// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "PhonemeSeqDataSet.h"
#include "string_utils.h"

namespace Torch {
// REAL_FORMAT is a printf/scanf-style conversion specifier chosen at compile
// time: "%lf" when USEDOUBLE is defined, "%f" otherwise.
// NOTE(review): presumably this mirrors the width of the project's `real`
// typedef (double vs. float) -- confirm against the header that defines it.
// The macro is not referenced in the visible part of this file.
#ifdef USEDOUBLE
#define REAL_FORMAT "%lf"
#else
#define REAL_FORMAT "%f"
#endif

// Builds a dataset holding one example per occurrence of `phoneme` in `data`.
//
// Parameters:
//   data    - source dataset; every example is visited twice through
//             setExample(i) and read from data->inputs->ptr.
//   phoneme - index of the phoneme to extract.
//   dict    - dictionary mapping a word index to its phoneme sequence
//             (dict->words[word][k], k < dict->word_length[word]).
//   models  - per-phoneme HMMs; only models[p]->n_states is read.
//
// For a source example that carries alignments (n_alignments > 0), each
// alignment entry whose phoneme matches yields the frame range
// [alignment[j-1], alignment[j]) (starting at 0 for j == 0).
// Otherwise the frames are split evenly: every phoneme of every target word
// receives (n_frames / total_states) * (n_states - 2) frames, where
// total_states is the sum of (n_states - 2) over all phonemes of the example.
// Frames left over by the integer division are not assigned to any phoneme.
//
// The new examples do not copy frame data; see addExample().
PhonemeSeqDataSet::PhonemeSeqDataSet(SeqDataSet* data,int phoneme, Dictionary* dict, HMM** models)
{
  n_observations = data->n_observations;
  // NOTE(review): n_inputs is not set here although addExample() tests it;
  // presumably a base-class constructor initialises it -- confirm.
  n_file_names = 0;
  file_names = NULL;

  // First pass: count the occurrences of the phoneme so that the example
  // array can be allocated in one go.
  n_real_examples = 0;
  for (int i=0;i<data->n_examples;i++) {
    data->setExample(i);
    SeqExample* ex = (SeqExample*)data->inputs->ptr;
    if (ex->n_alignments > 0) {
      for (int j=0;j<ex->n_alignments;j++) {
        if (ex->alignment_phoneme[j] == phoneme)
          n_real_examples++;
      }
    } else {
      for (int j=0;j<ex->n_seqtargets;j++) {
        int word = (int)ex->seqtargets[j][0];
        for (int k=0;k<dict->word_length[word];k++) {
          if (dict->words[word][k] == phoneme)
            n_real_examples++;
        }
      }
    }
  }

  // Second pass: fill the array. addExample() increments n_examples, which
  // ends up equal to n_real_examples because both passes apply the same tests.
  n_examples = 0;
  examples = (SeqExample*) xalloc(sizeof(SeqExample)*n_real_examples);
  for(int i=0;i<data->n_examples;i++){
    data->setExample(i);
    SeqExample* ex = (SeqExample*)data->inputs->ptr;
    if (ex->n_alignments > 0) {
      for (int j=0;j<ex->n_alignments;j++) {
        if (ex->alignment_phoneme[j] == phoneme) {
          int start_align = j == 0 ? 0 : ex->alignment[j-1];
          int end_align = ex->alignment[j];
          addExample(ex,start_align,end_align);
        }
      }
    } else {
      // No alignment available: spread the frames uniformly over the states
      // of all phonemes of the target word sequence.
      int current_n_frames = 0;
      int n_frames_per_state;
      int n_states = 0;
      for (int j=0;j<ex->n_seqtargets;j++) {
        int word = (int)ex->seqtargets[j][0];
        for (int k=0;k<dict->word_length[word];k++) {
          n_states += models[dict->words[word][k]]->n_states - 2;
        }
      }
      // Guard the division: n_states is 0 when the example has no targets,
      // only empty words, or only two-state models, and integer division by
      // zero is undefined behaviour. With 0 frames per state every phoneme
      // simply receives an empty frame range.
      n_frames_per_state = n_states > 0 ? ex->n_frames / n_states : 0;
      for (int j=0;j<ex->n_seqtargets;j++) {
        int word = (int)ex->seqtargets[j][0];
        for (int k=0;k<dict->word_length[word];k++) {
          int n_frames_for_phoneme = n_frames_per_state * 
            (models[dict->words[word][k]]->n_states - 2);
          if (dict->words[word][k] == phoneme) {
            addExample(ex,current_n_frames,current_n_frames+n_frames_for_phoneme);
          } 
          // Advance past this phoneme whether or not it was selected.
          current_n_frames += n_frames_for_phoneme;
        }
      }
    }
  }
}

// Appends to `examples` a new sequence made of frames [start, end) of `ex`.
//
// Parameters:
//   ex    - source example whose frame rows are referenced.
//   start - index of the first frame to include.
//   end   - index one past the last frame to include.
//
// Only the arrays of row pointers are allocated here; each row still points
// into ex->observations / ex->inputs, so the frame data itself is shared with
// the source example and must outlive this dataset.
// The caller must guarantee that n_examples is smaller than the number of
// slots allocated for `examples`; no bounds check is performed.
void PhonemeSeqDataSet::addExample(SeqExample* ex, int start, int end)
{
  SeqExample* ex_i = &examples[n_examples];
  int n_f = end - start;
  ex_i->n_real_frames = n_f;
  ex_i->n_frames = n_f;
  ex_i->n_seqtargets = 0;
  ex_i->seqtargets = NULL;
  // The slot is raw memory from xalloc: clear the alignment fields too so
  // that no reader of this example ever sees uninitialised values.
  ex_i->n_alignments = 0;
  ex_i->alignment = NULL;
  ex_i->alignment_phoneme = NULL;
  ex_i->name = NULL;
  ex_i->selected_frames = NULL;
  ex_i->current_frame = 0;
  if (n_observations>0) {
    ex_i->observations = (real**)xalloc(sizeof(real*)*n_f);
    for (int i=0;i<n_f;i++) {
      ex_i->observations[i] = ex->observations[start+i];
    }
  } else {
    ex_i->observations = NULL;
  }
  if (n_inputs>0) {
    ex_i->inputs = (real**)xalloc(sizeof(real*)*n_f);
    for (int i=0;i<n_f;i++) {
      ex_i->inputs[i] = ex->inputs[start+i];
    }
  } else {
    ex_i->inputs = NULL;
  }
  n_examples++;
}

// Destructor: all cleanup is delegated to freeMemory().
PhonemeSeqDataSet::~PhonemeSeqDataSet()
{
  this->freeMemory();
}


void PhonemeSeqDataSet::freeMemory()
{
  
  for(int example = 0; example < n_real_examples; example++)
  {
    for (int i=0;i<examples[example].n_seqtargets;i++)
      free(examples[example].seqtargets[i]);
    if (examples[example].inputs) {
      free(examples[example].inputs);
      examples[example].inputs = NULL;
    }
    if (examples[example].observations) {
      free(examples[example].observations);
      examples[example].observations = NULL;
    }
    if(examples[example].selected_frames) {
      free(examples[example].selected_frames);
      examples[example].selected_frames = NULL;
    }
  }
  free(examples);
  for (int i = 0;i<n_file_names;i++){
     if(file_names[i] != NULL){
        free(file_names[i]);
        file_names[i] = NULL;
     }
  }
  free(file_names);
}

}

