/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                      Copyright (c) 1995,1996                          */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission to use, copy, modify, distribute this software and its    */
/*  documentation for research, educational and individual use only, is  */
/*  hereby granted without fee, subject to the following conditions:     */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*  This software may not be used for commercial purposes without        */
/*  specific prior written permission from the authors.                  */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*                 Authors:  Alan W Black and Simon King                 */
/*                 Date   :  January 1997                                */
/*-----------------------------------------------------------------------*/
/*  A simple use of the Viterbi decoder                                  */
/*                                                                       */
/*=======================================================================*/

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include "EST.h"

// Loader for a file of fixed-length string vectors; not defined in this
// file.  Used by load_given() to read the per-frame ngram contexts.
EST_read_status load_TList_of_StrVector(EST_TList<EST_StrVector> &w,
					const EST_String &filename,
					const int vec_len);

// Output and search driver
static void print_results(EST_Stream &wstream);
static void do_search(EST_Stream &wstream);
// Callbacks handed to EST_Viterbi_Decoder in do_search()
static EST_VTPath *vit_npath(EST_VTPath *p,EST_VTCandidate *c);
static EST_VTCandidate *vit_candlist(EST_Stream_Item &s);
// Input loading
static void load_vocab(const EST_String &vfile);

static int add_word(EST_Stream &w, const EST_String &word, int pos);

static void load_wstream(const EST_String &filename,
			 const EST_String &vfile,
			 EST_Stream &w,
			 EST_Track &obs);

static void load_given(const EST_String &filename,
		       const int ngram_order);
		       
// Ngram lookup using a sliding window over the path so far
static double find_gram_prob(EST_VTPath *p,int *state);

// special stuff for non-sliding window ngrams
static double find_extra_gram_prob(EST_VTPath *p,int *state, int time);
static void get_history(EST_StrVector &history, EST_VTPath *p);
static void fill_window(EST_StrVector &window,EST_StrVector &history,
			EST_VTPath *p,const int time);
static int is_a_special(EST_String &s, int &val);
// Largest history offset referenced by a "<-k>" entry in the given file;
// set by load_given() and used to size the history vector.
static int max_history=0;

// The language model, loaded from the -ngram file
static EST_Ngrammar ngram;
// Padding symbols used when the path is shorter than the required context:
// pstring fills the first slot before the start of the path, ppstring fills
// every slot before that.  Overridden by -prev_tag and -prev_prev_tag.
static EST_String pstring = SENTENCE_START_MARKER;
static EST_String ppstring = SENTENCE_END_MARKER;
// Multipliers applied to the (floored) log scores: lm_scale to the language
// model score, ob_scale/ob_scale2 to the first/second observation score.
static float lm_scale = 1.0;
static float ob_scale = 1.0;
static float ob_scale2 = 1.0;

// always logs
// Lower bounds applied to the log observation and log LM scores before
// scaling; see vit_candlist() and vit_npath().
static double ob_log_prob_floor = SAFE_LOG_ZERO;  
static double ob_log_prob_floor2 = SAFE_LOG_ZERO;  
static double lm_log_prob_floor = SAFE_LOG_ZERO;  

// Not referenced elsewhere in this file.
int btest_debug = FALSE;
// Output file name; the empty string means write to stdout.
static EST_String outfile = "";
// Vocabulary names, one per observation channel (checked in load_wstream()).
static EST_StrList vocab;
// Observation tracks, indexed as (frame, channel) in vit_candlist();
// observations2 is only used when num_obs == 2.
static EST_Track observations;  
static EST_Track observations2;  
// Per-frame ngram left contexts loaded by load_given()
static EST_TList<EST_StrVector> given; // to do : convert to array for speed
// TRUE once a -given file has been loaded; selects find_extra_gram_prob()
int using_given=FALSE;

// default is that obs are already logs
// TRUE means observations are probabilities and safe_log10 is applied to them
int take_logs = FALSE;
// Number of observation tracks in use: 1, or 2 when -observes2 is given
int num_obs = 1;

