#include <sys/sendfile.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>

#include <fcntl.h>
#include <netdb.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <signal.h>

#include <linux/socket.h>
#include <linux/types.h>
#include <net/zerocopy.h>

/* Transfer modes selectable with -t. */
enum {
	TYPE_SENDFILE = 0,	/* sendfile() on a socket with zero-copy setup */
	TYPE_RW = 1,		/* plain recv() + write() loop */
};

/* Print a printf-style message to stderr. */
#define ulog(f, ...) fprintf(stderr, f, ##__VA_ARGS__)
/* Like ulog(), then append ": <strerror(errno)> [<errno>]." and a newline. */
#define ulog_err(f, ...) fprintf(stderr, f ": %s [%d].\n", ##__VA_ARGS__, strerror(errno), errno)

/*
 * Set by the SIGINT handler and polled by the transfer loops so they can
 * stop early.  A signal handler may only portably store to an object of
 * type volatile sig_atomic_t; a plain int may also be cached in a register
 * by the polling loops, so the update would never be seen.
 */
static volatile sig_atomic_t need_exit;

/* SIGINT handler: record the signal number so the transfer loops exit. */
void SIGINT_h(int signo)
{
	need_exit = signo;
}

/*
 * Pack four octets into one 32-bit value with @a as the most significant
 * byte, e.g. num2ip(192, 168, 0, 1) == 0xc0a80001.  The result is in host
 * byte order.
 */
static inline uint32_t num2ip(uint8_t a, uint8_t b, uint8_t c, uint8_t d)
{
	return ((uint32_t)a << 24) |
	       ((uint32_t)b << 16) |
	       ((uint32_t)c << 8) |
	       (uint32_t)d;
}

static int setup_zerocopy(__u32 saddr, __u16 sport, __u32 daddr, __u16 dport, int in, int s, int pnum, int op)
{
	unsigned char buf[sizeof(struct sock_zc_setup_data) + sizeof(struct tcp_udp_v4_priv)];
	struct sock_zc_setup_data *setup_data;
	struct tcp_udp_v4_priv *priv;
	
	setup_data = (struct sock_zc_setup_data *)buf;
	priv = (struct tcp_udp_v4_priv *)(setup_data + 1);

	setup_data->op = op;
	setup_data->type = IPPROTO_TCP;
	setup_data->size = htonl(sizeof(struct tcp_udp_v4_priv));
	priv->dst = daddr;
	priv->dport = dport;
	priv->src = saddr;
	priv->sport = sport;
	priv->fd = in;
	priv->pnum = pnum;

	ulog("%08x -> %08x.\n", priv->src, priv->dst);
	
	err = setsockopt(s, SOL_SOCKET, SO_ZEROCOPY, 
			setup_data, sizeof(struct sock_zc_setup_data) + sizeof(struct tcp_udp_v4_priv));
	if (err) {
		ulog_err("Failed to setup zero-copy socket");
		return err;
	}

	return 0;
}

int create_socket(char *addr, unsigned short port, char *bind_addr, unsigned short bind_port, int in, int pnum, int type)
{
	int s, err;
	struct hostent *h, *bh;
	struct sockaddr_in sa, bsa;

	ulog("%s:%u -> %s:%u.\n", bind_addr, bind_port, addr, port);
	
	h = gethostbyname(addr);
	if (!h) {
		ulog_err("gethostbyname %s", addr);
		return -1;
	}

	sa.sin_family = AF_INET;
	sa.sin_port = htons(port);
	memcpy(&sa.sin_addr.s_addr, h->h_addr_list[0], 4);
	
	bh = gethostbyname(bind_addr);
	if (!bh) {
		ulog_err("gethostbyname %s", bind_addr);
		return -1;
	}
	
	bsa.sin_family = AF_INET;
	bsa.sin_port = htons(bind_port);
	memcpy(&bsa.sin_addr.s_addr, bh->h_addr_list[0], 4);

	s = socket(AF_INET, SOCK_STREAM, 0);
	if (s == -1) {
		ulog_err("Failed to create a socket");
		return -1;
	}

	if (bind(s, (struct sockaddr *)&bsa, sizeof(bsa)) == -1) {
		ulog_err("bind");
		close(s);
		return -1;
	}

	if (type == TYPE_SENDFILE) {
		err = setup_zerocopy(bsa.sin_addr.s_addr, bsa.sin_port, sa.sin_addr.s_addr, sa.sin_port, 
				in, s, pnum, ZC_OP_SETUP);
		if (err) {
			close(s);
			return -1;
		}
	}

	if (connect(s, (struct sockaddr *)&sa, sizeof(sa)) == -1) {
		ulog_err("connect");
		setup_zerocopy(bsa.sin_addr.s_addr, bsa.sin_port, sa.sin_addr.s_addr, sa.sin_port, 
				in, s, pnum, ZC_OP_CLEANUP);
		close(s);
		return -1;
	}

	return s;
}

