package de.tadris.flang_lib.script

import java.io.File
import kotlin.math.abs
import kotlin.math.min

/**
 * Configuration for hyperparameter optimization.
 */
/**
 * Central configuration for the hyperparameter optimizer: search settings, evaluation inputs
 * and the location of the cflang build.
 */
object OptimizerConfig {
    /**
     * Relative hill-climbing step. Each step multiplies the current value by
     * `1 + STEP_SIZE` (up) or `1 - STEP_SIZE` (down).
     */
    const val STEP_SIZE = 0.01

    /** Safety cap on hill-climbing iterations for a single parameter. */
    const val MAX_ITERATIONS = 50

    /** Dataset passed to the metric calculation (path relative to the working directory). */
    const val DATASET_FILE = "doc/eval_dataset_4.txt"

    /** cflang executable used to evaluate the dataset. */
    const val CFLANG_PATH = "./cflang/cflang"

    /** Directory in which `make clean` / `make` are run to rebuild cflang. */
    const val CFLANG_DIR = "cflang"

    /**
     * Reduces a full set of evaluation metrics to the single scalar being minimized
     * (lower is better). Change this lambda to optimize for a different metric;
     * it currently selects the median error.
     */
    val METRIC_SELECTOR: (EvalMetrics) -> Double = { metrics -> metrics.medianError }
}

/**
 * Tracks the optimization history for a single parameter.
 */
/**
 * Mutable record of how a single [Parameter] fared during optimization.
 *
 * @property parameter the parameter being tuned.
 * @property originalValue value read from the parameter before optimization started
 *   (the optimizer stores 0.0 here when reading it failed).
 * @property bestValue best value found; stays at [originalValue] unless an improvement is accepted.
 * @property bestMetric metric measured for [bestValue]; lower is better.
 * @property history accepted hill-climbing steps, in order, as `(value, metric)` pairs.
 * @property skipped true when the parameter was not (successfully) optimized.
 * @property skipReason human-readable explanation when [skipped] is true, otherwise null.
 */
data class ParameterHistory(
    val parameter: Parameter,
    val originalValue: Double,
    var bestValue: Double,
    var bestMetric: Double,
    val history: MutableList<Pair<Double, Double>> = mutableListOf(),
    var skipped: Boolean = false,
    var skipReason: String? = null
)

/**
 * Main hyperparameter optimization script.
 * Optimizes all evaluation function parameters by hill-climbing on the error metric.
 */
fun main() {
    println("=== Hyperparameter Optimization System ===\n")

    // Every tuned value is a #define in the same header. The parameter name equals the define name.
    val headerPath = "app/src/main/cpp/include/fast_evaluation.h"
    val defineNames = listOf(
        // Evaluation term weights
        "EVAL_WEIGHT_MATRIX",
        "EVAL_WEIGHT_MOVEMENT",
        "EVAL_WEIGHT_PIECE_VALUE",
        "EVAL_WEIGHT_KINGS_EVAL",
        "EVAL_KINGS_WEIGHT_MULTIPLIER",
        // Piece values
        "PIECE_VALUE_PAWN",
        "PIECE_VALUE_HORSE",
        "PIECE_VALUE_ROOK",
        "PIECE_VALUE_FLANGER",
        "PIECE_VALUE_UNI",
        "PIECE_VALUE_KING"
    )
    val parameters = defineNames.map { define ->
        Parameter.DefineParameter(
            name = define,
            filePath = headerPath,
            defineName = define
        )
    }

    println("Evaluating baseline...")
    println("Dataset: ${OptimizerConfig.DATASET_FILE}")
    println("Cflang: ${OptimizerConfig.CFLANG_PATH}")

    // Rebuild first. The final metric is measured on a freshly built binary, so the baseline
    // must be too. Otherwise a stale cflang binary would make the comparison meaningless.
    val baselineBuilt = try {
        runMake()
    } catch (e: Exception) {
        println("ERROR: Exception while building cflang: ${e.message}")
        e.printStackTrace()
        return
    }
    if (!baselineBuilt) {
        println("ERROR: Failed to build cflang for baseline evaluation")
        return
    }

    val baselineMetrics = try {
        calculateMetrics(
            OptimizerConfig.DATASET_FILE,
            OptimizerConfig.CFLANG_PATH
        )
    } catch (e: Exception) {
        println("ERROR: Exception during baseline evaluation: ${e.message}")
        e.printStackTrace()
        return
    }

    if (baselineMetrics == null) {
        println("ERROR: Failed to get baseline metrics (returned null)")
        return
    }

    val baselineMetric = OptimizerConfig.METRIC_SELECTOR(baselineMetrics)
    println("Baseline metric (median error): %.2f\n".format(baselineMetric))

    // Parameters are optimized one at a time, in order. Each one starts from the values
    // the previous parameters were left at.
    val results = parameters.mapIndexed { index, param ->
        println("[${index + 1}/${parameters.size}] Optimizing ${param.name}...")
        optimizeParameter(param).also { println() }
    }

    printOptimizationSummary(results, baselineMetric, baselineMetrics)
}

