/*
 * 2007+ Copyright (c) Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * All rights reserved.
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/resource.h>
#include <sys/wait.h>
#include <sys/poll.h>

#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <time.h>
#include <ctype.h>
#include <netdb.h>

#include <asm/byteorder.h>
#include <fs/pohmelfs/netfs.h>
#include <linux/connector.h>

#include "fserver.h"

/* Socket option level passed to setsockopt() in main() for the netlink socket. */
#define SOL_NETLINK    270

/*
 * Sequence number of the next request.  pohmelfs_netlink_send() stamps it on
 * every outgoing message and then increments it; pohmelfs_recv_ack() compares
 * incoming acks against (pohmelfs_seq - 1), i.e. the last request sent.
 */
static unsigned int pohmelfs_seq;

/*
 * Print one row of the "show" listing for the server described by @ctl.
 *
 * The very first call also prints the config index and the column header;
 * a static flag suppresses them on every later call.  An entry whose
 * protocol field is zero is treated as empty and produces no row.
 */
static void pohmelfs_display_function(struct pohmelfs_ctl *ctl)
{
	static int header_printed;

	if (!header_printed) {
		header_printed = 1;
		printf("Config Index = %d\n", ctl->idx);
		printf("Family    Server IP                                            Port     \n");
	}

	/* Nothing valid in this entry: no row to print. */
	if (ctl->proto == 0)
		return;

	switch (ctl->addr.sa_family) {
	case AF_INET: {
		struct sockaddr_in *v4 = (struct sockaddr_in *)&ctl->addr;

		printf("AF_INET   %s%46d\n", inet_ntoa(v4->sin_addr), ntohs(v4->sin_port));
		break;
	}
	case AF_INET6: {
		struct sockaddr_in6 *v6 = (struct sockaddr_in6 *)&ctl->addr;
		char text[128];

		/* Render the address first, then print it next to the port. */
		sprintf(text, NIP6_FMT, NIP6(v6->sin6_addr));
		printf("AF_INET6  %s%18d\n", text, ntohs(v6->sin6_port));
		break;
	}
	default:
		/* Unrecognised family: dump the raw address field and protocol. */
		printf("Unknown   %s%18d\n", ctl->addr.addr, ctl->proto);
		break;
	}
}

static int pohmelfs_recv_ack(int s, unsigned int flags)
{
	struct pollfd pfd;
	char buf[4096];
	struct pohmelfs_cn_ack *ack;
	struct nlmsghdr *nlh;
	int err;

	pfd.fd = s;
	pfd.events = POLLIN;
	pfd.revents = 0;

	while (1) {
		switch (poll(&pfd, 1, 1000)) {
			case 0:
				ulog("Timed out polling for ack\n");
				return -1;
			case -1:
				ulog_err("Error polling for ack\n");
				return -1;
		}

		memset(buf, 0, sizeof(buf));

		err = recv(s, buf, sizeof(buf), 0);
		if (err == -1) {
			ulog_err("recv from cn failed\n");
			return -1;
		}

		nlh = (struct nlmsghdr *)buf;

		switch (nlh->nlmsg_type) {
			case NLMSG_ERROR:
				ulog("Received error message rather than ack.\n");
				return -1;
			case NLMSG_DONE:
				ack = (struct pohmelfs_cn_ack *)NLMSG_DATA(nlh);

				/*
				 * XXX: worry about matching acks to the right request
				 * and resending if we don't get an ack.
				 */
				if (ack->msg.seq != pohmelfs_seq-1) {
					ulog("Uh oh... received ack for wrong seqnum (got %d, expected %d)"
						 " - bail for now\n", ack->msg.seq, pohmelfs_seq-1);
					return -1;
				}

				if (ack->msg.ack != 1) {
					ulog("Uh oh... received wrong ack (got %d, expected %d)"
						" for right seqnum - bail for now\n", ack->msg.ack, 1);
					return -1;
				}

				errno = ack->error;
				if(errno)
					return errno;
				if (flags == POHMELFS_FLAGS_SHOW) {
					 pohmelfs_display_function(&ack->ctl);
				}
				if (!ack->msg_num)
					return 0;
				break;
			default:
				ulog("Received unrecognised message type %d rather than ack\n", nlh->nlmsg_type);
				return -1;
		}
	}
	return -1;
}

