/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2019 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/
#include <cdi.h>
#include "cdo_int.h"
#include "specspace.h"
#include "grid.h"

// Prototypes for helper routines used below; they are not defined in this file section.
// geninx: fills the two factor tables f/g used by dv2uv (called from dvtrans_new).
void geninx(long ntr, double *f, double *g);
// scaluv: scales Fourier work arrays latitude-wise by the given per-latitude factors.
void scaluv(double *fu, const double *rclat, int nlat, int lot);
// uv2dv / dv2uv: conversions between wind components and divergence/vorticity.
void uv2dv(double *fu, double *fv, double *sd, double *sv, double *pol2, double *pol3, long klev, long nlat, long nt);
void dv2uv(const double *d, const double *o, double *u, double *v, double *f, double *g, long nt, long nsp, long nlev);

// Legendre table initialisation; the "full" variant additionally fills pdev/pol2/pol3.
void after_legini_full(long ntr, long nlat, double *restrict poli, double *restrict pold, double *restrict pdev,
                       double *restrict pol2, double *restrict pol3, double *restrict coslat);
void after_legini(long ntr, long nlat, double *restrict poli, double *restrict pold, double *restrict coslat);

/*
 * Transform one gridpoint field into spectral coefficients.
 *
 * The gridpoint dimensions are taken from gridIDin and the truncation from the
 * spectral target grid gridIDout. The data goes gridpoint -> Fourier (gp2fc)
 * -> spectral (fc2sp, using the pold table of sptrans).
 */
void
grid2spec(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncation = gridInqTrunc(gridIDout);
  const long nx = gridInqXsize(gridIDin);
  const long ny = gridInqYsize(gridIDin);
  const long levels = 1;
  const long ncoef = 2 * (truncation + 1);  // Fourier coefficients per latitude and level

  std::vector<double> fourier(ny * ncoef * levels);

  gp2fc(sptrans->trig, sptrans->ifax, arrayIn, fourier.data(), ny, nx, levels, ncoef);
  fc2sp(fourier.data(), arrayOut, sptrans->pold, levels, ny, ncoef, truncation);
}

/*
 * Transform one field of spectral coefficients back to gridpoint space.
 *
 * The truncation comes from the spectral input grid gridIDin and the gridpoint
 * dimensions from gridIDout. The data goes spectral -> Fourier (sp2fc, using
 * the poli table of sptrans) -> gridpoint (fc2gp).
 */
void
spec2grid(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncation = gridInqTrunc(gridIDin);
  const long nx = gridInqXsize(gridIDout);
  const long ny = gridInqYsize(gridIDout);
  const long levels = 1;
  const long ncoef = 2 * (truncation + 1);  // Fourier coefficients per latitude and level

  std::vector<double> fourier(ny * ncoef * levels);

  sp2fc(arrayIn, fourier.data(), sptrans->poli, levels, ny, ncoef, truncation);
  fc2gp(sptrans->trig, sptrans->ifax, fourier.data(), arrayOut, ny, nx, levels, ncoef);
}

/*
 * Transform one field of Fourier coefficients into spectral coefficients (fc2sp).
 *
 * The truncation is read from the spectral output grid. The latitude count
 * comes from sptrans, so the input grid is not consulted.
 */
void
four2spec(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  (void) gridIDin;  // unused: nlat is taken from sptrans instead

  const long truncation = gridInqTrunc(gridIDout);
  const long ncoef = 2 * (truncation + 1);
  const long nlat = sptrans->nlat;
  const long levels = 1;

  fc2sp(arrayIn, arrayOut, sptrans->pold, levels, nlat, ncoef, truncation);
}

/*
 * Transform one field of spectral coefficients into Fourier coefficients (sp2fc).
 *
 * The truncation is read from the spectral input grid. The latitude count is
 * derived via nfc_to_nlat() from the total size of the Fourier output grid
 * together with that truncation.
 */
