#include "tldevel.h"
#ifdef HAVE_OPENMP
#include <omp.h>
#endif

#ifdef HAVE_AVX2
#include <xmmintrin.h>
#include <mm_malloc.h>
#endif


#include "tlrng.h"

#include "bisectingKmeans.h"

#include "msa_struct.h"
/* #include "global.h" */
#include "task.h"
#include "sequence_distance.h"
#include "euclidean_dist.h"

/* #include "alignment.h" */
#include "pick_anchor.h"
#include "esl_stopwatch.h"

/* Binary guide-tree node.
   Leaves carry a sequence index in `id` (set by upgma()); internal nodes
   are created with id == -1 by alloc_node() and receive their label in
   label_internal(). */
struct node{
        struct node* left;      /* left child, NULL for a leaf  */
        struct node* right;     /* right child, NULL for a leaf */
        int id;                 /* sequence index or internal-node label */
};

/* Outcome of one 2-means split of a sample set (produced by split()/split2()). */
struct kmeans_result{
        int* sl;        /* sample ids assigned to the left cluster (capacity num_samples)  */
        int* sr;        /* sample ids assigned to the right cluster (capacity num_samples) */
        int nl;         /* number of valid entries in sl */
        int nr;         /* number of valid entries in sr */
        float score;    /* sum of each sample's distance to its nearer centroid;
                           lower is better, FLT_MAX when not yet computed */
};

/* Allocate / release a kmeans_result whose sl and sr arrays each hold
   num_samples entries. */
static struct kmeans_result* alloc_kmeans_result(int num_samples);
static void free_kmeans_results(struct kmeans_result* k);

/* Agglomerative clustering on a numseq x numseq distance matrix that is
   modified in place; leaf i receives id samples[i]. */
struct node* upgma(float **dm,int* samples, int numseq);
static struct node* alloc_node(void);

/* Tree post-processing: number the internal nodes (post-order), then emit
   one alignment task per internal node, freeing children as it goes. */
static int label_internal(struct node *n, int label);
static void create_tasks(struct node*n, struct aln_tasks* t);
/* static void create_tasks(struct node*n, struct aln_tasks* t); */


/* Recursive bisecting k-means driver; takes ownership of `samples`. */
/* static int bisecting_kmeans_serial(struct msa *msa, struct node **ret_n, float **dm, int *samples, int num_samples); */
static int bisecting_kmeans(struct msa* msa, struct node** ret_n,
                            const float * const * dm,
                            int* samples, int num_samples);
/* static int bisecting_kmeans_parallel(struct msa* msa, struct node** ret_n, float** dm,int* samples, int num_samples); */

/* One 2-means split of `samples`, seeded with the row of
   samples[seed_pick]. split() is the older serial variant; split2() is
   the one used by bisecting_kmeans(). */
static int split(const float * const * dm, int *samples, int num_anchors, int num_samples,
                 int seed_pick, struct kmeans_result **ret);
static int split2(const float * const * dm,const int* samples, const int num_anchors,const int num_samples,const int seed_pick,struct kmeans_result** ret);

/* Three-way float comparison with an absolute tolerance.
   Returns 0 when |a - b| < 1e-6, 1 when a > b, and -1 in every other case
   (including when either argument is NaN). */
static inline int cmp_floats(const float a, const float b)
{
        const float tolerance = 1e-6;

        if(fabsf(a - b) < tolerance){
                return 0;
        }
        return (a > b) ? 1 : -1;
}


/* Build the alignment guide tree with bisecting k-means on a perturbed
   sequence-to-anchor distance matrix.

   Same pipeline as build_tree_kmeans(), except that when seed != 0 and
   noise_sigma > 0 every distance is multiplied by
   tl_random_gaussian(rng, 1.0, noise_sigma) (presumably mean 1, standard
   deviation noise_sigma), with the factor clamped below at 0.1.

   msa          alignment; msa->seq_distances is allocated if needed and
                filled with each sequence's mean anchor distance divided by
                its length.
   tasks        in/out task list. When *tasks is NULL a new list is
                allocated and stored in *tasks.
   seed         RNG seed; 0 disables the perturbation.
   noise_sigma  spread of the multiplicative noise.

   Returns OK on success, FAIL otherwise. */