/*
 * Receive from socket @s with recv() into a 4096-byte bounce buffer and
 * write everything to file descriptor @in.  Stops when the peer closes the
 * connection, on a recv()/write() error, or once need_exit is set by SIGINT.
 *
 * Returns the number of bytes received, or -ENOMEM if the buffer cannot be
 * allocated.
 *
 * @count is not used as a transfer limit: the loop always runs until
 * EOF/error/SIGINT.  NOTE(review): confirm whether a limit was intended.
 */
static ssize_t test_write(int in, int s, int count)
{
	const size_t sz = 4096;
	ssize_t bytes = 0;
	char *buf;

	(void)count;

	buf = malloc(sz);
	if (!buf) {
		ulog("Failed to allocate buffer of %zu bytes.\n", sz);
		return -ENOMEM;
	}

	while (!need_exit) {
		ssize_t got = recv(s, buf, sz, 0);
		char *ptr = buf;

		if (got == 0)
			break;			/* orderly shutdown by the peer */
		if (got < 0) {
			if (errno == EINTR)
				continue;	/* re-check need_exit */
			ulog_err("recv");
			break;
		}

		bytes += got;

		/* write() may be partial; loop until the whole chunk is stored. */
		while (got > 0) {
			ssize_t wr = write(in, ptr, got);

			if (wr <= 0) {
				if (wr < 0 && errno == EINTR)
					continue;
				/* Stop rather than silently dropping received data. */
				ulog_err("write");
				goto out;
			}
			got -= wr;
			ptr += wr;
		}
	}

out:
	free(buf);

	return bytes;
}

/*
 * Repeatedly call sendfile(in, s, NULL, count), with output fd @in and
 * input fd @s, so data flows from socket @s into file @in.  Each call
 * transfers at most @count bytes.  Stops when sendfile() returns 0, on an
 * error, or once need_exit is set by SIGINT.
 *
 * NOTE(review): using a socket as sendfile()'s input presumably relies on
 * the zero-copy support configured by setup_zerocopy(); confirm against
 * the kernel patch.
 *
 * Returns the total number of bytes transferred.
 */
static ssize_t test_sendfile(int in, int s, int count)
{
	ssize_t bytes = 0;

	while (!need_exit) {
		/* sendfile() returns ssize_t; don't truncate it into an int. */
		ssize_t ret = sendfile(in, s, NULL, count);

		if (ret == 0)
			break;			/* EOF: nothing more to transfer */
		if (ret < 0) {
			if (errno == EINTR)
				continue;	/* re-check need_exit */
			ulog_err("sendfile");
			break;
		}

		bytes += ret;
	}

	return bytes;
}

static void usage(const char *p)
{
	ulog("Usage: %s\n", p);
}

/*
 * Parse @str as a base-10 long in [@min, @max].
 * Returns 0 and stores the value in *@out on success.  Returns -1 on
 * malformed or out-of-range input; atoi() would instead silently yield 0
 * or a truncated value.
 */
