#!/usr/bin/env python3
"""
Generate an image from the speed report
"""
import argparse
import colorsys
import json
import math
from enum import Enum
from typing import Any, NamedTuple, Self

from PIL import Image, ImageDraw

# Size of the generated image, in pixels
IMAGE_WIDTH = 2400
IMAGE_HEIGHT = 1600

# Margins between the image border and the graph area, in pixels.
# The legend is drawn inside the right margin.
MARGIN_LEFT = 80
MARGIN_RIGHT = 160
MARGIN_TOP = 80
MARGIN_BOTTOM = 80

# Fill color (RGB) of the direction line, keyed by the "category" field of a
# data point
DIRECTION_COLOR_BY_CATEGORY = {
    0: (25, 25, 25),
    1: (128, 128, 128),
    2: (255, 50, 50),
}

# Radius, in pixels, of the dot drawn for each data point
FIELD_POINT_RADIUS = 2


class Direction(Enum):
    """Side identifiers used as keys of ARROW_COLORS for the key indicators."""

    LEFT = 1
    RIGHT = 2


# Length, in pixels, of the lines marking the "left"/"right" keys; the
# direction line length is derived from it as well
KEY_INDICATOR_LENGTH = 14

# Color (RGB) of the key indicator line for each side
ARROW_COLORS = {
    Direction.LEFT: (255, 0, 0),
    Direction.RIGHT: (0, 255, 0),
}

# NOTE(review): not referenced by the code visible in this file; confirm
# whether it is still needed
ARROW_SIZE = 6


class Point(NamedTuple):
    x: float
    y: float

    @staticmethod
    def from_polar(*, angle: float, radius: float) -> Self:
        return Point(
            radius * math.cos(angle),
            radius * math.sin(angle),
        )

    def __add__(self, other: Self) -> Self:
        return Point(
            self.x + other.x,
            self.y + other.y,
        )

    def __sub__(self, other: Self) -> Self:
        return Point(
            self.x - other.x,
            self.y - other.y,
        )

    def angle(self) -> float:
        return math.atan2(self.y, self.x)


def color_for_value(
    value: float, min_value: float, max_value: float
) -> tuple[int, int, int]:
    """Convert a value to an RGB color using the HSV color space.

    The value is mapped linearly onto a hue going from blue (``min_value``)
    to red (``max_value``). Values outside of the range are clamped to the
    nearest bound, so they get the color of that bound.

    Args:
        value: Value to convert.
        min_value: Value mapped to blue.
        max_value: Value mapped to red.

    Returns:
        The color as an ``(r, g, b)`` tuple of integers in the range 0-255.
    """
    if max_value == min_value:
        # Degenerate range: everything gets the "minimum" color
        normalized = 0.0
    else:
        normalized = (value - min_value) / (max_value - min_value)

    # Without clamping, an out-of-range value (e.g. above a user-supplied
    # maximum) produces a negative hue and RGB components outside 0-255.
    normalized = min(1.0, max(0.0, normalized))

    # Use hue from blue (240°) to red (0°)
    hue = (1 - normalized) * 0.67  # 0.67 is blue in HSV
    saturation = 1.0
    brightness = 0.9

    r, g, b = colorsys.hsv_to_rgb(hue, saturation, brightness)
    return (int(r * 255), int(g * 255), int(b * 255))