void
spec2four(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncation = gridInqTrunc(gridIDin);
  const long outSize = gridInqSize(gridIDout);
  const long nlat = nfc_to_nlat(outSize, truncation);
  const long ncoef = 2 * (truncation + 1);  // Fourier coefficients per latitude and level
  const long levels = 1;

  sp2fc(arrayIn, arrayOut, sptrans->poli, levels, nlat, ncoef, truncation);
}

/*
 * Transform one field of Fourier coefficients to gridpoint space (fc2gp).
 *
 * Only the truncation is taken from the Fourier input grid. The gridpoint
 * dimensions come from the output grid gridIDout.
 */
void
four2grid(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncation = gridInqTrunc(gridIDin);
  const long ncoef = 2 * (truncation + 1);
  const long nx = gridInqXsize(gridIDout);
  const long ny = gridInqYsize(gridIDout);
  const long levels = 1;

  fc2gp(sptrans->trig, sptrans->ifax, arrayIn, arrayOut, ny, nx, levels, ncoef);
}

/*
 * Transform one gridpoint field into Fourier coefficients (gp2fc).
 *
 * The gridpoint dimensions come from the input grid gridIDin. The number of
 * retained coefficients follows from the truncation of the Fourier output grid.
 */
void
grid2four(SPTRANS *sptrans, int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncation = gridInqTrunc(gridIDout);
  const long ncoef = 2 * (truncation + 1);
  const long nx = gridInqXsize(gridIDin);
  const long ny = gridInqYsize(gridIDin);
  const long levels = 1;

  gp2fc(sptrans->trig, sptrans->ifax, arrayIn, arrayOut, ny, nx, levels, ncoef);
}

/*
 * Convert spectral coefficients from the truncation of gridIDin to the
 * truncation of gridIDout via sp2sp().
 */
void
spec2spec(int gridIDin, double *arrayIn, int gridIDout, double *arrayOut)
{
  const long truncIn = gridInqTrunc(gridIDin);
  const long truncOut = gridInqTrunc(gridIDout);

  sp2sp(arrayIn, truncIn, arrayOut, truncOut);
}

/*
 * Apply spcut() to the spectral field arrayIn, writing to arrayOut.
 * The truncation is read from gridIDin. The waves array is forwarded unchanged.
 */
void
speccut(int gridIDin, double *arrayIn, double *arrayOut, int *waves)
{
  const long truncation = gridInqTrunc(gridIDin);
  spcut(arrayIn, arrayOut, truncation, waves);
}

/*
 * Allocate and initialise a spectral transformation descriptor.
 *
 * nlon, nlat: gridpoint dimensions; ntr: spectral truncation.
 * flag:       when non-zero, the additional pol2/pol3 tables are allocated and
 *             filled via after_legini_full(); otherwise they are left nullptr
 *             and only poli/pold/coslat are filled via after_legini().
 * Returns a heap-allocated SPTRANS; release it with sptrans_delete().
 */
SPTRANS *
sptrans_new(long nlon, long nlat, long ntr, int flag)
{
  SPTRANS *sp = (SPTRANS *) Malloc(sizeof(SPTRANS));

  sp->nlon = nlon;
  sp->nlat = nlat;
  sp->ntr = ntr;

  // Each Legendre table holds (ntr+1)(ntr+2)/2 entries per latitude.
  const long perLat = (ntr + 1) * (ntr + 2) / 2;
  sp->poldim = perLat * nlat;

  sp->trig = (double *) Malloc(nlon * sizeof(double));
  fft_set(sp->trig, sp->ifax, nlon);

  const bool full = (flag != 0);

  sp->poli = (double *) Malloc(sp->poldim * sizeof(double));
  sp->pold = (double *) Malloc(sp->poldim * sizeof(double));
  sp->pol2 = full ? (double *) Malloc(sp->poldim * sizeof(double)) : nullptr;
  sp->pol3 = full ? (double *) Malloc(sp->poldim * sizeof(double)) : nullptr;

  sp->coslat = (double *) Malloc(nlat * sizeof(double));
  sp->rcoslat = (double *) Malloc(nlat * sizeof(double));

  if (full)
    after_legini_full(ntr, nlat, sp->poli, sp->pold, nullptr, sp->pol2, sp->pol3, sp->coslat);
  else
    after_legini(ntr, nlat, sp->poli, sp->pold, sp->coslat);

  // Cache the reciprocal of coslat for later latitude-wise rescaling.
  for (long lat = 0; lat < nlat; ++lat) sp->rcoslat[lat] = 1.0 / sp->coslat[lat];

  return sp;
}