/*
 * Wrap @len bytes of @data into a netlink + connector message addressed to
 * POHMELFS_CN_IDX/POHMELFS_CN_VAL, send it on socket @s and wait for the
 * acknowledgement.
 *
 * @flags is stored in the connector header and also forwarded to
 * pohmelfs_recv_ack().  The message is stamped with the current value of
 * pohmelfs_seq, which is then incremented.
 *
 * Returns -1 if the payload does not fit into the message buffer or if
 * send() fails; otherwise the result of pohmelfs_recv_ack().
 */
static int pohmelfs_netlink_send(int s, unsigned int flags, void *data, unsigned int len)
{
	struct nlmsghdr *nlh;
	unsigned int size;
	int err;
	char buf[4096];
	struct cn_msg *m;

	/*
	 * Reject oversized payloads before touching the buffer.  @len is
	 * checked on its own first so the header arithmetic below cannot wrap.
	 */
	if (len > sizeof(buf) ||
	    NLMSG_SPACE(sizeof(struct cn_msg) + len) > sizeof(buf)) {
		ulog("%s: payload of %u bytes does not fit into a %u byte message.\n",
			__func__, len, (unsigned int)sizeof(buf));
		return -1;
	}

	size = NLMSG_SPACE(sizeof(struct cn_msg) + len);

	/* Zero the buffer so alignment padding does not carry stack contents. */
	memset(buf, 0, sizeof(buf));

	nlh = (struct nlmsghdr *)buf;
	nlh->nlmsg_seq = pohmelfs_seq++;
	nlh->nlmsg_pid = getpid();
	nlh->nlmsg_type = NLMSG_DONE;
	nlh->nlmsg_len = NLMSG_LENGTH(size - sizeof(*nlh));
	nlh->nlmsg_flags = 0;

	m = NLMSG_DATA(nlh);

	m->id.idx = POHMELFS_CN_IDX;
	m->id.val = POHMELFS_CN_VAL;

	/* The connector header repeats the netlink sequence number. */
	m->seq = nlh->nlmsg_seq;
	m->ack = 0;
	m->len = len;
	m->flags = flags;

	memcpy(m->data, data, len);

	err = send(s, nlh, size, 0);
	if (err == -1) {
		ulog("Failed to send: %s [%d].\n",
			strerror(errno), errno);
		return err;
	}

	return pohmelfs_recv_ack(s, flags);
}

/*
 * Resolve @addr and store the first result, combined with @port, in
 * ctl->addr / ctl->addrlen.
 *
 * Only AF_INET and AF_INET6 results are accepted.  The port returned by the
 * resolver is ignored and replaced with @port in network byte order.
 *
 * Returns 0 on success and -1 if resolution fails or the first result is of
 * an unsupported family.
 */
static int pohmelfs_sock_init(struct  pohmelfs_ctl *ctl, char *addr, unsigned short port)
{
	struct addrinfo *h, hints;
	struct sockaddr_in sa;
	struct sockaddr_in6 sa6;
	int err;

	memset(&hints, '\0', sizeof(hints));
	hints.ai_protocol = IPPROTO_TCP;

	/*
	 * No service is requested (the port is filled in below), so pass NULL
	 * rather than an empty service name.  getaddrinfo() reports failures
	 * through its return value, not errno, hence gai_strerror().
	 */
	err = getaddrinfo(addr, NULL, &hints, &h);
	if (err) {
		ulog("%s: Failed to get address of '%s': %s.\n",
			__func__, addr, gai_strerror(err));
		return -1;
	}

	err = 0;
	if (h->ai_family == AF_INET && h->ai_addrlen >= sizeof(sa)) {
		memcpy(&sa, h->ai_addr, sizeof(sa));
		sa.sin_port = htons(port);
		sa.sin_family = AF_INET;
		memcpy(&(ctl->addr), &sa, sizeof(sa));
		ctl->addrlen = sizeof(sa);
	} else if (h->ai_family == AF_INET6 && h->ai_addrlen >= sizeof(sa6)) {
		memcpy(&sa6, h->ai_addr, sizeof(sa6));
		sa6.sin6_port = htons(port);
		sa6.sin6_family = AF_INET6;
		memcpy(&(ctl->addr), &sa6, sizeof(sa6));
		ctl->addrlen = sizeof(sa6);
	} else {
		/* Never copy more than the resolver actually returned. */
		ulog("%s: Unsupported address family %d for '%s'.\n",
			__func__, h->ai_family, addr);
		err = -1;
	}

	freeaddrinfo(h);
	return err;
}

