#include <stdio.h>
#include <inttypes.h>
#include <math.h>

#include "popcount.h"
#include "bitvector.h"
#include "xutil.h"

/**
 * Builds the rank/select directory for B and attaches it as B->table.
 *
 * With n = B->bits, two block lengths are derived:
 *   b = floor(log2(n)/2)      and      s = b*floor(log2(n)).
 *
 * Two packed tables are then filled:
 *   - Rs: entries of popcount->bits bits each. Entry 0 is 0; entry i is
 *     entry i-1 plus the number of 1s in the s bits of B starting at
 *     (i-1)*s, i.e. a running total from the start of the bitvector.
 *   - Rb: entries of popcount->logS bits each, one per b-bit block. An
 *     entry holds the number of 1s counted from the start of the current
 *     s-bit block (j*s) over b, 2b, 3b, ... bits, and is written as 0 each
 *     time the s-bit block index j changes.
 *
 * NOTE(review): for 1 <= n < 4 the formula yields b == 0 (hence s == 0);
 * if the popcount_t fields are integers, the divisions by s and b below
 * are then divisions by zero — confirm callers never pass such vectors.
 * NOTE(review): a table already stored in B->table is overwritten here
 * without being freed — confirm this is only called once per bitvector.
 */
void succinct_popcount_preprocess(bitvector_t *restrict B) {
	uint64_t i, j, k = 0;
	uint64_t size;
	double logSize = log2(B->bits);
	uint64_t index = 0;
	uint64_t tmp = 0; uint64_t bits = B->bits;
	uint64_t sp, smodbit;
	uint64_t read;

	// Allocate the table descriptor and derive the block geometry from n:
	//   b    - length of a small block, floor(log2(n)/2) bits
	//   s    - length of a large block, b*floor(log2(n)) bits
	//   logS - width of one Rb entry, floor(log2(s)+1) bits
	//   bits - width of one Rs entry, floor(log2(n)+1) bits
	// (xmalloc presumably aborts on allocation failure — TODO confirm.)
	popcount_t *popcount = xmalloc(sizeof(popcount_t));
	popcount->b = floor((double)logSize/2);
	popcount->s = popcount->b*floor(logSize);
	popcount->logS = floor(log2(popcount->s)+1);
	popcount->bits = floor(log2(bits)+1);

	// Cache calculations: the tail of a large block that does not fill a
	// whole WORD, and the initial read length used while filling Rb.
	smodbit = popcount->s % WORD;
	read = popcount->b;

	// Rs needs (n/s + 1) entries of popcount->bits bits each:
	// (log_2(n)+1)*(n/s) = n/(0.5*(log_2(n)+1))
	// The bit total is rounded up to whole WORD-sized units and multiplied
	// by sizeof(bitvector) (presumably the storage unit of WORD bits —
	// TODO confirm) to get a byte count.
	size = ceil((double)((bits/popcount->s+1)*popcount->bits)/WORD)
			*sizeof(bitvector);
	popcount->Rs = xmalloc(size);
	// Presumably clears the freshly allocated table — verify in xutil.h.
	xmemset(popcount->Rs, size);

	// Rb needs (n/b + 1) entries of logS bits each: n/b*log_2(s)
	size = ceil((((double)bits/popcount->b+1)*popcount->logS)/WORD)
			*sizeof(bitvector);
	popcount->Rb = xmalloc(size);
	xmemset(popcount->Rb, size);

	// Calculating Rs: entry 0 is 0, then one entry per complete large block.
	// `index` is the bit offset of the entry being written.
	bitvector_set_bits(popcount->Rs, index, popcount->bits, 0);
	index += popcount->bits;
	for (i = 1; i <= floor(bits/popcount->s); i++) {
		sp = 0;

		// Count the 1s of large block i-1, one full WORD at a time ...
		for (j = 0; j < floor(popcount->s/WORD); j++) {
			sp += __builtin_popcountll(
				bitvector_get_bits(
						B->B, (i-1)*popcount->s+j*WORD, WORD	
				)
			);
		}

		// ... plus the remaining s % WORD bits, if any. j is left at the
		// number of full words by the loop above.
		if (smodbit != 0) {
			sp += __builtin_popcountll(
				bitvector_get_bits(B->B, (i-1)*popcount->s+j*WORD, smodbit)
			);
		}

		// Entry i = previous entry + count of this block (running total).
		bitvector_set_bits(
				popcount->Rs, 
				index, 
				popcount->bits, 
				bitvector_get_bits(
					popcount->Rs, index-popcount->bits, popcount->bits
				) + sp
		);

		index += popcount->bits;
	}


	// Restart the write cursor at the beginning of Rb
	index = 0;

	// Calculating Rb: entry 0 is 0, then one entry per small block.
	bitvector_set_bits(popcount->Rb, index, popcount->logS, 0);
	index += popcount->logS;

	for (i = 1; i <= floor(bits/popcount->b); i++) {
		sp = 0;
		// Index of the large block this small block belongs to;
		// floor(logSize) equals s/b, the number of small blocks per large
		// block.
		j = floor(i/floor(logSize));

		if (tmp != j) {
			// First small block of a new large block: its count restarts
			// at 0 and the read length restarts at one small block.
			tmp = j;
			read = popcount->b;
			bitvector_set_bits(popcount->Rb, index, popcount->logS, 0);
		} else {
			// Count the 1s in the first `read` bits of large block j,
			// whole WORDs first and then the read % WORD remainder.
			smodbit = read % WORD;

			for (k = 0; k < floor(read/WORD); k++) {
				sp += __builtin_popcountll(
					bitvector_get_bits(B->B, j*popcount->s + k*WORD, WORD)
				);
			}

			if (smodbit) {
				sp += __builtin_popcountll(
					bitvector_get_bits(B->B, j*popcount->s + k*WORD, smodbit)
				);
			}

			bitvector_set_bits(popcount->Rb, index, popcount->logS, sp);

			// The next entry of this large block covers one more small block.
			read += popcount->b;
		}

		// Every number can at most fill log2(s) bits
		index += popcount->logS;
	}

	// Publish the finished directory on the bitvector.
	B->table = popcount;
}

