/* ---------------------------------------------------------- 
%   (C)1995 Institute for New Generation Computer Technology 
%       (Read COPYRIGHT for detailed information.) 
----------------------------------------------------------- */
/*-----------------------------------------------------------------

   Discrete HMM: residue assignment by the Viterbi algorithm
         Ver 1.0 1994.10.13 by H.Tanaka

   work.log.output[N][N][A]
   sample.seq[S][L]         --> sample.beststates[S][T]
   hmnet.arc[N][N].link         hmnet.arc[N][N].profile[A]
                           

-----------------------------------------------------------------*/
#include <stdio.h>
#include <string.h>
#include <math.h>

#include "defs.h"
#include "e_struct.h"

/* #define DEBUG_trelis */
/* #define DEBUG_profile */
/* #define DEBUG_work */
/* #define DEBUG_align */
/* #define DURATION */

/*
 * Viterbi trellis, indexed [position in the sample][state].
 *
 * trelis_log[n][j]  : best score found so far for being in state j after
 *                     consuming symbol n; reset to mRANGE by clear_trelis().
 * trelis_path[n][j] : the predecessor state that produced that best score,
 *                     followed backwards by get_bestpath().
 */
int trelis_log[SAMPLE_LENGTH][NODEMAX];
int trelis_path[SAMPLE_LENGTH][NODEMAX];

/*
 * Run one Viterbi re-estimation pass over every sample.
 *
 * Each sample is aligned to the network (trellis fill + traceback), the
 * usage/profile counters are rebuilt from those alignments, and the
 * transition and output tables in work_log are updated.  The value returned
 * by the output update is stored in the global sumeval and compared with
 * the value from the previous pass.
 *
 * Returns V_END when the previous value is <= the new one and at least one
 * pass has run since split_after was last reset (split_after is then zeroed);
 * otherwise returns V_CONTINUE and increments split_after.
 */
int disc_viterbi()
{
  int idx;
  int previous;

  /* Align every sample against the current model. */
  for(idx=0; idx<sample.number; idx++) {
    clear_trelis(idx);
    make_trelis(idx);
    get_bestpath(idx);
  }

  /* Rebuild the counters from the alignments, then re-estimate. */
  clear_statistics();
  take_statistics();
  calc_trans();

  previous = sumeval;
  if(tied_switch==OFF) {
    sumeval = calc_output_and_diff();
  } else {
    sumeval = calc_tied_output_and_diff();
  }

  if(previous > sumeval || split_after <= 0) {
    split_after++;
    return(V_CONTINUE);
  }

  split_after = 0;
  return(V_END);
}

/*
 * Fill the trellis for sample s.
 *
 * Position 0 is handled by initial_proc(), the last position (when the
 * sample has more than one symbol) by end_proc(), and every position in
 * between by routine_proc().  A sample of length 0 leaves the trellis
 * untouched.
 */
make_trelis(s)
int s;
{
  int state,pos;
  int last;

  last = sample.length[s]-1;   /* index of the final symbol */

  if(last >= 0) initial_proc(s);

  /* time axis (position along the sequence), interior positions only */
  for(pos=1; pos<last; pos++)
    routine_proc(s,pos);

  if(last > 0) end_proc(s);

#ifdef DEBUG_trelis
  /* One output row per state, one column per position. */
  for(state=0; state<hmnet.nodenum; state++) {
    printf("DBTR");
    for(pos=0; pos<sample.length[s]; pos++) {
      if(trelis_log[pos][state]==mRANGE) {
        printf("      --");
      } else {
        printf(" %7d",trelis_log[pos][state]);
      }
    }
    printf("\n");
  }
#endif
}

/*
 * Fill trellis column 0 for sample s (transition out of a start state).
 *
 * Every entry of trelis_path[0] is first set to TERMINALSTATE.  Then, for
 * each initial state and each non-terminal destination linked to it, the
 * candidate score is built from the initial-state value, the arc's output
 * entry for the first symbol and the arc's transition entry; the best
 * candidate per destination is kept in trelis_log[0] together with its
 * source state in trelis_path[0].
 *
 * With DURATION defined, a self-loop whose first two symbols are equal is
 * scored without the transition term and without checking the link flag.
 */
