# -*- coding: utf-8 -*-
"""
Created on Fri Aug 26 10:58:19 2016

@author: fl207
"""

# -*- coding: utf-8 -*-
"""
    Created in Thu March  22 10:47:00 2016
    
    @author: Baseline: Remi Eyraud & Sicco Verwer
    Smoothing techniques added by Farhana Ferdousi Liza
    
    Usage: python 3gram_baseline.py train_file prefixes_file output_file
    Role: learns a 3-gram model on the whole sequences of train_file, then generates a ranking of the 5 most probable symbols for each prefix of prefixes_file and stores these rankings in output_file (one ranking per line, in the same order as in the prefix file)
    Example: python 3gram_baseline.py ../train/0.spice.train ../prefixes/0.spice.prefix.public 0.spice.ranking
"""




from numpy import *
from decimal import *
from sys import *
import math

import collections
import itertools


# Command-line arguments (see the usage line in the header docstring):
#   argv[1] -> train_file:    whole training sequences, PAutomaC/SPiCe format
#   argv[2] -> prefixes_file: all prefixes to rank, PAutomaC/SPiCe format
#   argv[3] -> output_file:   where the rankings on the prefixes are stored
train_file, prefixes_file, output_file = argv[1], argv[2], argv[3]


def number(arg):
    """Return ``arg`` converted to a ``Decimal``.

    ``arg`` may be anything the ``Decimal`` constructor accepts (the script
    passes strings such as ``'0.75'``, ints and floats).
    """
    as_decimal = Decimal(arg)
    return as_decimal

#implement a unigram
def unigramdict(sett):
    """Count the occurrences of every symbol over all sequences of ``sett``.

    ``sett`` is an iterable of symbol sequences. The result is a
    ``collections.Counter`` mapping each symbol to the number of times it
    appears in the whole sample (no start/end markers are added).
    """
    occurrences = collections.Counter()
    for sequence in sett:
        for symbol in sequence:
            occurrences[symbol] += 1
    return occurrences
 
#implement a bigram

def bigramdict(sett):
    """Build the bigram count tables of a sample.

    Every sequence of ``sett`` is padded with the start marker -1 in front
    and the end marker -2 behind, then each adjacent pair
    ``(previous, following)`` is counted.

    Returns a dict keyed by the 1-tuple ``(previous,)``. Each value is a
    dict mapping ``following`` to the number of times it was seen after
    ``previous``; the extra key -1 of that inner dict holds the total number
    of observations of the context (its use as a "total" slot relies on -1
    never appearing inside a sequence).

    Sequences may be any iterable of symbols (lists, tuples, ...).
    """
    DPdict = {}
    for sequence in sett:
        # list() lets tuples and other iterables be padded like lists.
        padded = [-1] + list(sequence) + [-2]
        for previous, following in zip(padded, padded[1:]):
            context = (previous,)
            table = DPdict.get(context)
            if table is None:
                # First sighting of this context: one follower, total of 1.
                DPdict[context] = {following: 1, -1: 1}
            else:
                table[following] = table.get(following, 0) + 1
                table[-1] += 1
    return DPdict


def threegramdict(sett):
    """Build the trigram count tables of a sample.

    Every sequence of ``sett`` is padded with two start markers -1 in front
    and one end marker -2 behind, then each window of three consecutive
    symbols ``(first, second, following)`` is counted.

    Returns a dict keyed by the 2-tuple ``(first, second)``. Each value is a
    dict mapping ``following`` to the number of times it was seen after that
    context; the extra key -1 of that inner dict holds the total number of
    observations of the context (its use as a "total" slot relies on -1
    never appearing inside a sequence).

    Sequences may be any iterable of symbols (lists, tuples, ...).
    """
    DPdict = {}
    for sequence in sett:
        # list() lets tuples and other iterables be padded like lists.
        padded = [-1, -1] + list(sequence) + [-2]
        for end in range(2, len(padded)):
            context = tuple(padded[end - 2:end])
            following = padded[end]
            table = DPdict.get(context)
            if table is None:
                # First sighting of this context: one follower, total of 1.
                DPdict[context] = {following: 1, -1: 1}
            else:
                table[following] = table.get(following, 0) + 1
                table[-1] += 1
    return DPdict

def _discounted_estimate(table, symbol, discount):
    """Absolute-discounting estimate of ``symbol`` from one context table.

    ``table`` maps every symbol seen after a context to its count and keeps
    the total count of the context under the key -1 (the layout built by
    ``bigramdict`` / ``threegramdict``). ``None`` or an empty table means
    the context was never observed.

    Returns ``(probability, backoff_weight)`` as Decimals:
      * probability    = max(count(symbol) - discount, 0) / total
      * backoff_weight = discount * (number of distinct followers) / total,
        i.e. the probability mass removed by discounting, to be handed to
        the lower-order model.
    For an unobserved context the pair is ``(0, 1)``: all the mass goes to
    the lower-order model.
    """
    if not table:
        return Decimal(0), Decimal(1)
    total = Decimal(table[-1])
    # The -1 slot is the running total, not a follower, so leave it out.
    distinct_followers = len(table) - 1
    discounted = Decimal(table.get(symbol, 0)) - discount
    if discounted < 0:
        discounted = Decimal(0)
    return discounted / total, discount * distinct_followers / total


