/*
 * 2007+ Copyright (c) Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* Uncomment to enable verbose ulog() tracing of matrix generation and decoding. */
//#define CODE_DEBUG

#ifdef CODE_DEBUG
/* Debug trace: printf-style message to stdout (GNU named-variadic macro syntax). */
#define ulog(f, a...) fprintf(stdout, f, ##a)
#else
/* Tracing disabled: expands to an empty statement, the arguments are not evaluated. */
#define ulog(f, a...) 	do {} while (0)
#endif

/*
 * Error report on stderr: message @f followed by ": <strerror(errno)> [<errno>].\n".
 * @f must be a string literal because it is concatenated with that suffix.
 */
#define ulog_err(f, a...) fprintf(stderr, f ": %s [%d].\n", ##a, strerror(errno), errno)

/*
 * Each data bit owns one packed unsigned int used as a small stack of
 * response bits coming from the check nodes:
 *   bits [0, MAX_SPARSE_WEIGHT)                       - stored response bits
 *   bits [MAX_SPARSE_WEIGHT, +MAX_SPARSE_ORDER)       - index of the last stored bit
 *   bit  MAX_SPARSE_WEIGHT + MAX_SPARSE_ORDER          - set when nothing is stored
 */
#define MAX_SPARSE_ORDER	4
#define MAX_SPARSE_WEIGHT	(1<<MAX_SPARSE_ORDER)

#define LDPC_RESP_MASK		((1u << MAX_SPARSE_WEIGHT) - 1)
#define LDPC_POS_MASK		((1u << MAX_SPARSE_ORDER) - 1)
#define LDPC_EMPTY_FLAG		(1u << (MAX_SPARSE_ORDER + MAX_SPARSE_WEIGHT))

/*
 * Store @empty_pos as the index of the last stored response bit in *@x,
 * keeping the response bits. A negative @empty_pos marks the slot empty.
 *
 * All shifts are done on unsigned values: the old code shifted -1 left,
 * which is undefined behaviour in C.
 */
static inline void ldpc_set_empty_pos(unsigned int *x, int empty_pos)
{
	unsigned int tmp = *x & LDPC_RESP_MASK;

	if (empty_pos < 0)
		tmp |= LDPC_EMPTY_FLAG;
	else
		tmp |= ((unsigned int)empty_pos & LDPC_POS_MASK) << MAX_SPARSE_WEIGHT;

	*x = tmp;
}

/*
 * Return the index of the last stored response bit in *@x, or -1 when the
 * slot is empty. Only the MAX_SPARSE_ORDER index bits are extracted.
 */
static inline int ldpc_get_empty_pos(unsigned int *x)
{
	if (*x & LDPC_EMPTY_FLAG)
		return -1;

	return (int)((*x >> MAX_SPARSE_WEIGHT) & LDPC_POS_MASK);
}

/*
 * Pop the most recently stored response bit from *@x.
 * Returns 0 or 1, or -1 (leaving *@x untouched) when the slot is empty.
 */
static inline int ldpc_get_response_bit(unsigned int *x)
{
	int empty_pos = ldpc_get_empty_pos(x);
	int bit;

	if (empty_pos < 0)
		return empty_pos;

	bit = (*x >> empty_pos) & 1;
	ldpc_set_empty_pos(x, empty_pos - 1);
	return bit;
}

/*
 * Push response @bit (treated as boolean) onto *@x.
 * Returns the index it was stored at, or -1 (leaving *@x untouched) when all
 * MAX_SPARSE_WEIGHT slots are already used - storing more would overwrite
 * the index field.
 */
static inline int ldpc_add_response_bit(unsigned int *x, int bit)
{
	int empty_pos = ldpc_get_empty_pos(x) + 1;
	unsigned int tmp = *x;

	if (empty_pos >= MAX_SPARSE_WEIGHT)
		return -1;

	if (bit)
		tmp |= 1u << empty_pos;
	else
		tmp &= ~(1u << empty_pos);

	ldpc_set_empty_pos(&tmp, empty_pos);

	*x = tmp;

	return empty_pos;
}

/*
 * Sparse LDPC parity-check matrix together with per-data-bit decoder scratch space.
 */
