package io.ente.mobile_ocr

import android.graphics.Bitmap
import android.graphics.PointF
import ai.onnxruntime.*
import java.nio.FloatBuffer
import java.util.ArrayDeque
import kotlin.math.*

/**
 * A detected text region paired with its detection score.
 *
 * @property box the detected region.
 * @property score the score reported for [box] by the detector.
 */
data class DetectionCandidate(
    val box: TextBox,
    val score: Float
)

/**
 * Outcome of a detection pass as produced by `TextDetector.collectHighConfidenceDetections`.
 *
 * @property examinedDetections number of detections that were inspected before the pass ended.
 * @property maxDetectionScore highest score among the inspected detections, or `null` when
 *   nothing was inspected.
 * @property candidates detections whose score met the caller-supplied minimum confidence.
 */
data class DetectionStageSummary(
    val examinedDetections: Int,
    val maxDetectionScore: Float?,
    val candidates: List<DetectionCandidate>
)

class TextDetector(
    private val session: OrtSession,
    private val ortEnv: OrtEnvironment
) {
    companion object {
        // Longest allowed image side before the input is scaled down for the model.
        private const val LIMIT_SIDE_LEN = 960
        // Per-pixel probability above which a pixel is marked in the binary map.
        private const val THRESH = 0.3f
        // Minimum mean probability inside a candidate rectangle for it to be kept.
        private const val BOX_THRESH = 0.6f
        // Expansion ratio passed to unclipBox when growing a kept rectangle.
        private const val UNCLIP_RATIO = 1.5f
        // Minimum length of the shortest side of an expanded rectangle.
        private const val MIN_SIZE = 3
        // Maximum number of connected components (largest first) that are examined.
        private const val MAX_CANDIDATES = 1000
        // Tolerance used for near-zero checks in the polygon offset helpers.
        private const val EPSILON = 1e-6f
    }

    /**
     * Detects text regions in [bitmap].
     *
     * Every detection reported by the detection pass is kept (scores are ignored and the
     * pass is never stopped early); the result is then ordered by [sortBoxes].
     *
     * @param bitmap image to analyse.
     * @return the detected boxes in reading order, or an empty list when nothing was found.
     */
    fun detect(bitmap: Bitmap): List<TextBox> {
        val detected = ArrayList<TextBox>()

        // Returning false tells the detection pass to keep going.
        runDetection(bitmap) { region, _ ->
            detected.add(region)
            false
        }

        return if (detected.isEmpty()) emptyList() else sortBoxes(detected)
    }

    /**
     * Runs a detection pass and gathers the detections that meet a minimum confidence.
     *
     * The pass stops as soon as a qualifying detection is seen while the candidate list is
     * at (or has just reached) [maxCandidates], so [DetectionStageSummary.examinedDetections]
     * only counts detections inspected up to that point.
     *
     * @param bitmap image to analyse.
     * @param minimumDetectionConfidence inclusive lower bound on the score of a candidate.
     * @param maxCandidates upper bound on the number of candidates returned; values of zero
     *   or less yield no candidates.
     * @return a summary with the number of inspected detections, the best score seen
     *   (`null` if nothing was inspected) and the collected candidates.
     */
    fun collectHighConfidenceDetections(
        bitmap: Bitmap,
        minimumDetectionConfidence: Float,
        maxCandidates: Int
    ): DetectionStageSummary {
        var examined = 0
        var maxScore = Float.NEGATIVE_INFINITY
        val candidates = mutableListOf<DetectionCandidate>()

        runDetection(bitmap) { box, score ->
            examined++
            if (score > maxScore) {
                maxScore = score
            }
            if (score >= minimumDetectionConfidence) {
                // Never exceed the cap, even when the caller passes maxCandidates <= 0.
                if (candidates.size < maxCandidates) {
                    candidates.add(DetectionCandidate(box, score))
                }
                // true stops the detection pass once the cap has been reached.
                candidates.size >= maxCandidates
            } else {
                false
            }
        }

        val bestScore = if (examined == 0) null else maxScore
        return DetectionStageSummary(
            examinedDetections = examined,
            maxDetectionScore = bestScore,
            candidates = candidates
        )
    }

    /**
     * Runs the model on [bitmap] and feeds every resulting detection to [handler].
     *
     * The input tensor and the complete session result are released before returning,
     * including when inference or post-processing throws.
     *
     * @param bitmap image to analyse.
     * @param handler receives each detected box with its score; returning `true` stops the pass.
     */
    private fun runDetection(
        bitmap: Bitmap,
        handler: (TextBox, Float) -> Boolean
    ) {
        val originalWidth = bitmap.width
        val originalHeight = bitmap.height

        val (inputTensor, resizedWidth, resizedHeight) = preprocessImage(bitmap)

        try {
            val result = session.run(mapOf("x" to inputTensor))
            try {
                val output = result[0] as OnnxTensor

                postprocessDetection(
                    output = output,
                    originalWidth = originalWidth,
                    originalHeight = originalHeight,
                    resizedWidth = resizedWidth,
                    resizedHeight = resizedHeight,
                    handler = handler
                )
            } finally {
                // Closing the result releases every output value it holds, not just the first.
                result.close()
            }
        } finally {
            // Separate finally so the input is released even if closing the result fails.
            inputTensor.close()
        }
    }


    /**
     * Converts [bitmap] into a normalized `1 x 3 x H x W` float tensor for the model.
     *
     * The bitmap is scaled to the size chosen by [calculateResizeDimensions]; each colour
     * channel is mapped to `[0, 1]`, then mean/std normalized and written in planar (CHW)
     * layout with the blue plane first.
     *
     * @param bitmap source image; it is never recycled by this method.
     * @return the input tensor (the caller must close it) together with the resized width
     *   and height.
     */
    private fun preprocessImage(bitmap: Bitmap): Triple<OnnxTensor, Int, Int> {
        val (resizedWidth, resizedHeight) = calculateResizeDimensions(bitmap.width, bitmap.height)
        val planeSize = resizedWidth * resizedHeight

        // Resize bitmap; Android may return the original if sizes match
        val resizedBitmap = Bitmap.createScaledBitmap(bitmap, resizedWidth, resizedHeight, true)
        val pixels = try {
            IntArray(planeSize).also { buffer ->
                resizedBitmap.getPixels(buffer, 0, resizedWidth, 0, 0, resizedWidth, resizedHeight)
            }
        } finally {
            // Release the scaled copy as soon as its pixels are read (or reading failed),
            // but never recycle the caller's bitmap.
            if (resizedBitmap !== bitmap && !resizedBitmap.isRecycled) {
                resizedBitmap.recycle()
            }
        }

        // Normalization parameters from OnnxOCR
        val mean = floatArrayOf(0.485f, 0.456f, 0.406f)
        val std = floatArrayOf(0.229f, 0.224f, 0.225f)
        val scale = 1.0f / 255.0f

        val inputArray = FloatArray(3 * planeSize)
        for (index in 0 until planeSize) {
            val pixel = pixels[index]
            val b = (pixel and 0xFF) * scale
            val g = ((pixel shr 8) and 0xFF) * scale
            val r = ((pixel shr 16) and 0xFF) * scale

            // CHW format, BGR order to match PaddleOCR training data
            inputArray[index] = (b - mean[0]) / std[0]
            inputArray[index + planeSize] = (g - mean[1]) / std[1]
            inputArray[index + 2 * planeSize] = (r - mean[2]) / std[2]
        }

        val shape = longArrayOf(1, 3, resizedHeight.toLong(), resizedWidth.toLong())
        val inputTensor = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(inputArray), shape)

        return Triple(inputTensor, resizedWidth, resizedHeight)
    }

    /**
     * Chooses the model input size for an image of [width] x [height].
     *
     * The image is scaled down so that its longest side does not exceed `LIMIT_SIDE_LEN`
     * (it is never scaled up), and each side is then rounded up to a multiple of 32.
     *
     * @return the resized width and height, each a multiple of 32 and at least 32.
     */
    private fun calculateResizeDimensions(width: Int, height: Int): Pair<Int, Int> {
        val longestSide = max(width, height)
        val ratio = if (longestSide > LIMIT_SIDE_LEN) LIMIT_SIDE_LEN.toFloat() / longestSide else 1.0f

        // Scale one side, keep it at least one pixel, then round up to the next multiple of 32.
        fun alignedSide(side: Int): Int {
            val scaled = max(1, (side * ratio).roundToInt())
            return max(((scaled + 31) / 32) * 32, 32)
        }

        return Pair(alignedSide(width), alignedSide(height))
    }

    /**
     * Turns the raw model output into text boxes and reports each one to [handler].
     *
     * The first `resizedHeight * resizedWidth` output values are read row by row as the
     * probability map (assumed to be the first output channel), thresholded into a binary
     * map and split into connected components. The largest components are then converted
     * to scored, expanded rectangles, mapped back to the original image coordinates.
     *
     * @param output model output tensor.
     * @param originalWidth width of the source image.
     * @param originalHeight height of the source image.
     * @param resizedWidth width of the model input.
     * @param resizedHeight height of the model input.
     * @param handler receives each box and its score; returning `true` ends the loop.
     */
    private fun postprocessDetection(
        output: OnnxTensor,
        originalWidth: Int,
        originalHeight: Int,
        resizedWidth: Int,
        resizedHeight: Int,
        handler: (TextBox, Float) -> Boolean
    ) {
        val flatOutput = ImageUtils.toFloatArray(output.floatBuffer)

        // Probability map, one FloatArray per image row.
        val probMap = Array(resizedHeight) { row ->
            val rowStart = row * resizedWidth
            FloatArray(resizedWidth) { col -> flatOutput[rowStart + col] }
        }

        // Binary map: true where the probability exceeds THRESH.
        val binaryMap = Array(resizedHeight) { row ->
            BooleanArray(resizedWidth) { col -> probMap[row][col] > THRESH }
        }

        val largestComponents = extractConnectedComponents(binaryMap)
            .sortedByDescending { it.size }
            .take(MAX_CANDIDATES)

        // Factors that map model-input coordinates back onto the source image.
        val scaleX = originalWidth.toFloat() / resizedWidth
        val scaleY = originalHeight.toFloat() / resizedHeight

        for (component in largestComponents) {
            if (component.size < 4) continue

            val hull = convexHull(component)
            if (hull.size < 3) continue

            val tightRect = minimumAreaRectangle(hull, pointsAreConvex = true)
            if (tightRect.isEmpty()) continue

            // Score the tight rectangle before it is expanded.
            val score = calculateBoxScore(probMap, tightRect)
            if (score < BOX_THRESH) continue

            val grownPolygon = unclipBox(tightRect, UNCLIP_RATIO)
            if (grownPolygon.isEmpty()) continue

            val grownRect = minimumAreaRectangle(grownPolygon, pointsAreConvex = false)
            if (grownRect.isEmpty()) continue

            if (getMinSide(grownRect) < MIN_SIZE) continue

            val inBounds = ImageUtils.clipBoxToImageBounds(grownRect, resizedWidth, resizedHeight)
            val inSourceCoords = inBounds.map { corner ->
                PointF(corner.x * scaleX, corner.y * scaleY)
            }

            val box = TextBox(ImageUtils.orderPointsClockwise(inSourceCoords))
            if (handler(box, score)) break
        }
    }

    /**
     * Groups the `true` cells of [binaryMap] into 8-connected components.
     *
     * Rows are scanned top to bottom and cells left to right; each unvisited `true` cell
     * seeds a depth-first flood fill over its eight neighbours.
     *
     * @param binaryMap rows of flags indexed as `binaryMap[y][x]`; the width is taken from row 0.
     * @return one list of `(x, y)` points per component, in discovery order.
     */
    private fun extractConnectedComponents(binaryMap: Array<BooleanArray>): List<List<PointF>> {
        val rows = binaryMap.size
        val cols = if (rows > 0) binaryMap[0].size else 0
        val seen = Array(rows) { BooleanArray(cols) }
        val found = mutableListOf<List<PointF>>()
        val pending = ArrayDeque<Pair<Int, Int>>()

        for (seedY in 0 until rows) {
            for (seedX in 0 until cols) {
                if (!binaryMap[seedY][seedX] || seen[seedY][seedX]) continue

                val members = mutableListOf<PointF>()
                pending.clear()
                pending.addLast(Pair(seedX, seedY))
                seen[seedY][seedX] = true

                while (pending.isNotEmpty()) {
                    val (cx, cy) = pending.removeLast()
                    members.add(PointF(cx.toFloat(), cy.toFloat()))

                    // Visit the eight neighbours, row by row.
                    for (ny in (cy - 1)..(cy + 1)) {
                        for (nx in (cx - 1)..(cx + 1)) {
                            if (nx == cx && ny == cy) continue
                            if (nx !in 0 until cols || ny !in 0 until rows) continue
                            if (!binaryMap[ny][nx] || seen[ny][nx]) continue

                            // Mark on push so a cell is never queued twice.
                            seen[ny][nx] = true
                            pending.addLast(Pair(nx, ny))
                        }
                    }
                }

                found.add(members)
            }
        }

        return found
    }

    /**
     * Averages the probabilities of the cells whose centres lie inside [polygon].
     *
     * Only cells within the polygon's bounding box (clamped to the map) are tested; a cell
     * counts when its centre `(x + 0.5, y + 0.5)` passes [isPointInsideQuad].
     *
     * @param probMap probability map indexed as `probMap[y][x]`.
     * @param polygon vertices of the region to score.
     * @return the mean probability, or `0` when the polygon is empty or covers no cell centre.
     */
    private fun calculateBoxScore(probMap: Array<FloatArray>, polygon: List<PointF>): Float {
        if (polygon.isEmpty()) return 0f

        val lastCol = probMap[0].size - 1
        val lastRow = probMap.size - 1

        // Clamp an index into [0, upper].
        fun clamp(value: Int, upper: Int): Int = min(max(value, 0), upper)

        val left = clamp(floor(polygon.minOf { it.x.toDouble() }).toInt(), lastCol)
        val right = clamp(ceil(polygon.maxOf { it.x.toDouble() }).toInt(), lastCol)
        val top = clamp(floor(polygon.minOf { it.y.toDouble() }).toInt(), lastRow)
        val bottom = clamp(ceil(polygon.maxOf { it.y.toDouble() }).toInt(), lastRow)

        if (right < left || bottom < top) return 0f

        var total = 0f
        var samples = 0
        for (row in top..bottom) {
            val centerY = row + 0.5f
            for (col in left..right) {
                if (!isPointInsideQuad(col + 0.5f, centerY, polygon)) continue
                total += probMap[row][col]
                samples++
            }
        }

        return if (samples == 0) 0f else total / samples
    }

    /**
     * Finds the smallest-area rectangle that encloses [points] and has one side parallel
     * to an edge of their convex hull.
     *
     * @param points input points.
     * @param pointsAreConvex when `true`, [points] is used directly as the hull; otherwise
     *   the hull is computed with [convexHull].
     * @return the four corners of the best rectangle; the axis-aligned bounding box of the
     *   hull when no edge produces a rectangle wider and taller than `1e-3`; or an empty
     *   list when the hull has fewer than three points.
     */
    private fun minimumAreaRectangle(points: List<PointF>, pointsAreConvex: Boolean = false): List<PointF> {
        val hull = if (pointsAreConvex) points else convexHull(points)
        if (hull.size < 3) return emptyList()

        var smallestArea = Float.MAX_VALUE
        var smallestRect: List<PointF> = emptyList()

        for (edgeIndex in hull.indices) {
            val origin = hull[edgeIndex]
            // Zero-length edges have no direction and are skipped.
            val axis = normalizeVector(origin, hull[(edgeIndex + 1) % hull.size]) ?: continue
            val perp = PointF(-axis.y, axis.x)

            // Extents of the hull along the edge direction and its perpendicular.
            var alongLo = Float.MAX_VALUE
            var alongHi = -Float.MAX_VALUE
            var acrossLo = Float.MAX_VALUE
            var acrossHi = -Float.MAX_VALUE

            for (vertex in hull) {
                val offX = vertex.x - origin.x
                val offY = vertex.y - origin.y
                val along = offX * axis.x + offY * axis.y
                val across = offX * perp.x + offY * perp.y

                if (along < alongLo) alongLo = along
                if (along > alongHi) alongHi = along
                if (across < acrossLo) acrossLo = across
                if (across > acrossHi) acrossHi = across
            }

            val extentAlong = alongHi - alongLo
            val extentAcross = acrossHi - acrossLo
            val area = extentAlong * extentAcross

            if (area < smallestArea && extentAlong > 1e-3f && extentAcross > 1e-3f) {
                smallestArea = area

                // Map (along, across) coordinates back into image space.
                fun cornerAt(along: Float, across: Float) = PointF(
                    origin.x + axis.x * along + perp.x * across,
                    origin.y + axis.y * along + perp.y * across
                )

                smallestRect = listOf(
                    cornerAt(alongLo, acrossLo),
                    cornerAt(alongHi, acrossLo),
                    cornerAt(alongHi, acrossHi),
                    cornerAt(alongLo, acrossHi)
                )
            }
        }

        return if (smallestRect.isEmpty()) axisAlignedBoundingBox(hull) else smallestRect
    }

    /**
     * Returns the unit vector pointing from [from] to [to], or `null` when the two points
     * are closer than `1e-6`.
     */
    private fun normalizeVector(from: PointF, to: PointF): PointF? {
        val deltaX = to.x - from.x
        val deltaY = to.y - from.y
        val magnitude = sqrt(deltaX * deltaX + deltaY * deltaY)
        return if (magnitude < 1e-6f) null else PointF(deltaX / magnitude, deltaY / magnitude)
    }

    /**
     * Returns the corners of the axis-aligned bounding box of [points] in the order
     * (minX, minY), (maxX, minY), (maxX, maxY), (minX, maxY), or an empty list when
     * [points] is empty.
     */
    private fun axisAlignedBoundingBox(points: List<PointF>): List<PointF> {
        if (points.isEmpty()) return emptyList()

        var left = Float.MAX_VALUE
        var top = Float.MAX_VALUE
        var right = -Float.MAX_VALUE
        var bottom = -Float.MAX_VALUE

        points.forEach { p ->
            if (p.x < left) left = p.x
            if (p.x > right) right = p.x
            if (p.y < top) top = p.y
            if (p.y > bottom) bottom = p.y
        }

        return listOf(
            PointF(left, top),
            PointF(right, top),
            PointF(right, bottom),
            PointF(left, bottom)
        )
    }

    /**
     * Tests whether the point ([x], [y]) lies inside or on the boundary of the convex
     * polygon [quad].
     *
     * The point is inside when it is on the same side of every edge, i.e. the edge cross
     * products never take both strictly positive and strictly negative values.
     *
     * @return `false` when [quad] has fewer than three vertices or the point is outside.
     */
    private fun isPointInsideQuad(x: Float, y: Float, quad: List<PointF>): Boolean {
        if (quad.size < 3) return false

        var seenPositive = false
        var seenNegative = false

        for ((index, start) in quad.withIndex()) {
            val end = quad[(index + 1) % quad.size]
            val side = (end.x - start.x) * (y - start.y) - (end.y - start.y) * (x - start.x)
            when {
                side > 0 -> seenPositive = true
                side < 0 -> seenNegative = true
            }
            // Both signs seen: the point is on different sides of two edges.
            if (seenPositive && seenNegative) return false
        }

        return true
    }

    /**
     * Computes the convex hull of [points] with the monotone-chain method.
     *
     * Points are sorted by x then y; a lower chain is built over the sorted order and an
     * upper chain over the reverse order, dropping points that do not make a strictly
     * positive turn. Each chain omits its final point because it starts the other chain.
     *
     * @return the hull vertices, or [points] itself when it has fewer than three elements.
     */
    private fun convexHull(points: List<PointF>): List<PointF> {
        if (points.size < 3) return points

        val ordered = points.sortedWith(compareBy({ it.x }, { it.y }))

        // Builds one half of the hull over the given traversal order, without its last point.
        fun halfHull(sequence: List<PointF>): List<PointF> {
            val chain = ArrayList<PointF>()
            for (candidate in sequence) {
                while (chain.size >= 2 &&
                    crossProduct(chain[chain.size - 2], chain[chain.size - 1], candidate) <= 0
                ) {
                    chain.removeAt(chain.lastIndex)
                }
                chain.add(candidate)
            }
            chain.removeAt(chain.lastIndex)
            return chain
        }

        return halfHull(ordered) + halfHull(ordered.asReversed())
    }

    /**
     * Returns the z-component of the cross product of the vectors [o]→[a] and [o]→[b].
     */
    private fun crossProduct(o: PointF, a: PointF, b: PointF): Float {
        val oaX = a.x - o.x
        val oaY = a.y - o.y
        val obX = b.x - o.x
        val obY = b.y - o.y
        return oaX * obY - oaY * obX
    }

    /** Returns the Euclidean distance between [p1] and [p2]. */
    private fun distance(p1: PointF, p2: PointF): Float {
        val deltaX = p2.x - p1.x
        val deltaY = p2.y - p1.y
        val squared = deltaX * deltaX + deltaY * deltaY
        return sqrt(squared)
    }

    /**
     * Orders [boxes] top to bottom and, within a row, left to right.
     *
     * Boxes are first sorted by their topmost y. A row starts at the next unplaced box and
     * takes the following boxes while their topmost y is within 10 units of the row's first
     * box; each row is then sorted by leftmost x.
     *
     * @return the reordered boxes, or an empty list when [boxes] is empty.
     */
    private fun sortBoxes(boxes: List<TextBox>): List<TextBox> {
        if (boxes.isEmpty()) return emptyList()

        fun topOf(box: TextBox): Float = box.points.minOf { it.y }
        fun leftOf(box: TextBox): Float = box.points.minOf { it.x }

        val byTop = boxes.sortedBy { topOf(it) }
        val result = ArrayList<TextBox>(byTop.size)

        var rowStart = 0
        while (rowStart < byTop.size) {
            val rowTop = topOf(byTop[rowStart])

            // Extend the row while boxes stay close to the row's first box.
            var rowEnd = rowStart
            while (rowEnd < byTop.size && abs(topOf(byTop[rowEnd]) - rowTop) <= 10f) {
                rowEnd++
            }

            result.addAll(byTop.subList(rowStart, rowEnd).sortedBy { leftOf(it) })
            rowStart = rowEnd
        }

        return result
    }

    /**
     * Expands [box] outward by `|area| * unclipRatio / perimeter`.
     *
     * @param box polygon vertices.
     * @param unclipRatio multiplier applied to the area/perimeter ratio.
     * @return the expanded polygon; [box] unchanged when the computed offset is negligible;
     *   or an empty list when [box] has fewer than three vertices, has a negligible
     *   perimeter, or the expansion yields fewer than three vertices.
     */
    private fun unclipBox(box: List<PointF>, unclipRatio: Float): List<PointF> {
        if (box.size < 3) return emptyList()

        val perimeter = polygonPerimeter(box)
        if (perimeter <= EPSILON) return emptyList()

        val offset = abs(polygonSignedArea(box)) * unclipRatio / perimeter
        if (offset <= EPSILON) return box

        return offsetPolygon(box, offset).takeIf { it.size >= 3 } ?: emptyList()
    }

    /**
     * Returns the length of the shortest edge of the closed polygon [box], or `0` when it
     * has fewer than two vertices.
     */
    private fun getMinSide(box: List<PointF>): Float {
        if (box.size < 2) return 0f

        val shortest = box.indices.fold(Float.MAX_VALUE) { best, index ->
            val edgeLength = distance(box[index], box[(index + 1) % box.size])
            if (edgeLength < best) edgeLength else best
        }

        // The sentinel survives only if no edge was shorter than it.
        return if (shortest == Float.MAX_VALUE) 0f else shortest
    }

    /**
     * Returns the signed area of the closed polygon [points] using the shoelace formula;
     * the sign depends on the vertex winding.
     */
    private fun polygonSignedArea(points: List<PointF>): Float {
        val doubled = points.indices.fold(0f) { acc, index ->
            val current = points[index]
            val following = points[(index + 1) % points.size]
            acc + (current.x * following.y - following.x * current.y)
        }
        return doubled / 2f
    }

    /** Returns the total edge length of the closed polygon [points]. */
    private fun polygonPerimeter(points: List<PointF>): Float =
        points.indices.fold(0f) { total, index ->
            total + distance(points[index], points[(index + 1) % points.size])
        }

    /**
     * Moves every edge of the polygon [points] outward by [offset] and returns the new
     * vertices, found by intersecting each pair of adjacent shifted edges.
     *
     * Vertices with a zero-length adjacent edge are skipped. When the two adjacent edges
     * are parallel there is no intersection: for a collinear vertex (edges pointing the
     * same way) the vertex is shifted along the shared outward normal; for a reversal
     * (edges pointing opposite ways) the vertex is kept where it is.
     *
     * @param points polygon vertices in either winding order.
     * @param offset distance to move each edge outward.
     * @return the offset vertices, or an empty list when [points] has fewer than three.
     */
    private fun offsetPolygon(points: List<PointF>, offset: Float): List<PointF> {
        val count = points.size
        if (count < 3) return emptyList()

        // The winding decides which perpendicular of an edge points away from the interior.
        val isCounterClockwise = polygonSignedArea(points) > 0f
        val result = ArrayList<PointF>(count)

        for (i in 0 until count) {
            val prev = points[(i - 1 + count) % count]
            val curr = points[i]
            val next = points[(i + 1) % count]

            val edge1 = PointF(curr.x - prev.x, curr.y - prev.y)
            val edge2 = PointF(next.x - curr.x, next.y - curr.y)

            val dir1 = normalize(edge1) ?: continue
            val dir2 = normalize(edge2) ?: continue

            val normal1 = if (isCounterClockwise) PointF(dir1.y, -dir1.x) else PointF(-dir1.y, dir1.x)
            val normal2 = if (isCounterClockwise) PointF(dir2.y, -dir2.x) else PointF(-dir2.y, dir2.x)

            val offsetPoint1 = PointF(curr.x + normal1.x * offset, curr.y + normal1.y * offset)
            val offsetPoint2 = PointF(curr.x + normal2.x * offset, curr.y + normal2.y * offset)

            val intersection = intersectLines(offsetPoint1, dir1, offsetPoint2, dir2)
            if (intersection != null) {
                result.add(intersection)
                continue
            }

            // Parallel edges: a positive dot product means the vertex is collinear, so both
            // shifted edges lie on one line and the shifted vertex is the right answer.
            val sameDirection = dir1.x * dir2.x + dir1.y * dir2.y > 0f
            result.add(if (sameDirection) offsetPoint1 else PointF(curr.x, curr.y))
        }

        return result
    }

    /**
     * Returns [vector] scaled to unit length, or `null` when its length is below `EPSILON`.
     */
    private fun normalize(vector: PointF): PointF? {
        val magnitude = sqrt(vector.x * vector.x + vector.y * vector.y)
        return if (magnitude < EPSILON) null else PointF(vector.x / magnitude, vector.y / magnitude)
    }

    /**
     * Intersects the line through [point] with direction [direction] and the line through
     * [otherPoint] with direction [otherDirection].
     *
     * @return the intersection point, or `null` when the directions are (nearly) parallel,
     *   i.e. the magnitude of their cross product is below `EPSILON`.
     */
    private fun intersectLines(point: PointF, direction: PointF, otherPoint: PointF, otherDirection: PointF): PointF? {
        val denominator = direction.x * otherDirection.y - direction.y * otherDirection.x
        if (abs(denominator) < EPSILON) return null

        val deltaX = otherPoint.x - point.x
        val deltaY = otherPoint.y - point.y

        // Parameter along the first line at which it meets the second.
        val t = (deltaX * otherDirection.y - deltaY * otherDirection.x) / denominator

        return PointF(point.x + direction.x * t, point.y + direction.y * t)
    }
}
