#include <sys/sendfile.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <netdb.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <linux/types.h>
#include <linux/unistd.h>
#include <linux/kevent.h>

/*
 * Userspace stubs for the kevent system calls, generated with the
 * _syscallN() macros from <linux/unistd.h>:
 *
 *   kevent_ctl(fd, buf)
 *       Issues a control command described by the struct
 *       kevent_user_control at buf.  In this file it is called with fd 0 and
 *       KEVENT_CTL_INIT to obtain a control descriptor (kevent_init), and
 *       with that descriptor and KEVENT_CTL_WAIT to collect ready events
 *       that follow the header in buf (aio_sendfile_wait).
 *
 *   aio_sendfile(ctl_fd, data_fd, s, size, flags)
 *       Called by test_aio_sendfile as (control fd, file fd, socket,
 *       byte count, 0).
 *
 * NOTE(review): the __NR_* syscall numbers must come from the
 * kevent-patched kernel headers, and later kernels dropped _syscallN()
 * from userspace headers entirely — confirm both match the target kernel.
 */
_syscall2(int, kevent_ctl, int, arg1, void *, argv2);
_syscall5(int, aio_sendfile, int, ctl_fd, int, data_fd, int, s, size_t, size, unsigned, flags);

/* Test selected with -t (main() runs the AIO test for any other value). */
enum {
	TYPE_SENDFILE = 0,	/* synchronous sendfile() loop */
	TYPE_AIO      = 1,	/* aio_sendfile() + kevent wait */
};

/* Upper bound for -n; also sizes the per-client descriptor arrays. */
enum { MAX_CLIENT_NUM = 100 };