initial_proc(s) /* start -> position 0 */
int s;
{
  int k,from,to;
  int score;

  for(to=0; to<hmnet.nodenum; to++)
    trelis_path[0][to] = TERMINALSTATE;

  for(k=0; k<hmnet.initStateNumber; k++) {   /* from (Start) */
    from = hmnet.initialState[k];

    for(to=0; to<hmnet.nodenum; to++) {      /* to */
      if(to == TERMINALSTATE) continue;

#ifdef DURATION
      if(from == to && sample.seq[s][0]==sample.seq[s][1]) { /* Duration Model */
        score = plus(log_near1((double)hmnet.initialProb[from]),
                     work_log.output[from][to][sample.seq[s][0]]);
        if(score > trelis_log[0][to]) {
          trelis_log[0][to] = score;
          trelis_path[0][to] = from;
        }
        continue;
      }
#endif

      if(hmnet.arc[from][to].link != LINK) continue;

      score = plus(log_near1((double)hmnet.initialProb[from]),
                   plus(work_log.output[from][to][sample.seq[s][0]],
                        work_log.trans[from][to]));
      if(score > trelis_log[0][to]) {
        trelis_log[0][to] = score;
        trelis_path[0][to] = from;
      }
    }
  }
}

/*
 * Fill trellis column n for sample s from column n-1.
 *
 * make_trelis() calls this for 0 < n < sample.length[s]-1.  For every pair
 * of non-terminal states joined by a link, the candidate score is the
 * previous column's score of the source plus the arc's output entry for
 * symbol n plus the arc's transition entry; the best candidate per
 * destination and its source are stored in trelis_log[n] / trelis_path[n].
 *
 * With DURATION defined, a self-loop where symbol n equals symbol n+1 is
 * scored without the transition term and without checking the link flag.
 */
routine_proc(s,n) /* 0 < n < sample.length-1 */
int s,n;
{
  int from,to;
  int score;

  for(from=0; from<hmnet.nodenum; from++) {   /* state: from */
    if(from == TERMINALSTATE) continue;

    for(to=0; to<hmnet.nodenum; to++) {       /* state: to */
      if(to == TERMINALSTATE) continue;

#ifdef DURATION
      if(from == to && sample.seq[s][n]==sample.seq[s][n+1]) { /* Duration Model */
        score = plus(trelis_log[n-1][from],
                     work_log.output[from][to][sample.seq[s][n]]);
        if(score > trelis_log[n][to]) {
          trelis_log[n][to] = score;
          trelis_path[n][to] = from;
        }
        continue;
      }
#endif

      if(hmnet.arc[from][to].link != LINK) continue;

      score = plus(trelis_log[n-1][from],
                   plus(work_log.output[from][to][sample.seq[s][n]],
                        work_log.trans[from][to]));
      if(score > trelis_log[n][to]) {
        trelis_log[n][to] = score;
        trelis_path[n][to] = from;
      }
    }
  }
}

/*
 * Fill the last trellis column for sample s: the only destination
 * considered is TERMINALSTATE.
 *
 * make_trelis() calls this only when the sample has more than one symbol,
 * so column n-1 always exists.  For every non-terminal state linked to
 * TERMINALSTATE the candidate score is the previous column's score plus the
 * arc's output entry for the final symbol plus the arc's transition entry;
 * the best candidate and its source state are recorded.
 *
 * s : index of the sample being aligned.
 */
end_proc(s) /* n == sample.length -1 */
int s;
{
  int i,n;
  int candidate;

  n = sample.length[s]-1;

  for(i=0;i<hmnet.nodenum;i++){   /* state: from */
    if (i==TERMINALSTATE) continue;
    if(hmnet.arc[i][TERMINALSTATE].link != LINK) continue;

    candidate = plus(trelis_log[n-1][i],
                plus(work_log.output[i][TERMINALSTATE][sample.seq[s][n]],
                     work_log.trans[i][TERMINALSTATE]));
    if(trelis_log[n][TERMINALSTATE] < candidate) {
      trelis_log[n][TERMINALSTATE] = candidate;
      trelis_path[n][TERMINALSTATE] = i;
    }
  } /* from */
}

/*
 * Reset the score table before aligning sample s: every trelis_log cell in
 * the first sample.length[s] columns, for every state of the network, is
 * set to mRANGE.  trelis_path is not touched here.
 */
clear_trelis(s)
int s;
{
  int pos,state;

  for(pos=0; pos<sample.length[s]; pos++) {
    for(state=0; state<hmnet.nodenum; state++) {
      trelis_log[pos][state] = mRANGE;
    }
  }
}

/*
 * Trace the best path for sample s backwards through trelis_path.
 *
 * Starting from TERMINALSTATE at the last position, the predecessor stored
 * for the current state is written to sample.beststates[s][pos] and becomes
 * the current state for the position before it.
 */
get_bestpath(s)
int s;
{
  int pos;
  int cur;

  cur = TERMINALSTATE;
  for(pos=sample.length[s]; pos-- > 0; ) {
    sample.beststates[s][pos] = trelis_path[pos][cur];
    cur = sample.beststates[s][pos];
  }

#ifdef DEBUG_align
  /* Symbols on one line, the aligned states underneath. */
  printf("DBXAM ");
  for(pos=0; pos<sample.length[s]; pos++) {
    printf("%2c",sample.seq[s][pos]+'a');
  }
  printf("\nDBXST");
  for(pos=0; pos<sample.length[s]; pos++) {
    printf("%2d",sample.beststates[s][pos]);
  }
  printf(" E\n");
#endif
}

