/*
    Theseus - maximum likelihood superpositioning of macromolecular structures

    Copyright (C) 2004-2009 Douglas L. Theobald

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the:

    Free Software Foundation, Inc.,
    59 Temple Place, Suite 330,
    Boston, MA  02111-1307  USA

    -/_|:|_|_\-
*/

#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "CovMat.h"
#include "HierarchVars.h"
#include "pdbMalloc.h"
#include "pdbStats.h"
#include "pdbIO.h"
#include "pdbUtils.h"
#include "ProcLAPACKSVD.h"
#include "ProcLAPACKSVDOcc.h"
#include "CovMat.h"
#include "MultiPose.h"
#include "MultiPose2MSA.h"
#include "DLTmath.h"
#include "theseuslib.h"


/* Round-trips the coordinates held in cdsA through the flat-buffer
   theseus_() entry point.

   The x, y and z vectors of every structure are packed into three contiguous
   scratch buffers (structure i starts at element i*vlen), the buffers are
   passed to theseus_(), and whatever theseus_() leaves in them is copied
   back over the coordinates in cdsA.

   Returns 1 on success.  Returns 0, leaving cdsA untouched, if cdsA reports
   a non-positive structure count or vector length, or if the scratch
   buffers cannot be allocated. */
int
test_charmm(CoordsArray *cdsA)
{
    int         nensem, natomx, cnum, vlen;
    /* theseus_() dereferences this flag, so it must hold a defined value;
       0 leaves the algorithm settings at whatever CoordsArrayInit() chose
       (only a value of 1 makes theseus_() switch to least squares) */
    int         ls = 0;
    double     *xbuf = NULL, *ybuf = NULL, *zbuf = NULL;
    size_t      rowbytes;
    int         i;

    nensem = cnum = cdsA->cnum;
    natomx = vlen = cdsA->vlen;

    /* nothing to superpose, and malloc(0)-style requests are ambiguous */
    if (cnum <= 0 || vlen <= 0)
        return(0);

    /* sizes are computed in size_t so that large ensembles cannot overflow int */
    rowbytes = (size_t) vlen * sizeof(double);

    /* calloc() rejects a cnum*rowbytes product that would overflow */
    xbuf = calloc((size_t) cnum, rowbytes);
    ybuf = calloc((size_t) cnum, rowbytes);
    zbuf = calloc((size_t) cnum, rowbytes);

    if (xbuf == NULL || ybuf == NULL || zbuf == NULL)
    {
        fprintf(stderr, "\n ERROR: test_charmm() could not allocate coordinate buffers\n");
        free(xbuf);
        free(ybuf);
        free(zbuf);
        return(0);
    }

    /* pack: structure i occupies elements [i*vlen, (i+1)*vlen) of each buffer */
    for (i = 0; i < cnum; ++i)
    {
        memcpy(&xbuf[(size_t) i * vlen], cdsA->coords[i]->x, rowbytes);
        memcpy(&ybuf[(size_t) i * vlen], cdsA->coords[i]->y, rowbytes);
        memcpy(&zbuf[(size_t) i * vlen], cdsA->coords[i]->z, rowbytes);
    }

    theseus_(xbuf, ybuf, zbuf, &natomx, &nensem, &ls);

    /* unpack the buffers back into the caller's structures */
    for (i = 0; i < cnum; ++i)
    {
        memcpy(cdsA->coords[i]->x, &xbuf[(size_t) i * vlen], rowbytes);
        memcpy(cdsA->coords[i]->y, &ybuf[(size_t) i * vlen], rowbytes);
        memcpy(cdsA->coords[i]->z, &zbuf[(size_t) i * vlen], rowbytes);
    }

    free(xbuf);
    free(ybuf);
    free(zbuf);

    return(1);
}