def threegramrank(prefix, alphabet, DPdict_3g, DPdict_2g, DPdict_1g):
    """Rank every possible next symbol after ``prefix``, most probable first.

    The score of a candidate ``w`` after the last two symbols ``(u, v)`` of
    the start-padded prefix is an interpolated, absolute-discounted trigram
    estimate whose last fallback is a Kneser-Ney style continuation
    probability::

        P(w | u, v) = d3(w | u, v)
                      + l3(u, v) * ( d2(w | v) + l2(v) * Pcont(w) )

    where ``d3`` / ``d2`` are the discounted relative frequencies,
    ``l3`` / ``l2`` the matching back-off weights (see
    ``_discounted_estimate``) and ``Pcont(w)`` is the number of distinct
    contexts of ``DPdict_2g`` that ``w`` follows, divided by the number of
    distinct (context, follower) pairs in ``DPdict_2g``. The discount is
    0.75 at both orders.

    Only the final trigram decides the score, so earlier positions of the
    prefix are not visited. Every weight is derived from the context
    alone, which keeps the scores of all candidates comparable.

    Parameters:
        prefix: sequence of already observed symbols (may be empty).
        alphabet: number of symbols; candidates are ``0 .. alphabet - 1``.
        DPdict_3g: trigram tables as returned by ``threegramdict``.
        DPdict_2g: bigram tables as returned by ``bigramdict``.
        DPdict_1g: unigram counts; accepted for interface compatibility
            but not used by the scoring.

    Returns:
        A list of ``alphabet + 1`` symbols sorted by decreasing score. The
        end of the sequence is scored with the training end marker -2 and
        reported as -1. Candidates with equal scores keep the order
        -1, 0, 1, ... (the sort is stable).
    """
    discount = Decimal('0.75')

    # A trigram model only looks at the last two symbols of the history.
    history = ([-1, -1] + list(prefix))[-2:]
    tri_table = DPdict_3g.get(tuple(history))
    bi_table = DPdict_2g.get(tuple(history[1:]))

    # One pass over the bigram tables yields, for every symbol, the number
    # of distinct contexts it follows, plus the number of bigram types.
    followed_contexts = {}
    bigram_types = 0
    for table in DPdict_2g.values():
        for symbol in table:
            if symbol != -1:  # skip the per-context total slot
                followed_contexts[symbol] = followed_contexts.get(symbol, 0) + 1
                bigram_types += 1

    def score(symbol):
        tri_prob, tri_backoff = _discounted_estimate(tri_table, symbol, discount)
        bi_prob, bi_backoff = _discounted_estimate(bi_table, symbol, discount)
        if bigram_types:
            # Exact Decimal division (no float or integer division).
            continuation = (Decimal(followed_contexts.get(symbol, 0))
                            / Decimal(bigram_types))
        else:
            continuation = Decimal(0)
        return tri_prob + tri_backoff * (bi_prob + bi_backoff * continuation)

    # End of sequence first, then the symbols in increasing order.
    probs = [(-1, score(-2))]
    for x in range(alphabet):
        probs.append((x, score(x)))
    probs.sort(key=lambda pair: pair[1], reverse=True)
    return [pair[0] for pair in probs]



def readset(f):
    """Read a sample in PAutomaC/SPiCe format from the open text file ``f``.

    The first line holds two integers: the number of sequences and the
    alphabet size. Each of the following lines holds one sequence: its
    first field (the announced length) is skipped, the remaining fields are
    the integer symbols. Fields may be separated by any run of whitespace.

    Returns:
        ``(alphabet_size, sett)`` where ``sett`` is a list with one list of
        ints per sequence. A line that is missing because the file ends
        early yields an empty sequence.

    Raises:
        ValueError: if the header does not contain two integers or a
            symbol is not an integer.
    """
    header_line = f.readline()
    header = header_line.split()
    if len(header) < 2:
        raise ValueError(
            "expected a header line 'num_strings alphabet_size', got %r"
            % header_line)
    num_strings = int(header[0])
    alphabet_size = int(header[1])

    sett = []
    for _ in range(num_strings):
        fields = f.readline().split()
        # fields[0] is the announced sequence length; it is not checked.
        sett.append([int(symbol) for symbol in fields[1:]])
    return alphabet_size, sett


def list_to_string(l):
    """Return the items of ``l`` converted with ``str`` and joined by spaces.

    An empty ``l`` gives the empty string.
    """
    return " ".join(str(x) for x in l)

# ---- Learning -------------------------------------------------------------
print("Getting training sample")
# The context manager closes the training file as soon as it has been read.
with open(train_file, "r") as training:
    alphabet, train = readset(training)
print("Start Learning")
dict_1g = unigramdict(train)
dict_2g = bigramdict(train)
dict_3g = threegramdict(train)
print("Learning Ended")

# ---- Ranking --------------------------------------------------------------
print("Start rankings computation")
# Both files are closed even if ranking a prefix raises.
with open(prefixes_file, "r") as p, open(output_file, "w") as o:
    # Skip the first line of prefixes_file: it only holds the number of
    # examples and the size of the alphabet.
    p.readline()
    # Stream the prefixes line by line instead of loading the whole file.
    for line in p:
        fields = line.split()
        # The first field is the prefix length, the rest are its symbols.
        symbols = [int(field) for field in fields[1:]]
        ranking = threegramrank(symbols, alphabet, dict_3g, dict_2g, dict_1g)[:5]
        # One ranking per line, every symbol followed by a single space.
        o.write(''.join(str(symbol) + ' ' for symbol in ranking) + '\n')

print("End of rankings computation")

