/*****************************************************************************
  FILE           : $Source: /usr/local/bv/SNNS/SNNSv4.1/tools/sources/RCS/analyze.c,v $
  SHORTNAME      : analyze.c
  SNNS VERSION   : 4.1

  PURPOSE        : Network analysis tool -- evaluates the classification
                   results stored in SNNS result files
  NOTES          :

  AUTHOR         : Stefan Broeckel, Tobias Soyez 
  DATE           : 30.07.92

  CHANGED BY     : Michael Vogt
  IDENTIFICATION : $State: Exp $ $Locker:  $
  RCS VERSION    : $Revision: 2.10 $
  LAST CHANGE    : $Date: 1995/11/16 07:19:48 $

             Copyright (c) 1990-1995  SNNS Group, IPVR, Univ. Stuttgart, FRG

******************************************************************************/


/*****************************************************************************/
/* included headers                                                          */
/*****************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>   /* getopt (POSIX) */


/*****************************************************************************/
/* constants                                                                 */
/*****************************************************************************/

/* results of the classification functions; distinct bits so that they can
   be OR-ed into the selection mask built by get_options (-w, -r, -u, -a)   */
enum { WRONG = 1, RIGHT = 2, UNKNOWN = 4 } ;

/* two-state switches for the -v, -s and -c options                         */
enum { ON = 1, OFF = 2 } ;

/* error (classification) functions selectable with -e                      */
enum { R402040 = 1, WTA = 2, band = 3 } ;


/*****************************************************************************/
/* type definitions                                                          */
/*****************************************************************************/

typedef struct
{
  /* numeric fields read from the result file header */
  int  no_of_patterns ;
  int  no_of_input_units ;
  int  no_of_output_units ;
  int  startpattern ;
  int  endpattern ;

  /* flags derived from the header (0 = no, 1 = yes) */
  int  input_pattern_included ;     /* "input patterns included" found      */
  int  teaching_output_included ;   /* "teaching output included" found     */
  int  sub_pattern_present ;        /* endpattern - startpattern + 1 differs
                                       from no_of_patterns                  */
} FileHeaderInfo ;


/*****************************************************************************/
/* global variables                                                          */
/*****************************************************************************/

/* the result file being analyzed and the report destination; get_options()
   opens them from -i / -o and falls back to stdin / stdout                 */
FILE  *in_file,
      *out_file ;



/*****************************************************************************/
/* function Error                                                            */
/*****************************************************************************/

/* Error : sum of the squared differences between teaching output and actual
           output over the first no_of_units units (0 if there are none).   */
float Error (float *output, float *teaching_output, int no_of_units)
{
  float  sum = 0.0f ;
  int    k ;

  for (k = 0 ; k < no_of_units ; ++k)
  {
    const float  delta = teaching_output[k] - output[k] ;

    sum += delta * delta ;
  }

  return sum ;
}

/*****************************************************************************/
/* function ClassNo                                                          */
/*****************************************************************************/

/* ClassNo : class number of a pattern, i.e. the index of the unit with the
             greatest teaching output (the first one on ties) -- the same
             criterion F_402040 and F_WTA use to pick the target unit.
             Returns 0 if no_of_units <= 0; otherwise the result always lies
             in [0, no_of_units), so it can safely index per-class arrays.
             (Summing teaching_output[i] * i only worked for exact 1-of-n
             vectors; other encodings produced out-of-range indices.)       */
int ClassNo (float *teaching_output, int no_of_units)
{
  int    i, no ;
  float  t_max ;

  if (no_of_units <= 0)
    return 0 ;

  no    = 0 ;
  t_max = teaching_output[0] ;

  for (i = 1 ; i < no_of_units ; i++)
  {
    if (teaching_output[i] > t_max)
    {
      t_max = teaching_output[i] ;
      no    = i ;
    }
  }

  return no ;
}

/*****************************************************************************/
/* The function F_402040 returns :                                           */
/*                                                                           */
/*    RIGHT   : the output of exactly one unit is >= high               and  */
/*              this unit has the greatest teaching output              and  */
/*              the output of all the other units is <= low                  */
/*                                                                           */
/*    WRONG   : the output of exactly one unit is >= high               and  */
/*              this unit has NOT the greatest teaching output          and  */
/*              the output of all the other units is <= low                  */
/*                                                                           */
/*    UNKNOWN : in any other case                                            */
/*                                                                           */
/* high, low  : parameters of F_402040                                       */
/*              default values: low = 0.4, high = 0.6                        */
/*****************************************************************************/

int F_402040 (float *output, float *teaching_output, int no_of_units,
              float high, float low)
{
  /* Each unit falls into exactly one of three groups: output >= high,
     output <= low, or neither (including NaN outputs). The pattern is
     classified only if exactly one unit is high and none is in between. */
  int    k ;
  int    n_high   = 0 ;                    /* units with output >= high   */
  int    n_middle = 0 ;                    /* units neither high nor low  */
  int    winner   = 0 ;                    /* last unit with output >= high */
  int    target   = 0 ;                    /* first unit with max teaching */
  float  best     = teaching_output[0] ;

  for (k = 0 ; k < no_of_units ; k++)
  {
    if (output[k] >= high)
    {
      n_high++ ;
      winner = k ;
    }
    else if (!(output[k] <= low))
      n_middle++ ;

    if (teaching_output[k] > best)
    {
      best   = teaching_output[k] ;
      target = k ;
    }
  }

  if ((n_high != 1) || (n_middle != 0))
    return (UNKNOWN) ;

  return (winner == target) ? RIGHT : WRONG ;
}



/*****************************************************************************/
/* The function F_WTA returns :                                              */
/*                                                                           */
/*    RIGHT   :  there is exactly one unit j with maximal output a       and */
/*               a > high                                                and */
/*               the output of all the other units is < a - low          and */
/*               the unit j has the greatest teaching output                 */
/*                                                                           */
/*    WRONG   :  there is exactly one unit j with maximal output a       and */
/*               a > high                                                and */
/*               the output of all the other units is < a - low          and */
/*               the unit j has NOT the greatest teaching output             */
/*                                                                           */
/*    UNKNOWN :  in any other case                                           */
/*                                                                           */
/* high, low  : parameters of F_WTA                                          */
/*              default values: low = 0.0, high = 0.0                        */
/*****************************************************************************/

int F_WTA (float *output, float *teaching_output, int no_of_units,
           float high, float low)
{
  int    i, o_pos, t_pos, no_of_max ;
  float  min, max, max2, t_max ;

  /* without units there is no winner (and nothing may be read) */
  if (no_of_units <= 0)
    return (UNKNOWN) ;

  o_pos     = 0 ;
  t_pos     = 0 ;
  t_max     = teaching_output[0] ;
  max       = output[0] ;
  min       = max ;
  no_of_max = 1 ;

  /* maximal output (position o_pos, multiplicity no_of_max), minimal output,
     and position t_pos of the first maximal teaching output */
  for (i = 1 ; i < no_of_units ; i++)
  {
    if (output[i] > max)
    {
      max       = output[i] ;
      o_pos     = i ;
      no_of_max = 1 ;
    }
    else if (output[i] == max) no_of_max++ ;
    else if (output[i] <  min) min = output[i] ;

    if (teaching_output[i] > t_max)
    {
      t_max = teaching_output[i] ;
      t_pos = i ;
    }
  }

  /* runner-up: the greatest output strictly below max. It starts at the
     minimum and stays equal to max only if no smaller output exists (e.g. a
     single unit), in which case the test below fails for low >= 0.
     Unit 0 must be scanned too -- it may well be the runner-up. */
  max2 = min ;
  for (i = 0 ; i < no_of_units ; i++)
    if ((output[i] > max2) && (output[i] < max)) max2 = output[i] ;

  if ((no_of_max == 1) && (max > high) && (max2 < max - low))
    return (o_pos == t_pos) ? RIGHT : WRONG ;

  return (UNKNOWN) ;
}



/*****************************************************************************/
/* The function F_band returns :                                             */
/*                                                                           */
/*    RIGHT   : for all units :                                              */
/*              the output is >= the teaching output - low              and  */
/*              the output is <= the teaching output + high                  */
/*                                                                           */
/*    WRONG   : for all units :                                              */
/*              the output is < the teaching output - low               or   */
/*              the output is > the teaching output + high                   */
/*                                                                           */
/*    UNKNOWN : in any other case                                            */
/*                                                                           */
/* high, low  : parameters of F_band                                         */
/*              default values: low = 0.1, high = 0.1                        */
/*****************************************************************************/

int F_band (float *output, float *teaching_output, int no_of_units,
            float high, float low)
{
  int  inside  = 0 ;   /* some unit lies within its band  */
  int  outside = 0 ;   /* some unit lies outside its band */
  int  k ;

  for (k = 0 ; k < no_of_units ; k++)
  {
    if ((output[k] >= teaching_output[k] - low) &&
        (output[k] <= teaching_output[k] + high))
      inside = 1 ;
    else if ((output[k] <= teaching_output[k] - low) ||
             (output[k] >= teaching_output[k] + high))
      outside = 1 ;   /* a NaN output satisfies neither test */

    /* a mix of in-band and out-of-band units is neither RIGHT nor WRONG */
    if (inside && outside)
      return (UNKNOWN) ;
  }

  /* if no unit lies in its band (also when no_of_units <= 0): WRONG */
  return inside ? (RIGHT) : (WRONG) ;
}



/*****************************************************************************/
/*  function get_options                                                     */
/*****************************************************************************/

/* Parses the command line with getopt(), stores the settings through the
   pointer arguments and opens in_file / out_file (defaulting to stdin /
   stdout). Returns the number of errors found; non-zero means the caller
   should print the usage text. */
int get_options (int argc, char *argv[], int *function, int *sel_cond,
                 int *output_text, float *high, float *low,
                 int *statistics, int *class_statistics)
{
  int          c       ;
  extern char  *optarg ;
  int          error   ;
  int          hl_flag ;   /* bit 0: -l was given, bit 1: -h was given     */
  char         extra   ;   /* detects trailing garbage after a number      */

  *function    = R402040       ;
  *sel_cond    = 0             ;
  *output_text = OFF           ;
  *statistics  = OFF           ;
  *class_statistics = OFF      ;
  in_file      = (FILE *) NULL ;
  out_file     = (FILE *) NULL ;
  hl_flag      = 0             ;

  error = 0 ;

  while ((c = getopt (argc, argv, "awruvce:i:o:h:l:s")) != -1)
    switch (c)
    {
      case 'l' :
        /* an invalid number must be an error: setting hl_flag anyway would
           suppress the default while *low keeps whatever the caller left */
        if (sscanf (optarg, "%f %c", low, &extra) == 1)
          hl_flag = hl_flag | 1 ;
        else
        {
          fprintf (stderr, "error:  invalid number %s\n", optarg) ;
          error++ ;
        }
        break ;
      case 'h' :
        if (sscanf (optarg, "%f %c", high, &extra) == 1)
          hl_flag = hl_flag | 2 ;
        else
        {
          fprintf (stderr, "error:  invalid number %s\n", optarg) ;
          error++ ;
        }
        break ;
      case 'e' :
        if      (strcmp (optarg, "402040") == 0)
                *function = R402040 ;
        else if (strcmp (optarg, "WTA"   ) == 0)
                *function = WTA    ;
        else if (strcmp (optarg, "band") == 0)
                *function = band    ;
        else error++ ;
        break ;
      case 'r' :
        *sel_cond = *sel_cond | RIGHT   ;
        break ;
      case 'w' :
        *sel_cond = *sel_cond | WRONG   ;
        break ;
      case 'u' :
        *sel_cond = *sel_cond | UNKNOWN ;
        break ;
      case 'a' :
        *sel_cond = WRONG | RIGHT | UNKNOWN ;
        break ;
      case 'v' :
        *output_text = ON ;
        break ;
      case 's' :
        *statistics  = ON ;
        break ;
      case 'c' :
        *class_statistics = ON ;
        break ;
      case 'i' :
        /* -i given more than once: don't leak the previously opened file */
        if (in_file != (FILE *) NULL) fclose (in_file) ;
        if ((in_file = fopen (optarg, "r")) == (FILE *) NULL)
        {
           fprintf (stderr, "error:  can't read file %s \n", optarg) ;
           error++ ;
        }
        break ;
      case 'o' :
        if (out_file != (FILE *) NULL) fclose (out_file) ;
        if ((out_file = fopen (optarg, "w")) == (FILE *) NULL)
        {
           fprintf (stderr, "error:  can't create file %s\n", optarg) ;
           error++ ;
        }
        break ;
      default  : error++ ;
    }

  /* report wrongly classified patterns unless something else was chosen */
  if (*sel_cond == 0) *sel_cond = WRONG  ;

  /* defaults for the bounds depend on the selected error function */
  if ((hl_flag & 1) == 0)
  {  switch (*function)
     {
       case R402040 : *low  = 0.4 ; break ;
       case WTA     : *low  = 0.0 ; break ;
       case band    : *low  = 0.1 ; break ;
     }
  }

  if ((hl_flag & 2) == 0)
  {  switch (*function)
     {
       case R402040 : *high = 0.6 ; break ;
       case WTA     : *high = 0.0 ; break ;
       case band    : *high = 0.1 ; break ;
     }
  }

  if (in_file  == (FILE *) NULL) in_file  = stdin  ;
  if (out_file == (FILE *) NULL) out_file = stdout ;

  return (error) ;
}



/*****************************************************************************/
/* function read_file_header                                                 */
/*                                                                           */
/* reads from the input file :                                               */
/*     no. of patterns                                                       */
/*     no. of input units                                                    */
/*     no. of output units                                                   */
/*     startpattern                                                          */
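/*     endpattern                                                            */
/*     whether input patterns and teaching output are included               */
/*                                                                           */
/* returns 0 if the header is valid, otherwise a non-zero error count        */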
/*****************************************************************************/

int read_file_header (FileHeaderInfo *file_header_info)
{
  /* every %s below is limited to 79 characters: the buffers hold 80 bytes
     and the input file is not trusted to contain short tokens only */
  char  str1[80], str2[80], str3[80] ;
  int   error ;

  error = 0 ;

  if ((fscanf (in_file, "%79s %79s %79s", str1, str2, str3) != 3) ||
      (strcmp (str1, "SNNS"  ) != 0) ||
      (strcmp (str2, "result") != 0) ||
      (strcmp (str3, "file"  ) != 0))
  {
    fprintf (stderr, "error:  no SNNS result file\n") ;
    return (1) ;
  }

  /* skip the next 8 tokens (presumably version and generation date) */
  fscanf (in_file, "%*s %*s %*s %*s %*s %*s %*s %*s") ;

  /* the five numeric fields; the %*s conversions skip their label words.
     A missing value would otherwise leave the field uninitialized. */
  if ((fscanf (in_file, "%*s %*s %*s %*s     %d",
               &(file_header_info -> no_of_patterns)    ) != 1) ||
      (fscanf (in_file, "%*s %*s %*s %*s %*s %d",
               &(file_header_info -> no_of_input_units) ) != 1) ||
      (fscanf (in_file, "%*s %*s %*s %*s %*s %d",
               &(file_header_info -> no_of_output_units)) != 1) ||
      (fscanf (in_file, "%*s %*s             %d",
               &(file_header_info -> startpattern      )) != 1) ||
      (fscanf (in_file, "%*s %*s             %d",
               &(file_header_info -> endpattern        )) != 1))
  {
    fprintf (stderr, "error:  incomplete file header\n") ;
    return (1) ;
  }

  /* counts the caller uses for allocation and indexing must be sane */
  if ((file_header_info -> no_of_patterns     < 0) ||
      (file_header_info -> no_of_input_units  < 0) ||
      (file_header_info -> no_of_output_units < 1))
  {
    fprintf (stderr, "error:  invalid pattern or unit count\n") ;
    return (1) ;
  }

  /* set when the pattern range does not match the pattern count */
  file_header_info -> sub_pattern_present =
      (file_header_info -> endpattern)
    - (file_header_info -> startpattern)
    + 1
        != (file_header_info -> no_of_patterns) ;

  if (fscanf (in_file, "%79s", str1) != 1) str1[0] = '\0' ;

  /* optional "input patterns included"; str1 receives the next keyword */
  file_header_info -> input_pattern_included = 0 ;
  if (strcmp (str1, "input") == 0)
  {
    if ((fscanf (in_file, "%79s %79s %79s", str2, str3, str1) >= 2) &&
        (strcmp (str2, "patterns") == 0) &&
        (strcmp (str3, "included") == 0))
      file_header_info -> input_pattern_included = 1 ;
  }

  /* mandatory "teaching output included". The third token read here is the
     one following that phrase (presumably the first pattern label); main()
     only skips tokens between patterns, so it has to be consumed here. It
     may be missing when the file ends, hence ">= 2". */
  file_header_info -> teaching_output_included = 0 ;
  if ((strcmp (str1, "teaching") == 0) &&
      (fscanf (in_file, "%79s %79s %79s", str2, str3, str1) >= 2) &&
      (strcmp (str2, "output"  ) == 0) &&
      (strcmp (str3, "included") == 0))
    file_header_info -> teaching_output_included = 1 ;
  else
  {
    fprintf (stderr, "error:  missing teaching_output \n") ;
    error++ ;
  }

  return (error) ;
}



/*****************************************************************************/
/* main program                                                              */
/*****************************************************************************/

main (argc, argv)

int   argc   ;
char  *argv[] ;

{
  int             pat_no              ;
  int             i, result           ;
  int             function            ;
  int             sel_cond            ;
  int             output_text         ;
  float           low, high           ;
  int             statistics          ;
  int             class_statistics    ; 
  int             class_no            ;
  int             right, wrong, 
                  unknown             ;
  float           error               ;
  float           *output             ;
  float           *teaching_output    ;
  int             *class_stat_wrong   ;
  int             *class_stat_right   ;
  int             *class_stat_unknown ;
  FileHeaderInfo  file_header_info    ;


  if (get_options (argc, argv, &function, &sel_cond, &output_text,
                   &high, &low, &statistics, &class_statistics) != 0) 
  {
    fprintf (stderr, "usage: %s [options]        \n", argv[0]) ;
    fprintf (stderr, "analyzes result files which are generated by SNNS\n");
    fprintf (stderr, "options are:\n");
    fprintf (stderr, 
      "\t-w               : report wrong classified patterns (default)\n") ;
    fprintf (stderr, 
      "\t-r               : report right classified patterns\n") ;
    fprintf (stderr, 
      "\t-u               : report unclassified patterns\n") ;
    fprintf (stderr, 
      "\t-a               : same as -w -r -u\n") ;
    fprintf (stderr, 
      "\t-s               : show statistic information\n") ;
    fprintf (stderr, 
      "\t-c               : show class statistic information\n") ; 
    fprintf (stderr, 
      "\t-v               : verbous mode\n");
    fprintf (stderr, 
      "\t-e <function>    : select error function \n") ;
    fprintf (stderr, 
      "\t                   <function> = [402040 | WTA | band]\n") ;
    fprintf (stderr,
      "\t                   default = 402040\n") ;
    fprintf (stderr, 
      "\t-l <float>       : lower bound level (see documentation) \n") ;
    fprintf (stderr, 
      "\t                   default: 0.4 for 402040\n") ;
    fprintf (stderr, 
      "\t                   default: 0.0 for WTA\n") ;
    fprintf (stderr, 
      "\t                   default: 0.1 for band\n") ;
    fprintf (stderr, 
      "\t-h <float>       : upper bound level (see documentation) \n") ;
    fprintf (stderr, 
      "\t                   default: 0.6 for 402040\n") ;
    fprintf (stderr, 
      "\t                   default: 0.0 for WTA\n") ;
    fprintf (stderr, 
      "\t                   default: 0.1 for band\n") ;
    fprintf (stderr, 
      "\t-i <input file>  : input result file (default stdin)\n");
    fprintf (stderr, 
      "\t-o <output file> : output file (default stdout)\n") ;
  }
  else
  {
    if (read_file_header (&file_header_info) == 0) 
    {
      output          = (float *) malloc (file_header_info.no_of_output_units 
					  * sizeof(float)) ;
      teaching_output = (float *) malloc (file_header_info.no_of_output_units 
					  * sizeof(float)) ;
      class_stat_wrong = (int *) malloc (file_header_info.no_of_output_units 
					  * sizeof(int)) ; 
      class_stat_right = (int *) malloc (file_header_info.no_of_output_units 
					  * sizeof(int)) ;
      class_stat_unknown = (int *) malloc (file_header_info.no_of_output_units 
					  * sizeof(int)) ;
      wrong   = 0   ;
      right   = 0   ;
      unknown = 0   ;
      error   = 0.0 ;
      
      for (i = 0; i < file_header_info.no_of_output_units; i++) 
      { 
         class_stat_wrong[i] = 0;   
         class_stat_right[i] = 0;   
         class_stat_unknown[i] = 0;   
      }

      for (pat_no = 0 ; pat_no < file_header_info.no_of_patterns ; pat_no++)
      {                
        if (file_header_info.input_pattern_included != 0)
	{ for (i = 1 ; i <= file_header_info.no_of_input_units ; i++)
	      fscanf (in_file, "%*f") ;
        }

        for (i = 0 ; i < file_header_info.no_of_output_units ; i++)
          fscanf (in_file, "%f", &teaching_output[i]) ;

        for (i = 0 ; i < file_header_info.no_of_output_units ; i++)
          fscanf (in_file, "%f", &output[i]) ;
        
        switch (function)
	{ 
          case R402040 :
            result = F_402040 (output, teaching_output, 
			       file_header_info.no_of_output_units, high, low);
            break ;
    	  case WTA    :
            result = F_WTA    (output, teaching_output, 
			       file_header_info.no_of_output_units, high, low);
            break ;
          case band :
            result = F_band   (output, teaching_output, 
			       file_header_info.no_of_output_units, high, low);
            break ;
        }

        class_no = ClassNo(teaching_output, 
                           file_header_info.no_of_output_units);

        switch (result)
	{
	  case WRONG   :  wrong++                        ; 
                          class_stat_wrong[class_no]++   ;
                          break                          ;  
          case RIGHT   :  right++                        ;
                          class_stat_right[class_no]++   ;
                          break                          ;
          case UNKNOWN :  unknown++                      ; 
                          class_stat_unknown[class_no]++ ;
                          break                          ;
	}

        if (statistics == OFF)
        {     
     	  if (output_text == OFF)
	  {
             if      ((sel_cond & WRONG)   == result) 
		fprintf(out_file, "%d\n", file_header_info.sub_pattern_present 
			? pat_no + 1 : pat_no + file_header_info.startpattern);
             else if ((sel_cond & RIGHT)   == result) 
		fprintf(out_file, "%d\n", file_header_info.sub_pattern_present
			? pat_no + 1 : pat_no + file_header_info.startpattern);
             else if ((sel_cond & UNKNOWN) == result)
		fprintf(out_file, "%d\n", file_header_info.sub_pattern_present 
			? pat_no + 1 : pat_no + file_header_info.startpattern);
	  }
          else
	  {
             if      ((sel_cond & WRONG)   == result) 
		fprintf(out_file, "wrong   : %d\n", 
			file_header_info.sub_pattern_present 
			? pat_no + 1 : pat_no + file_header_info.startpattern);
             else if ((sel_cond & RIGHT)   == result) 
		fprintf(out_file, "right   : %d\n", 
			file_header_info.sub_pattern_present 
			? pat_no + 1 : pat_no + file_header_info.startpattern);
             else if ((sel_cond & UNKNOWN) == result) 
		fprintf(out_file, "unknown : %d\n", 
			file_header_info.sub_pattern_present 
			? pat_no + 1 : pat_no + file_header_info.startpattern);
	  }
        }
        else error = error + Error (output, teaching_output, 
				    file_header_info.no_of_output_units) ;
	
        if (pat_no < file_header_info.no_of_patterns - 1) 
	    fscanf (in_file, "%*s") ;    
      }

      free (output)          ;
      free (teaching_output) ;

      if (statistics == ON)
      { 
         fprintf (out_file, "STATISTICS ( %d patterns )\n", 
		  file_header_info.no_of_patterns) ;
         fprintf (out_file, "wrong   : %5.2f %%  ( %d pattern(s) )\n", 
                  100.0*wrong  /file_header_info.no_of_patterns, wrong) ;
         fprintf (out_file, "right   : %5.2f %%  ( %d pattern(s) )\n",
                  100.0*right  /file_header_info.no_of_patterns, right) ;
         fprintf (out_file, "unknown : %5.2f %%  ( %d pattern(s) )\n",
                  100.0*unknown/file_header_info.no_of_patterns, unknown) ;
         fprintf (out_file, "error   : %f\n", error)  ;
      } 
      if (class_statistics == ON) 
      {
        fprintf (out_file,"\n\n");
        for (i = 0; i < file_header_info.no_of_output_units; i++)
	{
          fprintf (out_file,"STATISTICS FOR CLASS NO. : %d \n",i);
          fprintf (out_file,"wrong   : %5.2f %%  ( %d pattern(s) )\n",
                  100.0*class_stat_wrong[i]  /(class_stat_wrong[i]+
                  class_stat_right[i]+class_stat_unknown[i]), 
                  class_stat_wrong[i]) ;
          fprintf (out_file,"right   : %5.2f %%  ( %d pattern(s) )\n",
                  100.0*class_stat_right[i]  /(class_stat_wrong[i]+
                  class_stat_right[i]+class_stat_unknown[i]), 
                  class_stat_right[i]) ;
          fprintf (out_file,"unknown : %5.2f %%  ( %d pattern(s) )\n",
                  100.0*class_stat_unknown[i]/(class_stat_wrong[i]+
                  class_stat_right[i]+class_stat_unknown[i]), 
                  class_stat_unknown[i]) ;
          fprintf (out_file,"\n");
        }
      }
    }
    else
    {
      fprintf (stderr, "error:  invalid file header\n") ;
    }
  } 
  free (class_stat_wrong);
  free (class_stat_right);
  free (class_stat_unknown);

  if (in_file  != stdin)  fclose (in_file)  ;
  if (out_file != stdout) fclose (out_file) ;

}



/*****************************************************************************/
/* end of file                                                               */
/*****************************************************************************/