/*
 * Accumulate usage and symbol counts from the best paths of all samples.
 *
 * For each position n of each sample, the arc taken runs from
 * beststates[s][n] to beststates[s][n+1] (or to TERMINALSTATE after the
 * last symbol) and emits seq[s][n].  The arc's and the source node's
 * profile[symbol] and usage counters are each incremented by one.
 *
 * The counters are only incremented here; clear_statistics() resets them.
 * A sample of length 0 contributes nothing (previously it indexed
 * position -1).
 */
take_statistics()
{
  int i,j,s,n,a;
  int last;

  for(s=0;s<sample.number;s++) {
    last = sample.length[s]-1;

    for(n=0;n<=last;n++){
      i = sample.beststates[s][n];
      /* The final symbol is emitted on the arc into the terminal state. */
      j = (n<last) ? sample.beststates[s][n+1] : TERMINALSTATE;
      a = sample.seq[s][n];

      hmnet.arc[i][j].profile[a]++;
      hmnet.arc[i][j].usage++;
      hmnet.node[i].profile[a]++;
      hmnet.node[i].usage++;
    }
  }

#ifdef DEBUG_profile
  /* Per-arc symbol counts. */
  for(i=0;i<hmnet.nodenum;i++) {
    printf("DBP");
    for(j=0;j<hmnet.nodenum;j++) {
      if(hmnet.arc[i][j].link != LINK) continue;
      printf(" (%d,%d)",i,j); 
      for(a=0;a<AMINOS;a++)
        if(hmnet.arc[i][j].profile[a]>0)
          printf(" %c%d",a+'a',hmnet.arc[i][j].profile[a]);
    }
    printf("\n");
  }

  /* Per-arc usage counts. */
  for(i=0;i<hmnet.nodenum;i++) {
    printf("DBA");
    for(j=0;j<hmnet.nodenum;j++)
      if(hmnet.arc[i][j].usage>0)
        printf(" (%d,%d): %d",i,j,hmnet.arc[i][j].usage);
    printf("\n");
  }

  /* Per-node usage counts. */
  printf("DBU");
  for(i=0;i<hmnet.nodenum;i++)
    if(hmnet.node[i].usage>0)
      printf(" (%d,--): %d",i,hmnet.node[i].usage);
  printf("\n");
#endif
}

/*
 * Zero every counter that take_statistics() accumulates: the usage count
 * and per-symbol profile of each node, and of each arc between any two
 * nodes of the network.
 */
clear_statistics()
{
  int from,to,sym;

  for(from=0; from<hmnet.nodenum; from++) {
    /* node counters */
    hmnet.node[from].usage = 0;
    for(sym=0; sym<AMINOS; sym++) {
      hmnet.node[from].profile[sym] = 0;
    }

    /* counters of every arc leaving this node */
    for(to=0; to<hmnet.nodenum; to++) {
      hmnet.arc[from][to].usage = 0;
      for(sym=0; sym<AMINOS; sym++) {
        hmnet.arc[from][to].profile[sym] = 0;
      }
    }
  }
}

/*
 * Re-estimate the transition table from the path statistics.
 *
 * For every non-terminal source node that was used at least once, each
 * transition entry that is not mRANGE is combined (logplus) with
 * minus_II(arc usage, node usage) and then reduced by log_near1(2).
 * NOTE(review): this looks like averaging the old value with the new
 * count-based estimate in the log domain -- confirm against the helper
 * definitions.
 */
calc_trans()
{
  int from,to;

  for(from=0; from<hmnet.nodenum; from++) {
    if(from==TERMINALSTATE || hmnet.node[from].usage==0) continue;

    for(to=0; to<hmnet.nodenum; to++) {
      if(work_log.trans[from][to]==mRANGE) continue;

      work_log.trans[from][to] =
        logplus(work_log.trans[from][to],
                minus_II(hmnet.arc[from][to].usage,hmnet.node[from].usage));
      work_log.trans[from][to] =
        minus(work_log.trans[from][to],log_near1((double)2));
    }
  }

#ifdef DEBUG_work
  for(from=0; from<hmnet.nodenum; from++) {
    if(from==TERMINALSTATE) continue;
    for(to=0; to<hmnet.nodenum; to++) {
      if(work_log.trans[from][to]>mRANGE) {
        printf("DBWTR (%d,%d): %d\n",from,to,work_log.trans[from][to]);
      }
    }
  }
#endif
}