static int pohmelfs_setup_ctl(struct pohmelfs_ctl *ctl, char *addr, int port, unsigned int idx)
{
	int err;

	err = pohmelfs_sock_init(ctl, addr, port);
	if (err)
		return err;

	ctl->type = SOCK_STREAM;
	ctl->proto = IPPROTO_TCP;
	ctl->idx = idx;

	return 0;
}

static int pohmelfs_show_remote(int s, unsigned int idx)
{
	struct pohmelfs_ctl ctl;
	int err;
	memset(&ctl, 0, sizeof(struct pohmelfs_ctl));
	ctl.idx = idx;
	err =  pohmelfs_netlink_send(s, POHMELFS_FLAGS_SHOW,
			&ctl, sizeof(struct pohmelfs_ctl));

	return err;
}

static int pohmelfs_add_remote(int s, char *addr, int port, unsigned int idx, int action)
{
	int err;
	struct pohmelfs_ctl ctl;

	err = pohmelfs_setup_ctl(&ctl, addr, port, idx);
	if (err)
		return err;
	if (action == POHMELFS_FLAGS_ADD)
		return pohmelfs_netlink_send(s, POHMELFS_FLAGS_ADD,
				&ctl, sizeof(struct pohmelfs_ctl));
	else
		return pohmelfs_netlink_send(s, POHMELFS_FLAGS_DEL,
				&ctl, sizeof(struct pohmelfs_ctl));
}

/*
 * Send a crypto setup request for config index @idx.
 *
 * The payload is a struct pohmelfs_crypto followed by the NUL-terminated
 * algorithm name @algo and then up to 128 bytes of key material read from
 * @key_file.  @type is stored in the request as the crypto type.
 *
 * Returns -1 if the algorithm name is too long or the key file cannot be
 * opened or read; otherwise the result of pohmelfs_netlink_send().
 */
static int pohmelfs_add_crypto(int s, char *algo, char *key_file,
		unsigned int type, unsigned int idx)
{
	unsigned char buf[4096], *ptr;
	struct pohmelfs_crypto *c = (struct pohmelfs_crypto *)buf;
	const unsigned int max_key_size = 128;
	size_t name_size;
	int fd, err;

	/*
	 * The name comes from the command line, so bound it before copying.
	 * pohmelfs_netlink_send() wraps the payload in netlink and connector
	 * headers inside a buffer of this same size, so leave room for those
	 * headers and for the largest key as well.
	 */
	name_size = strlen(algo) + 1;
	if (name_size > sizeof(buf) ||
	    NLMSG_SPACE(sizeof(struct cn_msg) + sizeof(struct pohmelfs_crypto) +
			name_size + max_key_size) > sizeof(buf)) {
		ulog("%s: algorithm name of %lu bytes is too long.\n",
			__func__, (unsigned long)(name_size - 1));
		return -1;
	}

	/* Give every header field and unused byte a defined value. */
	memset(buf, 0, sizeof(buf));

	c->idx = idx;
	c->type = type;

	/* Algorithm name, including its terminating NUL, goes first. */
	ptr = c->data;
	memcpy(ptr, algo, name_size);
	c->strlen = name_size;
	ptr += name_size;

	fd = open(key_file, O_RDONLY);
	if (fd == -1) {
		ulog_err("%s: failed to open key file '%s' for algo '%s'",
				__func__, key_file, algo);
		return -1;
	}

	/* The key follows the name; an empty key file is treated as an error. */
	err = read(fd, ptr, max_key_size);
	if (err <= 0) {
		ulog_err("%s: failed to read key from file '%s' for algo '%s'",
				__func__, key_file, algo);
		close(fd);
		return -1;
	}

	close(fd);

	c->keysize = err;
	
	return pohmelfs_netlink_send(s, POHMELFS_FLAGS_CRYPTO, c,
			sizeof(struct pohmelfs_crypto) + c->keysize + c->strlen);
}

/*
 * Log the one-line command-line synopsis, with @p substituted as the
 * program name.
 */
static void pohmelfs_usage(char *p)
{
	ulog("Usage: %s"
		" -A[ction]{add/del/show}"
		" -a[ddress]{ipv4/ipv6}"
		" -p[ort]"
		" -i[ndex]"
		" -k[cipher_key_file]"
		" -K [hash_key_file]"
		" -C[ipher]"
		" -H[ash]"
		" -h\n",
		p);
}