struct sparse_matrix
{
	/*
	 * word_bits: number of data bits (matrix columns);
	 * weight:    number of data bits referenced by each check row;
	 * checks:    number of check rows (set equal to word_bits by sparse_matrix_alloc()).
	 */
	int		word_bits, weight, checks;
	/*
	 * One packed entry per data bit, used by the decoder as a small stack of
	 * response bits received from the check nodes of the Tanner graph:
	 *  bits [0 : @MAX_SPARSE_WEIGHT)			- stored response bits
	 *  bits [@MAX_SPARSE_WEIGHT : @MAX_SPARSE_WEIGHT+@MAX_SPARSE_ORDER)	- index of the most recently stored response bit
	 *  bit  @MAX_SPARSE_WEIGHT+@MAX_SPARSE_ORDER	- set when no response bit is stored
	 */
	unsigned int	*tmp; 
	/* bits[i][0 .. weight-1]: column (data bit) indices covered by check row i. */
	int		**bits;
};

/*
 * Allocate an empty square matrix: @word_bits data bits, @word_bits check
 * rows, each row a zeroed array of @weight column indices. The per-bit
 * response slots (m->tmp) live in the same allocation as the header.
 *
 * Returns NULL on allocation failure or invalid parameters (a message is
 * printed for the latter). Release the result with sparse_matrix_free().
 */
static struct sparse_matrix *sparse_matrix_alloc(int word_bits, int weight)
{
	struct sparse_matrix *m;
	int i;

	if (weight > MAX_SPARSE_WEIGHT) {
		fprintf(stderr, "Too big weight %d, max: %d.\n",
				weight, MAX_SPARSE_WEIGHT);
		return NULL;
	}

	/*
	 * Negative counts would wrap to huge size_t values below, and a zero-bit
	 * word makes callers divide by word_bits.
	 */
	if (weight < 0 || word_bits <= 0) {
		fprintf(stderr, "Invalid matrix parameters: word_bits: %d, weight: %d.\n",
				word_bits, weight);
		return NULL;
	}

	m = malloc(sizeof(*m) + (size_t)word_bits * sizeof(*m->tmp));
	if (!m)
		return NULL;

	/*
	 * The header contains pointers, so its size is a multiple of pointer
	 * alignment and m + 1 is suitably aligned for unsigned int.
	 */
	m->tmp = (unsigned int *)(m + 1);
	m->word_bits = word_bits;
	m->weight = weight;

	memset(m->tmp, 0, (size_t)word_bits * sizeof(*m->tmp));

	/* One parity check per data bit. */
	m->checks = m->word_bits;

	m->bits = calloc(m->checks, sizeof(*m->bits));
	if (!m->bits)
		goto err_out_free;

	for (i = 0; i < m->checks; ++i) {
		m->bits[i] = calloc(weight, sizeof(**m->bits));
		if (!m->bits[i])
			goto err_out_free_bits;
	}

	return m;

err_out_free_bits:
	while (--i >= 0)
		free(m->bits[i]);
	free(m->bits);
err_out_free:
	free(m);
	return NULL;
}

/*
 * Release a matrix obtained from sparse_matrix_alloc(): every check row,
 * the row table, and the header (which also holds the response slots).
 */
static void sparse_matrix_free(struct sparse_matrix *m)
{
	int row = m->checks;

	while (row-- > 0)
		free(m->bits[row]);

	free(m->bits);
	free(m);
}

/*
 * Write @size bytes of @data to file @name, creating or truncating it
 * (mode 0644). The dump is best-effort: failures are reported on stderr
 * but never abort the caller. Interrupted writes are retried.
 */
static void ldpc_dump_word(const unsigned char *data, int size, const char *name)
{
	int fd, saved;

	fd = open(name, O_WRONLY | O_TRUNC | O_CREAT, 0644);
	if (fd == -1) {
		saved = errno;
		fprintf(stderr, "Failed to open '%s' for writing: %s [%d].\n",
				name, strerror(saved), saved);
		return;
	}

	while (size > 0) {
		ssize_t written = write(fd, data, size);

		if (written < 0) {
			if (errno == EINTR)
				continue;
			saved = errno;
			fprintf(stderr, "Failed to write '%s': %s [%d].\n",
					name, strerror(saved), saved);
			break;
		}
		if (written == 0)
			break;

		data += written;
		size -= (int)written;
	}

	if (close(fd) == -1) {
		saved = errno;
		fprintf(stderr, "Failed to close '%s': %s [%d].\n",
				name, strerror(saved), saved);
	}
}

