/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * Authors: Waiman Long <waiman.long@hp.com>
 *
 * This kernel module enables us to evaluate the contention behavior of locks.
 *
 * The following sysfs variables will be created:
 * 1) /sys/kernel/locktest/c_count
 * 2) /sys/kernel/locktest/i_count
 * 3) /sys/kernel/locktest/l_count
 * 4) /sys/kernel/locktest/p_count
 * 5) /sys/kernel/locktest/locktype
 * 6) /sys/kernel/locktest/rw_ratio
 * 7) /sys/kernel/locktest/loadtype	# load type
 * 8) /sys/kernel/locktest/etime	# elapsed time in microseconds
 * 9) /sys/kernel/locktest/status	# current test status
 */
#include <linux/module.h>	// included for all kernel modules
#include <linux/kernel.h>	// included for KERN_INFO
#include <linux/init.h>		// included for __init and __exit macros
#include <linux/kobject.h>
#include <linux/sysfs.h>
#include <linux/string.h>
#include <linux/atomic.h>
#include <linux/delay.h>
#include <linux/spinlock.h>
#include <linux/sched.h>
#include <linux/slab.h>		// included for kzalloc_node and kfree

/*
 * Test parameters. Each one is instantiated as a read/write sysfs attribute
 * of the same name by SHOW_STORE_ATTR().
 */
static int c_count ;	/* Contention count: etime readers decrement it and wait for 0 */
static int i_count ;	/* Iteration count: number of passes of each test loop	*/
static int l_count ;	/* Load count: work done inside the critical section	*/
static int p_count ;	/* Pause count: cpu_relax() calls between lock ops	*/
static int locktype;	/* Lock type: one of the LOCK_* or TIME_* values	*/
static int loadtype;	/* Load type: one of the LOAD_* values			*/
static int rw_ratio;	/* RW (n:1) ratio: reads per write in the rw tests	*/

/*
 * Read through ACCESS_ONCE() in the loop conditions of the timing tests.
 * Nothing in this file writes it, so it keeps its zero initial value.
 */
int read_dummy;

/*
 * Show, store and attribute macro
 *
 * For an int variable v this generates:
 *   v_show()     - prints v as a decimal number followed by a newline
 *   v_store()    - parses a decimal int from the written text into v
 *   v_attribute  - the 0664 kobj_attribute binding both to the name "v"
 *
 * v_store() returns -EINVAL and leaves v untouched when the written text
 * does not start with a number; otherwise the whole write is consumed.
 */
