/*
 * Copyright 2020 Michael Moessner
 *
 * This file is part of Tuner.
 *
 * Tuner is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Tuner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Tuner.  If not, see <http://www.gnu.org/licenses/>.
 */

package de.moekadu.tuner.notedetection

import de.moekadu.tuner.misc.MemoryPool
import kotlin.math.pow

/** Class to compute the auto correlation of a real-valued time series via FFT.
 *
 * The input is zero-padded to twice its length, transformed, turned into a power spectrum and transformed back.
 * The first [size] values of the back-transform are returned. Because of the zero padding, they correspond to
 * lags 0 until [size] of the (non-circular) auto correlation. The absolute scaling follows whatever
 * normalization FFT.ifftReal applies.
 *
 * An instance reuses an internal scratch buffer on every call of [correlate]. Do not call [correlate] on the
 * same instance concurrently from several threads.
 *
 * @param size Number of input values, which is a time series.
 * @param windowType Type of windowing to apply for FFT (first step of computing auto correlation).
 *   WindowingFunction.Tophat disables the windowing.
 */
class Correlation(val size : Int, val windowType : WindowingFunction = WindowingFunction.Tophat) {

  private val fft = FFT(2 * size)
  /** Scratch buffer. It holds the bit-reversed, zero-padded input, then the power spectrum, and finally the
   *  inverse transform. Its size is 2*size+2: size+1 complex bins of the real FFT, stored as (re, im) pairs. */
  private val inputBitReversed = FloatArray(2 * size + 2)
  private val window = FloatArray(size)

  init {
    getWindow(windowType, size).copyInto(window)
  }

  /** Auto correlation of input.
   * @param input Input data which should be correlated (required size: size)
   * @param output Output array where we store the autocorrelation, (required size: size)
   * @param disableWindow if true, we disable windowing, even when it is defined in the constructor.
   * @param spectrum If a non-null array is given, we will store the spectrum of the zero-padded input here
   *   (the input is zero-padded to become twice the size before we start correlating). If it is null, we use
   *   the internal class storage. If it is non-null, the size of the spectrum must be 2*size+2, i.e. size+1
   *   complex values stored as consecutive (real, imaginary) pairs.
   * @throws IllegalArgumentException if any of the array sizes does not match the requirements above.
   */
  fun correlate(input : FloatArray, output : FloatArray, disableWindow : Boolean = false, spectrum : FloatArray? = null) {
    require(input.size == size) {"input size must be equal to the size of the correlation size"}
    require(output.size == size) {"output  size must be correlation size"}
    if(spectrum != null) {
      require(spectrum.size == 2 * size + 2) { "output spectrum size must be 2*size+2" }
    }
    val spectrumStorage = spectrum ?: inputBitReversed

    // Copy the input into spectrumStorage in bit-reversed order. Consecutive (even, odd) floats are treated as
    // one complex value. Real indices >= size are the zero padding. Each swap pair (i, rev(i)) is handled only
    // once, when i >= rev(i). The windowed and unwindowed variants are kept as separate loops so that there is
    // no per-element branch.
    if(windowType == WindowingFunction.Tophat || disableWindow) {
      for (i in 0 until size) {
        val ir2 = 2 * fft.bitReverseTable[i]
        val i2 = 2 * i
        if (i2 >= ir2) {
          spectrumStorage[i2] = if (ir2 < size) input[ir2] else 0f
          spectrumStorage[i2 + 1] = if (ir2 + 1 < size) input[ir2 + 1] else 0f
          spectrumStorage[ir2] = if (i2 < size) input[i2] else 0f
          spectrumStorage[ir2 + 1] = if (i2 + 1 < size) input[i2 + 1] else 0f
        }
      }
    }
    else {
      for (i in 0 until size) {
        val ir2 = 2 * fft.bitReverseTable[i]
        val i2 = 2 * i
        if (i2 >= ir2) {
          spectrumStorage[i2] = if (ir2 < size) window[ir2] * input[ir2] else 0f
          spectrumStorage[i2 + 1] = if (ir2 + 1 < size) window[ir2 + 1] * input[ir2 + 1] else 0f
          spectrumStorage[ir2] = if (i2 < size) window[i2] * input[i2] else 0f
          spectrumStorage[ir2 + 1] = if (i2 + 1 < size) window[i2 + 1] * input[i2 + 1] else 0f
        }
      }
    }
    fft.fftBitReversed(spectrumStorage)
    fft.combineFFTResultForRealFFT(spectrumStorage)

    // Power spectrum |X_k|^2 for the size+1 bins. The imaginary parts of the result are zero.
    // When spectrum == null, spectrumStorage and inputBitReversed are the same array. Each element is read
    // before it is overwritten, so doing this in place is safe.
    // Plain multiplication avoids the Float -> Double -> Float round trip of Float.pow(2) in this hot loop.
    for (i in 0 .. size) {
      val re = spectrumStorage[2 * i]
      val im = spectrumStorage[2 * i + 1]
      inputBitReversed[2 * i] = re * re + im * im
      inputBitReversed[2 * i + 1] = 0f
    }

    // it is allowed to pass the same variable here to do the ifft inplace
    fft.ifftReal(inputBitReversed, inputBitReversed)

    // Only the first `size` values (lags 0 until size) are part of the result.
    inputBitReversed.copyInto(output, destinationOffset = 0, startIndex = 0, endIndex = size)
  }
}


/** Pool of [Correlation] instances, so that their FFT and buffer allocations can be reused.
 *
 * An instance is considered reusable when both its size and its window type match the request.
 */
class MemoryPoolCorrelation {
  private val correlationPool = MemoryPool<Correlation>()

  /** Obtain a [Correlation] for the given [size] and [windowType] from the pool.
   * A new instance is created only if the underlying pool offers no matching one.
   * The exact return type (e.g. a handle that must be released) is whatever MemoryPool.get returns.
   * NOTE(review): check the MemoryPool contract for how to hand the instance back.
   */
  fun get(size: Int, windowType: WindowingFunction) = correlationPool.get(
    factory = { Correlation(size, windowType) },
    checker = { candidate -> candidate.windowType == windowType && candidate.size == size }
  )
}