/*
 * Per-row candidate list used by ldcp_generate_sparse_matrix():
 * w[0 .. total-1] are the column indices this row may still pick.
 */
struct gen_row
{
	int		num, total;	/* num: set to 0 by the generator, otherwise unused here; total: live entries in w */
	int		*w;
};

/*
 * Fill @sm->bits with a random sparse parity-check matrix of @m rows (checks)
 * and @n columns (data bits). Every row receives @rw distinct column indices
 * and no column is picked by more than @cw rows (@cw == 0 is treated as 1).
 * If @cw * @n cannot cover @rw * @m entries, the row weight (and @sm->weight)
 * is reduced until it can.
 *
 * Each row keeps a candidate list g[i].w[0 .. total-1]. Rows after the
 * current one are permuted in lockstep with it, and a column that reaches
 * @cw picks is dropped from all later rows. This needs m * n ints of
 * temporary memory.
 *
 * Returns 0 on success. Returns -1 on allocation failure, when the row weight
 * would drop below zero, or when a row runs out of candidate columns before
 * receiving @rw entries.
 */
static int ldcp_generate_sparse_matrix(struct sparse_matrix *sm, int m, int n, int rw, int cw)
{
	int i, j, *gc = NULL, err = -1;
	struct gen_row *g, *r;

	if (!cw)
		cw = 1;

	printf("m: %d, n: %d, rw: %d, cw: %d.\n", m, n, rw, cw);

	/* Shrink the row weight until the columns can absorb every row entry. */
	while (cw * n < rw * m) {
		rw--;
		sm->weight--;
		printf("Invalid parameters, reducing row weight to %d.\n", rw);
		if (rw < 0)
			return -1;

		cw = rw*m/n;
	}

	g = malloc(sizeof(struct gen_row) * m);
	if (!g)
		goto err_out_exit;
	memset(g, 0, sizeof(struct gen_row) * m);

	/* Every row starts with all n columns as candidates. */
	for (i=0; i<m; ++i) {
		g[i].w = malloc(sizeof(int) * n);
		if (!g[i].w)
			goto err_out_free_g;
		for (j=0; j<n; ++j)
			g[i].w[j] = j;
		g[i].num = 0;
		g[i].total = n;
	}

	/* gc[c]: number of rows that have picked column c so far. */
	gc = calloc(n, sizeof(int));
	if (!gc)
		goto err_out_free_g;

	for (i=0; i<m; ++i) {
		for (j=0; j<rw; ++j) {
			int tmp, k, pos, tmp_pos;

			r = &g[i];

			/*
			 * Every remaining column is already saturated. Picking now would
			 * compute tmp_pos == -1 and index r->w[-1], so fail instead.
			 */
			if (r->total <= 0) {
				fprintf(stderr, "Failed to generate matrix: row %d has no free columns left for entry %d of %d.\n",
						i, j, rw);
				goto err_out_free_gc;
			}

			pos = (int) (((double)(r->total)) * rand() / (RAND_MAX + 1.0));

			ulog("%d/%d: ", i, m);
			for (k=0; k<r->total; ++k)
				ulog("%d ", r->w[k]);
			ulog("\n");

			ulog("%d/%d: %d/%d, total: %d, pos: %d, %d <-> %d, gc[%d]: %d.\n",
					i, m, j, rw, r->total, pos, 
					r->w[r->total - 1], r->w[pos],
					r->w[pos], gc[r->w[pos]]);

			/* Remove the picked column from this row (move the last candidate into its place). */
			tmp = r->w[pos];
			tmp_pos = r->total - 1;
			r->w[pos] = r->w[tmp_pos];
			r->total--;

			sm->bits[i][j] = tmp;

			/* Keep later rows in lockstep: the picked column is parked at tmp_pos. */
			for (k=i+1; k<m; ++k) {
				int tmp1, t;

				r = &g[k];

				tmp1 = r->w[pos];
				r->w[pos] = r->w[tmp_pos];
				r->w[tmp_pos] = tmp1;
				
				ulog("  %d/%d, total: %d, %d <- %d: ", 
						k, m, r->total, r->w[pos], r->w[tmp_pos]);
				for (t=0; t<r->total; ++t)
					ulog("%d ", r->w[t]);
				ulog("\n");
			}

			/* Column reached its target weight: drop it from all later rows. */
			if (++gc[tmp] == cw) {
				ulog("removing column %d [%d].\n", pos, tmp);
				for (k=i+1; k<m; ++k) {
					int t;
					r = &g[k];

					r->w[tmp_pos] = r->w[r->total - 1];
					r->total--;
				
					ulog("  %d/%d, total: %d, %d <- %d: ", 
							k, m, r->total, r->w[pos], r->w[r->total]);
					for (t=0; t<r->total; ++t)
						ulog("%d ", r->w[t]);
					ulog("\n");
				}
			}
		}
		ulog("%d/%d, total: %d.\n\n", i, m, g[i].total);
	}

	err = 0;

err_out_free_gc:
	free(gc);
	i = m;
err_out_free_g:
	while (--i >= 0)
		free(g[i].w);
	free(g);
err_out_exit:
	return err;
}