/*
 * Release an SPTRANS created by sptrans_new(), including every table it owns.
 * Passing nullptr is a no-op. Each member pointer is reset after being freed.
 */
void
sptrans_delete(SPTRANS *sptrans)
{
  if (sptrans == nullptr) return;

  if (sptrans->trig) { Free(sptrans->trig); sptrans->trig = nullptr; }
  if (sptrans->poli) { Free(sptrans->poli); sptrans->poli = nullptr; }
  if (sptrans->pold) { Free(sptrans->pold); sptrans->pold = nullptr; }
  // pol2/pol3 are nullptr when the descriptor was created without the full tables.
  if (sptrans->pol2) { Free(sptrans->pol2); sptrans->pol2 = nullptr; }
  if (sptrans->pol3) { Free(sptrans->pol3); sptrans->pol3 = nullptr; }
  if (sptrans->coslat) { Free(sptrans->coslat); sptrans->coslat = nullptr; }
  if (sptrans->rcoslat) { Free(sptrans->rcoslat); sptrans->rcoslat = nullptr; }

  Free(sptrans);
}

/*
 * Allocate a divergence/vorticity helper for truncation ntr.
 *
 * Two factor tables f1/f2 of (ntr+1)(ntr+2)/2 entries each are allocated and
 * filled by geninx(). Release the result with dvtrans_delete().
 */
DVTRANS *
dvtrans_new(long ntr)
{
  DVTRANS *dv = (DVTRANS *) Malloc(sizeof(DVTRANS));

  dv->ntr = ntr;
  dv->fdim = (ntr + 1) * (ntr + 2) / 2;

  dv->f1 = (double *) Malloc(dv->fdim * sizeof(double));
  dv->f2 = (double *) Malloc(dv->fdim * sizeof(double));

  geninx(ntr, dv->f1, dv->f2);

  return dv;
}

/*
 * Release a DVTRANS created by dvtrans_new() together with its f1/f2 tables.
 * Passing nullptr is a no-op.
 */
void
dvtrans_delete(DVTRANS *dvtrans)
{
  if (dvtrans == nullptr) return;

  if (dvtrans->f1) { Free(dvtrans->f1); dvtrans->f1 = nullptr; }
  if (dvtrans->f2) { Free(dvtrans->f2); dvtrans->f2 = nullptr; }

  Free(dvtrans);
}

/*
 * Convert gridpoint wind components to spectral divergence and vorticity.
 *
 * sptrans: descriptor that must have been created by sptrans_new() with a
 *          non-zero flag, so that the pol2/pol3 tables exist.
 * nlev:    number of levels stored contiguously in gu/gv/sd/svo.
 * gridID1: Gaussian grid of gu/gv (aborts otherwise).
 * gridID2: spectral grid of sd/svo (aborts otherwise); supplies the truncation.
 *
 * gu/gv are transformed to Fourier space, scaled latitude-wise by coslat
 * (scaluv), and passed to uv2dv(), which writes sd (divergence) and svo.
 */