int build_tree_kmeans_noisy(struct msa* msa, struct aln_tasks** tasks,
                            uint64_t seed, float noise_sigma)
{
        struct aln_tasks* t = NULL;
        struct node* root = NULL;
        struct rng_state* rng = NULL;
        float** dm = NULL;
        int* samples = NULL;
        int* anchors = NULL;
        int num_anchors = 0;
        int numseq = 0;
        int kmeans_status = OK;
        int i;
        int j;

        ASSERT(msa != NULL, "No alignment.");
        ASSERT(tasks != NULL, "No task list.");

        t = *tasks;
        if(!t){
                RUN(alloc_tasks(&t, msa->numseq));
                /* Hand the new list to the caller immediately; previously it
                   was never stored in *tasks and was lost. */
                *tasks = t;
        }
        numseq = msa->numseq;

        DECLARE_TIMER(timer);
        if(!msa->quiet){
                LOG_MSG("Calculating pairwise distances (noisy, seed=%lu)", (unsigned long)seed);
        }
        START_TIMER(timer);
        RUNP(anchors = pick_anchor(msa, &num_anchors));
        RUNP(dm = d_estimation(msa, anchors, num_anchors, 0));

        /* Add multiplicative Gaussian noise to distance matrix */
        if(seed != 0 && noise_sigma > 0.0f){
                RUNP(rng = init_rng(seed));
                for(i = 0; i < numseq; i++){
                        for(j = 0; j < num_anchors; j++){
                                double noise = tl_random_gaussian(rng, 1.0, (double)noise_sigma);
                                /* Keep distances positive and bounded away from 0. */
                                if(noise < 0.1){
                                        noise = 0.1;
                                }
                                dm[i][j] *= (float)noise;
                        }
                }
                free_rng(rng);
                rng = NULL;
        }

        STOP_TIMER(timer);
        if(!msa->quiet){
                GET_TIMING(timer);
        }
        MFREE(anchors);
        anchors = NULL;

        MMALLOC(samples, sizeof(int) * numseq);
        for(i = 0; i < numseq; i++){
                samples[i] = i;
        }

        START_TIMER(timer);
        if(!msa->quiet){
                LOG_MSG("Building guide tree.");
        }

        /* One thread starts the recursion and bisecting_kmeans() spawns
           OpenMP tasks for the rest of the team. A goto may not leave the
           parallel region, so the status is checked afterwards. */
#ifdef HAVE_OPENMP
#pragma omp parallel shared(kmeans_status)
#pragma omp single nowait
#endif
        kmeans_status = bisecting_kmeans(msa, &root, (const float * const *)dm, samples, numseq);
        /* bisecting_kmeans() takes ownership of samples and frees it. */
        samples = NULL;

        STOP_TIMER(timer);
        if(!msa->quiet){
                GET_TIMING(timer);
        }
        ASSERT(kmeans_status == OK && root != NULL, "Guide tree construction failed.");

        label_internal(root, numseq);
        create_tasks(root, t);

        /* Compute per-sequence normalized mean distance (noisy variant). */
        if(msa->seq_distances == NULL){
                MMALLOC(msa->seq_distances, sizeof(float) * numseq);
        }
        for(i = 0; i < numseq; i++){
                float sum = 0.0f;
                for(j = 0; j < num_anchors; j++){
                        sum += dm[i][j];
                }
                /* Guard against 0/0 when no anchors were picked. */
                float mean_dist = (num_anchors > 0) ? sum / (float)num_anchors : 0.0f;
                float seq_len = (float)msa->sequences[i]->len;
                msa->seq_distances[i] = (seq_len > 0.0f) ? mean_dist / seq_len : 0.0f;
        }

        MFREE(root);
        for(i = 0; i < numseq; i++){
#ifdef HAVE_AVX2
                _mm_free(dm[i]);
#else
                MFREE(dm[i]);
#endif
        }
        MFREE(dm);
        DESTROY_TIMER(timer);
        return OK;
ERROR:
        if(rng){
                free_rng(rng);
        }
        if(anchors){
                MFREE(anchors);
        }
        if(samples){
                MFREE(samples);
        }
        if(dm){
                for(i = 0; i < numseq; i++){
#ifdef HAVE_AVX2
                        _mm_free(dm[i]);
#else
                        MFREE(dm[i]);
#endif
                }
                MFREE(dm);
        }
        return FAIL;
}

/* Build the alignment guide tree with bisecting k-means.

   Each sequence is represented by its distances to a set of anchor
   sequences (pick_anchor() + d_estimation()). bisecting_kmeans() clusters
   these rows into a binary tree; internal nodes are then labelled and
   turned into alignment tasks.

   msa    alignment; msa->seq_distances is allocated if needed and filled
          with each sequence's mean anchor distance divided by its length
          (used for distance-dependent gap penalty scaling).
   tasks  in/out task list. When *tasks is NULL a new list is allocated and
          stored in *tasks.

   Returns OK on success, FAIL otherwise. */
int build_tree_kmeans(struct msa* msa, struct aln_tasks** tasks)
{
        struct aln_tasks* t = NULL;
        struct node* root = NULL;
        float** dm = NULL;
        int* samples = NULL;
        int* anchors = NULL;
        int num_anchors = 0;
        int numseq = 0;
        int kmeans_status = OK;
        int i;
        int j;

        ASSERT(msa != NULL, "No alignment.");
        ASSERT(tasks != NULL, "No task list.");

        t = *tasks;
        if(!t){
                RUN(alloc_tasks(&t, msa->numseq));
                /* Hand the new list to the caller immediately; previously it
                   was never stored in *tasks and was lost. */
                *tasks = t;
        }
        numseq = msa->numseq;

        DECLARE_TIMER(timer);
        /* pick anchors . */
        if(!msa->quiet){
                LOG_MSG("Calculating pairwise distances");
        }
        START_TIMER(timer);
        RUNP(anchors = pick_anchor(msa, &num_anchors));
        RUNP(dm = d_estimation(msa, anchors, num_anchors, 0));

        STOP_TIMER(timer);
        if(!msa->quiet){
                GET_TIMING(timer);
        }
        MFREE(anchors);
        anchors = NULL;

        MMALLOC(samples, sizeof(int) * numseq);
        for(i = 0; i < numseq; i++){
                samples[i] = i;
        }

        START_TIMER(timer);
        if(!msa->quiet){
                LOG_MSG("Building guide tree.");
        }

        /* One thread starts the recursion and bisecting_kmeans() spawns
           OpenMP tasks for the rest of the team. A goto may not leave the
           parallel region, so the status is checked afterwards. */
#ifdef HAVE_OPENMP
#pragma omp parallel shared(kmeans_status)
#pragma omp single nowait
#endif
        kmeans_status = bisecting_kmeans(msa, &root, (const float * const *)dm, samples, numseq);
        /* bisecting_kmeans() takes ownership of samples and frees it. */
        samples = NULL;

        STOP_TIMER(timer);
        if(!msa->quiet){
                GET_TIMING(timer);
        }
        ASSERT(kmeans_status == OK && root != NULL, "Guide tree construction failed.");

        label_internal(root, numseq);
        create_tasks(root, t);

        /* Compute per-sequence normalized mean distance to anchors.
           Used for distance-dependent gap penalty scaling. */
        if(msa->seq_distances == NULL){
                MMALLOC(msa->seq_distances, sizeof(float) * numseq);
        }
        for(i = 0; i < numseq; i++){
                float sum = 0.0f;
                for(j = 0; j < num_anchors; j++){
                        sum += dm[i][j];
                }
                /* Guard against 0/0 when no anchors were picked. */
                float mean_dist = (num_anchors > 0) ? sum / (float)num_anchors : 0.0f;
                float seq_len = (float)msa->sequences[i]->len;
                msa->seq_distances[i] = (seq_len > 0.0f) ? mean_dist / seq_len : 0.0f;
        }

        MFREE(root);
        for(i = 0; i < numseq; i++){
#ifdef HAVE_AVX2
                _mm_free(dm[i]);
#else
                MFREE(dm[i]);
#endif
        }
        MFREE(dm);
        DESTROY_TIMER(timer);
        return OK;
ERROR:
        if(anchors){
                MFREE(anchors);
        }
        if(samples){
                MFREE(samples);
        }
        if(dm){
                for(i = 0; i < numseq; i++){
#ifdef HAVE_AVX2
                        _mm_free(dm[i]);
#else
                        MFREE(dm[i]);
#endif
                }
                MFREE(dm);
        }
        return FAIL;
}