/*
 * Build a random LDPC parity-check matrix for a @word_bits-bit word with row
 * weight @weight. @ldpc_init seeds rand(); 0 means "seed from the current
 * time". The seed is printed so the run can be repeated with -i.
 *
 * Returns the matrix (release with sparse_matrix_free()) or NULL on failure.
 */
static struct sparse_matrix *ldpc_generate_matrix(int ldpc_init, int word_bits, int weight)
{
	struct sparse_matrix *m;
	int cw;

	/* The column weight below divides by word_bits. */
	if (word_bits <= 0) {
		fprintf(stderr, "Invalid number of word bits %d.\n", word_bits);
		return NULL;
	}

	m = sparse_matrix_alloc(word_bits, weight);
	if (!m)
		return NULL;

	if (!ldpc_init)
		ldpc_init = (int)time(NULL);
	srand((unsigned int)ldpc_init);
	printf("Initialization constant: %u, checks: %d, weight: %d.\n", 
			(unsigned int)ldpc_init, m->checks, m->weight);

	ulog("Generation matrix.\n");

	/* Column weight that balances weight * checks entries over word_bits columns; widened against int overflow. */
	cw = (int)((long long)m->weight * m->checks / m->word_bits);
	if (ldcp_generate_sparse_matrix(m, m->checks, m->word_bits, m->weight, cw)) {
		sparse_matrix_free(m);
		return NULL;
	}

	return m;
}

/*
 * Compute the check word for @data: bit r of the result is the XOR of the
 * data bits listed in row r of @m.
 *
 * Returns a zero-initialised malloc'ed buffer of one bit per check (rounded
 * up to whole bytes) owned by the caller, or NULL if allocation fails.
 */
static unsigned char *ldpc_generate_checkword(unsigned char *data, struct sparse_matrix *m)
{
	unsigned char *check;
	int nbytes, row, col;

	printf("Generating checksum...\n");

	nbytes = ((m->checks + 7) & ~7)/8;

	check = calloc(nbytes, 1);
	if (check == NULL)
		return NULL;

	for (row = 0; row < m->checks; ++row) {
		int parity = 0;

		for (col = 0; col < m->weight; ++col) {
			int column = m->bits[row][col];

			parity ^= (data[column / 8] >> (column % 8)) & 1;
		}

		if (parity)
			check[row / 8] |= 1 << (row % 8);
	}

	ulog("Checksum: 0x");
	for (row = 0; row < nbytes; ++row)
		ulog("%02x", check[row]);
	ulog("\n");

	return check;
}

/*
 * Replace each of the first m->weight entries of @f with the response the
 * check node sends back to that data bit: check_bit XOR (XOR of all other
 * f[j], j != i). Works in place.
 *
 * The XOR of everything is computed once and f[i] is XOR-ed back out. This
 * is O(weight) and needs no variable-length array, which was undefined
 * behaviour for a zero weight.
 */
static void ldpc_generate_response(int *f, int check_bit, struct sparse_matrix *m)
{
	int i, parity = check_bit;

	for (i = 0; i < m->weight; ++i)
		parity ^= f[i];

	for (i = 0; i < m->weight; ++i)
		f[i] ^= parity;
}

/*
 * Majority-vote hard decision for every data bit of @data (modified in place).
 *
 * A bit's ballot is its current value plus up to m->weight response bits
 * popped from m->tmp[bit]. Ones winning sets the bit, zeros winning clears
 * it, and a tie leaves it unchanged.
 *
 * Returns the number of bits whose value changed.
 */
