// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "ViterbiTrainer.h"
#include "log_add.h"

namespace Torch {

// Creates a Viterbi trainer for distribution_ over the sequence dataset
// data_. All setup is delegated to the EMTrainer base class.
ViterbiTrainer::ViterbiTrainer(Distribution *distribution_, SeqDataSet *data_)
  : EMTrainer(distribution_, data_)
{
}

// Trains the distribution with Viterbi (best-path) EM.
//
// Each iteration:
//   - calls eMIterInitialize(), then for every training example runs
//     viterbiForward() and viterbiAccPosteriors(..., LOG_ONE), summing
//     -log_probability into the negative log-likelihood (NLL);
//   - divides that NLL by the total number of training frames and calls
//     eMUpdate();
//   - runs viterbiForward() on every other dataset that has measurers,
//     so those measurers can be updated.
// It stops when the relative change of the per-frame NLL falls below
// end_accuracy, or after max_iter iterations when max_iter > 0.
// measureEnd() is then called on every measurer.
//
// measurers: list of measurers to drive. They are grouped per dataset by
//            extractMeasurers(). Group 0 is treated as the training-set
//            group (sdata is passed to extractMeasurers for that purpose).
void ViterbiTrainer::train(List *measurers)
{
  int iter = 0;
  real prev_nll = INF;
  real nll = INF;
  int n_train = sdata->n_examples;

  DataSet **datas;
  Measurer ***mes;
  int *n_mes;
  int n_datas;

  message("ViterbiTrainer: training");

  extractMeasurers(measurers, sdata, &datas, &mes, &n_mes, &n_datas);

  // First compute tot_n_frames, used below to normalize the NLL per frame.
  sdata->totNFrames();

  while (1) {
    // E-step: accumulate statistics along the best (Viterbi) path only.
    distribution->eMIterInitialize();
    nll = 0;
    for (int t=0;t<n_train;t++) {
      data->setExample(t);
      distribution->viterbiForward(data->inputs);
      nll -= distribution->log_probability;
      distribution->viterbiAccPosteriors(data->inputs,LOG_ONE);

      for(int i = 0; i < n_mes[0]; i++)
        mes[0][i]->measureEx();
    }
    nll /= sdata->tot_n_frames;
    // M-step.
    distribution->eMUpdate();

    for(int i = 0; i < n_mes[0]; i++)
      mes[0][i]->measureIter();

    // Evaluate the updated model on every additional dataset.
    for(int d = 1; d < n_datas; d++) {
      SeqDataSet *dataset = (SeqDataSet*)datas[d];

      for(int t=0;t<dataset->n_examples;t++) {
        dataset->setExample(t);
        distribution->viterbiForward(dataset->inputs);

        for(int i = 0; i < n_mes[d]; i++)
          mes[d][i]->measureEx();
      }
      for(int i = 0; i < n_mes[d]; i++)
        mes[d][i]->measureIter();
    }


    // Stopping criterion: relative change of the per-frame NLL.
    // The denominator must be |prev_nll|. The NLL is not guaranteed to be
    // positive (log_probability may exceed 0, e.g. for continuous
    // densities). With a negative prev_nll, the signed ratio would always be
    // <= 0 < end_accuracy, and training would stop prematurely.
    // prev_nll starts at INF, so the first iteration is not expected to stop.
    if (fabs(prev_nll - nll)/fabs(prev_nll) < end_accuracy) {
      print("\n");
      break;
    }
    prev_nll = nll;
    print(".");
    iter++;
    // max_iter <= 0 means "no iteration limit".
    if ((iter >= max_iter) && (max_iter > 0)) {
      print("\n");
      warning("ViterbiTrainer: you have reached the maximum number of iterations");
      break;
    }
  }
  for(int i=0;i<n_datas;i++) {
    for(int j=0;j<n_mes[i];j++)
      mes[i][j]->measureEnd();
  }
  deleteExtractedMeasurers(datas, mes, n_mes, n_datas);
}

// Evaluates the current distribution on every dataset referenced by
// measurers. Neither viterbiAccPosteriors() nor eMUpdate() is called.
//
// For each dataset group, the function resets that group's measurers and
// calls eMIterInitialize(). It then runs viterbiForward() on every example,
// calling measureEx() after each one, and finally calls measureIter() once.
//
// measurers: list of measurers. Unlike train(), no dataset is passed to
//            extractMeasurers() to be singled out.
void ViterbiTrainer::test(List *measurers)
{
  DataSet **datas;
  Measurer ***mes;
  int *n_mes;
  int n_datas;

  message("ViterbiTrainer: testing");

  extractMeasurers(measurers, NULL, &datas, &mes, &n_mes, &n_datas);

  for (int d = 0; d < n_datas; d++) {
    DataSet *ds = datas[d];
    Measurer **group = mes[d];
    const int n_group = n_mes[d];

    for (int m = 0; m < n_group; m++)
      group[m]->reset();

    distribution->eMIterInitialize();

    for (int t = 0; t < ds->n_examples; t++) {
      ds->setExample(t);
      distribution->viterbiForward(ds->inputs);
      for (int m = 0; m < n_group; m++)
        group[m]->measureEx();
    }

    for (int m = 0; m < n_group; m++)
      group[m]->measureIter();
  }

  deleteExtractedMeasurers(datas, mes, n_mes, n_datas);
}

// ViterbiTrainer's own constructor allocates nothing, so there is nothing
// to free here.
ViterbiTrainer::~ViterbiTrainer()
{
}

}