int main(int argc, char **argv)
{
    EST_StrList files;
    EST_Option al;
    EST_Stream wstream;
    double floor; // a temporary

    parse_command_line2(argc, argv, 
       EST_String("Usage:\n")+
       "viterbi  <options>\n"+
       "-ngram <string>     Grammar file, required\n"+
       "-given <string>     ngram left contexts, per frame\n"+
       "-observes <string>  Observations file, Ascii Track, required\n"+
       "                    One line per frame, probabilities in\n"+
       "                    order of vocab.\n"+
       "-vocab <string>     File with names of vocabulary, this\n"+
       "                    must be same number as width of observations, required\n"+
       "-output <string>    File to save output to, if unspecified\n"+
       "                    output goes to stdout\n"+
       "-ob_type <string>   Observation type : \"probs\" or \"logs\" (default is \"logs\")\n"+
       "\nFloor values and scaling (scaling is applied after floor value)\n"+
       "-lm_floor <float>   LM floor probability\n"+
       "-lm_scale <float>   LM scale factor factor (applied to log prob)\n"+
       "-ob_floor <float>   Observations floor probability\n"+
       "-ob_scale <float>   Observation scale factor (applied to prob or log prob, depending on -ob_type)\n\n"+
       "-prev_tag <string>\n"+
       "                 tag before sentence start\n"+
       "-prev_prev_tag <string>\n"+
       "                 all words before 'prev_tag'\n"+
       "-last_tag <string>\n"+
       "                 after sentence end\n"+
       "-default_tags    use default tags of "+SENTENCE_START_MARKER+","
			SENTENCE_END_MARKER+" and "+SENTENCE_END_MARKER+"\n"+
       "                 respectively\n"+

       "-observes2 <string> second observations (overlays first, ob_type must be same)\n"+
       "-ob_floor2 <float>  \n"+
       "-ob_scale2 <float>  \n",
			files, al);

    if (al.present("-ngram"))
    {
	ngram.load(al.val("-ngram"));
    }
    else
    {
	cerr << argv[0] << ": no ngram specified" << endl;
	exit(-1);
    }

    if (al.present("-observes") &&
	al.present("-vocab"))
    {
	load_wstream(al.val("-observes"),al.val("-vocab"),wstream,observations);
	if (al.present("-observes2"))
	{
	    load_wstream(al.val("-observes2"),al.val("-vocab"),wstream,observations2);
	    num_obs = 2;
	}
    }
    else
    {
	cerr << argv[0] << ": no observations or vocab file specified" << endl;
	exit(-1);
    }

    if (al.present("-given"))
    {
	load_given(al.val("-given"),ngram.order());
	using_given=TRUE;
    }

    if (al.present("-output"))
	outfile = al.val("-output");
    else
	outfile = "";

    if (al.present("-lm_scale"))
	lm_scale = al.fval("-lm_scale");
    else
	lm_scale = 1.0;

    if (al.present("-ob_scale"))
	ob_scale = al.fval("-ob_scale");
    else
	ob_scale = 1.0;

    if (al.present("-ob_scale2"))
	ob_scale2 = al.fval("-ob_scale2");
    else
	ob_scale2 = 1.0;

    if (al.present("-prev_tag"))
	pstring = al.val("-prev_tag");
    if (al.present("-prev_prev_tag"))
	ppstring = al.val("-prev_prev_tag");


    // language model floor
    if (al.present("-lm_floor"))
    {
	floor = al.fval("-lm_floor");
	if(floor < 0)
	{
	    cerr << "Error : LM floor probability is negative !" << endl;
	    exit(-1);
	}
	else if(floor > 1)
	{
	    cerr << "Error : LM floor probability > 1 " << endl;
	    exit(-1);
	}
	lm_log_prob_floor = safe_log(floor);
    }

    // observations floor
    if (al.present("-ob_floor"))
    {
	floor = al.fval("-ob_floor");
	if(floor < 0)
	{
	    cerr << "Error : Observation floor probability is negative !" << endl;
	    exit(-1);
	}
	else if(floor > 1)
	{
	    cerr << "Error : Observation floor probability > 1 " << endl;
	    exit(-1);
	}
	ob_log_prob_floor = safe_log(floor);
    }

    if (al.present("-ob_floor2"))
    {
	floor = al.fval("-ob_floor2");
	if(floor < 0)
	{
	    cerr << "Error : Observation2 floor probability is negative !" << endl;
	    exit(-1);
	}
	else if(floor > 1)
	{
	    cerr << "Error : Observation2 floor probability > 1 " << endl;
	    exit(-1);
	}
	ob_log_prob_floor2 = safe_log(floor);
    }
    

    if (al.present("-ob_type"))
    {
	if(al.val("-ob_type") == "logs")
	    take_logs = false;
	else if(al.val("-ob_type") == "probs")
	    take_logs = true;
	else
	{
	    cerr << "\"" << al.val("-ob_type") 
		<< "\" is not a valid ob_type : try \"logs\" or \"probs\"" << endl;
	    exit(-1);
	}
    }

    do_search(wstream);
    print_results(wstream);
    return 0;
}