static int ldpc_correct_data(unsigned char *data, struct sparse_matrix *m)
{
	int flipped = 0;
	int pos;

	for (pos = 0; pos < m->word_bits; ++pos) {
		unsigned char *byte = &data[pos / 8];
		int shift = pos % 8;
		int before = (*byte >> shift) & 1;
		int after = before;
		int ones = before;
		int zeros = !before;
		int n;

		ulog("%s: %3d/%3d: %d ", __func__, pos, m->word_bits, before);
		for (n = 0; n < m->weight; ++n) {
			int vote = ldpc_get_response_bit(&m->tmp[pos]);

			if (vote < 0) {
				ulog("  ");
				continue;
			}

			ones += vote;
			zeros += !vote;
			ulog("%d ", vote);
		}

		if (ones != zeros) {
			after = ones > zeros;
			if (after)
				*byte |= 1 << shift;
			else
				*byte &= ~(1 << shift);
		}

		ulog("| %d => %d ", before, after);

		if (before != after) {
			flipped++;
			ulog("+");
		}
		ulog("\n");
	}

	return flipped;
}

static int ldpc_decode_hard(unsigned char *data, unsigned char *check, struct sparse_matrix *m)
{
	int i, j, check_bit, idx, off;
	int f[m->weight];

	for (i=0; i<m->word_bits; ++i) {
		m->tmp[i] = 0;
		ldpc_set_empty_pos(&m->tmp[i], -1);
	}

	for (i=0; i<m->checks; ++i) {
		ulog("%s: %2d/%2d: ", __func__, i, m->checks);
		for (j=0; j<m->weight; ++j) {
			idx = m->bits[i][j] / 8;
			off = m->bits[i][j] % 8;

			f[j] = (data[idx] >> off) & 1;
			ulog("%d ", f[j]);
		}
		ulog("| ");

		idx = i / 8;
		off = i % 8;

		check_bit = (check[idx] >> off) & 1;

		ldpc_generate_response(f, check_bit, m);

		for (j=0; j<m->weight; ++j) {
			idx = m->bits[i][j];
			ldpc_add_response_bit(&m->tmp[idx], f[j]);
			ulog("%d.%x.%d ", idx, m->tmp[idx], f[j]);
		}
		ulog("\n");
	}
	
	return ldpc_correct_data(data, m);
}

/* Print the command-line synopsis for program @p to stderr. */
static void ldpc_usage(char *p)
{
	static const char synopsis[] =
		"Usage: %s -i seed -n runs -j weight -b word_bits -l noise_level -f file -h\n";

	fprintf(stderr, synopsis, p);
}

/* Count the set bits in the first @size bytes of @data. */
static long long ldpc_count_nonzero_bits(const unsigned char *data, long long size)
{
	long long i, nonzero = 0;
	int j;

	for (i = 0; i < size; ++i) {
		for (j = 0; j < 8; ++j) {
			if ((data[i] >> j) & 1)
				nonzero++;
		}
	}

	return nonzero;
}

/*
 * Test driver:
 *  - map @file and build a random LDPC matrix over its first word_bits bits;
 *  - compute the check word;
 *  - flip noise_level percent of the data+check bits;
 *  - run up to total_runs rounds of hard decoding, dumping intermediate words
 *    to "<run>.save" and the final one to "last.save".
 *
 * NOTE: the file is mapped MAP_SHARED, so the injected noise and the
 * decoder's corrections are written back to @file itself.
 *
 * Returns 0 on success, -1 on any error.
 */
