package xyz.lepisma.harp.screens.metrics

import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.Path
import androidx.compose.ui.graphics.StrokeCap
import androidx.compose.ui.graphics.StrokeJoin
import androidx.compose.ui.graphics.drawscope.Stroke
import androidx.compose.ui.graphics.drawscope.withTransform
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.drawText
import androidx.compose.ui.text.rememberTextMeasurer
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import kotlinx.datetime.LocalDateTime
import kotlinx.datetime.TimeZone
import kotlinx.datetime.format
import kotlinx.datetime.format.DayOfWeekNames
import kotlinx.datetime.format.MonthNames
import kotlinx.datetime.format.char
import kotlinx.datetime.toInstant
import xyz.lepisma.harp.data.Metric
import xyz.lepisma.harp.data.MetricValue
import xyz.lepisma.harp.data.to2d
import kotlin.time.ExperimentalTime

/**
 * A labelled tick on the vertical (value) axis of a metrics plot.
 *
 * @property value position of the tick, in the same units as the plotted values.
 * @property label text drawn next to the tick.
 * @property showLine whether a horizontal grid line should be drawn across the plot at [value].
 */
data class YTick(
    val value: Float,
    val label: String,
    val showLine: Boolean,
)

/**
 * A labelled tick on the horizontal (time) axis of a metrics plot.
 *
 * @property value timestamp the tick marks.
 * @property label text drawn beneath the plot at that timestamp.
 */
data class XTick(
    val value: LocalDateTime,
    val label: String,
)

/**
 * Builds the value-axis ticks for [mvs].
 *
 * Candidate ticks are, in this order: the lower axis bound, the smallest observed value, the
 * largest observed value and the upper axis bound. Each is labelled with [to2d] and marked to
 * draw a grid line. When [metric] declares a `range`, it supplies the axis bounds, widened as
 * needed so every observed value still lies on the axis. Otherwise the bounds are the observed
 * min and max. Candidates with equal values are emitted only once.
 *
 * The plot reads the first and last returned tick as its axis bounds, so the bound-first /
 * bound-last ordering is part of this function's contract.
 *
 * @param mvs the values being plotted.
 * @param metric optional metadata; only its `range` is read here.
 * @return the ticks, or an empty list when [mvs] is empty.
 */
fun getYTicks(mvs: List<MetricValue>, metric: Metric?): List<YTick> {
    if (mvs.isEmpty()) return emptyList()

    val minVal = mvs.minOf { it.value }
    val maxVal = mvs.maxOf { it.value }

    // Widen the declared range to cover out-of-range values. Otherwise those points would be
    // mapped outside the axis and drawn above/below the plot area.
    val rangeLow = metric?.range?.first
    val rangeHigh = metric?.range?.second
    val yLow = if (rangeLow != null) minOf(rangeLow, minVal) else minVal
    val yHigh = if (rangeHigh != null) maxOf(rangeHigh, maxVal) else maxVal

    // Without a range (or when the data touches a range end) a bound coincides with min/max.
    // Drop duplicates so the same label and grid line are not drawn twice.
    return listOf(yLow, minVal, maxVal, yHigh)
        .distinct()
        .map { YTick(it, it.to2d(), true) }
}

/** Label format for time-axis ticks: abbreviated English month, day of month, comma, year. */
private val X_TICK_DATE_FORMAT = LocalDateTime.Format {
    monthName(MonthNames.ENGLISH_ABBREVIATED)
    char(' ')
    dayOfMonth()
    chars(", ")
    year()
}

/**
 * Builds the time-axis ticks for [mvs]: one at the earliest and one at the latest timestamp.
 *
 * The earliest and latest timestamps are used rather than the first and last list elements
 * because the plot's time axis spans exactly that interval, whatever order [mvs] is in. When
 * both ends fall on the same timestamp (for example a single value), only one tick is returned.
 *
 * @param mvs the values being plotted.
 * @return the ticks, or an empty list when [mvs] is empty.
 */
fun getXTicks(mvs: List<MetricValue>): List<XTick> {
    if (mvs.isEmpty()) return emptyList()

    val earliest = mvs.minOf { it.datetime }
    val latest = mvs.maxOf { it.datetime }

    return listOf(earliest, latest)
        .distinct()
        .map { XTick(value = it, label = it.format(X_TICK_DATE_FORMAT)) }
}

/**
 * Draws a time-series line plot of [mvs], using [metric] metadata when present.
 *
 * Points are placed horizontally by timestamp (converted to epoch millis in the system default
 * time zone) and vertically between the first and last tick returned by [getYTicks]. The value
 * labels and horizontal grid lines come from [getYTicks], and the date labels under the plot
 * come from [getXTicks]. When [metric] declares a `healthyRange`, that band is shaded in
 * translucent green behind the data and clipped to the plot area. If [mvs] is empty, only the
 * enclosing [Box] is emitted.
 *
 * @param mvs values to plot.
 * @param metric optional metadata supplying `range` (axis bounds) and `healthyRange`.
 * @param modifier applied to the enclosing [Box]; the canvas fills it.
 * @param padding inset from each canvas edge to the plot area.
 * @param lineColor colour of the line connecting the points.
 * @param dotColor colour of the point markers.
 * @param gridColor colour of the horizontal grid lines.
 * @param axisColor colour of the vertical axis line.
 */