/* Expects xbuf, etc. to be allocated vectors of length natomx*nensem */
int
theseus_(double *xbuf, double *ybuf, double *zbuf, int *natomx, int *nensem, int *ls)
{
    CoordsArray    *cdsA = NULL;
    int         cnum = *nensem, vlen = *natomx;
    int         i, j;

    cdsA = CoordsArrayInit();

    if (*ls == 1)
    {
        cdsA->algo->leastsquares = 1;
        cdsA->algo->varweight = 0;
        cdsA->algo->hierarch = 0;
        cdsA->algo->LedoitWolf = 0;
    }

    CoordsArrayAlloc(cdsA, cnum, vlen);

    /* DLT debug - shouldn't need to do this, I think prob is in SuperPose() */
    for (j = 0; j < cnum; ++j)
        for (i = 0; i < vlen; ++i)
            cdsA->coords[j]->o[i] = 1.0;

    for (i = 0; i < cnum; ++i)
    {
        memcpy(cdsA->coords[i]->x, &xbuf[i*vlen], vlen*sizeof(double));
        memcpy(cdsA->coords[i]->y, &ybuf[i*vlen], vlen*sizeof(double));
        memcpy(cdsA->coords[i]->z, &zbuf[i*vlen], vlen*sizeof(double));
    }

    MultiPoseLib(cdsA);

    for (i = 0; i < cnum; ++i)
    {
        memcpy(&xbuf[i*vlen], cdsA->coords[i]->x, vlen*sizeof(double));
        memcpy(&ybuf[i*vlen], cdsA->coords[i]->y, vlen*sizeof(double));
        memcpy(&zbuf[i*vlen], cdsA->coords[i]->z, vlen*sizeof(double));
    }

    CoordsArrayDestroy(&cdsA);

    return(1);
}


/* Expects xbuf, etc. to be allocated vectors of length natomx*nensem */
int
theseus2_(double *xbuf, double *ybuf, double *zbuf, int *natomx, int *nensem, int *ls)
{
    CoordsArray    *cdsA = NULL;
    int         cnum = *nensem, vlen = *natomx;
    int         i, j;

    cdsA = CoordsArrayInit();

    if (*ls == 1)
    {
        cdsA->algo->leastsquares = 1;
        cdsA->algo->varweight = 0;
        cdsA->algo->hierarch = 0;
        cdsA->algo->LedoitWolf = 0;
    }

    CoordsArrayAlloc(cdsA, cnum, vlen);

    /* DLT debug - shouldn't need to do this, I think prob is in SuperPose() */
    for (j = 0; j < cnum; ++j)
        for (i = 0; i < vlen; ++i)
            cdsA->coords[j]->o[i] = 1.0;

    for (i = 0; i < cnum; ++i)
    {
        free(cdsA->coords[i]->x);
        free(cdsA->coords[i]->y);
        free(cdsA->coords[i]->z);
    
        cdsA->coords[i]->x = &xbuf[i*vlen];
        cdsA->coords[i]->y = &ybuf[i*vlen];
        cdsA->coords[i]->z = &zbuf[i*vlen];
    }

    MultiPoseLib(cdsA);

    for (i = 0; i < cnum; ++i)
    {
        cdsA->coords[i]->x = NULL;
        cdsA->coords[i]->y = NULL;
        cdsA->coords[i]->z = NULL;
    }

    CoordsArrayDestroy(&cdsA);

    return(1);
}


/* Computes, for each of nsell atom pairs, the ensemble-averaged pair
   distance, an order parameter and a force-like vector.

   Pair k is formed by element k and element nsell + k of every structure's
   x/y/z vectors, so those vectors must hold at least 2*nsell elements.

   Outputs (each of length nsell):
     rij[k]   -- mean over structures of the pair distance
     s2[k]    -- 1.5 * (sum of squared second-moment entries) / rij[k]^4 - 0.5,
                 where the second moments are the ensemble averages of the
                 products of the difference-vector components
     bxij[k], byij[k], bzij[k]
              -- the second-moment matrix applied to the difference vector
                 of structure number whoiam

   NOTE(review): whoiam is compared against the 0-based structure index; if
   it matches no structure, the incoming contents of bxij/byij/bzij are used
   as the difference vector -- confirm callers always pass a valid index.

   A diagnostic is printed to stdout for every pair whose s2 is NaN. */