// Write the decoded path, one line per stream item, as
// "<best candidate name> <best score>".
//
// Output goes to the file named by `outfile`, or to stdout when `outfile`
// is empty.  Exits with -1 if the file cannot be opened.
static void print_results(EST_Stream &wstream)
{
    FILE *out = stdout;

    if (outfile != "")
    {
	out = fopen(outfile,"wb");
	if (out == NULL)
	{
	    cerr << "can't open \"" << outfile << "\" for output" << endl;
	    exit(-1);
	}
    }

    EST_Stream_Item *item = wstream.head();
    while (item != 0)
    {
	EST_String best_name;
	float best_score;

	// both features are added to the stream by the decoder
	best_name = item->feature("best").string();
	best_score = item->feature("best_score");
	fprintf(out,"%s %f\n",(const char *)best_name,best_score);

	item = next(item);
    }

    // only close what we opened ourselves
    if (out != stdout)
	fclose(out);
}

// Run the Viterbi decoder over the word stream.
//
// Candidates for each stream item come from vit_candlist() and path
// extensions from vit_npath(); the decoder is sized by the number of
// states in the ngram.  The winning values are stored back on the stream
// under the "best" feature name, which print_results() reads.
static void do_search(EST_Stream &wstream)
{
    const int n_states = ngram.states();
    EST_Viterbi_Decoder decoder(vit_candlist,vit_npath,n_states);

    decoder.initialise(wstream);
    decoder.search();
    decoder.result("best");
}

// Load an observation track and, on first use, the vocabulary and the
// word stream that goes with it.
//
//   filename  observations file to load into `obs`
//   vfile     vocabulary file, read only if `vocab` is still empty
//   w         word stream; filled with one item per frame only if empty
//   obs       track that receives the observations
//
// Exits with -1 if the track cannot be loaded or if its number of
// channels differs from the vocabulary size.  (This isn't general.)
static void load_wstream(const EST_String &filename,
			 const EST_String &vfile, 
			 EST_Stream &w,
			 EST_Track &obs)
{
    if (vocab.empty())
	load_vocab(vfile);

    if (obs.load(filename,0.10) != 0)
    {
	cerr << "can't find observations file \"" << filename << "\"" << endl;
	exit(-1);
    }

    // one channel per vocabulary entry is required
    if (vocab.length() != obs.num_channels())
    {
	cerr << "Number in vocab (" << vocab.length() << 
	    ") not equal to observation's width (" <<
		obs.num_channels() << ")" << endl;
	exit(-1);
    }

    // a second observation track reuses the stream built for the first
    if (!w.empty())
	return;

    for (int frame = 0; frame < obs.num_frames(); frame++)
	add_word(w,itoString(frame),frame);
}


static void load_given(const EST_String &filename,
		       const int ngram_order)
{

    EST_String word, pos;
    EST_TBI *p;
    int i,j;

    if (load_TList_of_StrVector(given,filename,ngram_order-1) != 0)
    {
	cerr << "can't load given file \"" << filename << "\"" << endl;
	exit(-1);
    }

    // set max history
    for (p = given.head(); p; p = next(p))
    {
	for(i=0;i<given(p).num_points();i++)
	    if(	is_a_special( given(p)(i), j) && (-j > max_history))
		max_history = -j;
	
    }
    
}

static void load_vocab(const EST_String &vfile)
{
    // Load vocabulary (strings)
    EST_TokenStream ts;

    if (ts.open(vfile) == -1)
    {
	cerr << "can't find vocab file \"" << vfile << "\"" << endl;
	exit(-1);
    }

    while (!ts.eof())
	if (ts.peek() != "")
	    vocab.append(ts.get().string());

    ts.close();
}

// Append a new "Word" item to stream `w`.
//
// The item is named `word`, carries `pos` as its "pos" feature and gets
// the next address from a counter that persists across calls.  Returns
// the address given to the item.
static int add_word(EST_Stream &w, const EST_String &word, int pos)
{
    static int next_addr = 0;
    const int this_addr = next_addr;
    EST_Stream_Item entry;

    entry.init("Word");
    entry.set_name(word);
    entry.set_feature("pos",pos);
    entry.set_addr(this_addr);
    w.append(entry);

    next_addr = this_addr + 1;
    return this_addr;
}