def parse_arguments() -> argparse.Namespace:
    """Parse the command line and return the resulting namespace.

    The namespace provides ``field``, ``max_value``, ``output`` and ``file``.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Optional arguments, declared as (flags, settings) pairs
    option_specs = (
        (
            ("-f", "--field"),
            {
                "default": "speed",
                "help": "Field to use for coloring (default: speed)",
            },
        ),
        (
            ("--max-value",),
            {
                "type": float,
                "default": None,
                "help": "Maximum value for the selected field (default: auto-detect)",
            },
        ),
        (
            ("-o", "--output"),
            {
                "help": "Output file name (default: SPEED_FILE with .png extension)",
            },
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    # Positional argument
    parser.add_argument("file", metavar="SPEED_FILE", help="Input JSONL file")

    return parser.parse_args()


def load_json_data(filename: str) -> list[dict[str, Any]]:
    """
    Load data from a JSONL file and return a list of dictionaries.

    Blank lines are ignored. Lines which are not valid JSON, or which do not
    contain a JSON object, are reported on stdout and skipped.

    Args:
        filename: Path of the JSONL file, read as UTF-8.

    Returns:
        One dictionary per valid line, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    print(f"Loading data from {filename}...")
    data_points: list[dict[str, Any]] = []

    # JSON is UTF-8 by specification; do not depend on the locale encoding
    with open(filename, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error line {idx}: {e}")
                continue

            # A valid JSON line may still be a list, number, string...
            # Callers index the records by key, so only keep objects.
            if not isinstance(data, dict):
                print(
                    f"Error line {idx}: expected a JSON object, "
                    f"got {type(data).__name__}"
                )
                continue

            data_points.append(data)

    print(f"Loaded {len(data_points)} data points.")
    return data_points


def draw_legend(
    draw: ImageDraw.Draw,
    min_value: float,
    max_value: float,
) -> None:
    """Draw the vertical color scale in the right margin of the image.

    The scale shows the color of ``max_value`` at the top and goes down
    towards the color of ``min_value``; both bounds are printed as text
    labels next to the scale.
    """
    bar_width = 80
    bar_height = IMAGE_HEIGHT - MARGIN_TOP - MARGIN_BOTTOM
    left = IMAGE_WIDTH - MARGIN_RIGHT + 20
    top = MARGIN_TOP
    right = left + bar_width
    bottom = top + bar_height
    label_color = (255, 255, 255)

    # Background and frame
    draw.rectangle(
        (left, top, right, bottom),
        fill=(30, 30, 30),
        outline=(200, 200, 200),
    )

    # Gradient: one horizontal line per pixel row, from max down to min
    value_span = max_value - min_value
    for row in range(bar_height):
        row_value = max_value - (row / bar_height) * value_span
        row_y = top + row
        draw.line(
            [(left + 5, row_y), (right - 5, row_y)],
            fill=color_for_value(row_value, min_value, max_value),
        )

    # Labels above and below the scale
    draw.text((left, top - 15), f"Max: {max_value:.1f}", fill=label_color)
    draw.text((left, bottom + 5), f"Min: {min_value:.1f}", fill=label_color)


def main() -> None:
    """Render the speed report given on the command line as a PNG image.

    The data points are scaled to fill the graph area, connected in file
    order, and colored according to the selected field. Key indicators and
    direction lines are drawn perpendicular to the direction of travel, and
    a color legend is added in the right margin.

    Raises:
        SystemExit: If a data point lacks a required field, or if all points
            share the same X or Y coordinate (the graph cannot be scaled).
    """
    args = parse_arguments()
    if args.output:
        output = args.output
    else:
        output = args.file.removesuffix(".jsonl") + ".png"

    color_field = args.field

    data_points = load_json_data(args.file)

    if not data_points:
        print("No valid data points found")
        return

    try:
        x_values = [point["x"] for point in data_points]
        y_values = [point["y"] for point in data_points]
        color_values = [point[color_field] for point in data_points]
    except KeyError as err:
        # Typically a mistyped --field: report it instead of a traceback
        raise SystemExit(
            f"Field {err} is missing from at least one data point"
        ) from err

    # Image bounds
    min_x, max_x = min(x_values), max(x_values)
    min_y, max_y = min(y_values), max(y_values)

    # Value bounds
    min_value = min(color_values)
    if args.max_value is not None:
        max_value = args.max_value
        print(f"Using user-defined maximum {color_field} value: {max_value}")
    else:
        max_value = max(color_values)

    print(f"X range: {min_x} to {max_x}")
    print(f"Y range: {min_y} to {max_y}")
    print(f"{color_field.capitalize()} range: {min_value} to {max_value}")

    # Explicit check rather than assert: asserts are stripped by `python -O`,
    # which would turn this into a ZeroDivisionError while scaling.
    if min_x == max_x or min_y == max_y:
        raise SystemExit(
            "Cannot scale the graph: all points share the same X or Y coordinate"
        )

    # Area available for the graph
    graph_width = IMAGE_WIDTH - MARGIN_RIGHT - MARGIN_LEFT
    graph_height = IMAGE_HEIGHT - MARGIN_TOP - MARGIN_BOTTOM

    image = Image.new("RGB", size=(IMAGE_WIDTH, IMAGE_HEIGHT))
    draw = ImageDraw.Draw(image, mode="RGBA")

    def map_to_image(point: Point) -> Point:
        """Convert data coordinates to image coordinates (Y axis flipped)."""
        x_ratio = (point.x - min_x) / (max_x - min_x)
        y_ratio = (point.y - min_y) / (max_y - min_y)

        img_x = MARGIN_LEFT + (x_ratio * graph_width)
        img_y = IMAGE_HEIGHT - MARGIN_BOTTOM - (y_ratio * graph_height)

        return Point(img_x, img_y)

    # Both drawing passes need the image coordinates: compute them once
    image_points = [
        map_to_image(Point(data["x"], data["y"])) for data in data_points
    ]

    # Draw direction lines and key indicators
    prev_point = None
    for data, point in zip(data_points, image_points):
        if prev_point is not None:
            # Direction of travel, in image space
            angle = (point - prev_point).angle()

            # Key indicators
            if data.get("left"):
                delta = Point.from_polar(
                    angle=angle - math.pi / 2, radius=KEY_INDICATOR_LENGTH
                )
                draw.line(
                    [point, point + delta],
                    width=1,
                    fill=ARROW_COLORS[Direction.LEFT],
                )

            if data.get("right"):
                delta = Point.from_polar(
                    angle=angle + math.pi / 2, radius=KEY_INDICATOR_LENGTH
                )
                draw.line(
                    [point, point + delta],
                    width=1,
                    fill=ARROW_COLORS[Direction.RIGHT],
                )

            # Direction line
            category = data["category"]
            if category > -1:
                direction = data.get("direction", 0)
                # The sign of `direction` selects the side the line is drawn on
                length = int(direction * KEY_INDICATOR_LENGTH * 0.9)
                if length != 0:
                    delta = Point.from_polar(angle=angle - math.pi / 2, radius=length)
                    draw.line(
                        [point, point + delta],
                        width=3,
                        fill=DIRECTION_COLOR_BY_CATEGORY[category],
                    )

        prev_point = point

    # Draw connected points for selected field; done in a second pass so the
    # track is painted on top of the indicators
    prev_point = None
    for data, point in zip(data_points, image_points):
        color = color_for_value(data[color_field], min_value, max_value)

        draw.ellipse(
            (
                point.x - FIELD_POINT_RADIUS,
                point.y - FIELD_POINT_RADIUS,
                point.x + FIELD_POINT_RADIUS,
                point.y + FIELD_POINT_RADIUS,
            ),
            fill=color,
        )

        if prev_point is not None:
            draw.line([prev_point, point], fill=color, width=1)

        prev_point = point

    draw_legend(draw, min_value, max_value)
    image.save(output)
    print(f"Visualization saved to {output}")


# Run the command-line tool only when executed as a script, not when imported
if __name__ == "__main__":
    main()