/**
 * Optimizes a single parameter using hill-climbing.
 */
/**
 * Optimizes a single parameter by multiplicative hill-climbing on
 * [OptimizerConfig.METRIC_SELECTOR]. Lower metric values are better.
 *
 * Procedure:
 * 1. Read the current value and back up the parameter's file.
 * 2. Build and evaluate with the current value to obtain the baseline metric.
 * 3. Evaluate one step up (`x * (1 + STEP_SIZE)`) and one step down (`x * (1 - STEP_SIZE)`).
 * 4. If neither step improves on the baseline, write the original value back, rebuild,
 *    and mark the parameter as skipped.
 * 5. Otherwise keep stepping in the better direction while the metric strictly improves,
 *    for at most [OptimizerConfig.MAX_ITERATIONS] iterations. Then write and build the
 *    best value found.
 *
 * If an exception is thrown, the file is restored from its backup and cflang is rebuilt.
 * The backup is always deleted before returning.
 *
 * @return the optimization record. `skipped`/`skipReason` are set when the parameter
 *   was not optimized.
 */
fun optimizeParameter(param: Parameter): ParameterHistory {
    val originalValue = try {
        val value = param.getCurrentValue()
        println("  Original value: $value")
        value
    } catch (e: Exception) {
        println("  ERROR: Failed to read parameter value: ${e.message}")
        e.printStackTrace()
        return ParameterHistory(
            parameter = param,
            originalValue = 0.0,
            bestValue = 0.0,
            bestMetric = Double.MAX_VALUE,
            skipped = true,
            skipReason = "Failed to read value: ${e.message}"
        )
    }

    val history = ParameterHistory(
        parameter = param,
        originalValue = originalValue,
        bestValue = originalValue,
        bestMetric = Double.MAX_VALUE
    )

    // (path, backup) pairs for every file this parameter touches.
    val backups = mutableListOf<Pair<String, File>>()

    // Puts the backed-up file contents back in place.
    fun restoreBackups() {
        backups.forEach { (path, backup) -> restoreFromBackup(path, backup) }
    }

    try {
        // Back up the file before making any changes.
        backups.add(param.filePath to backupFile(param.filePath))

        println("  Getting baseline metric with current value...")
        val baselineMetricValue = testParameterValue(param, originalValue)
        if (baselineMetricValue == null) {
            println("  ERROR: Failed to get baseline metric")
            history.skipped = true
            history.skipReason = "Failed to get baseline metric"
            // testParameterValue may have rewritten the file, possibly with a value that
            // failed read-back verification. Put the original bytes back.
            restoreBackups()
            return history
        }
        println("  Baseline metric: %.2f".format(baselineMetricValue))
        history.bestMetric = baselineMetricValue

        // Probe one relative step in each direction.
        val upValue = originalValue * (1.0 + OptimizerConfig.STEP_SIZE)
        val downValue = originalValue * (1.0 - OptimizerConfig.STEP_SIZE)

        println("  Testing (%.3f) and (%.3f)...".format(upValue, downValue))

        val upMetric = testParameterValue(param, upValue)
        val downMetric = testParameterValue(param, downValue)

        when {
            upMetric == null || downMetric == null -> {
                // testParameterValue returns null for set/verify, build, or evaluation failures.
                println("  Compilation or evaluation failed, skipping parameter")
                history.skipped = true
                history.skipReason = "Compilation or evaluation failed"
                param.setValue(originalValue)
                runMake()
            }

            upMetric >= baselineMetricValue && downMetric >= baselineMetricValue -> {
                println("  Both directions worsen metric (baseline: %.2f, up: %.2f, down: %.2f)".format(baselineMetricValue, upMetric, downMetric))
                println("  Already at local optimum")
                history.skipped = true
                history.skipReason = "Already optimal"
                // Restore the original value before moving on to the next parameter.
                param.setValue(originalValue)
                runMake()
            }

            else -> {
                // At least one direction beats the baseline. Follow the better one (ties go DOWN).
                val (direction, startValue, startMetric) = if (upMetric < downMetric) {
                    println("  Direction: UP (metric: %.2f vs %.2f)".format(upMetric, downMetric))
                    Triple(OptimizerConfig.STEP_SIZE, upValue, upMetric)
                } else {
                    println("  Direction: DOWN (metric: %.2f vs %.2f)".format(downMetric, upMetric))
                    Triple(-OptimizerConfig.STEP_SIZE, downValue, downMetric)
                }

                var currentValue = startValue
                var currentMetric = startMetric
                history.history.add(currentValue to currentMetric)

                var iteration = 1
                while (iteration < OptimizerConfig.MAX_ITERATIONS) {
                    val nextValue = currentValue * (1.0 + direction)

                    // Steps are multiplicative by a positive factor, so the sign never changes.
                    // This therefore only trips when the starting value was zero or negative.
                    if (nextValue <= 0.0) {
                        println("  Stopping: value would become non-positive")
                        break
                    }

                    val nextMetric = testParameterValue(param, nextValue)

                    if (nextMetric == null) {
                        println("  Compilation failed at iteration $iteration")
                        break
                    }

                    // Accept only strict improvements.
                    if (nextMetric >= currentMetric) {
                        println("  Converged at iteration $iteration (metric: %.2f)".format(currentMetric))
                        break
                    }

                    currentValue = nextValue
                    currentMetric = nextMetric
                    history.history.add(currentValue to currentMetric)

                    println("  Iteration $iteration: value=%.3f, metric=%.2f".format(currentValue, currentMetric))
                    iteration++
                }

                if (iteration >= OptimizerConfig.MAX_ITERATIONS) {
                    println("  Stopped: reached maximum iterations")
                }

                // The file currently holds the last *tested* value. Write back the best one.
                history.bestValue = currentValue
                history.bestMetric = currentMetric
                param.setValue(currentValue)

                if (!runMake()) {
                    println("  ERROR: Failed to build with best value, restoring original")
                    param.setValue(originalValue)
                    runMake()
                    history.skipped = true
                    history.skipReason = "Failed to build with best value"
                }

                // Improvement is measured in metric units relative to this parameter's baseline.
                val improvement = baselineMetricValue - currentMetric
                val pctChange = ((currentValue - originalValue) / originalValue) * 100.0
                println("  Best value: %.3f (%.1f%% change, metric: %.2f, improvement: %.2f)"
                    .format(currentValue, pctChange, currentMetric, improvement))
            }
        }

    } catch (e: Exception) {
        e.printStackTrace()
        println("  ERROR: ${e.message}")
        history.skipped = true
        history.skipReason = "Exception: ${e.message}"

        restoreBackups()
        runMake()

    } finally {
        // Backups are only needed while this parameter is being optimized. Always discard them.
        backups.forEach { (_, backup) ->
            backup.delete()
        }
    }

    return history
}