void
CalcS2(CoordsArray *cdsA, const int nsell, double *bxij, double *byij, double *bzij, 
       double *rij, double *s2, const int whoiam)
{
    const int       nstruct = cdsA->cnum;
    Coords         *cds = NULL;
    int             pair, m;
    double          dx, dy, dz;
    double          mxx, myy, mzz, mxy, mxz, myz;
    double          dist_sum, rmean, r4, diag, offdiag;

    for (pair = 0; pair < nsell; ++pair)
    {
        mxx = myy = mzz = mxy = mxz = myz = 0.0;
        dist_sum = 0.0;

        /* accumulate distances and second moments over the ensemble */
        for (m = 0; m < nstruct; ++m)
        {
            cds = cdsA->coords[m];

            dx = cds->x[pair] - cds->x[nsell + pair];
            dy = cds->y[pair] - cds->y[nsell + pair];
            dz = cds->z[pair] - cds->z[nsell + pair];

            /* remember the difference vector of the selected structure */
            if (m == whoiam)
            {
                bxij[pair] = dx;
                byij[pair] = dy;
                bzij[pair] = dz;
            }

            dist_sum += sqrt(dx*dx + dy*dy + dz*dz);

            mxx += dx*dx;
            myy += dy*dy;
            mzz += dz*dz;
            mxy += dx*dy;
            mxz += dx*dz;
            myz += dy*dz;
        }

        /* turn the sums into ensemble averages */
        mxx /= nstruct;
        myy /= nstruct;
        mzz /= nstruct;
        mxy /= nstruct;
        mxz /= nstruct;
        myz /= nstruct;

        rmean = dist_sum / nstruct;
        rij[pair] = rmean;

        /* order parameter */
        r4 = rmean*rmean*rmean*rmean;
        diag = mxx*mxx + myy*myy + mzz*mzz;
        offdiag = mxy*mxy + mxz*mxz + myz*myz;
        s2[pair] = (1.5/r4) * (diag + 2.0*offdiag) - 0.5;

        /* components of force */
        dx = bxij[pair];
        dy = byij[pair];
        dz = bzij[pair];
        bxij[pair] = mxx*dx + mxy*dy + mxz*dz;
        byij[pair] = myy*dy + mxy*dx + myz*dz;
        bzij[pair] = mzz*dz + mxz*dx + myz*dy;

        if (isnan(s2[pair]))
        {
            printf("ERRORTH1> %3d: %8.3e %8.3e %8.3e %8.3e %8.3e %8.3e %8.3e\n",
                   pair, rmean, mxx, myy, mzz, mxy, mxz, myz);
            printf("ERRORTH2> %3d: %8.3e %8.3e %8.3e \n", pair, dx, dy, dz);
        }
    }
}


/* Superposes the atoms belonging to a list of selected atom pairs and
   computes per-pair statistics with CalcS2().

   xbuf, ybuf, zbuf  -- vectors of length natomx*nensem; structure i occupies
                        elements [i*natomx, (i+1)*natomx).  Read only.
   natomx            -- pointer to the number of atoms per structure
   nensem            -- pointer to the number of structures
   nsell             -- pointer to the number of selected pairs
   il, jl            -- length-nsell lists holding the two atom numbers of
                        each pair; 1 is subtracted from every entry before it
                        is used as an index, i.e. they are treated as 1-based
   bxij, byij, bzij,
   rij, s2           -- length-nsell output vectors filled by CalcS2()
   ls                -- pointer to a flag; a value of 1 selects least squares
                        and switches off varweight, hierarch and LedoitWolf
   whoiam            -- pointer to the structure number handed to CalcS2()

   NOTE(review): *whoiam is forwarded unchanged and CalcS2() compares it with
   a 0-based structure index, whereas il/jl are converted from 1-based --
   confirm which convention the caller uses.

   A private CoordsArray of 2*nsell atoms per structure is built: slot j
   holds atom il[j] and slot nsell + j holds its partner jl[j].  Always
   returns 1. */