void
trans_uv2dv(SPTRANS *sptrans, long nlev, int gridID1, double *gu, double *gv, int gridID2, double *sd, double *svo)
{
  if (gridInqType(gridID1) != GRID_GAUSSIAN)
    cdoAbort("unexpected grid1 type: %s instead of Gaussian", gridNamePtr(gridInqType(gridID1)));

  if (gridInqType(gridID2) != GRID_SPECTRAL)
    cdoAbort("unexpected grid2 type: %s instead of spectral", gridNamePtr(gridInqType(gridID2)));

  // sptrans_new() leaves pol2/pol3 as nullptr when called with flag == 0.
  // uv2dv() is handed these tables, so fail with a clear message instead of
  // (presumably) dereferencing a null pointer.
  if (sptrans->pol2 == nullptr || sptrans->pol3 == nullptr)
    cdoAbort("trans_uv2dv: spectral transformation was created without the pol2/pol3 Legendre tables!");

  const long ntr = gridInqTrunc(gridID2);
  const long nlon = gridInqXsize(gridID1);
  const long nlat = gridInqYsize(gridID1);
  const long waves = ntr + 1;
  const long nfc = waves * 2;

  std::vector<double> fpwork1(nlat * nfc * nlev);
  std::vector<double> fpwork2(nlat * nfc * nlev);

  gp2fc(sptrans->trig, sptrans->ifax, gu, fpwork1.data(), nlat, nlon, nlev, nfc);
  gp2fc(sptrans->trig, sptrans->ifax, gv, fpwork2.data(), nlat, nlon, nlev, nfc);

  scaluv(fpwork1.data(), sptrans->coslat, nlat, nfc * nlev);
  scaluv(fpwork2.data(), sptrans->coslat, nlat, nfc * nlev);

  uv2dv(fpwork1.data(), fpwork2.data(), sd, svo, sptrans->pol2, sptrans->pol3, nlev, nlat, ntr);
}

/*
 * Convert spectral divergence (sd) and vorticity (svo) to gridpoint wind
 * components gu/gv.
 *
 * gridID1 should be spectral (supplies the truncation) and gridID2 Gaussian
 * (supplies nlon/nlat). A mismatch only produces a Warning here, not an abort
 * (unlike trans_uv2dv).
 * nlev: number of levels stored contiguously in the input and output arrays.
 */
void
trans_dv2uv(SPTRANS *sptrans, DVTRANS *dvtrans, long nlev, int gridID1, double *sd, double *svo, int gridID2, double *gu, double *gv)
{
  if (gridInqType(gridID1) != GRID_SPECTRAL) Warning("unexpected grid1 type: %s", gridNamePtr(gridInqType(gridID1)));
  if (gridInqType(gridID2) != GRID_GAUSSIAN) Warning("unexpected grid2 type: %s", gridNamePtr(gridInqType(gridID2)));

  const long ntr = gridInqTrunc(gridID1);
  const long nlon = gridInqXsize(gridID2);
  const long nlat = gridInqYsize(gridID2);
  const long waves = ntr + 1;
  const long nfc = waves * 2;
  const long dimsp = (ntr + 1) * (ntr + 2);

  // The spectral u/v work arrays deliberately alias the gridpoint output
  // buffers: dv2uv writes spectral u into gu and spectral v into gv.
  // NOTE(review): this assumes gu/gv are at least nlev * dimsp long — confirm with callers.
  double *su = gu;
  double *sv = gv;

  dv2uv(sd, svo, su, sv, dvtrans->f1, dvtrans->f2, ntr, dimsp, nlev);

  std::vector<double> fpwork(nlat * nfc * nlev);

  // su is fully consumed into fpwork by sp2fc before fc2gp overwrites gu (== su).
  sp2fc(su, fpwork.data(), sptrans->poli, nlev, nlat, nfc, ntr);
  // Undo the latitude scaling using the cached reciprocal of coslat.
  scaluv(fpwork.data(), sptrans->rcoslat, nlat, nfc * nlev);
  fc2gp(sptrans->trig, sptrans->ifax, fpwork.data(), gu, nlat, nlon, nlev, nfc);

  // Same sequence for v; sv (== gv) is still untouched at this point.
  sp2fc(sv, fpwork.data(), sptrans->poli, nlev, nlat, nfc, ntr);
  scaluv(fpwork.data(), sptrans->rcoslat, nlat, nfc * nlev);
  fc2gp(sptrans->trig, sptrans->ifax, fpwork.data(), gv, nlat, nlon, nlev, nfc);
}