/**
 * Tests a parameter value by setting it, compiling, and evaluating.
 * Returns the metric value or null on failure.
 */
/**
 * Scores one candidate [value] for [param].
 *
 * The value is written, read back and checked against an absolute tolerance of 0.001.
 * It is then compiled via [runMake] and evaluated on [OptimizerConfig.DATASET_FILE].
 *
 * This function never restores the previous value; callers are responsible for that.
 * Exceptions thrown by the evaluation step itself are not caught here.
 *
 * @return the score from [OptimizerConfig.METRIC_SELECTOR], or null when writing,
 *   verification, compilation or evaluation fails.
 */
fun testParameterValue(param: Parameter, value: Double): Double? {
    val readBack: Double = try {
        param.setValue(value)
        param.getCurrentValue()
    } catch (e: Exception) {
        println("    ERROR setting value: ${e.message}")
        e.printStackTrace()
        return null
    }

    // Guard against the value being written in a form that does not round-trip.
    if (abs(readBack - value) > 0.001) {
        println("    ERROR: Value verification failed. Expected $value, got $readBack")
        return null
    }

    val compiled = runMake()
    if (!compiled) return null

    val metrics = calculateMetrics(
        OptimizerConfig.DATASET_FILE,
        OptimizerConfig.CFLANG_PATH
    ) ?: run {
        println("    ERROR: Evaluation failed")
        return null
    }

    return OptimizerConfig.METRIC_SELECTOR(metrics)
}