static int parse_long(const char *str, long min, long max, long *out)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(str, &end, 10);
	if (end == str || *end != '\0' || errno == ERANGE || v < min || v > max)
		return -1;

	*out = v;
	return 0;
}

/*
 * Bind to -b/-B, connect to -a/-p, and receive data into the file named by
 * -f using the transfer mode chosen with -t.  Then print the byte count,
 * throughput and elapsed time to stderr.
 *
 * Returns 0 on success, or a negative value on usage, setup or
 * transfer-setup errors.
 */
int main(int argc, char *argv[])
{
	int s, in, ch;
	size_t count = 1024*1024*1000;
	int type = -1, pnum = 32;
	ssize_t bytes;
	char *addr = NULL, *bind_addr = NULL, *file = NULL;
	unsigned short port = 0, bind_port = 0;
	long utime, val;
	double speed;
	struct timeval tm1, tm2;

	while ((ch = getopt(argc, argv, "a:p:f:t:b:B:n:h")) != -1) {
		switch (ch) {
			case 'a':
				addr = optarg;
				break;
			case 'p':
				if (parse_long(optarg, 1, USHRT_MAX, &val))
					goto bad_value;
				port = (unsigned short)val;
				break;
			case 'b':
				bind_addr = optarg;
				break;
			case 'B':
				if (parse_long(optarg, 1, USHRT_MAX, &val))
					goto bad_value;
				bind_port = (unsigned short)val;
				break;
			case 'f':
				file = optarg;
				break;
			case 't':
				/* Reject unknown modes instead of silently falling back to RW. */
				if (parse_long(optarg, INT_MIN, INT_MAX, &val) ||
				    (val != TYPE_SENDFILE && val != TYPE_RW))
					goto bad_value;
				type = (int)val;
				break;
			case 'n':
				if (parse_long(optarg, INT_MIN, INT_MAX, &val))
					goto bad_value;
				pnum = (int)val;
				break;
			case 'h':
			default:
				usage(argv[0]);
				return -1;
		}
	}

	if (!file || !addr || !bind_addr || !port || !bind_port || type == -1) {
		usage(argv[0]);
		return -1;
	}

	signal(SIGINT, SIGINT_h);

	/* The file is the transfer destination: create it or truncate it. */
	in = open(file, O_RDWR | O_TRUNC | O_CREAT, 0644);
	if (in == -1) {
		ulog_err("open");
		return -1;
	}

	s = create_socket(addr, port, bind_addr, bind_port, in, pnum, type);
	if (s < 0) {
		close(in);
		return s;
	}

	gettimeofday(&tm1, NULL);

	/* count (1000 MiB) fits in an int, as the test_*() helpers expect. */
	switch (type) {
		case TYPE_SENDFILE:
			bytes = test_sendfile(in, s, count);
			break;
		case TYPE_RW:
		default:
			bytes = test_write(in, s, count);
			break;
	}

	gettimeofday(&tm2, NULL);

	close(s);
	close(in);

	if (bytes < 0) {
		/* test_write() returns a negative errno if it cannot allocate its buffer. */
		ulog("Transfer failed: %s [%zd].\n", strerror((int)-bytes), bytes);
		return -1;
	}

	utime = (tm2.tv_sec - tm1.tv_sec) * 1000000L + (tm2.tv_usec - tm1.tv_usec);
	/*
	 * The value is bytes / 2^20 per second, i.e. MiB/s.
	 * NOTE(review): the "Mb/sec" label is left unchanged in case something
	 * parses this output, even though Mb normally means megabits.
	 * A zero interval (e.g. immediate EOF) would otherwise print inf/nan.
	 */
	speed = (utime > 0) ? ((double)bytes * 1000000) / ((double)utime * 1024 * 1024) : 0.0;

	ulog("transferred:%zd bytes, speed:%f Mb/sec, time:%ld usec.\n",
			bytes, speed, utime);

	return 0;

bad_value:
	ulog("Invalid value '%s' for option -%c.\n", optarg, ch);
	usage(argv[0]);
	return -1;
}