int main(int argc, char *argv[])
{
	int word_bits, weight, init, changed, total_runs, fd, noise;
	char *file;
	unsigned char *check, *data;
	long long file_size;
	int ch, err = -1;
	struct sparse_matrix *m;

	total_runs = 3;
	init = 0;
	word_bits = 0;
	file = NULL;
	noise = 10;
	weight = 3;

	while ((ch = getopt(argc, argv, "l:f:n:i:j:b:h")) != -1) {
		switch (ch) {
			case 'l':
				noise = atoi(optarg);
				break;
			case 'f':
				file = optarg;
				break;
			case 'n':
				total_runs = atoi(optarg);
				break;
			case 'i':
				init = atoi(optarg);
				break;
			case 'j':
				weight = atoi(optarg);
				break;
			case 'b':
				word_bits = atoi(optarg);
				break;
			case 'h':
			default:
				ldpc_usage(argv[0]);
				return -1;
		}
	}

	if (noise > 100 || noise < 0) {
		fprintf(stderr, "Wrong parameters, noise must be between 0 and 100.\n");
		ldpc_usage(argv[0]);
		return -1;
	}

	if (!file) {
		fprintf(stderr, "Wrong parameters, you need to specify filename.\n");
		ldpc_usage(argv[0]);
		return -1;
	}

	fd = open(file, O_RDWR);
	if (fd == -1) {
		ulog_err("Failed to open file '%s'", file);
		return -1;
	}

	file_size = lseek(fd, 0, SEEK_END);
	if (file_size == -1) {
		fprintf(stderr, "Failed to determine file size.\n");
		goto err_out_close;
	}

	/* Touching a mapped page that lies wholly past EOF raises SIGBUS. */
	if (file_size == 0) {
		fprintf(stderr, "File '%s' is empty.\n", file);
		goto err_out_close;
	}

	if (!word_bits) {
		if (file_size > INT_MAX / 8) {
			fprintf(stderr, "File '%s' is too big (%lld bytes), use -b to limit word_bits.\n",
					file, file_size);
			goto err_out_close;
		}
		word_bits = (int)(file_size * 8);
	}

	/* The decoder indexes data[bit / 8] for every bit < word_bits. */
	if (word_bits < 0 || (long long)word_bits > file_size * 8) {
		fprintf(stderr, "Wrong parameters, word_bits must be between 1 and %lld.\n",
				file_size * 8);
		ldpc_usage(argv[0]);
		goto err_out_close;
	}

	/* Map exactly the file: nothing beyond EOF is ever accessed. */
	data = mmap(NULL, (size_t)file_size, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
	if (data == MAP_FAILED) {
		ulog_err("Failed to map %lld bytes of file '%s'", file_size, file);
		goto err_out_close;
	}

	printf("LDPC: weight: %d, word_bits: %d, file: %s, noise_level: %d.\n",
			weight, word_bits, file, noise);
	printf("Nonzero bits: %lld.\n", ldpc_count_nonzero_bits(data, file_size));

	/* Percentage -> number of bits; widened because word_bits * 100 may overflow int. */
	noise = (int)((long long)word_bits * noise / 100);

	m = ldpc_generate_matrix(init, word_bits, weight);
	if (!m)
		goto err_out_unmap;

	check = ldpc_generate_checkword(data, m);
	if (!check)
		goto err_out_free_matrix;

	printf("Generating noise (%d bits changed)... ", noise);
	while (noise-- > 0) {
		/*
		 * Uniform over data bits followed by check bits. The sum is formed in
		 * double to avoid int overflow; the value is unchanged otherwise, so
		 * a given seed yields the same noise.
		 */
		int pos = (int) (((double)m->word_bits + m->checks) * (rand() / (RAND_MAX + 1.0)));
		int idx, off;
		char prefix;

		if (pos >= m->word_bits) {
			pos -= m->word_bits;
			idx = pos / 8;
			off = pos % 8;

			check[idx] ^= 1 << off;
			prefix='C';
		} else {
			idx = pos / 8;
			off = pos % 8;

			data[idx] ^= 1 << off;
			prefix='D';
		}
		ulog("%c%d ", prefix, pos);
	}
	printf("\n");
	
	ldpc_dump_word(data, m->word_bits/8, "modified.save");

	printf("Decoding has been started.\n");

	while (--total_runs >= 0) {
		char name[32];

		changed = ldpc_decode_hard(data, check, m);

		snprintf(name, sizeof(name), "%d.save", total_runs);

		printf("Changed %d bits, saving to %s, nonzero bits: %lld.\n",
				changed, name, ldpc_count_nonzero_bits(data, file_size));

		if (changed == 0)
			break;

		ldpc_dump_word(data, m->word_bits/8, name);
	}
	ldpc_dump_word(data, m->word_bits/8, "last.save");

	err = 0;

	free(check);
err_out_free_matrix:
	sparse_matrix_free(m);
err_out_unmap:
	munmap(data, (size_t)file_size);
err_out_close:
	close(fd);
	return err;
}