/**
 * Runs make clean && make in the cflang directory.
 * Returns true on success, false on failure.
 */
/**
 * Rebuilds cflang by running `make clean` and then `make` in [OptimizerConfig.CFLANG_DIR].
 *
 * The output of `make clean` is drained and discarded. A non-zero exit code from it is
 * reported as a warning but does not by itself fail the build. If `make` exits non-zero,
 * the first 20 lines of its combined stdout/stderr are printed (blank lines omitted),
 * followed by a count of the remaining lines.
 *
 * @return true if `make` exited with code 0, false otherwise.
 * @throws java.io.IOException if a `make` process cannot be started (not caught here).
 */
fun runMake(): Boolean {
    val workDir = File(OptimizerConfig.CFLANG_DIR)
    val maxShownLines = 20

    // make clean: read the output fully so the child cannot block on a full pipe, then discard it.
    val cleanProcess = ProcessBuilder("make", "clean")
        .directory(workDir)
        .redirectErrorStream(true)
        .start()
    cleanProcess.inputStream.bufferedReader().use { it.readText() }
    val cleanExitCode = cleanProcess.waitFor()
    if (cleanExitCode != 0) {
        println("    WARNING: 'make clean' exited with code $cleanExitCode")
    }

    val makeProcess = ProcessBuilder("make")
        .directory(workDir)
        .redirectErrorStream(true)
        .start()
    val makeOutput = makeProcess.inputStream.bufferedReader().use { it.readText() }
    val exitCode = makeProcess.waitFor()

    if (exitCode != 0) {
        println("    ERROR: Compilation failed (exit code $exitCode):")
        // Show only the head of the output to keep the log readable.
        val outputLines = makeOutput.lines()
        outputLines.take(maxShownLines)
            .filter { it.isNotBlank() }
            .forEach { println("    $it") }
        if (outputLines.size > maxShownLines) {
            println("    ... (${outputLines.size - maxShownLines} more lines)")
        }
        return false
    }

    return true
}

/**
 * Prints a comprehensive summary of the optimization results.
 */
/**
 * Prints a summary of the optimization run.
 *
 * The summary includes the baseline metrics, which parameters changed or were skipped,
 * and the metrics of a fresh evaluation with the current (optimized) build. If that final
 * evaluation succeeds, the detailed results are also written to
 * `doc/optimization_results.txt` via [saveResultsToFile].
 *
 * @param results per-parameter optimization records.
 * @param baselineMetric scalar metric (per [OptimizerConfig.METRIC_SELECTOR]) before optimization.
 * @param baselineMetrics full metric set before optimization.
 */