/* Recursively build a guide (sub)tree for `samples` by bisecting k-means.

   Sets smaller than KALIGN_KMEANS_UPGMA_THRESHOLD are finished with
   upgma() on the matrix returned by d_estimation(msa, samples,
   num_samples, 1) (presumably all-pairs distances). Larger sets are split
   in two by split2(). Candidate seeds are tried four at a time, and the
   split with the lowest score wins; rounds stop when a round brings no
   improvement. Each half is then processed recursively, as OpenMP tasks
   when available.

   msa          alignment passed through to d_estimation().
   ret_n        receives the root of the subtree on success.
   dm           sequence-to-anchor distance rows, indexed by sequence id.
   samples      sequence ids to cluster. Ownership is taken: the array is
                freed by this call on success and on failure.
   num_samples  number of entries in samples.

   Returns OK on success, FAIL otherwise. NOTE(review): on failure after
   the recursion, the partially built subtree is not freed. */
int bisecting_kmeans(struct msa* msa, struct node** ret_n, const float * const * dm,int* samples, int num_samples)
{
        struct kmeans_result* res[4] = {NULL, NULL, NULL, NULL};
        struct kmeans_result* best = NULL;
        struct kmeans_result* res_tmp = NULL;
        struct node* n = NULL;
        float** pair_dm = NULL;
        int* sl = NULL;
        int* sr = NULL;
        int split_status[4] = {OK, OK, OK, OK};
        int status_l = OK;
        int status_r = OK;
        int num_anchors = 0;
        int num_l;
        int num_r;
        int tries = 40;
        int step;
        int change;
        int i;
        int j;

        /* NOTE(review): must equal the column count of dm, i.e. the anchor
           count chosen by pick_anchor(); confirm the two stay in sync. */
        num_anchors = MACRO_MIN(32, msa->numseq);

        if(num_samples < KALIGN_KMEANS_UPGMA_THRESHOLD){
                RUNP(pair_dm = d_estimation(msa, samples, num_samples, 1));
                RUNP(n = upgma(pair_dm, samples, num_samples));
                gfree(pair_dm);
                pair_dm = NULL;
                MFREE(samples);
                samples = NULL;
                *ret_n = n;
                return OK;
        }

        tries = MACRO_MIN(tries, num_samples);
        step = num_samples / tries;
        for(i = 0; i < tries; i += 4){
                change = 0;
                /* Four candidate splits per round, seeded at evenly spaced
                   positions of samples. The modulo keeps the seed index in
                   range when tries is not a multiple of 4. It is a no-op
                   otherwise. */
#ifdef HAVE_OPENMP
#pragma omp task shared(dm,samples,num_anchors,num_samples,i,step,res,split_status)
#endif
                split_status[0] = split2(dm, samples, num_anchors, num_samples, (i * step) % num_samples, &res[0]);
#ifdef HAVE_OPENMP
#pragma omp task shared(dm,samples,num_anchors,num_samples,i,step,res,split_status)
#endif
                split_status[1] = split2(dm, samples, num_anchors, num_samples, ((i + 1) * step) % num_samples, &res[1]);
#ifdef HAVE_OPENMP
#pragma omp task shared(dm,samples,num_anchors,num_samples,i,step,res,split_status)
#endif
                split_status[2] = split2(dm, samples, num_anchors, num_samples, ((i + 2) * step) % num_samples, &res[2]);
#ifdef HAVE_OPENMP
#pragma omp task shared(dm,samples,num_anchors,num_samples,i,step,res,split_status)
#endif
                split_status[3] = split2(dm, samples, num_anchors, num_samples, ((i + 3) * step) % num_samples, &res[3]);
#ifdef HAVE_OPENMP
#pragma omp taskwait
#endif
                /* A failed split leaves res[j] unusable (possibly NULL). */
                for(j = 0; j < 4; j++){
                        if(split_status[j] != OK || res[j] == NULL){
                                goto ERROR;
                        }
                }

                /* Keep the lowest-scoring split in best; displaced results are
                   recycled into res[] so split2() can reuse their buffers. */
                for(j = 0; j < 4; j++){
                        if(!best){
                                change++;
                                best = res[j];
                                res[j] = NULL;
                        }else if(best->score > res[j]->score){
                                res_tmp = best;
                                best = res[j];
                                res[j] = res_tmp;
                                change++;
                        }
                }
                if(!change){
                        break;
                }
        }

        sl = best->sl;
        sr = best->sr;
        num_l = best->nl;
        num_r = best->nr;
        /* The index arrays now belong to the recursive calls below. */
        best->sl = NULL;
        best->sr = NULL;
        free_kmeans_results(best);
        best = NULL;

        for(j = 0; j < 4; j++){
                free_kmeans_results(res[j]);
                res[j] = NULL;
        }

        MFREE(samples);
        samples = NULL;

        RUNP(n = alloc_node());

#ifdef HAVE_OPENMP
#pragma omp task shared(msa,n,dm,status_l)
#endif
        status_l = bisecting_kmeans(msa, &n->left, dm, sl, num_l);

#ifdef HAVE_OPENMP
#pragma omp task shared(msa,n,dm,status_r)
#endif
        status_r = bisecting_kmeans(msa, &n->right, dm, sr, num_r);

#ifdef HAVE_OPENMP
#pragma omp taskwait
#endif
        /* Both recursive calls consumed (freed) their index arrays. */
        sl = NULL;
        sr = NULL;
        if(status_l != OK || status_r != OK){
                goto ERROR;
        }

        *ret_n = n;
        return OK;
ERROR:
        for(j = 0; j < 4; j++){
                free_kmeans_results(res[j]);
        }
        free_kmeans_results(best);
        if(pair_dm){
                gfree(pair_dm);
        }
        if(samples){
                MFREE(samples);
        }
        if(sl){
                MFREE(sl);
        }
        if(sr){
                MFREE(sr);
        }
        return FAIL;
}