@OptIn(ExperimentalTime::class)
@Composable
fun MetricsPlot(
    mvs: List<MetricValue>,
    metric: Metric?,
    modifier: Modifier = Modifier.Companion,
    padding: Dp = 24.dp,
    lineColor: Color = MaterialTheme.colorScheme.primary,
    dotColor: Color = MaterialTheme.colorScheme.primary,
    gridColor: Color = Color(0xFFD1D5DB),
    axisColor: Color = Color(0xFF4B5563)
) {
    Box(modifier = modifier) {
        val textMeasurer = rememberTextMeasurer()
        val textColor = MaterialTheme.colorScheme.outline

        if (mvs.isEmpty()) return

        val yTicks = getYTicks(mvs, metric)
        val xTicks = getXTicks(mvs)

        Canvas(modifier = Modifier.Companion.fillMaxSize()) {
            val left = padding.toPx()
            val right = size.width - padding.toPx()
            val top = padding.toPx()
            val bottom = size.height - padding.toPx()

            // Vertical bounds: getYTicks returns the lower bound first and the upper bound last.
            val yLow = yTicks.first().value
            val yHigh = yTicks.last().value
            // A zero span (e.g. one value and no declared range) would make every normalised
            // y position NaN, so fall back to 1.
            val spanV = (yHigh - yLow).takeIf { it > 0f } ?: 1f

            // Horizontal bounds in epoch millis, so points are spaced by real elapsed time.
            val timeZone = TimeZone.Companion.currentSystemDefault()
            fun epochMillis(dt: LocalDateTime): Long =
                dt.toInstant(timeZone).toEpochMilliseconds()

            val times = mvs.map { epochMillis(it.datetime) }
            val minT = times.minOrNull() ?: 0L
            val maxT = times.maxOrNull() ?: 1L
            val spanT = (maxT - minT).takeIf { it > 0 } ?: 1L

            // Data-space -> canvas-space mappings shared by labels, grid, band and points, so
            // every element is positioned on the same scale.
            fun xOf(t: Long): Float =
                left + ((t - minT).toFloat() / spanT.toFloat()) * (right - left)

            fun yOf(v: Float): Float =
                bottom - ((v - yLow) / spanV) * (bottom - top)

            drawLine(
                axisColor, start = Offset(left, top), end = Offset(left, bottom),
                strokeWidth = 1.5f
            )

            // Value labels: just left of the axis, vertically centred on their tick.
            yTicks.forEach { yt ->
                val layout = textMeasurer.measure(
                    text = AnnotatedString(yt.label),
                    style = TextStyle(
                        fontSize = 10.sp,
                        textAlign = TextAlign.Companion.Right,
                        color = textColor
                    )
                )

                drawText(
                    textLayoutResult = layout,
                    topLeft = Offset(
                        left - layout.size.width - 8f,
                        yOf(yt.value) - layout.size.height / 2f
                    )
                )
            }

            // Date labels: centred under their timestamp, using the same time scale as the
            // points so they line up even when mvs is not sorted by time.
            xTicks.forEach { xt ->
                val layout = textMeasurer.measure(
                    text = AnnotatedString(xt.label),
                    style = TextStyle(
                        fontSize = 10.sp,
                        textAlign = TextAlign.Companion.Center,
                        color = textColor
                    )
                )

                drawText(
                    textLayoutResult = layout,
                    topLeft = Offset(
                        xOf(epochMillis(xt.value)) - layout.size.width / 2f,
                        bottom + 20f
                    )
                )
            }

            yTicks.filter { it.showLine }.forEach { yt ->
                val y = yOf(yt.value)
                drawLine(gridColor, Offset(left, y), Offset(right, y), strokeWidth = 1f)
            }

            // Healthy band goes behind the data so it does not tint the points. It is clamped to
            // the plot area so a band wider than the axis does not spill into the labels.
            // coerceAtLeast/AtMost (not coerceIn) avoid an exception when the canvas is
            // smaller than twice the padding and top > bottom.
            val healthyLow = metric?.healthyRange?.first
            val healthyHigh = metric?.healthyRange?.second

            if (healthyLow != null && healthyHigh != null) {
                val bandTop = yOf(maxOf(healthyLow, healthyHigh))
                    .coerceAtLeast(top).coerceAtMost(bottom)
                val bandBottom = yOf(minOf(healthyLow, healthyHigh))
                    .coerceAtLeast(top).coerceAtMost(bottom)

                if (bandBottom > bandTop) {
                    drawRect(
                        color = Color.Green,
                        topLeft = Offset(x = left, y = bandTop),
                        alpha = 0.2f,
                        size = Size(width = right - left, height = bandBottom - bandTop)
                    )
                }
            }

            val points = mvs.zip(times) { mv, t -> Offset(xOf(t), yOf(mv.value)) }

            val path = Path()
            points.forEachIndexed { idx, p ->
                if (idx == 0) path.moveTo(p.x, p.y) else path.lineTo(p.x, p.y)
            }

            drawPath(
                path = path,
                color = lineColor,
                style = Stroke(
                    width = 3f,
                    cap = StrokeCap.Companion.Round,
                    join = StrokeJoin.Companion.Round
                )
            )

            // Markers last so they sit on top of the connecting line.
            points.forEach { p ->
                drawCircle(
                    color = dotColor,
                    radius = 6f,
                    center = p
                )
            }
        }
    }
}