#include <linux/string.h>
#include <linux/sched.h>
#include <linux/sched/clock.h>
#include <linux/av_thread_priorities.h>

/* Compile-time switches for the diagnostic PRINT() output used below. */
#define DEBUG_ENABLED 0
#define USE_FTRACE 0

/*
 * PRINT() expands to nothing unless DEBUG_ENABLED is set, so its arguments
 * are not even evaluated in normal builds. When enabled it forwards to
 * trace_printk() if USE_FTRACE is set, and to printk() otherwise.
 */
#if !DEBUG_ENABLED
#define PRINT(fmt, ...)
#elif USE_FTRACE
#define PRINT(fmt, ...) trace_printk(fmt, ##__VA_ARGS__)
#else
#define PRINT(fmt, ...) printk(fmt, ##__VA_ARGS__)
#endif

/*
 * One row of the A/V thread table: a comm-name prefix plus the scheduling
 * policy and priority to apply to tasks whose name starts with it.
 */
typedef struct {
	char name[TASK_COMM_LEN];	/* prefix compared against task->comm */
	u32 hash;			/* first 4 chars of name, first char in the top byte */
	int policy, priority;		/* policy and sched_priority for sched_setscheduler_nocheck() */
} THREAD_INFO;

/*
 * Pack four characters into a 32-bit value, first character in the most
 * significant byte. This mirrors how av_thread_priority_set() packs
 * task->comm[0..3]. Each *_HASH below must equal the first four characters
 * of every threads[] name that uses it, otherwise that entry can never match.
 */
#define AV_NAME_HASH(a, b, c, d)					\
	(((unsigned int)(unsigned char)(a) << 24) |			\
	 ((unsigned int)(unsigned char)(b) << 16) |			\
	 ((unsigned int)(unsigned char)(c) << 8) |			\
	 (unsigned int)(unsigned char)(d))

#define AAMP_HASH		AV_NAME_HASH('a', 'a', 'm', 'p')
#define AUDHAL_HASH		AV_NAME_HASH('a', 'm', 'l', 'A')
#define AUDPROCFRM_HASH	AV_NAME_HASH('a', 'u', 'd', 'P')
#define AUDSRC_HASH		AV_NAME_HASH('a', 'u', 'd', 's')
#define AQUEUE_HASH		AV_NAME_HASH('a', 'q', 'u', 'e')
#define APPSRC_HASH		AV_NAME_HASH('a', 'p', 'p', 's')
#define AV_APOLL_HASH	AV_NAME_HASH('a', 'v', 's', '_')	/* "avs_apoll" */
#define BASE_HASH		AV_NAME_HASH('B', 'a', 's', 'e')
#define DOLBY_HASH		AV_NAME_HASH('D', 'O', 'L', 'B')
#define DRMSYSTEM_HASH	AV_NAME_HASH('D', 'R', 'M', 'S')
#define FOG_HASH		AV_NAME_HASH('f', 'o', 'g', 'c')
#define GRPCPP_HASH		AV_NAME_HASH('g', 'r', 'p', 'c')
#define MAINWEB_HASH	AV_NAME_HASH('M', 'a', 'i', 'n')
#define MEDIA_P_HASH	AV_NAME_HASH('m', 'e', 'd', 'i')
#define MULTI_Q_HASH	AV_NAME_HASH('m', 'u', 'l', 't')
#define OCDM_HASH		AV_NAME_HASH('O', 'C', 'D', 'M')
#define PLAYBACK_HASH	AV_NAME_HASH('p', 'l', 'a', 'y')
#define WPE_POOL_HASH	AV_NAME_HASH('p', 'o', 'o', 'l')
#define QUEUE_HASH		AV_NAME_HASH('q', 'u', 'e', 'u')
#define SOURCE_HASH		AV_NAME_HASH('s', 'o', 'u', 'r')
#define VIDSRC_HASH		AV_NAME_HASH('v', 'i', 'd', 's')
#define VIDPROC_HASH	AV_NAME_HASH('v', 'i', 'd', 'P')
#define VQUEUE_HASH		AV_NAME_HASH('v', 'q', 'u', 'e')
#define V8_DEF_HASH		AV_NAME_HASH('V', '8', ' ', 'D')	/* "V8 Default": a space, not '_' */
#define WST_SINK_HASH	AV_NAME_HASH('w', 's', 't', 'S')
#define WST_V_HASH		AV_NAME_HASH('w', 's', 't', 'V')