/*
 * Command-line entry point.
 *
 * Parses the options, opens a netlink connector socket bound to this
 * process and subscribed to the POHMELFS connector group, then performs the
 * requested action:
 *   add  - add a server (-a and -p) and/or hash (-H and -K) and/or cipher
 *          (-C and -k) settings to config index -i;
 *   del  - remove the server given by -a and -p from config index -i;
 *   show - list the servers of config index -i.
 *
 * Returns 0 on success, -1 on bad options or socket setup failure, -EINVAL
 * when no valid action was carried out, and otherwise the error reported by
 * the failing request.
 */
int main(int argc, char *argv[])
{
	int ch, port, err, s;
	unsigned int idx;
	char *addr, *action, *cipher_key, *hash_key, *cipher, *hash;
	struct sockaddr_nl l_local;

	addr = NULL;
	port = -1;
	idx = 0;
	action = cipher_key = hash_key = cipher = hash = NULL;

	/* getopt() returns -1 once all options have been consumed. */
	while ((ch = getopt(argc, argv, "A:a:p:i:k:K:C:H:h")) != -1) {
		switch (ch) {
			case 'i':
				idx = atoi(optarg);
				break;
			case 'C':
				cipher = optarg;
				break;
			case 'H':
				hash = optarg;
				break;
			case 'K':
				hash_key = optarg;
				break;
			case 'k':
				cipher_key = optarg;
				break;
			case 'a':
				addr = optarg;
				break;
			case 'p':
				port = atoi(optarg);
				break;
			case 'A':
				action = optarg;
				break;
			default:
				/* Covers -h as well as unknown options. */
				pohmelfs_usage(argv[0]);
				return -1;
		}
	}

	s = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_CONNECTOR);
	if (s == -1) {
		perror("socket");
		return -1;
	}

	/* Zero the whole address so the padding field is not stack garbage. */
	memset(&l_local, 0, sizeof(l_local));
	l_local.nl_family = AF_NETLINK;
	l_local.nl_groups = 1<<POHMELFS_CN_IDX; /* bitmask of requested groups */
	l_local.nl_pid = getpid();

	if (bind(s, (struct sockaddr *)&l_local, sizeof(struct sockaddr_nl)) == -1) {
		perror("bind");
		close(s);
		return -1;
	}

	/* NETLINK_ADD_MEMBERSHIP is given the group number, not a bitmask. */
	l_local.nl_groups = POHMELFS_CN_IDX;
	if (setsockopt(s, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
		&l_local.nl_groups, sizeof(l_local.nl_groups))) {
		perror("setsockopt");
		close(s);
		return -1;
	}

	/* Stays -EINVAL unless some request is actually carried out. */
	err = -EINVAL;

	if (!action) {
		pohmelfs_usage(argv[0]);
	} else if (!strncmp(action, "add", 3)) {
		int rc, attempted = 0, failed = 0;

		/*
		 * Every requested step is still attempted, but the first
		 * failure is remembered so that a later success cannot hide
		 * it from the exit status.
		 */
		if (addr && port != -1) {
			attempted = 1;
			rc = pohmelfs_add_remote(s, addr, port, idx, POHMELFS_FLAGS_ADD);
			if (rc && !failed)
				failed = rc;
		}
		if (hash && hash_key) {
			attempted = 1;
			rc = pohmelfs_add_crypto(s, hash, hash_key, POHMELFS_CRYPTO_HASH, idx);
			if (rc && !failed)
				failed = rc;
		}
		if (cipher && cipher_key) {
			attempted = 1;
			rc = pohmelfs_add_crypto(s, cipher, cipher_key, POHMELFS_CRYPTO_CIPHER, idx);
			if (rc && !failed)
				failed = rc;
		}
		if (attempted)
			err = failed;
	} else if (!strncmp(action, "del", 3)) {
		if (cipher_key || cipher || hash || hash_key) {
			/* Crypto options make no sense here; refuse the request. */
			ulog("cipher and hash parameters are not needed for deletion\n");
			pohmelfs_usage(argv[0]);
		} else if (addr && port != -1) {
			err = pohmelfs_add_remote(s, addr, port, idx, POHMELFS_FLAGS_DEL);
		} else {
			pohmelfs_usage(argv[0]);
		}
	} else if (!strncmp(action, "show", 4)) {
		if (addr || cipher_key || cipher || hash || hash_key) {
			/* show only takes the index; refuse anything else. */
			ulog("address, cipher and hash parameters are not needed for show\n");
			pohmelfs_usage(argv[0]);
		} else {
			err = pohmelfs_show_remote(s, idx);
		}
	} else {
		pohmelfs_usage(argv[0]);
	}

	close(s);

	if (err) {
		ulog("%s: err: %d.\n", __func__, err);
	}

	return err;
}