/**
 * Releases the lookup tables stored in B->table (the popcount_t descriptor
 * together with its Rs and Rb arrays).
 *
 * Does nothing when B->table is NULL. On return B->table is NULL, so calling
 * this function again on the same bitvector is harmless.
 */
void succinct_popcount_postprocess(bitvector_t *restrict B) {
	popcount_t *popcount = (popcount_t*) B->table;

	if (NULL == popcount) {
		return;
	}

	// free(NULL) is a no-op, so the member tables need no individual guards.
	free(popcount->Rs);
	free(popcount->Rb);
	free(popcount);

	// Clear the stale pointer: leaving it set would turn a second call into
	// a double free and any later table access into a use-after-free.
	B->table = NULL;
}

/**
 * Returns amount of the 1's up until the i'th offset in the bitvector B
 */
uint64_t succinct_popcount_rank(struct succinct_t *restrict succ) {
	uint64_t i = succ->i;
	bitvector_t *restrict B = succ->B;
	popcount_t *restrict popcount = (popcount_t *)B->table;
	uint64_t s = popcount->s;
	uint64_t b = popcount->b;
	uint64_t logS = popcount->logS;
	uint64_t bits = popcount->bits;
	uint64_t idivb = i/b;
	uint64_t read = (idivb*b+b <= B->bits) ? b : B->bits-idivb*b;
	uint64_t S = (read > 0 && idivb*b+read <= B->bits) ? 
		bitvector_get_bits(B->B, idivb*b, read) << (b-read) : 0;

	if ( unlikely(i == 0) ) {
		return 0;
	}

	return bitvector_get_bits(popcount->Rs, (i/s)*bits, bits) + 
		   bitvector_get_bits(popcount->Rb, idivb*logS, logS) +
		   __builtin_popcountll((S >> (b-(i%b))));
}


/**
 * Select query: returns the offset of the j'th occurrence of 1 in the
 * bitvector succ->B, where j is taken from succ->i.
 *
 * Returns 0 when j is 0. Returns B->bits when j >= B->bits, or when the
 * search reaches the end of the bitvector without finding the j'th 1.
 *
 * The search runs in three stages: a binary search over Rs for the last
 * s-bit block whose running total is still below j, a forward scan in
 * steps of b bits using popcount, and finally a bit-by-bit scan of the
 * b-bit step that contains the answer.
 */
uint64_t succinct_popcount_select(struct succinct_t *restrict succ) {
	bitvector_t *restrict B = succ->B;
	popcount_t *restrict popcount = (popcount_t *)B->table;
	const uint64_t j = succ->i;
	const uint64_t super_len = popcount->s;
	const uint64_t mini_len = popcount->b;
	const uint64_t entry_width = popcount->bits;
	uint64_t ones_before = 0;
	uint64_t ones_in_chunk = 0;
	uint64_t chunk;
	uint64_t probe;
	uint64_t pos;
	int64_t lo = 0;
	int64_t hi = floor(B->bits/super_len);

	if ( unlikely(j == 0) ) {
		return 0;
	}

	if ( unlikely(j >= B->bits) ) {
		return B->bits;
	}

	// Stage 1: narrow [lo, hi] down to two adjacent Rs entries.
	while ( likely(hi-lo > 1) ) {
		probe = lo + ((hi-lo) >> 1);
		ones_before = bitvector_get_bits(popcount->Rs, probe*entry_width, entry_width);

		if (ones_before < j) {
			lo = probe;
		} else {
			hi = probe;
		}
	}

	// Pick the upper candidate if its running total is still below j,
	// otherwise fall back to the lower one.
	ones_before = bitvector_get_bits(popcount->Rs, hi*entry_width, entry_width);
	if (ones_before < j) {
		lo = hi;
	} else {
		ones_before = bitvector_get_bits(popcount->Rs, lo*entry_width, entry_width);
	}

	// Stage 2: walk forward from the start of the chosen s-bit block, at
	// most b bits per step, until the step just read contains the j'th 1.
	pos = lo*super_len;
	if (B->bits-pos > mini_len) {
		chunk = mini_len;
	} else {
		chunk = B->bits-pos;
	}

	while ( likely(chunk > 0 && ones_before + ones_in_chunk < j) ) {
		if (B->bits-pos > mini_len) {
			chunk = mini_len;
		} else {
			chunk = B->bits-pos;
		}
		if ( unlikely(chunk == 0) ) {
			break;
		}
		// The previous step did not reach j, so fold it into the total.
		ones_before += ones_in_chunk;
		ones_in_chunk = __builtin_popcountll(bitvector_get_bits(B->B, pos, chunk));
		pos += chunk;
	}

	// Step back to the start of the last step that was read; if nothing
	// was left to read, the end of the bitvector has been reached.
	pos -= chunk;
	if ( unlikely(pos == B->bits) ) {
		return B->bits;
	}

	// Stage 3: test up to b single bits.
	// NOTE(review): the bit tested and the offset returned are both pos+1;
	// the indexing convention of BIT is not visible here — confirm.
	for (uint64_t scanned = 0; likely(scanned < mini_len); scanned++, pos++) {
		ones_before += BIT(B, pos+1);
		if ( unlikely(ones_before == j) ) {
			return pos+1;
		}
	}

	return B->bits;
}
