#include <string.h>
#include <stdlib.h>

#include "sankowski.h"
#include "matrix.h"
#include "mod.h"
#include "sm.h"

/*
 * Apply a rank-one Sherman-Morrison update to the n x n matrix s->C, taking
 * s->cv as the column vector u and s->rv as the row vector v^T:
 *
 *   C <- C - (C u)(v^T C) / (1 + v^T C u)
 *
 * All arithmetic goes through the addmod/submod/mulmod/divmod/mod helpers
 * (presumably arithmetic modulo p -- confirm against mod.h).
 *
 * Side effects on success: s->C is updated in place, s->cv is overwritten
 * with v^T C (computed from the pre-update C), and s->tc is set to the
 * number of non-zero entries of the updated C.
 *
 * Returns false when the denominator 1 + v^T C u is zero; in that case
 * s->C, s->cv and s->tc are left untouched. Building with NODET removes
 * that check. Returns true otherwise.
 *
 * Aborts if the scratch buffer cannot be allocated: false is reserved for
 * the zero-denominator case, so it is not reused to signal out-of-memory.
 */
bool sherman_morrison_full(sankowski_t *s) {
	int64_t inv_denom;
	uint32_t i, j;
	uint32_t n = s->n;

	int64_t *c = s->C->data,
			*rv = s->rv->data,
			*cv = s->cv->data;

	// Scratch vector for C * u. calloc zero-fills it, which the accumulation
	// below relies on. calloc(0, ...) may legitimately return NULL, so only
	// treat NULL as a failure when n is non-zero.
	int64_t *au = calloc(n, sizeof *au);
	if (au == NULL && n != 0) {
		abort();
	}

	// au = C * u
	for (i = 0; i < n; i++) {
		// Row offset computed in size_t so i*n cannot wrap in 32 bits.
		const int64_t *row = c + (size_t)i * n;

		for (j = 0; j < n; j++) {
			au[i] = addmod(au[i], mulmod(row[j], cv[j]));
		}
	}

	// denom = v^T * (C * u). The terms are summed without reducing after
	// each addition (the original author judged overflow unreachable for
	// realistic n); a single reduction follows the loop.
	int64_t denom = 0;
	for (i = 0; i < n; i++) {
		denom += mulmod(au[i], rv[i]);
	}

	denom = mod(denom, p);
	denom = addmod(1, denom);

#ifndef NODET
	// A zero denominator means the update cannot be applied.
	if (denom == 0) {
		free(au);
		return false;
	}
#endif

	inv_denom = divmod(1, denom);

	// cv = v^T * C, built from the still-unmodified C. The previous
	// contents of cv (the vector u) are no longer needed: au already
	// holds C * u.
	memset(cv, 0, n * sizeof(int64_t));

	for (i = 0; i < n; i++) {
		const int64_t *row = c + (size_t)i * n;

		for (j = 0; j < n; j++) {
			cv[j] = addmod(cv[j], mulmod(row[j], rv[i]));
		}
	}

	// C[i][j] -= au[i] * cv[j] / denom, counting non-zero results as we go.
	uint32_t counter = 0;

	for (i = 0; i < n; i++) {
		int64_t *row = c + (size_t)i * n;

		for (j = 0; j < n; j++) {
			int64_t r = submod(
				row[j],
				mulmod(
					mulmod(
						au[i],
						cv[j]
					),
					inv_denom
				)
			);

			row[j] = r;

			if (r != 0) {
				counter++;
			}
		}
	}

	s->tc = counter;
	free(au);

	return true;
}

/*
 * Apply a single-entry Sherman-Morrison update to the n x n matrix s->C.
 *
 * The update vectors are never materialised: the column vector is
 * delta * e_u and the row vector is e_v^T, so
 *
 *   C u   = delta * (column u of C)      -> stored in s->cv
 *   v^T C = row v of C                   -> copied into s->rv
 *   denom = 1 + delta * C[v][u]
 *   C    <- C - (C u)(v^T C) / denom
 *
 * This is the Sherman-Morrison formula for adding delta * e_u * e_v^T to
 * the matrix whose inverse is held in C. Arithmetic goes through the
 * addmod/submod/mulmod/divmod helpers (presumably arithmetic modulo p --
 * confirm against mod.h).
 *
 * u and v must both be smaller than s->n.
 *
 * Side effects: s->cv is always overwritten. On success s->rv is
 * overwritten too, s->C is updated in place and s->tc is set to the number
 * of non-zero entries of the updated C.
 *
 * Returns false when the denominator is zero; s->cv has already been
 * overwritten at that point, but s->rv, s->C and s->tc are untouched.
 * Building with NODET removes that check. Returns true otherwise.
 */
bool sherman_morrison(sankowski_t *s, uint32_t u, uint32_t v, int32_t delta) {
	int64_t inv_denom;
	uint32_t i, j;
	uint32_t n = s->n;

	int64_t *c = s->C->data,
			*rv = s->rv->data,
			*cv = s->cv->data;

	// Both indices address rows/columns of the n x n matrix.
	assert(u < n);
	assert(v < n);

	// cv = C * (delta * e_u), i.e. column u of C scaled by delta.
	// Offsets are computed in size_t so i*n cannot wrap in 32 bits.
	for (i = 0; i < n; i++) {
		cv[i] = mulmod(c[(size_t)i * n + u], delta);
	}

	// denom = 1 + e_v^T * C * (delta * e_u) = 1 + cv[v]
	int64_t denom = addmod(1, cv[v]);

#ifndef NODET
	// A zero denominator means the update cannot be applied.
	if (denom == 0) {
		return false;
	}
#endif

	inv_denom = divmod(1, denom);

	// rv = e_v^T * C, i.e. a snapshot of row v taken before C is modified.
	memcpy(rv, c + (size_t)v * n, n * sizeof(int64_t));

	// C[i][j] -= cv[i] * rv[j] / denom, counting non-zero results as we go.
	uint32_t counter = 0;

	for (i = 0; i < n; i++) {
		int64_t *row = c + (size_t)i * n;

		for (j = 0; j < n; j++) {
			int64_t r = submod(
				row[j],
				mulmod(
					mulmod(
						cv[i],
						rv[j]
					),
					inv_denom
				)
			);

			// Sanity checks on the reduced value.
			// NOTE(review): 2147483647 is presumably the modulus p -- confirm.
			assert(r >= 0);
			assert(r < 2147483647);
			row[j] = r;

			if (r != 0) {
				counter++;
			}
		}
	}

	s->tc = counter;

	return true;
}