fun printOptimizationSummary(
    results: List<ParameterHistory>,
    baselineMetric: Double,
    baselineMetrics: EvalMetrics
) {
    val resultsPath = "doc/optimization_results.txt"

    // Shared formatting for the baseline and final metric sections.
    // Note: in a Kotlin literal "%%" is what String.format needs for a literal '%'.
    fun printMetrics(title: String, metrics: EvalMetrics) {
        println("\n$title:")
        println("  Median error: %.2f".format(metrics.medianError))
        println("  MAE: %.2f".format(metrics.mae))
        println("  RMSE: %.2f".format(metrics.rmse))
        println("  Sign accuracy: %.1f%%".format(metrics.signAccuracy))
        println("  Correlation: %.3f".format(metrics.correlation))
    }

    println("\n" + "=".repeat(80))
    println("OPTIMIZATION SUMMARY")
    println("=".repeat(80))

    printMetrics("Baseline Metrics", baselineMetrics)

    val changedParams = results.filter {
        !it.skipped && abs(it.bestValue - it.originalValue) > 0.001
    }
    val skippedParams = results.filter { it.skipped }

    println("\nParameters changed: ${changedParams.size}/${results.size}")
    println("Parameters skipped: ${skippedParams.size}/${results.size}")

    if (changedParams.isNotEmpty()) {
        println("\n--- Changed Parameters ---")
        changedParams.forEach { history ->
            val pctChange = ((history.bestValue - history.originalValue) /
                            history.originalValue) * 100.0

            println("\n${history.parameter.name}:")
            println("  Original:  %.3f".format(history.originalValue))
            println("  Optimized: %.3f (%+.1f%%)".format(history.bestValue, pctChange))
            println("  Metric:    %.2f".format(history.bestMetric))
            println("  Steps:     ${history.history.size}")
        }
    }

    if (skippedParams.isNotEmpty()) {
        println("\n--- Skipped Parameters ---")
        skippedParams.forEach { history ->
            println("${history.parameter.name}: ${history.skipReason}")
        }
    }

    println("\nEvaluating final metrics...")
    val finalMetrics = calculateMetrics(
        OptimizerConfig.DATASET_FILE,
        OptimizerConfig.CFLANG_PATH
    )

    if (finalMetrics == null) {
        println("\nERROR: Failed to get final metrics")
        return
    }

    val finalMetric = OptimizerConfig.METRIC_SELECTOR(finalMetrics)
    // Positive improvement means the metric went down (lower is better).
    val totalImprovement = baselineMetric - finalMetric
    val improvementPct = (totalImprovement / baselineMetric) * 100.0

    printMetrics("Final Metrics", finalMetrics)

    println("\n" + "-".repeat(80))
    println("Overall Improvement:")
    println("  Baseline metric:  %.2f".format(baselineMetric))
    println("  Final metric:     %.2f".format(finalMetric))
    println("  Improvement:      %+.2f (%+.1f%%)".format(totalImprovement, improvementPct))
    println("=".repeat(80))

    saveResultsToFile(results, baselineMetric, finalMetric, resultsPath)
    println("\nDetailed results saved to $resultsPath")
}

/**
 * Saves optimization results to a file.
 */
/**
 * Writes a plain-text report of the optimization run to [filePath], overwriting any existing file.
 *
 * Missing parent directories are created. For each parameter the report lists its original
 * value. Skipped parameters also get the skip reason. Other parameters get the optimized
 * value, percentage change, best metric, and accepted-step history.
 *
 * @param results per-parameter optimization records.
 * @param baselineMetric scalar metric before optimization.
 * @param finalMetric scalar metric after optimization.
 * @param filePath destination path (relative paths resolve against the working directory).
 */
fun saveResultsToFile(
    results: List<ParameterHistory>,
    baselineMetric: Double,
    finalMetric: Double,
    filePath: String
) {
    val file = File(filePath)
    // Create e.g. "doc/" if it does not exist, so the writer does not fail with FileNotFoundException.
    file.absoluteFile.parentFile?.mkdirs()

    // Positive improvement means the metric went down (lower is better).
    val improvement = baselineMetric - finalMetric
    val improvementPct = (improvement / baselineMetric) * 100.0

    file.bufferedWriter().use { writer ->
        writer.write("=== Hyperparameter Optimization Results ===\n\n")
        writer.write("Baseline metric: %.2f\n".format(baselineMetric))
        writer.write("Final metric: %.2f\n".format(finalMetric))
        // "%%" in the format string yields a single literal '%'.
        writer.write("Improvement: %+.2f (%+.1f%%)\n\n".format(improvement, improvementPct))

        writer.write("--- Parameter Changes ---\n\n")
        results.forEach { history ->
            writer.write("${history.parameter.name}:\n")
            writer.write("  Original: %.3f\n".format(history.originalValue))

            if (history.skipped) {
                writer.write("  Status: SKIPPED (${history.skipReason})\n")
            } else {
                val pctChange = ((history.bestValue - history.originalValue) /
                                history.originalValue) * 100.0
                writer.write("  Optimized: %.3f (%+.1f%%)\n".format(history.bestValue, pctChange))
                writer.write("  Best metric: %.2f\n".format(history.bestMetric))
                writer.write("  Iterations: ${history.history.size}\n")

                if (history.history.isNotEmpty()) {
                    writer.write("  History:\n")
                    history.history.forEach { (value, metric) ->
                        writer.write("    %.3f -> %.2f\n".format(value, metric))
                    }
                }
            }
            writer.write("\n")
        }
    }
}