/* int bisecting_kmeans_serial(struct msa* msa, struct node** ret_n, float** dm,int* samples, int num_samples) */
/* { */
/*         struct kmeans_result* res_tmp = NULL; */
/*         struct kmeans_result* best = NULL; */
/*         struct kmeans_result** res_ptr = NULL; */
/*         int num_anchors = 0; */
/*         struct node* n = NULL; */
/*         int i,j; */
/*         int tries = 40; */
/*         /\* int t_iter; *\/ */
/*         /\* int r; *\/ */
/*         int* sl = NULL; */
/*         int* sr = NULL; */
/*         int num_l,num_r; */

/*         num_anchors = MACRO_MIN(32, msa->numseq); */

/*         if(num_samples < KALIGN_KMEANS_UPGMA_THRESHOLD){ */
/*                 float** dm = NULL; */
/*                 RUNP(dm = d_estimation(msa, samples, num_samples,1));// anchors, num_anchors,1)); */
/*                 n = upgma(dm,samples, num_samples); */
/*                 *ret_n = n; */
/*                 gfree(dm); */
/*                 MFREE(samples); */
/*                 return OK; */
/*         } */

/*         best = NULL; */
/*         res_tmp = NULL; */

/*         MMALLOC(res_ptr, sizeof(struct kmeans_result*) * 4); */
/*         for(i = 0; i < 4;i++){ */
/*                 res_ptr[i] = NULL; */
/*         } */
/*         tries = MACRO_MIN(tries, num_samples); */
/*         int step = num_samples / tries; */
/*         int change = 0; */
/*         for(i = 0;i < tries;i += 4){ */
/*                 change = 0; */
/*                 for(j = 0; j < 4;j++){ */
/*                         split(dm,samples,num_anchors, num_samples, (i+ j)*step, &res_ptr[j]); */
/*                 } */

/*                 for(j = 0; j < 4;j++){ */
/*                         if(!best){ */
/*                                 change++; */
/*                                 best = res_ptr[j]; */
/*                                 res_ptr[j] = NULL; */
/*                         }else{ */
/*                                 if(best->score > res_ptr[j]->score){ */
/*                                         res_tmp = best; */
/*                                         best = res_ptr[j]; */
/*                                         res_ptr[j] = res_tmp; */
/*                                         /\* LOG_MSG("Better!!! %f %f", res_tmp->score,best->score); *\/ */

/*                                         change++; */
/*                                 } */
/*                         } */
/*                 } */
/*                 if(!change){ */
/*                         break; */
/*                 } */

/*         } */
/*         sl = best->sl; */
/*         sr = best->sr; */

/*         num_l = best->nl; */
/*         num_r = best->nr; */

/*         for(i = 0; i < 4;i++){ */
/*                 free_kmeans_results(res_ptr[i]); */
/*         } */
/*         MFREE(res_ptr); */
/*         MFREE(best); */

/*         MFREE(samples); */
/*         n = alloc_node(); */

/*         bisecting_kmeans_serial(msa,&n->left , dm, sl, num_l); */
/*         bisecting_kmeans_serial(msa,&n->right, dm, sr, num_r); */
/*         *ret_n = n; */
/*         return OK; */
/* ERROR: */
/*         return FAIL; */
/* } */

/* Serial 2-means bisection of `samples`. This is the predecessor of
   split2(), kept for the commented-out serial driver.

   Initial centroids: cl is the row of samples[seed_pick], and cr is that
   row mirrored through the mean row of all samples. If cl and cr agree on
   every anchor to within 1e-6, or if a cluster becomes empty during the
   iterations, the set is cut in half (first half left, second half right)
   with score 0. Otherwise Lloyd iterations run until the centroids stop
   changing exactly. Ties in the assignment alternate by sample position.

   dm          distance rows with num_anchors columns, indexed by sequence id.
   samples     sequence ids to split (num_samples entries).
   seed_pick   index into samples of the seed row; must be in [0, num_samples).
   ret         in/out result. It is reused when *ret is non-NULL (its sl/sr
               arrays must hold num_samples entries) and allocated otherwise.

   Returns OK on success. Returns FAIL on a bad seed index, on allocation
   failure, or when 10000 iterations pass without convergence. */
