///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
//
// This code exercises sparse matrix-matrix ("mat-mat") expressions.
//
// It also prints matlab/octave code for non-regression testing, e.g.
//
// usage: 
//   gzip -d < in-animal-s2.hb.gz | blas3_tst | octave -q
//
//  machine-precision-independent output (errors rounded to the given number of digits):
//   gzip -d < in-animal-s2.hb.gz | blas3_tst -ndigit 12 | octave -q
//
# include "rheolef/skit.h"
using namespace rheolef;
using namespace std;

// Set by main() when a digit count is given on the command line: check()
// then rounds each error to a multiple of the octave variable eps1 (or eps2).
bool do_trunc (false);

// Print the octave/matlab statements that compare one C++ result with the
// same expression evaluated by octave.
//   b        : matrix computed on the C++ side, printed as e2
//   expr     : octave statement whose value becomes e1, e.g. "c=a+b"
//   sub_prec : only used when do_trunc is set; the error is then rounded to a
//              multiple of eps1^(1/sub_prec) instead of eps1
// Each call defines a new variable error1, error2, ... holding norm(e1-e2),
// rounded when do_trunc is set.
void check (const csr<Float>& b, const char* expr, int sub_prec = 1)
{
    static int n_check = 0;
    ++n_check;

    cout << "e1=" << expr << ";\n"
         << "e2=" << ml << b << ";\n";

    if (!do_trunc) {
        cout << "error" << n_check << "=norm(e1-e2)\n\n";
        return;
    }
    // truncation mode: pick the tolerance variable, defining eps2 if needed
    const char* eps_name = "eps1";
    if (sub_prec != 1) {
        cout << "eps2=eps1^(1.0/" << sub_prec << ");\n";
        eps_name = "eps2";
    }
    cout << "error" << n_check << "=" << eps_name
         << "*round(norm(e1-e2)/" << eps_name << ")\n\n";
}
int main (int argc, char* argv[])
{	
    cout.setf(ios::scientific, ios::floatfield);
    int digits10 = numeric_limits<Float>::digits10;
    
    // avoid machine-dependent output digits in non-regression mode:
    if (argc == 3) {
	do_trunc = true;
        digits10 = atoi(argv[2]);
	cout << "eps1=sqrt(10^(-" << digits10 << "));\n";
    }
    cout << setprecision(digits10);

    csr<Float> a;
    cin >> a;
    cout << "a=" << ml << a << ";\n";

    vec<Float> dv(a.ncol());
    vec<Float>::iterator iter = dv.begin();
    vec<Float>::iterator last = dv.end();
    int i = 0;
    while (iter != last) {
	*iter = i; ++i; ++iter;
    }
    basic_diag<Float> d = basic_diag<Float>(dv);
    cout << "dv=" << ml << dv << ";\n";
    cout << "d=diag(dv);\n";

  // mult with diagonal  
    csr<Float> b;
    b = a*d;
    check (b, "b=a*d");

  // addition, substraction
    cout << "b=" << ml << b << ";\n";

    csr<Float> c;
    c = a+b;
    check (c, "c=a+b");

    c = a-b;
    check (c, "c=a-b");

    // check auto-reference
    a = a-b;
    check (a, "a=a-b");
    
    a = b-a;
    check (a, "a=b-a");
  
  // transpose
   
     csr<Float> at = trans(a);
     check (at, "at=a'");

  // multiplication

     csr<Float> c1;
     c1 = at*a;
     check (c1, "c1=a'*a", 2);

     csr<Float> c2 = a*at;
     check (c2, "c2=a*a'", 2);

    // check auto-reference
     c2 = c2*(a*at);
     check (c2, "c2=c2*(a*a')", 4);

     // note: here, need a copy
     c2 = (a*at)*c2;
     check (c2, "c2=(a*a')*c2", 8);

#ifdef TODO
    b.left_mult(d);
    check (b, "b=d*b");

    b = d*a;
    check (b, "b=d*a");

    b *= d;
    check (b, "b=b*d");
#endif // TODO
    return 0;
}