#define	SHOW_STORE_ATTR(v)						\
	static ssize_t v ## _show(struct kobject *kobj,			\
				struct kobj_attribute *attr, char *buf)	\
	{								\
		return sprintf(buf, "%d\n", v);				\
	}								\
	static ssize_t v ## _store(struct kobject *kobj,		\
				struct kobj_attribute *attr,		\
				const char *buf, size_t count)		\
	{								\
		int val;						\
									\
		if (sscanf(buf, "%d", &val) != 1)			\
			return -EINVAL;					\
		v = val;						\
		return count;						\
	}								\
	static struct kobj_attribute v ## _attribute =			\
	__ATTR(v, 0664, v ## _show, v ## _store);

/*
 * Lock type
 *
 * Values of the locktype attribute; etime_show() uses them to pick the
 * test routine it runs and times.
 */
#define	LOCK_SPINLOCK	0	/* test_spinlock() */
#define	LOCK_RWLOCK	1	/* test_rwlock()   */
#define LOCK_MUTEX	2	/* test_mutex()    */
#define LOCK_RWSEM	3	/* test_rwsem()    */
#define LOCK_OTHER	4	/* test_other()    */
/*
 * Timing tests
 *
 * Also selected through locktype; these time an instruction sequence
 * instead of a lock and skip the c_count rendezvous in etime_show().
 */
#define TIME_PAUSE	10	/* pause timing test  */
#define TIME_RMB	11	/* lfence timing test */
#define TIME_MB		12	/* mfence timing test */

/*
 * Load type
 * 1) Standalone - lock & protected data in separate cachelines
 * 2) Embedded   - lock & protected data in the same cacheline
 *
 * As implemented by load(): the standalone load only spins on cpu_relax()
 * and never touches the data passed in, while the embedded load reads or
 * updates that data on every iteration.
 */
#define	LOAD_STANDALONE	0
#define	LOAD_EMBEDDED	1

/*
 * Scheduling check period mask
 *
 * The test loops call cond_resched() whenever their iteration counter
 * ANDed with this mask is zero, i.e. once every 256 iterations.
 */
#define RESCHED_MASK	0xff

/*
 * External functions
 *
 * Hand-written prototype for the node-aware allocator.
 * NOTE(review): <linux/slab.h> normally provides the allocation API;
 * confirm whether this manual declaration is still required.
 */
extern void *__kmalloc_node(size_t size, gfp_t flags, int node);

/*
 * Locks
 */
/* Payload stored next to each lock; load() reads or increments it. */
struct load {
	int c1, c2;	/* two independent counters */
};

/*
 * Spinlock under test plus its payload. The anonymous inner struct is
 * cacheline-aligned on SMP so the pair starts on a cacheline boundary.
 */
struct s_lock {
	struct {
		spinlock_t  lock;
		struct load data;
	} ____cacheline_aligned_in_smp;
};

/*
 * Reader/writer spinlock under test plus its payload, cacheline-aligned
 * on SMP.
 */
struct rw_lock {
	struct {
		rwlock_t    lock;
		struct load data;
	} ____cacheline_aligned_in_smp;
};

/*
 * Mutex under test plus its payload, cacheline-aligned on SMP.
 */
struct mutex_lock {
	struct {
		struct mutex lock;
		struct load  data;
	} ____cacheline_aligned_in_smp;
};

/*
 * Reader/writer semaphore under test plus its payload, cacheline-aligned
 * on SMP.
 */
struct rwsem_lock {
	struct {
		struct rw_semaphore lock;
		struct load	    data;
	} ____cacheline_aligned_in_smp;
};

/*
 * The single instance of each lock type exercised by the tests; allocated
 * and initialised by locktest_init().
 */
static struct s_lock     *slock;
static struct rw_lock    *rwlock;
static struct mutex_lock *mutex;
static struct rwsem_lock *rwsem;

/*
 * Show & store methods and attributes
 *
 * Each line instantiates <name>_show(), <name>_store() and
 * <name>_attribute for one of the int test parameters.
 */
SHOW_STORE_ATTR(c_count) ;
SHOW_STORE_ATTR(i_count) ;
SHOW_STORE_ATTR(l_count) ;
SHOW_STORE_ATTR(p_count) ;
SHOW_STORE_ATTR(locktype);
SHOW_STORE_ATTR(rw_ratio);
SHOW_STORE_ATTR(loadtype);

/*
 * Store method of the etime attribute.
 *
 * The written data is ignored; the write is reported as fully consumed
 * by returning count.
 */
static ssize_t etime_store(struct kobject *kobj, struct kobj_attribute *attr,
			   const char *buf, size_t count)
{
	/* Dummy method */
	return count;
}

/*
 * Store method of the status attribute.
 *
 * The written data is ignored; the write is reported as fully consumed
 * by returning count.
 */
static ssize_t status_store(struct kobject *kobj, struct kobj_attribute *attr,
			   const char *buf, size_t count)
{
	/* Dummy method */
	return count;
}

/*
 * Elapsed time between *start and *stop in microseconds (us).
 *
 * The second and nanosecond fields are differenced separately, combined
 * into one nanosecond total, and 500 is added before the division by 1000
 * so that the result is rounded rather than truncated.
 */
static unsigned long compute_us(struct timespec *start, struct timespec *stop)
{
	long sec_delta  = stop->tv_sec  - start->tv_sec;
	long nsec_delta = stop->tv_nsec - start->tv_nsec;
	long total_ns   = sec_delta * NSEC_PER_SEC + nsec_delta;

	return (total_ns + 500) / 1000;
}

/*
 * Simulated critical-section work.
 *
 * @data:  payload that the embedded load reads or updates
 * @ltype: LOAD_STANDALONE spins without touching @data; any other value
 *         is handled as the embedded load
 * @write: non-zero to update the payload, zero to only read it
 * @lcnt:  number of loop iterations; a negative value requests a sleep of
 *         -lcnt microseconds instead
 *
 * Returns the sum of all payload reads for an embedded read load, and 0
 * in every other case.
 */
static noinline int load(struct load *data, int ltype, int write, int lcnt)
{
	int total = 0;

	if (lcnt < 0) {
		/* Negative count: sleep for -lcnt us rather than spin */
		usleep_range(-lcnt, -lcnt+1);
		return 0;
	}

	if (ltype == LOAD_STANDALONE) {
		/* Pure delay, the payload is left alone */
		while (lcnt > 0) {
			cpu_relax();
			lcnt--;
		}
		return 0;
	}

	if (write) {
		/* Embedded write: dirty both counters on every pass */
		while (lcnt > 0) {
			ACCESS_ONCE(data->c1) += lcnt;
			ACCESS_ONCE(data->c2) += 2*lcnt;
			lcnt--;
		}
		return 0;
	}

	/* Embedded read: fetch both counters on every pass */
	while (lcnt > 0) {
		total += ACCESS_ONCE(data->c1) + ACCESS_ONCE(data->c2);
		lcnt--;
	}
	return total;
}

static noinline void test_spinlock(void)
{
	int i = i_count ;
	int p = p_count ;
	int c = l_count ;
	int l = loadtype;
	int j		;

	for ( ; i > 0 ; i--) {
		spin_lock(&slock->lock);
		load(&slock->data, l, 1, c);
		spin_unlock(&slock->lock);
		for (j = p ; j > 0 ; j--)
			cpu_relax();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
}

/*
 * Reader/writer spinlock contention loop.
 *
 * Runs i_count passes. rw_ratio read-locked passes are followed by one
 * write-locked pass (a ratio <= 0 makes every pass a write). Each pass
 * performs a load of l_count on the payload that sits next to the rwlock,
 * drops the lock and then executes p_count cpu_relax() calls.
 * cond_resched() is called whenever the remaining pass count ANDed with
 * RESCHED_MASK is zero.
 *
 * The load operates on rwlock->data, not on the spinlock's payload, so the
 * embedded load really touches the data stored alongside the lock under
 * test.
 */
static noinline void test_rwlock(void)
{
	int i = i_count ;
	int p = p_count ;
	int c = l_count ;
	int l = loadtype;
	int r = rw_ratio;
	int j, k	;

	for (j = 0 ; i > 0 ; i--, j++) {
		if (j >= r) {
			/* Write lock; restart the read count (j++ makes it 0) */
			j = -1;
			write_lock(&rwlock->lock);
			load(&rwlock->data, l, 1, c);
			write_unlock(&rwlock->lock);
		} else {
			/* Read lock */
			read_lock(&rwlock->lock);
			load(&rwlock->data, l, 0, c);
			read_unlock(&rwlock->lock);
		}
		for (k = p ; k > 0 ; k--)
			cpu_relax();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
}

/*
 * Mutex contention loop.
 *
 * Runs i_count passes; each pass takes the mutex, performs a write load of
 * l_count on the payload that sits next to the mutex, releases it and then
 * executes p_count cpu_relax() calls. cond_resched() is called whenever the
 * remaining pass count ANDed with RESCHED_MASK is zero.
 *
 * The load operates on mutex->data, not on the spinlock's payload, so the
 * embedded load really touches the data stored alongside the lock under
 * test.
 */
static noinline void test_mutex(void)
{
	int i = i_count ;
	int p = p_count ;
	int c = l_count ;
	int l = loadtype;
	int j		;

	for ( ; i > 0 ; i--) {
		mutex_lock(&mutex->lock);
		load(&mutex->data, l, 1, c);
		mutex_unlock(&mutex->lock);
		for (j = p ; j > 0 ; j--)
			cpu_relax();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
}

/*
 * Reader/writer semaphore contention loop.
 *
 * Runs i_count passes. rw_ratio passes holding the semaphore for read are
 * followed by one pass holding it for write (a ratio <= 0 makes every pass
 * a write). Each pass performs a load of l_count on the payload that sits
 * next to the rwsem, releases it and then executes p_count cpu_relax()
 * calls. cond_resched() is called whenever the remaining pass count ANDed
 * with RESCHED_MASK is zero.
 *
 * The load operates on rwsem->data, not on the spinlock's payload, so the
 * embedded load really touches the data stored alongside the lock under
 * test.
 */
static noinline void test_rwsem(void)
{
	int i = i_count ;
	int p = p_count ;
	int c = l_count ;
	int l = loadtype;
	int r = rw_ratio;
	int j, k	;

	for (j = 0 ; i > 0 ; i--, j++) {
		if (j >= r) {
			/* Write lock; restart the read count (j++ makes it 0) */
			j = -1;
			down_write(&rwsem->lock);
			load(&rwsem->data, l, 1, c);
			up_write(&rwsem->lock);
		} else {
			/* Read lock */
			down_read(&rwsem->lock);
			load(&rwsem->data, l, 0, c);
			up_read(&rwsem->lock);
		}
		for (k = p ; k > 0 ; k--)
			cpu_relax();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
}


/*
 * cpu_relax() (pause) timing test.
 *
 * @buf: sysfs output buffer
 *
 * Times i_count passes of a loop containing one cpu_relax(), then i_count
 * passes of a loop containing eleven. The difference divided by 10
 * (rounded) estimates the time in microseconds spent in i_count single
 * cpu_relax() calls, with the loop overhead cancelled out.
 *
 * A non-positive i_count runs no passes. If timing noise makes the longer
 * loop measure shorter than the short one, 0 is reported instead of
 * letting the unsigned subtraction wrap.
 *
 * Returns the number of bytes written to @buf.
 */
static noinline int test_pause(char *buf)
{
	struct timespec start, stop;
	unsigned long short_us, long_us, us;
	int i;

	/*
	 * Measure time with a 1-pause loop, then with a 11-pause loop.
	 * Get the time difference & divided by 10.
	 */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		cpu_relax();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	short_us = compute_us(&start, &stop);

	/* Unrolled by hand so no extra loop overhead is added per pause */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		cpu_relax();	/* 01 */
		cpu_relax();	/* 02 */
		cpu_relax();	/* 03 */
		cpu_relax();	/* 04 */
		cpu_relax();	/* 05 */
		cpu_relax();	/* 06 */
		cpu_relax();	/* 07 */
		cpu_relax();	/* 08 */
		cpu_relax();	/* 09 */
		cpu_relax();	/* 10 */
		cpu_relax();	/* 11 */
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	long_us = compute_us(&start, &stop);

	us = (long_us > short_us) ? (long_us - short_us + 5)/10 : 0;
	return sprintf(buf, "%lu\n", us);
}

/*
 * rmb() (read memory barrier) timing test.
 *
 * @buf: sysfs output buffer
 *
 * Times i_count passes of a loop containing one rmb(), then i_count passes
 * of a loop containing eleven. The difference divided by 10 (rounded)
 * estimates the time in microseconds spent in i_count single rmb() calls,
 * with the loop overhead cancelled out.
 *
 * A non-positive i_count runs no passes. If timing noise makes the longer
 * loop measure shorter than the short one, 0 is reported instead of
 * letting the unsigned subtraction wrap.
 *
 * Returns the number of bytes written to @buf.
 */
static noinline int test_rmb(char *buf)
{
	struct timespec start, stop;
	unsigned long short_us, long_us, us;
	int i;

	/*
	 * Measure time with a 1-rmb loop, then with a 11-rmb loop.
	 * Get the time difference & divided by 10.
	 */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		rmb();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	short_us = compute_us(&start, &stop);

	/* Unrolled by hand so no extra loop overhead is added per barrier */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		rmb();	/* 01 */
		rmb();	/* 02 */
		rmb();	/* 03 */
		rmb();	/* 04 */
		rmb();	/* 05 */
		rmb();	/* 06 */
		rmb();	/* 07 */
		rmb();	/* 08 */
		rmb();	/* 09 */
		rmb();	/* 10 */
		rmb();	/* 11 */
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	long_us = compute_us(&start, &stop);

	us = (long_us > short_us) ? (long_us - short_us + 5)/10 : 0;
	return sprintf(buf, "%lu\n", us);
}

/*
 * mb() (full memory barrier) timing test.
 *
 * @buf: sysfs output buffer
 *
 * Times i_count passes of a loop containing one mb(), then i_count passes
 * of a loop containing eleven. The difference divided by 10 (rounded)
 * estimates the time in microseconds spent in i_count single mb() calls,
 * with the loop overhead cancelled out.
 *
 * A non-positive i_count runs no passes. If timing noise makes the longer
 * loop measure shorter than the short one, 0 is reported instead of
 * letting the unsigned subtraction wrap.
 *
 * Returns the number of bytes written to @buf.
 */
static noinline int test_mb(char *buf)
{
	struct timespec start, stop;
	unsigned long short_us, long_us, us;
	int i;

	/*
	 * Measure time with a 1-mb loop, then with a 11-mb loop.
	 * Get the time difference & divided by 10.
	 */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		mb();
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	short_us = compute_us(&start, &stop);

	/* Unrolled by hand so no extra loop overhead is added per barrier */
	getnstimeofday(&start);
	for (i = i_count; i > 0 && !ACCESS_ONCE(read_dummy); i--) {
		mb();	/* 01 */
		mb();	/* 02 */
		mb();	/* 03 */
		mb();	/* 04 */
		mb();	/* 05 */
		mb();	/* 06 */
		mb();	/* 07 */
		mb();	/* 08 */
		mb();	/* 09 */
		mb();	/* 10 */
		mb();	/* 11 */
		if ((i & RESCHED_MASK) == 0)
			cond_resched();
	}
	getnstimeofday(&stop);
	long_us = compute_us(&start, &stop);

	us = (long_us > short_us) ? (long_us - short_us + 5)/10 : 0;
	return sprintf(buf, "%lu\n", us);
}

/*
 * Recursive read_lock test.
 *
 * Logs a message, then acquires the rwlock for read three times in a row
 * without releasing it in between, and finally drops all three read holds.
 * A cpu_relax() separates every two consecutive lock operations.
 */
static noinline void test_other(void)
{
	int depth;

	printk(KERN_INFO "Info: run recursive read_lock test\n");

	/* Nest the read lock three levels deep */
	for (depth = 0; depth < 3; depth++) {
		read_lock(&rwlock->lock);
		cpu_relax();
	}

	/* Unwind; no pause is needed before the first unlock */
	for (depth = 0; depth < 3; depth++) {
		if (depth)
			cpu_relax();
		read_unlock(&rwlock->lock);
	}
}

/*
 * Show method of the etime attribute: runs the selected test and reports
 * how long it took.
 *
 * For the TIME_* lock types the matching instruction timing test is run
 * immediately and its own result is returned.
 *
 * For the LOCK_* types every reader first decrements c_count and spins
 * until the count is no longer positive, so that all participating readers
 * start the test together. The elapsed time of the test, in microseconds,
 * is then printed. An unknown lock type runs no test and reports only the
 * timing overhead.
 *
 * Waiting for "not positive" rather than "exactly zero" keeps a reader
 * from spinning forever in the kernel when c_count was left at 0 (or set
 * lower than the number of readers) and the decrement drives it negative.
 *
 * Returns the number of bytes written to @buf.
 */
static ssize_t etime_show(struct kobject *kobj, struct kobj_attribute *attr,
			  char *buf)
{
	struct timespec start, stop;
	const int type = locktype;	/* one snapshot for the whole run */

	if (type == TIME_PAUSE)
		return test_pause(buf);	/* pause instruction timing test */
	else if (type == TIME_RMB)
		return test_rmb(buf);	/* lfence instruction timing test */
	else if (type == TIME_MB)
		return test_mb(buf);	/* mfence instruction timing test */

	atomic_dec((atomic_t *)&c_count);

	/* Wait until the count reaches 0 (or has already dropped below) */
	while (ACCESS_ONCE(c_count) > 0)
		cpu_relax();

	getnstimeofday(&start);
	switch (type) {
	case LOCK_SPINLOCK:
		test_spinlock();
		break;
	case LOCK_RWLOCK:
		test_rwlock();
		break;
	case LOCK_MUTEX:
		test_mutex();
		break;
	case LOCK_RWSEM:
		test_rwsem();
		break;
	case LOCK_OTHER:
		test_other();
		break;
	default:
		break;
	}
	getnstimeofday(&stop);
	return sprintf(buf, "%lu\n", compute_us(&start, &stop));
}

/* sysfs attribute "etime" (mode 0664): reading it runs the selected test. */
static struct kobj_attribute etime_attribute =
	__ATTR(etime, 0664, etime_show, etime_store);

/*
 * Show method of the status attribute.
 *
 * Produces output only while locktype is LOCK_RWSEM: the rwsem count,
 * whether its wait list is empty and, when the corresponding macro or
 * config option is defined, the wait_lock state, the osq tail and the
 * owner. For every other lock type nothing is written.
 *
 * Returns the number of bytes written to @buf.
 */
static ssize_t status_show(struct kobject *kobj, struct kobj_attribute *attr,
			  char *buf)
{
	ssize_t len = 0;

	/* Only the rwsem has state worth dumping */
	if (locktype != LOCK_RWSEM)
		return 0;

	len += sprintf(buf + len, "rwsem.count     = 0x%lx\n",
		       rwsem->lock.count);
	len += sprintf(buf + len, "rwsem.wait_list = %sempty\n",
		       list_empty(&rwsem->lock.wait_list) ? "" : "not ");
#ifdef raw_spin_is_locked
	len += sprintf(buf + len, "rwsem.wait_lock = %slocked\n",
		       raw_spin_is_locked(&rwsem->lock.wait_lock)
		       ? "" : "not ");
#endif
#ifdef CONFIG_RWSEM_SPIN_ON_OWNER
	len += sprintf(buf + len, "rwsem.osq       = %d\n",
		       atomic_read(&rwsem->lock.osq.tail));
	len += sprintf(buf + len, "rwsem.owner     = 0x%lx\n",
		       (unsigned long)rwsem->lock.owner);
#endif
	return len;
}

/* sysfs attribute "status" (mode 0664): reading it dumps the rwsem state. */
static struct kobj_attribute status_attribute =
	__ATTR(status, 0664, status_show, status_store);

/*
 * NULL-terminated list of all attributes placed in the locktest sysfs
 * directory.
 */
static struct attribute *attrs[] = {
	&c_count_attribute.attr ,
	&i_count_attribute.attr ,
	&l_count_attribute.attr ,
	&p_count_attribute.attr ,
	&locktype_attribute.attr,
	&loadtype_attribute.attr,
	&rw_ratio_attribute.attr,
	&etime_attribute.attr   ,
	&status_attribute.attr  ,
	NULL
};

/*
 * Unnamed attribute group: no .name is set, so the files are created
 * directly in the kobject's own directory.
 */
static struct attribute_group attr_group = {
        .attrs = attrs,
};

static struct kobject *locktest_kobj;

/*
 * Module init function
 *
 * Allocates zeroed node 0 memory for the four lock structures, initialises
 * the locks, and only then creates the "locktest" kobject under
 * kernel_kobj together with its attribute files. Doing the sysfs part last
 * guarantees that no show method can run before the locks exist.
 *
 * On any failure everything set up so far is released again, so a failed
 * load leaves neither memory nor sysfs entries behind.
 *
 * Returns 0 on success, -ENOMEM if an allocation or the kobject creation
 * fails, or the error reported by sysfs_create_group().
 */
static int __init locktest_init(void)
{
	int retval = -ENOMEM;

	/*
	 * Allocate node 0 memory for locks
	 */
	slock  = kzalloc_node(sizeof(*slock ), GFP_KERNEL, 0);
	rwlock = kzalloc_node(sizeof(*rwlock), GFP_KERNEL, 0);
	mutex  = kzalloc_node(sizeof(*mutex ), GFP_KERNEL, 0);
	rwsem  = kzalloc_node(sizeof(*rwsem ), GFP_KERNEL, 0);

	if (!slock || !rwlock || !mutex || !rwsem) {
		printk(KERN_WARNING "locktest: memory allocation failed!\n");
		goto err_free;
	}
	spin_lock_init(&slock->lock);
	rwlock_init(&rwlock->lock);
	mutex_init(&mutex->lock);
	init_rwsem(&rwsem->lock);

	locktest_kobj = kobject_create_and_add("locktest", kernel_kobj);
	if (!locktest_kobj)
		goto err_free;

	retval = sysfs_create_group(locktest_kobj, &attr_group);
	if (retval)
		goto err_put;

	printk(KERN_INFO "locktest module loaded!\n");
	return 0;

err_put:
	kobject_put(locktest_kobj);
	locktest_kobj = NULL;
err_free:
	/* kfree(NULL) is a no-op, so partial allocations are handled too */
	kfree(rwsem );
	kfree(mutex );
	kfree(rwlock);
	kfree(slock );
	rwsem  = NULL;
	mutex  = NULL;
	rwlock = NULL;
	slock  = NULL;

	// Non-zero return means that the module couldn't be loaded.
	return retval;
}

/*
 * Module exit function
 *
 * Drops the reference on the locktest kobject first, so the sysfs entries
 * go away before the locks they operate on, and then frees the lock
 * structures. Previously the structures were never freed and leaked on
 * every unload.
 */
static void __exit locktest_cleanup(void)
{
	kobject_put(locktest_kobj);

	kfree(slock );
	kfree(rwlock);
	kfree(mutex );
	kfree(rwsem );

	printk(KERN_INFO "locktest module unloaded.\n");
}

/* Register the load/unload entry points and the module metadata. */
module_init(locktest_init);
module_exit(locktest_cleanup);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Waiman Long");
MODULE_DESCRIPTION("Lock testing module");