// Decoder callback: build the candidate list for stream item `s`.
//
// One candidate is created per observation channel, named after the
// matching vocabulary entry.  Its score is the observation value for the
// item's frame (its "pos" feature), logged with safe_log10 when
// take_logs is set, floored, then scaled.  With two observation tracks
// the two scores are added.  Returns the head of the linked list; the
// list is in reverse channel order because each candidate is pushed on
// the front.
static EST_VTCandidate *vit_candlist(EST_Stream_Item &s)
{
    const int two_tracks = (num_obs == 2);
    EST_VTCandidate *head = 0;
    EST_TBI *entry = vocab.head();
    int frame;

    frame = s.feature("pos");  // index for observations TRACK

    for (int chan = 0; chan < observations.num_channels();
	 chan++, entry = next(entry))
    {
	EST_VTCandidate *cand = new EST_VTCandidate;
	cand->name = vocab(entry);  // to be more efficient this could be the index

	double score = observations.a(frame,chan);
	double score2 = 1.0;
	if (two_tracks)
	    score2 = observations2.a(frame,chan);

	// convert to logs if the input was given as probabilities
	if (take_logs)
	{
	    score = safe_log10(score);
	    if (two_tracks)
		score2 = safe_log10(score2);
	}

	// floor first, scale afterwards
	if (score < ob_log_prob_floor)
	    score = ob_log_prob_floor;
	if (two_tracks && (score2 < ob_log_prob_floor2))
	    score2 = ob_log_prob_floor2;

	score *= ob_scale;
	score2 *= ob_scale2;

	if (two_tracks)
	    cand->score = score + score2;
	else
	    cand->score = score;

	cand->s = &s;
	cand->next = head;
	head = cand;
    }

    return head;
}

// Decoder callback: build a (potential) new path link that extends
// previous path `p` (0 at the start) with candidate `c`.
//
// The language model probability comes from find_extra_gram_prob() when
// given contexts are in use, otherwise from find_gram_prob(); either
// call also sets the new link's state.  The probability is logged with
// safe_log10, floored and scaled, then added to the candidate's score.
// That local sum is stored as the "lscore" feature, and the link's score
// is the local sum plus the score of `p`, if any.
static EST_VTPath *vit_npath(EST_VTPath *p,EST_VTCandidate *c)
{
    EST_VTPath *link = new EST_VTPath;
    link->c = c;
    link->from = p;

    double gram_prob;
    if (using_given)
    {
	// the candidate's time is the "pos" feature of its stream item
	gram_prob = find_extra_gram_prob(link,&link->state,c->s->feature("pos"));
    }
    else
    {
	gram_prob = find_gram_prob(link,&link->state);
    }

    double log_gram = safe_log10(gram_prob);
    if (log_gram < lm_log_prob_floor)
	log_gram = lm_log_prob_floor;
    log_gram *= lm_scale;

    link->set_feature("lscore",(c->score+log_gram)); // simonk : changed prob to lprob

    if (p != 0)
	link->score = (c->score+log_gram) + p->score;
    else
	link->score = (c->score+log_gram);

    return link;
}

// Look up the ngram probability of the candidate on path link `p` given
// the candidates before it on the path, and set *state for the link.
//
// The context is read backwards along p->from.  When the path is shorter
// than the context, the first missing slot gets pstring and all earlier
// slots get ppstring.  Returns 0 if the distribution for this context
// has no samples.  Afterwards the window is shifted left by one and
// passed to ngram.predict() to obtain the state.
static double find_gram_prob(EST_VTPath *p,int *state)
{
    const int order = ngram.order();
    EST_StrVector window(order);
    EST_VTPath *back = p->from;
    int used_pstring = FALSE;

    // fill the context slots from the most recent word backwards
    for (int slot = order-2; slot >= 0; slot--)
    {
	if (back != 0)
	{
	    window(slot) = back->c->name.string();
	    back = back->from;
	}
	else if (!used_pstring)
	{
	    window(slot) = pstring;
	    used_pstring = TRUE;
	}
	else
	    window(slot) = ppstring;
    }
    // the predictee goes last
    window(order-1) = p->c->name.string();

    double prob = 0.0;
    const EST_DiscreteProbDistribution &dist = ngram.prob_dist(window);
    if (dist.samples() != 0)
	prob = (double)dist.probability(p->c->name.string());

    // shift the window along by one to find the state for this link
    for (int k = 0; k < order-1; k++)
	window(k) = window(k+1);

    double predicted_prob;  // required by predict(), not used here
    ngram.predict(window,&predicted_prob,state);

    return prob;
}