/*
 * Threads that receive a real-time scheduling policy/priority from
 * av_thread_priority_set().
 *
 * Each name is matched as a prefix of the task's comm, and the first matching
 * entry in table order wins. A longer name must therefore come before any
 * shorter name that is a prefix of it (e.g. "wstVServerConn" before
 * "wstVServer"). The hash field must equal the first four characters of name
 * (see the *_HASH macros), or the entry can never match.
 *
 * The table is only read, never written, so it is const.
 */
static const THREAD_INFO threads[] = {
	{"aampBuffHealth", AAMP_HASH, SCHED_RR, 3},
	{"aampfMP4DRM", AAMP_HASH, SCHED_RR, 3},
	{"aampHLSFetcher", AAMP_HASH, SCHED_RR, 10},
	{"aampInjector", AAMP_HASH, SCHED_RR, 10},
	{"aampLatencyMon", AAMP_HASH, SCHED_RR, 3},
	{"aampMPDFetcher", AAMP_HASH, SCHED_RR, 10},
	{"aampPSFetcher", AAMP_HASH, SCHED_RR, 10},
	{"amlAudioMixer16", AUDHAL_HASH, SCHED_RR, 40},
	{"amlAudOut_patch", AUDHAL_HASH, SCHED_RR, 40},
	{"amlAudOutMmap", AUDHAL_HASH, SCHED_RR, 40},
	{"audProcessFrame", AUDPROCFRM_HASH, SCHED_RR, 30},	/* netflix DRM */
	{"audsrc", AUDSRC_HASH, SCHED_RR, 10},			/* youtube WriteSample() (i.e. where frame time is applied) */
	{"aqueue", AQUEUE_HASH, SCHED_RR, 35},
	{"appsrc", APPSRC_HASH, SCHED_RR, 25},			/* netflix equivalent to aamp's multiqueue */
	{"avs_apoll", AV_APOLL_HASH, SCHED_RR, 40},
	{"BaseDIWorker", BASE_HASH, SCHED_RR, 6},
	{"BaseDIFrag", BASE_HASH, SCHED_RR, 7},
	{"DOLBY_MS12", DOLBY_HASH, SCHED_RR, 45},
	{"DRMSYSTEM", DRMSYSTEM_HASH, SCHED_RR, 30},		/* netflix DRM */
	{"fogcli", FOG_HASH, SCHED_RR, 5},
	{"grpcpp_sync_ser", GRPCPP_HASH, SCHED_RR, 40},
	{"MainWebModule", MAINWEB_HASH, SCHED_RR, 8},		/* youtube WriteSample() (i.e. where frame time is applied) */
	{"media_pipeline", MEDIA_P_HASH, SCHED_RR, 8},		/* youtube WriteSample() (i.e. where frame time is applied) */
	{"multiqueue", MULTI_Q_HASH, SCHED_RR, 25},
	{"OCDM_decrypt", OCDM_HASH, SCHED_RR, 30},
	{"OCDM_SaThread", OCDM_HASH, SCHED_RR, 30},
	{"playback_thread", PLAYBACK_HASH, SCHED_RR, 10},	/* youtube WriteSample() (i.e. where frame time is applied) */
	{"pool-WPEFramewo", WPE_POOL_HASH, SCHED_RR, 10},
	{"queue", QUEUE_HASH, SCHED_RR, 35},
	{"source:src", SOURCE_HASH, SCHED_RR, 10},
	{"vidsrc", VIDSRC_HASH, SCHED_RR, 10},			/* youtube WriteSample() (i.e. where frame time is applied) */
	{"vidProcessFrame", VIDPROC_HASH, SCHED_RR, 30},	/* netflix DRM */
	{"vqueue", VQUEUE_HASH, SCHED_RR, 35},
	{"V8 Default", V8_DEF_HASH, SCHED_RR, 8},		/* youtube AV1/VP9 */
	{"wstSinkVidDSP", WST_SINK_HASH, SCHED_RR, 10},
	{"wstSinkVidEOS", WST_SINK_HASH, SCHED_RR, 10},
	{"wstSinkVidOut", WST_SINK_HASH, SCHED_RR, 55},
	{"wstVOffload", WST_V_HASH, SCHED_RR, 75},
	{"wstVRefresh", WST_V_HASH, SCHED_RR, 70},
	{"wstVServerConn", WST_V_HASH, SCHED_RR, 65},		/* must precede "wstVServer" (prefix match) */
	{"wstVServer", WST_V_HASH, SCHED_RR, 60},
};