int split(const float * const * dm,int* samples, int num_anchors,int num_samples,int seed_pick,struct kmeans_result** ret)
{
        struct kmeans_result* res = NULL;
        int* sl = NULL;
        int* sr = NULL;
        int num_l = 0;
        int num_r = 0;
        float* w = NULL;
        float* wl = NULL;
        float* wr = NULL;
        float* cl = NULL;
        float* cr = NULL;
        float* acc = NULL;      /* alias: accumulator of the chosen cluster */
        float* swap = NULL;     /* alias: temporary for buffer swaps */
        float dl = 0.0F;
        float dr = 0.0F;
        float score = 0.0F;
        int num_var;
        int own_res = 0;
        int distinct;
        int halve = 0;
        int moved;
        int cmp;
        int stop = 0;
        int i;
        int j;
        int s;

        if(seed_pick < 0 || seed_pick >= num_samples){
                ERROR_MSG("Seed index out of range.");
        }

        /* Pad buffers to a multiple of 8 floats (presumably for the 8-wide
           edist_256 kernel); the padding is zeroed below. */
        num_var = num_anchors / 8;
        if(num_anchors % 8){
                num_var++;
        }
        num_var = num_var << 3;

#ifdef HAVE_AVX2
        RUNP(wr = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(wl = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(cr = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(cl = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(w = _mm_malloc(sizeof(float) * num_var, 32));
#else
        MMALLOC(wr, sizeof(float) * num_var);
        MMALLOC(wl, sizeof(float) * num_var);
        MMALLOC(cr, sizeof(float) * num_var);
        MMALLOC(cl, sizeof(float) * num_var);
        MMALLOC(w, sizeof(float) * num_var);
#endif

        if(*ret){
                res = *ret;
        }else{
                RUNP(res = alloc_kmeans_result(num_samples));
                own_res = 1;
        }
        res->score = FLT_MAX;
        sl = res->sl;
        sr = res->sr;

        for(i = 0; i < num_var; i++){
                w[i] = 0.0F;
                wr[i] = 0.0F;
                wl[i] = 0.0F;
                cr[i] = 0.0F;
                cl[i] = 0.0F;
        }

        /* w = mean row of all samples. */
        for(i = 0; i < num_samples; i++){
                s = samples[i];
                for(j = 0; j < num_anchors; j++){
                        w[j] += dm[s][j];
                }
        }
        for(j = 0; j < num_anchors; j++){
                w[j] /= (float)num_samples;
        }

        /* Seed centroids: the chosen row and its mirror through the mean. */
        s = samples[seed_pick];
        for(j = 0; j < num_anchors; j++){
                cl[j] = dm[s][j];
        }
        for(j = 0; j < num_anchors; j++){
                cr[j] = w[j] - (cl[j] - w[j]);
        }

#ifdef HAVE_AVX2
        _mm_free(w);
#else
        MFREE(w);
#endif
        w = NULL;

        /* cl == cr means the seed sits on the mean: nothing to separate. */
        distinct = 0;
        for(j = 0; j < num_anchors; j++){
                if(fabsf(cl[j] - cr[j]) > 1.0E-6){
                        distinct++;
                }
        }

        if(!distinct){
                halve = 1;
        }else{
                while(1){
                        stop++;
                        if(stop == 10000){
                                ERROR_MSG("Failed.");
                        }
                        num_l = 0;
                        num_r = 0;
                        for(i = 0; i < num_anchors; i++){
                                wr[i] = 0.0F;
                                wl[i] = 0.0F;
                        }
                        score = 0.0f;
                        for(i = 0; i < num_samples; i++){
                                s = samples[i];
#ifdef HAVE_AVX2
                                edist_256(dm[s], cl, num_anchors, &dl);
                                edist_256(dm[s], cr, num_anchors, &dr);
#else
                                edist_serial(dm[s], cl, num_anchors, &dl);
                                edist_serial(dm[s], cr, num_anchors, &dr);
#endif
                                score += MACRO_MIN(dl, dr);

                                cmp = cmp_floats(dr, dl);
                                if(cmp == 0){
                                        /* Tie: alternate sides by position. */
                                        cmp = (i & 1) ? -1 : 1;
                                }
                                if(cmp < 0){
                                        acc = wr;
                                        sr[num_r++] = s;
                                }else{
                                        acc = wl;
                                        sl[num_l++] = s;
                                }
                                for(j = 0; j < num_anchors; j++){
                                        acc[j] += dm[s][j];
                                }
                        }

                        /* An empty cluster would make its centroid 0/0 (NaN);
                           this used to end the whole process via exit(0). */
                        if(num_l == 0 || num_r == 0){
                                halve = 1;
                                break;
                        }

                        for(j = 0; j < num_anchors; j++){
                                wl[j] /= (float)num_l;
                                wr[j] /= (float)num_r;
                        }

                        moved = 0;
                        for(j = 0; j < num_anchors; j++){
                                if(wl[j] != cl[j] || wr[j] != cr[j]){
                                        moved = 1;
                                        break;
                                }
                        }
                        if(!moved){
                                break;
                        }
                        swap = cl;
                        cl = wl;
                        wl = swap;

                        swap = cr;
                        cr = wr;
                        wr = swap;
                }
        }

        if(halve){
                /* Cut in half; never hand a 1-element set to clustering. */
                score = 0.0F;
                num_l = 0;
                num_r = 0;
                for(i = 0; i < num_samples / 2; i++){
                        sl[num_l++] = samples[i];
                }
                for(i = num_samples / 2; i < num_samples; i++){
                        sr[num_r++] = samples[i];
                }
        }

#ifdef HAVE_AVX2
        _mm_free(wr);
        _mm_free(wl);
        _mm_free(cr);
        _mm_free(cl);
#else
        MFREE(wr);
        MFREE(wl);
        MFREE(cr);
        MFREE(cl);
#endif

        res->nl = num_l;
        res->nr = num_r;
        res->score = score;
        *ret = res;
        return OK;
ERROR:
#ifdef HAVE_AVX2
        if(w){ _mm_free(w); }
        if(wr){ _mm_free(wr); }
        if(wl){ _mm_free(wl); }
        if(cr){ _mm_free(cr); }
        if(cl){ _mm_free(cl); }
#else
        if(w){ MFREE(w); }
        if(wr){ MFREE(wr); }
        if(wl){ MFREE(wl); }
        if(cr){ MFREE(cr); }
        if(cl){ MFREE(cl); }
#endif
        if(own_res && res){
                if(res->sl){ MFREE(res->sl); }
                if(res->sr){ MFREE(res->sr); }
                MFREE(res);
        }
        return FAIL;
}

/* Split `samples` into two clusters with 2-means (Lloyd) iterations.

   Initial centroids: cl is the row of samples[seed_pick], and cr is that
   row mirrored through the mean row of all samples. Each iteration:
   - assigns every sample to the nearer centroid (ties, as judged by
     cmp_floats(), alternate by sample position);
   - recomputes both centroids as cluster means.
   Iteration stops when no centroid coordinate moves beyond cmp_floats()'
   tolerance, or after 500 iterations. If either cluster becomes empty, the
   set is cut in half instead (first half left, second half right) with
   score 0.

   dm          distance rows with num_anchors columns, indexed by sequence id.
   samples     sequence ids to split (num_samples entries).
   seed_pick   index into samples of the seed row; must be in [0, num_samples).
   ret         in/out result. It is reused when *ret is non-NULL (its sl/sr
               arrays must hold num_samples entries) and allocated otherwise.
               On success score is the sum of every sample's distance to its
               nearer centroid.

   Returns OK on success, FAIL on a bad seed index or allocation failure. */
int split2(const float * const * dm,const int* samples, const int num_anchors,const int num_samples,const int seed_pick,struct kmeans_result** ret)
{
        struct kmeans_result* res = NULL;
        int* sl = NULL;
        int* sr = NULL;
        int num_l = 0;
        int num_r = 0;
        float* w = NULL;
        float* wl = NULL;
        float* wr = NULL;
        float* cl = NULL;
        float* cr = NULL;
        float* acc = NULL;      /* alias: accumulator of the chosen cluster */
        float* swap = NULL;     /* alias: temporary for buffer swaps */
        float dl = 0.0F;
        float dr = 0.0F;
        float score = 0.0F;
        int num_var;
        int moved;
        int cmp;
        int i;
        int j;
        int s;

        /* Guards the samples[seed_pick] read below (also rejects num_samples <= 0). */
        if(seed_pick < 0 || seed_pick >= num_samples){
                goto ERROR;
        }

        /* Pad buffers to a multiple of 8 floats (presumably for the 8-wide
           edist_256 kernel); the padding is zeroed below. */
        num_var = num_anchors / 8;
        if(num_anchors % 8){
                num_var++;
        }
        num_var = num_var << 3;

#ifdef HAVE_AVX2
        RUNP(wr = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(wl = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(cr = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(cl = _mm_malloc(sizeof(float) * num_var, 32));
        RUNP(w = _mm_malloc(sizeof(float) * num_var, 32));
#else
        MMALLOC(wr, sizeof(float) * num_var);
        MMALLOC(wl, sizeof(float) * num_var);
        MMALLOC(cr, sizeof(float) * num_var);
        MMALLOC(cl, sizeof(float) * num_var);
        MMALLOC(w, sizeof(float) * num_var);
#endif

        if(*ret){
                res = *ret;
        }else{
                RUNP(res = alloc_kmeans_result(num_samples));
        }
        res->score = FLT_MAX;
        sl = res->sl;
        sr = res->sr;

        for(i = 0; i < num_var; i++){
                w[i] = 0.0F;
                wr[i] = 0.0F;
                wl[i] = 0.0F;
                cr[i] = 0.0F;
                cl[i] = 0.0F;
        }

        /* w = mean row of all samples. */
        for(i = 0; i < num_samples; i++){
                s = samples[i];
                for(j = 0; j < num_anchors; j++){
                        w[j] += dm[s][j];
                }
        }
        for(j = 0; j < num_anchors; j++){
                w[j] /= (float)num_samples;
        }

        /* Seed centroids: the chosen row and its mirror through the mean. */
        s = samples[seed_pick];
        for(j = 0; j < num_anchors; j++){
                cl[j] = dm[s][j];
        }
        for(j = 0; j < num_anchors; j++){
                cr[j] = w[j] - (cl[j] - w[j]);
        }

#ifdef HAVE_AVX2
        _mm_free(w);
#else
        MFREE(w);
#endif
        w = NULL;

        for(int stop = 0; stop < 500; stop++){
                num_l = 0;
                num_r = 0;
                for(i = 0; i < num_anchors; i++){
                        wr[i] = 0.0F;
                        wl[i] = 0.0F;
                }
                score = 0.0f;
                for(i = 0; i < num_samples; i++){
                        s = samples[i];
#ifdef HAVE_AVX2
                        edist_256(dm[s], cl, num_anchors, &dl);
                        edist_256(dm[s], cr, num_anchors, &dr);
#else
                        edist_serial(dm[s], cl, num_anchors, &dl);
                        edist_serial(dm[s], cr, num_anchors, &dr);
#endif
                        score += MACRO_MIN(dl, dr);

                        cmp = cmp_floats(dr, dl);
                        if(cmp == 0){
                                /* Tie: alternate sides by position. */
                                cmp = (i & 1) ? -1 : 1;
                        }
                        if(cmp < 0){
                                acc = wr;
                                sr[num_r++] = s;
                        }else{
                                acc = wl;
                                sl[num_l++] = s;
                        }
                        for(j = 0; j < num_anchors; j++){
                                acc[j] += dm[s][j];
                        }
                }

                if(num_l == 0 || num_r == 0){
                        /* Degenerate: cut in half rather than divide by zero. */
                        score = 0.0F;
                        num_l = 0;
                        num_r = 0;
                        for(i = 0; i < num_samples / 2; i++){
                                sl[num_l++] = samples[i];
                        }
                        for(i = num_samples / 2; i < num_samples; i++){
                                sr[num_r++] = samples[i];
                        }
                        break;
                }

                for(j = 0; j < num_anchors; j++){
                        wl[j] /= (float)num_l;
                        wr[j] /= (float)num_r;
                }

                moved = 0;
                for(j = 0; j < num_anchors; j++){
                        if(cmp_floats(wl[j], cl[j]) != 0 || cmp_floats(wr[j], cr[j]) != 0){
                                moved = 1;
                                break;
                        }
                }
                if(!moved){
                        break;
                }
                swap = cl;
                cl = wl;
                wl = swap;

                swap = cr;
                cr = wr;
                wr = swap;
        }

#ifdef HAVE_AVX2
        _mm_free(wr);
        _mm_free(wl);
        _mm_free(cr);
        _mm_free(cl);
#else
        MFREE(wr);
        MFREE(wl);
        MFREE(cr);
        MFREE(cl);
#endif

        res->nl = num_l;
        res->nr = num_r;
        res->score = score;
        *ret = res;
        return OK;
ERROR:
#ifdef HAVE_AVX2
        if(w){ _mm_free(w); }
        if(wr){ _mm_free(wr); }
        if(wl){ _mm_free(wl); }
        if(cr){ _mm_free(cr); }
        if(cl){ _mm_free(cl); }
#else
        if(w){ MFREE(w); }
        if(wr){ MFREE(wr); }
        if(wl){ MFREE(wl); }
        if(cr){ MFREE(cr); }
        if(cl){ MFREE(cl); }
#endif
        return FAIL;
}


/* Agglomerative guide-tree construction over numseq items.

   Repeatedly joins the closest pair of active clusters, reading only
   dm[i][j] with i < j; the first pair found wins ties. After a join the
   merged cluster's row/column in dm becomes the unweighted mean of the two
   rows plus 0.001 (WPGMA-style: not weighted by cluster size), and its
   diagonal entry is reset to 0.

   dm       numseq x numseq distance matrix, modified in place.
   samples  leaf labels: leaf i gets id samples[i]. Only read.
   numseq   number of items; must be >= 1.

   Returns the root (a single leaf when numseq == 1). Returns NULL on bad
   input or allocation failure. NOTE(review): on allocation failure the
   nodes already built are not freed. */
struct node* upgma(float **dm,int* samples, int numseq)
{
        struct node** tree = NULL;
        struct node* joined = NULL;
        int* active = NULL;
        float min_dist;
        int node_a = 0;
        int node_b = 0;
        int cnode = numseq;
        int numprofiles;
        int i;
        int j;

        /* With numseq <= 0 the loop target 2*numseq-1 is never reached. */
        if(numseq <= 0){
                goto ERROR;
        }
        numprofiles = (numseq << 1) - 1;

        /* active[i] != 0 while slot i still holds an unjoined cluster. */
        MMALLOC(active, sizeof(int) * numseq);
        for(i = 0; i < numseq; i++){
                active[i] = i + 1;
        }

        MMALLOC(tree, sizeof(struct node*) * numseq);
        for(i = 0; i < numseq; i++){
                tree[i] = NULL;
        }
        for(i = 0; i < numseq; i++){
                tree[i] = alloc_node();
                if(!tree[i]){
                        goto ERROR;
                }
                tree[i]->id = samples[i];
        }

        while(cnode != numprofiles){
                min_dist = FLT_MAX;
                node_a = -1;
                node_b = -1;
                for(i = 0; i < numseq - 1; i++){
                        if(!active[i]){
                                continue;
                        }
                        for(j = i + 1; j < numseq; j++){
                                if(active[j] && dm[i][j] < min_dist){
                                        min_dist = dm[i][j];
                                        node_a = i;
                                        node_b = j;
                                }
                        }
                }
                if(node_a < 0){
                        /* Every remaining distance is FLT_MAX or NaN. Previously
                           stale indices (or 0/0 on the first join) were merged
                           here, corrupting the tree. Join the first two
                           active clusters instead. */
                        for(i = 0; i < numseq && node_b < 0; i++){
                                if(active[i]){
                                        if(node_a < 0){
                                                node_a = i;
                                        }else{
                                                node_b = i;
                                        }
                                }
                        }
                }

                joined = alloc_node();
                if(!joined){
                        goto ERROR;
                }
                joined->left = tree[node_a];
                joined->right = tree[node_b];
                tree[node_a] = joined;
                tree[node_b] = NULL;

                /* Slot node_a now holds the merged cluster; node_b retires. */
                active[node_a] = cnode + 1;
                active[node_b] = 0;
                cnode++;

                for(j = 0; j < numseq; j++){
                        if(j != node_b){
                                dm[node_a][j] = (dm[node_a][j] + dm[node_b][j]) * 0.5F + 0.001F;
                        }
                }
                dm[node_a][node_a] = 0.0F;
                for(j = 0; j < numseq; j++){
                        dm[j][node_a] = dm[node_a][j];
                }
        }

        joined = tree[node_a];
        MFREE(tree);
        MFREE(active);
        return joined;
ERROR:
        if(tree){
                MFREE(tree);
        }
        if(active){
                MFREE(active);
        }
        return NULL;
}

/* Allocate a detached tree node: no children and id -1 (meaning "not yet
   labelled"). Returns NULL if the allocation fails. */
struct node* alloc_node(void)
{
        struct node* fresh = NULL;

        MMALLOC(fresh, sizeof(*fresh));
        fresh->id = -1;
        fresh->right = NULL;
        fresh->left = NULL;
        return fresh;
ERROR:
        return NULL;
}

int label_internal(struct node*n, int label)
{
        //n->d = d;
        if(n->left){
                label = label_internal(n->left, label);
        }
        if(n->right){
                label = label_internal(n->right, label);
        }
        if(n->id == -1){
                n->id = label;
                label++;
        }
        return label;

}

/* Pre-order walk that records one task per internal node: task->a and
   task->b are the children's ids and task->c is the node's own id. The
   task slots are taken from t->list in order, assumed pre-allocated by
   alloc_tasks() (TODO confirm capacity). After its subtrees are processed,
   an internal node frees both of its children; the root itself is left to
   the caller. */
void create_tasks(struct node*n, struct aln_tasks* t)
{
        const int internal = (n->left != NULL) && (n->right != NULL);

        if(internal){
                struct task* job = t->list[t->n_tasks];

                job->a = n->left->id;
                job->b = n->right->id;
                job->c = n->id;
                t->n_tasks++;
        }

        if(n->left != NULL){
                create_tasks(n->left, t);
        }
        if(n->right != NULL){
                create_tasks(n->right, t);
        }

        if(internal){
                MFREE(n->left);
                MFREE(n->right);
        }
}


/* Allocate an empty kmeans_result whose sl/sr arrays can each hold
   num_samples ids; score starts at FLT_MAX. Returns NULL on failure,
   releasing anything already allocated. */
struct kmeans_result* alloc_kmeans_result(int num_samples)
{
        struct kmeans_result* kr = NULL;

        ASSERT(num_samples != 0, "No samples???");

        MMALLOC(kr, sizeof(struct kmeans_result));
        /* Clear the array pointers first so the error path can free safely. */
        kr->sl = NULL;
        kr->sr = NULL;
        kr->nl = 0;
        kr->nr = 0;
        kr->score = FLT_MAX;

        MMALLOC(kr->sl, sizeof(int) * num_samples);
        MMALLOC(kr->sr, sizeof(int) * num_samples);
        return kr;
ERROR:
        free_kmeans_results(kr);
        return NULL;
}

/* Release a kmeans_result and whichever of its sl/sr arrays are set.
   A NULL argument is ignored. */
void free_kmeans_results(struct kmeans_result* k)
{
        if(k == NULL){
                return;
        }
        if(k->sl != NULL){
                MFREE(k->sl);
        }
        if(k->sr != NULL){
                MFREE(k->sr);
        }
        MFREE(k);
}

/* Build the guide tree from a precomputed all-pairs distance matrix.

   First records, in msa->seq_distances, each sequence's mean distance to
   all other sequences (0 when there is only one sequence). This must
   happen before upgma(), which overwrites dm. The tree is then built with
   upgma(), and its internal nodes are labelled and turned into tasks.

   msa    alignment (numseq, seq_distances).
   tasks  in/out task list; allocated when *tasks is NULL and stored back
          into *tasks on success.
   dm     numseq x numseq distance matrix; modified in place.

   Returns OK on success, FAIL otherwise. */
int build_tree_from_pairwise(struct msa* msa, struct aln_tasks** tasks, float** dm)
{
        struct aln_tasks* t = NULL;
        struct node* root = NULL;
        int* samples = NULL;
        int numseq;
        int i;
        int j;

        ASSERT(msa != NULL, "No alignment.");
        ASSERT(dm != NULL, "No distance matrix.");

        t = *tasks;
        if(t == NULL){
                RUN(alloc_tasks(&t, msa->numseq));
        }
        numseq = msa->numseq;

        /* Per-sequence mean distance to the others (gap/VSM scaling). */
        if(!msa->seq_distances){
                MMALLOC(msa->seq_distances, sizeof(float) * numseq);
        }
        for(i = 0; i < numseq; i++){
                float total = 0.0f;

                for(j = 0; j < numseq; j++){
                        if(j == i){
                                continue;
                        }
                        total += dm[i][j];
                }
                if(numseq > 1){
                        msa->seq_distances[i] = total / (float)(numseq - 1);
                }else{
                        msa->seq_distances[i] = 0.0f;
                }
        }

        MMALLOC(samples, sizeof(int) * numseq);
        for(i = 0; i < numseq; i++){
                samples[i] = i;
        }

        /* upgma() rewrites dm while merging clusters. */
        root = upgma(dm, samples, numseq);
        ASSERT(root != NULL, "UPGMA tree construction failed.");

        label_internal(root, numseq);
        create_tasks(root, t);

        MFREE(samples);
        MFREE(root);
        *tasks = t;
        return OK;
ERROR:
        if(samples){
                MFREE(samples);
        }
        return FAIL;
}
