#!/usr/bin/python3

import sys
import HMM
import fasta
import util

from collections import deque


def trainByCounting(hmm, train):
    """Add transition and emission counts from annotated training data to hmm.

    train[0] and train[1] are passed to fasta.read to obtain the symbol
    sequence and its per-position annotation.  A state path that agrees with
    the annotation is found by depth-first search with backtracking
    (_findStatePath), then every transition and emission along that path is
    counted into hmm['A'] and hmm['phi'] (_addCounts).  Counts are added in
    place on top of whatever the matrices already hold; normalize() turns
    them into probabilities afterwards.

    Keys of hmm used here:
        'start'  -- first annotation symbol -> name of the initial state
        'states' -- state name -> dict with 't' (successor state names),
                    'a' (annotation symbol) and 'i' (matrix index)
        'static' -- state name -> symbols that state may emit; states not
                    listed may emit any symbol
        'aIndex' -- emitted symbol -> column index in hmm['phi']
        'A', 'phi' -- count matrices, mutated in place

    Exits the process with status 1 if the training files cannot be read,
    if the sequence is empty or the annotation is shorter than it, or if no
    state path consistent with the annotation can be found.
    """
    try:
        dna = fasta.read(train[0])
        ann = fasta.read(train[1])
    except Exception as e:
        # Top-level boundary of a command-line tool: report and stop.
        print("Error reading trainingdata: ", e, file=sys.stderr)
        sys.exit(1)

    # Without these checks an empty or truncated input dies with a bare
    # IndexError somewhere inside the search or the counting loop.
    if len(dna) == 0:
        print("Error: training sequence is empty", file=sys.stderr)
        sys.exit(1)
    if len(ann) < len(dna):
        print("Error: annotation ({0}) is shorter than sequence ({1})"
              .format(len(ann), len(dna)), file=sys.stderr)
        sys.exit(1)

    path = _findStatePath(hmm, dna, ann)
    _addCounts(hmm, dna, path)


def _findStatePath(hmm, dna, ann):
    """Return one state name per position of dna that agrees with ann.

    Depth-first search: at each position the first successor of the current
    state is taken that carries the required annotation, has not already
    failed from this position, and (per hmm['static']) may emit the next
    symbol.  On a dead end the search steps back one position and marks the
    failing state as bad there.  Exits the process when it cannot step back
    any further, or after more than len(hmm['states']) consecutive steps back.
    """
    states = hmm['states']
    static = hmm['static']
    maxBacktrack = len(states) + 1

    # NOTE: the start state is taken as-is; whether it may emit dna[0] is
    # not checked against hmm['static'], and it is never backtracked over.
    state = {
        's': hmm['start'][ann[0]],  # state name at the current position
        'b': [],                    # successors that already failed from here
    }
    path = []       # states for the positions before the current one
    backtracks = 0  # consecutive steps back without progress
    i = 1           # position whose state is being chosen next

    while i < len(dna):
        want = ann[i]
        symb = dna[i]

        nextState = None  # None, not "", so any state name is a valid result
        for v in states[state['s']]['t']:
            if (states[v]['a'] == want
                    and v not in state['b']
                    and (v not in static or symb in static[v])):
                nextState = v
                break

        if nextState is None:
            # Dead end: return to the previous position and remember that
            # the current state leads nowhere from there.
            if not path:
                print("Error: Seq is empty, we cannot backtrack further",
                      file=sys.stderr)
                sys.exit(1)
            failed = state['s']
            state = path.pop()
            state['b'].append(failed)
            i -= 1
            backtracks += 1
        else:
            path.append(state)
            state = {'s': nextState, 'b': []}
            i += 1
            backtracks = 0

        if backtracks >= maxBacktrack:
            print("We backtracked too much, stopping", file=sys.stderr)
            sys.exit(1)

    path.append(state)
    return [entry['s'] for entry in path]


def _addCounts(hmm, dna, path):
    """Count emission dna[k] by path[k] and each transition path[k-1] -> path[k].

    Only hmm['A'] and hmm['phi'] are updated; no count is recorded for the
    initial state itself.
    """
    states = hmm['states']
    aIndex = hmm['aIndex']
    prev = None
    for k, name in enumerate(path):
        cur = states[name]['i']
        if prev is not None:
            hmm['A'][prev][cur] += 1
        hmm['phi'][cur][aIndex[dna[k]]] += 1
        prev = cur

def normalize(hmm):
    """Scale each row of hmm['A'] and hmm['phi'] in place so it sums to 1.

    Zero entries are never touched, so a row whose entries are all zero stays
    all zero instead of causing a division by zero.  Returns None.
    """
    for key in ('A', 'phi'):
        for row in hmm[key]:
            # Plain left-to-right accumulation (not sum()), which keeps the
            # float result identical across Python versions.
            total = 0
            for value in row:
                total += value

            for col, value in enumerate(row):
                if value == 0:
                    continue
                row[col] = value / total