// As find_gram_prob(), but the ngram context for frame `time` comes from
// the given file rather than from a sliding window over the path.
//
// The path history (max_history entries) is collected with
// get_history() and used by fill_window() to resolve special entries.
// Returns the probability of the candidate on `p` in that context, or 0
// if the distribution has no samples.  The history is then advanced by
// one (newest item at index 0) and the window for time+1 is passed to
// ngram.predict() to set *state.
static double find_extra_gram_prob(EST_VTPath *p,int *state,int time)
{
    EST_StrVector window(ngram.order());
    EST_StrVector history(max_history);

    get_history(history,p);
    fill_window(window,history,p,time);

    double prob = 0.0;
    const EST_DiscreteProbDistribution &dist = ngram.prob_dist(window);
    if (dist.samples() != 0)
	prob = (double)dist.probability(p->c->name.string());

    // shift history, adding latest item at 'end' (0)
    if (max_history > 0)
    {
	int k = history.num_points()-1;
	while (k > 0)
	{
	    history(k) = history(k-1);
	    k--;
	}
	history(0) = p->c->name.string();
    }

    fill_window(window,history,p,time+1);

    double predicted_prob;  // required by predict(), not used here
    ngram.predict(window,&predicted_prob,state);

    return prob;
}

// Fill `history` with the candidate names that precede `p` on its path,
// most recent first (index 0 is the name on p->from).
//
// If the path runs out before the vector is full, the first missing
// entry gets pstring and every later entry gets ppstring.
static void get_history(EST_StrVector &history, EST_VTPath *p)
{
    EST_VTPath *back = p->from;
    int used_pstring = FALSE;
    int slot = 0;

    while (slot < history.num_points())
    {
	if (back != 0)
	{
	    history(slot) = back->c->name.string();
	    back = back->from;
	}
	else if (!used_pstring)
	{
	    history(slot) = pstring;
	    used_pstring = TRUE;
	}
	else
	    history(slot) = ppstring;

	slot++;
    }
}

// Build the ngram window for frame `time` from the given contexts.
//
// The last slot receives the name of the candidate on `p` (the
// predictee).  The first order-1 slots are copied from the given entry
// for this frame, except that an entry recognised by is_a_special() with
// value j is replaced by history(-1-j), i.e. j=-1 -> (0), j=-2 -> (1),
// etc.  If `time` is beyond the end of the given list the window is left
// untouched.
//
// NOTE(review): a special value j >= 0 would give a negative history
// index; nothing here guards against that.
static void fill_window(EST_StrVector &window,EST_StrVector &history,
			EST_VTPath *p,const int time)
{
    // can we even do it?
    if (time >= given.length())
	return;

    // format should be run-time defined, but try this for now
    // also want vocab and grammar mismatch allowed !!!!!!

    const int last = ngram.order()-1;

    // predictee
    window(last) = p->c->name.string();

    // given info for this time
    EST_StrVector &context = given.item(time); // inefficient to count down a list

    for (int k = 0; k < last; k++)
    {
	int offset;
	if (is_a_special(context(k), offset))
	    window(k) = history(-1-offset);
	else
	    window(k) = context(k);
    }
}



// Test whether `s` is a special entry of the form "<int>".
//
// The text between the first "<" and the following ">" must be a
// complete decimal integer that fits in an int.  If so it is stored in
// `val` and TRUE is returned.  Otherwise FALSE is returned and `val` is
// left untouched.
//
// The text is parsed with strtol rather than atoi because atoi cannot
// report failure: a token such as "<foo>" would be accepted with value
// 0, which fill_window() would turn into the history index -1.
static int is_a_special(EST_String &s, int &val)
{
    if (!(s.contains("<") && s.contains(">")))
	return FALSE;

    EST_String inner = s.after("<");
    inner = inner.before(">");

    const char *text = inner;
    char *end = 0;
    long parsed = strtol(text,&end,10);

    // reject empty or partly numeric content
    if (end == text || *end != '\0')
	return FALSE;

    // reject values that do not survive narrowing to int
    int narrowed = (int)parsed;
    if ((long)narrowed != parsed)
	return FALSE;

    val = narrowed;
    return TRUE;
}