/* Number of entries in threads[] (a size_t). */
#define MAX_NUM_AV_THREADS	(sizeof(threads) / sizeof(threads[0]))

int av_thread_priority_set(struct task_struct *task)
{
	int i;
	u32 hash;
#if DEBUG_ENABLED
	u64 startTime;
	/* declare and init metrics counters */
	static u32 call_count = 0;			/* incremented each time function is called */
	static u32 hit_miss_count = 0;		/* incremented when get a hash match, but name compare fail */
	static u32 hit_hit_count = 0;		/* incremented when get a hash match and a name match */
#endif
	/* check for valid task pointer */
	if (task == NULL) return -EINVAL;
#if DEBUG_ENABLED
	startTime = local_clock();
	call_count++;
#endif
	/* convert first 4 chars of task name into an unsigned int */
	hash = (task->comm[0] << 24) | (task->comm[1] << 16) | (task->comm[2] << 8) | (task->comm[3]);
	/* fast check if task hash is in list */
	switch (hash) {
		case AAMP_HASH:
		case AUDHAL_HASH:
		case AUDPROCFRM_HASH:
		case AUDSRC_HASH:
		case AQUEUE_HASH:
		case APPSRC_HASH:
		case AV_APOLL_HASH:
		case BASE_HASH:
		case DOLBY_HASH:
		case DRMSYSTEM_HASH:
		case FOG_HASH:
		case GRPCPP_HASH:
		case MAINWEB_HASH:
		case MEDIA_P_HASH:
		case MULTI_Q_HASH:
		case OCDM_HASH:
		case PLAYBACK_HASH:
		case WPE_POOL_HASH:
		case QUEUE_HASH:
		case SOURCE_HASH:
		case VIDSRC_HASH:
		case VIDPROC_HASH:
		case VQUEUE_HASH:
		case V8_DEF_HASH:
		case WST_V_HASH:
		case WST_SINK_HASH:
			for (i = 0; i < MAX_NUM_AV_THREADS; i++) {
				if (hash == threads[i].hash) {
					int j;
					for (j = i; j < MAX_NUM_AV_THREADS; j++) {
						if (strncmp(task->comm, threads[j].name, strlen(threads[j].name)) == 0) {
							struct sched_param param = { .sched_priority = threads[j].priority };
							/* got a match, so set thread policy/priority */
							sched_setscheduler_nocheck(task, threads[j].policy, &param);
							PRINT("thread %s, 0x%x, pid=%d, set policy=%d, priority=%d, runtime=%lld, call_count=%d, hit_hit_count=%d, hit_miss_count=%d\n", task->comm, hash, task->pid, threads[j].policy, threads[j].priority, local_clock() - startTime, call_count, ++hit_hit_count, hit_miss_count);
							return 0;
						}
					}
				}
			}
			PRINT("thread %s 0x%x, pid=%d hash hit, name miss, runtime=%lld, call_count=%d, hit_hit_count=%d, hit_miss_count=%d\n", task->comm, hash, task->pid, local_clock() - startTime, call_count, hit_hit_count, ++hit_miss_count);
		return ENODEV;
		default:
		break;
	}
	PRINT("thread %s 0x%x, pid=%d not in list, runtime=%lld, call_count=%d, hit_hit_count=%d, hit_miss_count=%d\n", task->comm, hash, task->pid, local_clock() - startTime, call_count, hit_hit_count, hit_miss_count);
	return ENODEV;
}