int
theseuss2_(const double *xbuf, const double *ybuf, const double *zbuf,
           const int *natomx, const int *nensem, const int *nsell,
           double *bxij, double *byij, double *bzij,
           const int *il, const int *jl, double *rij, double *s2,
           const int *ls, const int *whoiam)
{
    CoordsArray    *cdsA = NULL;
    int         cnum = *nensem, len = *natomx, vlen = *nsell*2;
    int         i, j, ilj, jlj;

    cdsA = CoordsArrayInit();

    if (*ls == 1)
    {
        cdsA->algo->leastsquares = 1;
        cdsA->algo->varweight = 0;
        cdsA->algo->hierarch = 0;
        cdsA->algo->LedoitWolf = 0;
    }

    CoordsArrayAlloc(cdsA, cnum, vlen);

    /* DLT debug - shouldn't need to do this, I think prob is in SuperPose() */
    for (j = 0; j < cnum; ++j)
        for (i = 0; i < vlen; ++i)
            cdsA->coords[j]->o[i] = 1.0;

    /* gather the selected atoms: first members of the pairs in the lower
       half of each vector, second members in the upper half */
    for (i = 0; i < cnum; ++i)
    {
        for (j = 0; j < *nsell; ++j)
        {
            ilj = il[j] - 1;
            cdsA->coords[i]->x[j] = xbuf[i*len + ilj];
            cdsA->coords[i]->y[j] = ybuf[i*len + ilj];
            cdsA->coords[i]->z[j] = zbuf[i*len + ilj];

            jlj = jl[j] - 1;
            cdsA->coords[i]->x[*nsell + j] = xbuf[i*len + jlj];
            cdsA->coords[i]->y[*nsell + j] = ybuf[i*len + jlj];
            cdsA->coords[i]->z[*nsell + j] = zbuf[i*len + jlj];
        }
    }

    MultiPoseLib(cdsA);

    /* CalcS2() takes the number of PAIRS, not the number of atoms: it pairs
       element k with element nsell + k.  Passing vlen (= 2*nsell) here would
       make it index past the end of the coordinate vectors and of the
       length-nsell output arrays. */
    CalcS2(cdsA, *nsell, bxij, byij, bzij, rij, s2, *whoiam);

    CoordsArrayDestroy(&cdsA);

    return(1);
}


/* A version of MultiPose for a general library, very pared down.

   Iteratively superposes the structures in cdsA in place.  Each outer round
   estimates translations, then alternates rotation estimation and averaging
   until CheckConvergenceInner() reports convergence (or 160 inner
   iterations have run), then re-estimates covariances and weights.  The
   outer iteration ends when CheckConvergenceOuter() returns 1 right after
   the first rotation estimate of a round.

   Prints the number of outer rounds to stdout and returns it. */
int
MultiPoseLib(CoordsArray *cdsA)
{
    const int       nstruct = cdsA->cnum;
    int             s, round, inner, first; /* first: index of random coord to select as first */
    int             converged = 0;
    Algorithm      *algo = NULL;
    Statistics     *stats = NULL;
    Coords        **cds = NULL;
    Coords         *mean = NULL;

    CoordsArraySetup(cdsA);

    /* local aliases into cdsA, taken after setup */
    algo = cdsA->algo;
    stats = cdsA->stats;
    cds = cdsA->coords;
    mean = cdsA->avecoords;

    if (algo->covweight == 1)
        SetupCovWeighting(cdsA); /* DLT debug */

    stats->hierarch_p1 = 1.0;
    stats->hierarch_p2 = 1.0;
    algo->constant = 0.001;

    /* a randomly chosen structure serves as the initial mean structure */
    first = (int) (genrand_real2() * nstruct);
    CoordsCopyAll(mean, cdsA->coords[first]);

    if (algo->bfact > 0)
    {
        for (s = 0; s < nstruct; ++s)
            Bfacts2PrVars(cdsA, s);
    }

    if (algo->alignment == 1)
        CalcDf(cdsA);

    /* The EM algorithm.
       Outer loop:
       (1) calculate the translations
       (2) inner loop -- calculate rotations and average till convergence
       (3) holding the superposition constant, calculate the covariance
           matrices and corresponding weight matrices */
    for (round = 1; ; ++round)
    {
        algo->rounds = round;

        /* estimate translations: find weighted center, translate all coords */
        CalcTranslationsIp(cdsA, algo);

        for (s = 0; s < nstruct; ++s)
            ApplyCenterIp(cds[s]);

        /* keep the translation vector applied to each structure */
        for (s = 0; s < nstruct; ++s)
            memcpy(cds[s]->translation, cds[s]->center, 3 * sizeof(double));

        /* Inner loop, holding variances, covariances and translations
           constant: rotations -> rotate -> new average */
        for (inner = 1; ; ++inner)
        {
            /* find the optimal rotation matrices */
            if (algo->alignment == 1)
                CalcRotationsOcc(cdsA);
            else
                CalcRotations(cdsA);

            /* overall convergence is tested only on the first pass */
            if ((inner == 1) && (CheckConvergenceOuter(cdsA, round, algo->precision) == 1))
            {
                converged = 1;
                break;
            }

            /* apply the new rotation matrices */
            for (s = 0; s < nstruct; ++s)
                RotateCoordsIp(cds[s], (const double **) cds[s]->matrix);

            /* recompute the average structure */
            if (algo->alignment == 1)
            {
                AveCoordsOcc(cdsA);
                EM_MissingCoords(cdsA);
            }
            else
            {
                AveCoords(cdsA);
            }

            /* stop once the inner test passes or after 160 iterations */
            if ((CheckConvergenceInner(cdsA, algo->precision) != 0) || (inner >= 160))
                break;
        }

        if (converged)
            break;

        /* with the superposition fixed, update covariances and weights */
        CalcCovariances(cdsA);
        CalcWts(cdsA);
    }

    printf(" ENSS2ML>: THESEUS rounds: %d\n", round);

    return(round);
}