/*
 * Re-estimate the per-arc output table and measure how far it moved.
 *
 * For every arc that leaves a non-terminal node and was used at least once,
 * and every symbol whose output entry is not mRANGE:
 *   - the new estimate is minus_II(arc profile[symbol], arc usage);
 *   - logminus(old entry, new estimate) is accumulated with logplus;
 *   - the entry is replaced by logplus(old, new) reduced by log_near1(2).
 *
 * Returns the accumulated difference divided by the number of nodes.
 */
int calc_output_and_diff()
{
  int from,to,sym;
  int estimate;
  int delta;
  int total;

  total = mRANGE;
  for(from=0; from<hmnet.nodenum; from++) {
    if(from==TERMINALSTATE) continue;

    for(to=0; to<hmnet.nodenum; to++) {
      if(hmnet.arc[from][to].usage==0) continue;

      for(sym=0; sym<AMINOS; sym++) {
        if(work_log.output[from][to][sym]==mRANGE) continue;

        estimate = minus_II(hmnet.arc[from][to].profile[sym],
                            hmnet.arc[from][to].usage);

        /* difference is taken against the entry before it is updated */
        delta = logminus(work_log.output[from][to][sym],estimate);
        total = logplus(total,delta);

        work_log.output[from][to][sym] =
          logplus(work_log.output[from][to][sym],estimate);
        work_log.output[from][to][sym] =
          minus(work_log.output[from][to][sym],log_near1((double)2));
      }
    }
  }

  total /= hmnet.nodenum;

#ifdef DEBUG_work
  for(from=0; from<hmnet.nodenum; from++) {
    if(from==TERMINALSTATE) continue;
    for(to=0; to<hmnet.nodenum; to++) {
      for(sym=0; sym<AMINOS; sym++) {
        if(work_log.output[from][to][sym]>mRANGE) {
          printf("DBWOP (%d,%d) %c: %d\n",from,to,sym+'a',
                 work_log.output[from][to][sym]);
        }
      }
    }
  }
  printf("SumEval=%d\n",total);
#endif

  return(total);
}

/*
 * Tied-output counterpart of calc_output_and_diff(): the new estimate for a
 * symbol comes from the source NODE's counters, so every arc leaving the
 * same node is pulled towards the same value.
 *
 * For every non-terminal node that was used at least once, every arc j and
 * every symbol whose output entry is not mRANGE:
 *   - the new estimate is minus_II(node profile[symbol], node usage);
 *   - the entry is replaced by logplus(old, new) reduced by log_near1(2).
 * The change is measured once per node and symbol, on the i->i arc, as
 * logminus(old entry, new estimate), and accumulated with logplus.
 *
 * Fix: the difference used to be computed after the symbol loop had ended,
 * which read output[i][i][AMINOS] (one past the last symbol) together with
 * a stale or uninitialised estimate.  It is now evaluated per symbol with
 * the matching estimate, before the entry is smoothed, as in
 * calc_output_and_diff().
 * NOTE(review): nodes without an i->i output entry contribute nothing to
 * the returned value, as implied by the original output[i][i] indexing --
 * confirm this is the intended representative arc.
 *
 * Returns the accumulated difference divided by the number of nodes.
 */
int calc_tied_output_and_diff()
{
  int i,j,a;
  int temp;
  int eval;
  int sumeval;

  for(sumeval=mRANGE,i=0;i<hmnet.nodenum;i++) {
    if(i==TERMINALSTATE) continue;
    if(hmnet.node[i].usage==0) continue;   /* nothing observed for this node */

    for(j=0;j<hmnet.nodenum;j++){
      for(a=0;a<AMINOS;a++) {
        if(work_log.output[i][j][a]==mRANGE) continue;
        temp=minus_II(hmnet.node[i].profile[a],hmnet.node[i].usage);

        /* Measure the change once per (node, symbol), before smoothing. */
        if(j==i) {
          eval=logminus(work_log.output[i][j][a],temp);
          sumeval=logplus(sumeval,eval);
        }

        work_log.output[i][j][a]=logplus(work_log.output[i][j][a],temp);
        work_log.output[i][j][a]=minus(work_log.output[i][j][a],
                                       log_near1((double)2));
      }
    }
  }

  sumeval = sumeval/hmnet.nodenum;

#ifdef DEBUG_work
  for(i=0;i<hmnet.nodenum;i++) {
    if(i==TERMINALSTATE) continue;
    for(j=0;j<hmnet.nodenum;j++)
      for(a=0;a<AMINOS;a++)
        if(work_log.output[i][j][a]>mRANGE)
          printf("DBWOP (%d,%d) %c: %d\n",i,j,a+'a',work_log.output[i][j][a]);
  }
  printf("SumEval=%d\n",sumeval);
#endif

  return(sumeval);
}

/* end of file */