/* Log to stderr with ": <strerror(errno)> [<errno>]." appended. */
#define ulog_err(fmt, args...) \
	fprintf(stderr, fmt ": %s [%d].\n", ##args, strerror(errno), errno)

/* Plain printf-style log to stderr. */
#define ulog(fmt, args...) \
	fprintf(stderr, fmt, ##args)

/*
 * Non-zero once SIGINT has been received; polled by the sendfile test loop.
 *
 * It is written from a signal handler and read in a loop that makes no
 * other reference to it, so it must be volatile sig_atomic_t: that is the
 * only object type a handler may portably write, and volatile stops the
 * compiler from caching the value in a register and looping forever.
 */
static volatile sig_atomic_t need_exit;

/* SIGINT handler: record the signal number so the transfer loop stops. */
void SIGINT_h(int signo)
{
	need_exit = signo;
}

/*
 * Pack four dotted-quad octets into one host-order 32-bit value, with @a in
 * the most significant byte: num2ip(192, 168, 0, 1) == 0xC0A80001.
 */
static inline uint32_t num2ip(uint8_t a, uint8_t b, uint8_t c, uint8_t d)
{
	return ((uint32_t)a << 24) |
	       ((uint32_t)b << 16) |
	       ((uint32_t)c << 8)  |
	       (uint32_t)d;
}

/*
 * Resolve @host with gethostbyname() and fill *sa with its first IPv4
 * address and @port (converted to network byte order).
 *
 * gethostbyname() reports failures through h_errno, not errno, so h_errno
 * is logged rather than strerror(errno).  Results that are not a 4-byte
 * AF_INET address are rejected instead of being blindly memcpy()'d.
 *
 * Returns 0 on success, -1 on failure (reason logged to stderr).
 */
static int resolve_ipv4(const char *host, unsigned short port, struct sockaddr_in *sa)
{
	struct hostent *h = gethostbyname(host);

	if (!h) {
		ulog("gethostbyname %s: h_errno %d.\n", host, h_errno);
		return -1;
	}
	if (h->h_addrtype != AF_INET ||
	    (size_t)h->h_length != sizeof(sa->sin_addr)) {
		ulog("gethostbyname %s: no IPv4 address.\n", host);
		return -1;
	}

	/* Zero everything, including sin_zero, so no stack garbage reaches the kernel. */
	memset(sa, 0, sizeof(*sa));
	sa->sin_family = AF_INET;
	sa->sin_port = htons(port);
	memcpy(&sa->sin_addr, h->h_addr_list[0], sizeof(sa->sin_addr));

	return 0;
}

/*
 * Open a TCP connection to addr:port.
 *
 * Both addr and bind_addr are resolved with gethostbyname(), and only the
 * first IPv4 address of each is used.  bind_addr/bind_port are resolved, so
 * an unresolvable bind address still fails the call, but the bind() itself
 * is compiled out (#if 0), so the kernel chooses the local address.
 *
 * Returns the connected socket descriptor, or -1 on failure (reason logged).
 */
int create_socket(char *addr, unsigned short port, char *bind_addr, unsigned short bind_port)
{
	int s;
	struct sockaddr_in sa, bsa;

	ulog("%s:%u -> %s:%u.\n", bind_addr, (unsigned)bind_port, addr, (unsigned)port);

	if (resolve_ipv4(addr, port, &sa) < 0)
		return -1;
	if (resolve_ipv4(bind_addr, bind_port, &bsa) < 0)
		return -1;

	s = socket(AF_INET, SOCK_STREAM, 0);
	if (s == -1) {
		ulog_err("Failed to create a socket");
		return -1;
	}
#if 0
	/*
	 * NOTE(review): main() passes the same bind_port for every client, so
	 * enabling this as-is would likely fail with EADDRINUSE from the second
	 * socket on — confirm before re-enabling.
	 */
	if (bind(s, (struct sockaddr *)&bsa, sizeof(bsa)) == -1) {
		ulog_err("bind");
		close(s);
		return -1;
	}
#endif
	if (connect(s, (struct sockaddr *)&sa, sizeof(sa)) == -1) {
		ulog_err("connect");
		close(s);
		return -1;
	}

	return s;
}

static int aio_sendfile_wait(int ctl_fd, unsigned int timeout, unsigned int wait_num)
{
	int num, err = 0;
	struct kevent_user_control *ctl;
	struct ukevent *uk;
	int i;
	char buf[4096];

	ctl = (struct kevent_user_control *)buf;
	uk = (struct ukevent *)(ctl + 1);

	ctl->num = wait_num;
	ctl->timeout = timeout;
	ctl->cmd = KEVENT_CTL_WAIT;

	err = kevent_ctl(ctl_fd, buf);
	if (err < 0) {
		ulog_err("Failed to perform control operation");
		return err;
	}
	num = ctl->num;

	ulog("Wait: num=%d, ctl->num=%u.\n", num, ctl->num);
	uk = (struct ukevent *)(ctl+1);
	for (i=0; i<num; ++i) {
		ulog("%3u - %3u: id:%08x.%08x, ret_data:%08x.%08x, req_flags:%08x, ret_flags:%08x.\n", 
			i, ctl->num,
			uk[i].id.raw[0], uk[i].id.raw[1],
			uk[i].ret_data[0], uk[i].ret_data[1],
			uk[i].req_flags, uk[i].ret_flags);
	}

	return ctl->num;
}

/*
 * AIO test: queue one aio_sendfile() of @count bytes from in[i] to s[i]
 * (flags 0) for each of the @num clients on control descriptor @ctl, then
 * block in aio_sendfile_wait() (timeout -1) for @num completion events.
 *
 * Returns 0 on success.  The transferred byte count is not collected here.
 * NOTE(review): the completion ukevents carry ret_data; if that holds the
 * per-request byte count it could be summed — confirm against the kevent
 * interface.
 * Returns the negative result of the first failing aio_sendfile(), or of
 * the wait, on failure.  Previously a failed wait was silently reported as
 * success.
 */
static ssize_t test_aio_sendfile(int ctl, int *in, int *s, int num, int count)
{
	int err, i;
	unsigned flags = 0;

	for (i = 0; i < num; ++i) {
		err = aio_sendfile(ctl, in[i], s[i], count, flags);
		if (err < 0) {
			ulog_err("sendfile");
			return err;
		}
	}

	/* aio_sendfile_wait() logs its own failure; just propagate it. */
	err = aio_sendfile_wait(ctl, -1, num);
	if (err < 0)
		return err;

	return 0;
}

/*
 * Synchronous baseline: round-robin over the @num clients, calling
 * sendfile() to copy up to @count bytes from file in[i] to socket s[i].
 *
 * A NULL offset is passed, so each call advances in[i]'s own file offset
 * and sendfile() returns 0 once that file is exhausted.  The loop ends when
 *   - SIGINT sets need_exit,
 *   - a full pass sends nothing (every file is at EOF), or
 *   - sendfile() fails with anything other than EINTR (logged).
 *
 * Returns the total number of bytes sent before stopping.
 */
static ssize_t test_sendfile(int *in, int *s, int num, int count)
{
	ssize_t bytes = 0;
	int i, active = 1;

	while (!need_exit && active) {
		active = 0;
		for (i = 0; i < num; ++i) {
			/* sendfile(out_fd, in_fd, ...): the socket is the destination. */
			ssize_t sent = sendfile(s[i], in[i], NULL, count);

			if (sent < 0) {
				if (errno == EINTR) {
					/* Interrupted: go back and re-check need_exit. */
					active = 1;
					break;
				}
				ulog_err("sendfile");
				return bytes;
			}
			if (sent > 0) {
				bytes += sent;
				active = 1;
			}
		}
	}

	return bytes;
}

static int kevent_init(void)
{
	struct kevent_user_control ctl;
	int ctl_fd;
	
	ctl.cmd = KEVENT_CTL_INIT;
	ctl.num = ctl.timeout = 0;
	
	ctl_fd = kevent_ctl(0, &ctl);
	if (ctl_fd < 0) {
		ulog_err("Failed to obtain kevent control descriptor");
		return -1;
	}

	return ctl_fd;
}

/* Print the command-line synopsis to stderr; @prog is the program name. */
static void usage(const char *prog)
{
	ulog("Usage: %s -a addr -p port -b bind_addr -B bind_port "
	     "-f file -t test_type -n number_of_users\n", prog);
}

/*
 * Parse the decimal argument @arg of option -@opt into *out.
 *
 * Rejects empty input, any '-' (strtoul() would silently negate it),
 * trailing garbage, overflow and values above @max.  Logs the offending
 * argument on failure.  Returns 0 on success, -1 otherwise.
 */
static int parse_opt_num(int opt, const char *arg, unsigned long max, unsigned long *out)
{
	char *end;
	unsigned long v;

	if (strchr(arg, '-') == NULL) {
		errno = 0;
		v = strtoul(arg, &end, 10);
		if (end != arg && *end == '\0' && errno != ERANGE && v <= max) {
			*out = v;
			return 0;
		}
	}

	ulog("Invalid value '%s' for -%c.\n", arg, opt);
	return -1;
}

/*
 * Benchmark driver.
 *
 * Opens @num (-n, 1..MAX_CLIENT_NUM) descriptors on the file and @num TCP
 * connections, then runs the selected test: -t 0 uses sendfile(), and any
 * other value uses aio_sendfile().  It reports bytes transferred, throughput
 * (bytes / 2^20 per second) and elapsed time.
 *
 * Every descriptor opened here is closed on all exit paths after setup.
 * Returns 0 on success, -1 on bad arguments or setup/transfer failure.
 */
int main(int argc, char *argv[])
{
	int s[MAX_CLIENT_NUM], in[MAX_CLIENT_NUM];
	int ctl, ch, i;
	int type = -1, num = 1, opened = 0, ret = -1;
	size_t count = 1024 * 1024 * 1000;	/* bytes requested per send call */
	unsigned long val;
	ssize_t bytes;
	char *addr = NULL, *bind_addr = NULL, *file = NULL;
	unsigned short port = 0, bind_port = 0;
	long long utime;
	double speed;
	struct timeval tm1, tm2;

	while ((ch = getopt(argc, argv, "n:a:p:f:t:b:B:h")) != -1) {
		switch (ch) {
			case 'a':
				addr = optarg;
				break;
			case 'p':
				if (parse_opt_num(ch, optarg, USHRT_MAX, &val) < 0)
					goto bad_usage;
				port = (unsigned short)val;
				break;
			case 'b':
				bind_addr = optarg;
				break;
			case 'B':
				if (parse_opt_num(ch, optarg, USHRT_MAX, &val) < 0)
					goto bad_usage;
				bind_port = (unsigned short)val;
				break;
			case 'f':
				file = optarg;
				break;
			case 't':
				if (parse_opt_num(ch, optarg, INT_MAX, &val) < 0)
					goto bad_usage;
				type = (int)val;
				break;
			case 'n':
				if (parse_opt_num(ch, optarg, INT_MAX, &val) < 0)
					goto bad_usage;
				num = (int)val;
				break;
			case 'h':
			default:
				goto bad_usage;
		}
	}

	if (num < 1) {
		ulog("Number of clients must be at least 1.\n");
		goto bad_usage;
	}
	if (num > MAX_CLIENT_NUM) {
		ulog("Max number of client is %d.\n", MAX_CLIENT_NUM);
		goto bad_usage;
	}

	if (!file || !addr || !bind_addr || !port || !bind_port || type == -1) {
		ulog("You need at least -f -a -p -b -B -t parameters.\n");
		goto bad_usage;
	}

	signal(SIGINT, SIGINT_h);

	ctl = kevent_init();
	if (ctl < 0)
		return -1;

	for (i = 0; i < num; ++i) {
		in[i] = open(file, O_RDONLY);
		if (in[i] == -1) {
			ulog_err("open");
			goto out;
		}

		s[i] = create_socket(addr, port, bind_addr, bind_port);
		if (s[i] < 0) {
			close(in[i]);
			goto out;
		}
		/* Only fully set-up pairs are counted, so cleanup closes exactly those. */
		opened++;
	}

	gettimeofday(&tm1, NULL);

	switch (type) {
		case TYPE_SENDFILE:
			bytes = test_sendfile(in, s, num, count);
			break;
		default:	/* any other -t value runs the AIO test */
		case TYPE_AIO:
			bytes = test_aio_sendfile(ctl, in, s, num, count);
			break;
	}

	gettimeofday(&tm2, NULL);

	if (bytes < 0) {
		ulog("Transfer test failed.\n");
		goto out;
	}

	/* long long: a 32-bit long would overflow after ~35 minutes of usec. */
	utime = (long long)(tm2.tv_sec - tm1.tv_sec) * 1000000 +
		(tm2.tv_usec - tm1.tv_usec);
	/* Guard against a zero-length interval instead of printing inf/nan. */
	speed = utime > 0 ?
		(double)bytes * 1000000 / ((double)utime * 1024 * 1024) : 0.0;

	ulog("transferred:%zd bytes, speed:%f Mb/sec, time:%lld usec.\n",
			bytes, speed, utime);
	ret = 0;

out:
	for (i = 0; i < opened; ++i) {
		close(s[i]);
		close(in[i]);
	}
	close(ctl);
	return ret;

bad_usage:
	usage(argv[0]);
	return -1;
}