/* Simplified MultiPoseLib(): the same outer/inner iteration, but always
   using CalcRotations() and AveCoords(), and with none of the covariance
   weighting, B-factor or alignment setup steps.

   Superposes the structures in cdsA in place and returns the number of
   outer rounds performed; nothing is printed. */
int
MultiPoseLibSimp(CoordsArray *cdsA)
{
    const int       nstruct = cdsA->cnum;
    int             s, round, inner, first; /* first: index of random coord to select as first */
    int             converged = 0;
    Algorithm      *algo = NULL;
    Statistics     *stats = NULL;
    Coords        **cds = NULL;
    Coords         *mean = NULL;

    CoordsArraySetup(cdsA);

    /* local aliases into cdsA, taken after setup */
    algo = cdsA->algo;
    stats = cdsA->stats;
    cds = cdsA->coords;
    mean = cdsA->avecoords;

    stats->hierarch_p1 = 1.0;
    stats->hierarch_p2 = 1.0;
    algo->constant = 0.001;

    /* a randomly chosen structure serves as the initial mean structure */
    first = (int) (genrand_real2() * nstruct);
    CoordsCopyAll(mean, cdsA->coords[first]);

    /* The CEM algorithm.
       Outer loop:
       (1) calculate the translations
       (2) inner loop -- calculate rotations and average till convergence
       (3) holding the superposition constant, calculate the covariance
           matrices and corresponding weight matrices */
    for (round = 1; !converged; ++round)
    {
        algo->rounds = round;

        /* estimate translations: find weighted center, translate all coords */
        CalcTranslationsIp(cdsA, algo);

        for (s = 0; s < nstruct; ++s)
            ApplyCenterIp(cds[s]);

        /* keep the translation vector applied to each structure */
        for (s = 0; s < nstruct; ++s)
            memcpy(cds[s]->translation, cds[s]->center, 3 * sizeof(double));

        /* Inner loop, holding variances, covariances and translations
           constant: rotations -> rotate -> new average */
        for (inner = 1; ; ++inner)
        {
            /* find the optimal rotation matrices */
            CalcRotations(cdsA);

            /* overall convergence is tested only on the first pass */
            if ((inner == 1) && (CheckConvergenceOuter(cdsA, round, algo->precision) == 1))
            {
                converged = 1;
                break;
            }

            /* apply the new rotation matrices */
            for (s = 0; s < nstruct; ++s)
                RotateCoordsIp(cds[s], (const double **) cds[s]->matrix);

            /* recompute the average structure */
            AveCoords(cdsA);

            /* stop once the inner test passes or after 160 iterations */
            if ((CheckConvergenceInner(cdsA, algo->precision) != 0) || (inner >= 160))
                break;
        }

        /* leave before the loop increment so round keeps its final value */
        if (converged)
            break;

        /* with the superposition fixed, update covariances and weights */
        CalcCovariances(cdsA);
        CalcWts(cdsA);
    }

    return(round);
}

