# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.
"""Input artifacts for tasks."""

import abc
from collections.abc import Collection as ABCCollection
from typing import Any, Self, TYPE_CHECKING, assert_never, overload, override

from debusine.artifacts.models import ArtifactCategory, CollectionCategory
from debusine.client.models import RelationType
from debusine.tasks.executors import ExecutorImageCategory
from debusine.tasks.models import (
    ExtraDebusineRepository,
    ExtraExternalRepository,
    ExtraRepository,
    LookupMultiple,
    LookupSingle,
)
from debusine.tasks.server import (
    ArtifactInfo,
    CollectionInfo,
    MultipleArtifactInfo,
    TaskDatabaseInterface,
)

if TYPE_CHECKING:
    from debusine.tasks import BaseTask


class TaskInput[RD](abc.ABC):
    """Representation for one or more input artifacts for a task."""

    name: str

    @abc.abstractmethod
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> RD:
        """Resolve the input from task data."""

    def __set_name__(
        self, owner: type["BaseTask[Any, Any]"], name: str
    ) -> None:
        """Set the field name at class construction."""
        self.name = name
        # Avoid sharing the inputs definition across all subclasses of BaseTask
        if "inputs" not in owner.__dict__:
            owner.inputs = owner.inputs.copy()
        owner.inputs[name] = self

    def __get__(
        self,
        obj: "BaseTask[Any, Any]",
        _: type["BaseTask[Any, Any]"] | None = None,
    ) -> RD:
        """Look up this input in the task object."""
        from debusine.tasks import TaskConfigError

        # If __get__ is called, it means that self.name is not in obj.__dict__,
        # which only happens if resolve_to_task has not yet been called
        raise TaskConfigError(f"Cannot access unresolved input {self.name!r}")

    def resolve_to_task(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> None:
        """Resolve this input and set the result in the task dict."""
        task.__dict__[self.name] = self.resolve(task, task_database)

    @staticmethod
    def resolve_inputs(
        task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> None:
        """Resolve all input fields in the given task."""
        for inp in task.inputs.values():
            inp.resolve_to_task(task, task_database)


class DataFieldInput[RD](TaskInput[RD]):
    """Input field that resolves values based on a task data field."""

    field: str

    def __init__(self, *, field: str | None = None) -> None:
        """
        Declare an input defined by a task data field.

        :param field: dot-separated list of lookup names in task data. Defaults
          to the input member name
        """
        if field is not None:
            self.field = field

    def __set_name__(
        self, owner: type["BaseTask[Any, Any]"], name: str
    ) -> None:
        """Set the field name at class construction."""
        super().__set_name__(owner, name)
        if not hasattr(self, "field"):
            self.field = name

    def _resolve_field(self, task_data: Any) -> Any:
        """Look up the value corresponding to self.field in task_data."""
        # Prevent import loop
        from debusine.tasks import TaskConfigError

        try:
            lookup: Any = task_data
            for name in self.field.split("."):
                lookup = getattr(lookup, name)
        except AttributeError:
            raise TaskConfigError(f"Invalid input field: {self.field!r}")
        return lookup

    def _get_lookup_single(
        self,
        task_data: Any,
    ) -> LookupSingle | None:
        """Find the LookupSingle identified by self.field in task_data."""
        lookup = self._resolve_field(task_data)
        assert lookup is None or isinstance(lookup, LookupSingle)
        return lookup

    def _get_lookup_multiple(self, task_data: Any) -> LookupMultiple:
        """Find the LookupMultiple identified by self.field in task_data."""
        lookup = self._resolve_field(task_data)
        assert isinstance(lookup, LookupMultiple)
        return lookup

    def _get_lookup_list(
        self,
        task_data: Any,
    ) -> list[LookupSingle | LookupMultiple]:
        """Find the list of lookups identified by self.field in task_data."""
        lookups = self._resolve_field(task_data)
        if lookups is None:
            return []
        assert isinstance(lookups, list)
        for lookup in lookups:
            assert isinstance(lookup, LookupSingle | LookupMultiple)
        return lookups


class ArtifactInput[RD](DataFieldInput[RD]):
    """Look up artifacts enforcing their types."""

    def __init__(
        self,
        *,
        field: str | None = None,
        categories: ABCCollection[ArtifactCategory] | None = None,
    ) -> None:
        """
        Look up artifacts enforcing their type.

        :param field: dot-separated list of lookup names in task data
        :param categories: list of acceptable artifact categories
        """
        super().__init__(field=field)
        self.categories: tuple[ArtifactCategory, ...] | None = (
            tuple(categories) if categories is not None else None
        )

    def check_category(
        self, info: ArtifactInfo, field_name: str | None = None
    ) -> None:
        """Check the category of an ArtifactInfo."""
        # Prevent import loop
        from debusine.tasks import ensure_artifact_categories

        if self.categories is None:
            return

        ensure_artifact_categories(
            configuration_key=field_name or self.field,
            category=info.category,
            expected=self.categories,
        )


class SingleInput(ArtifactInput[ArtifactInfo]):
    """A mandatory single artifact lookup taken from a task data field."""

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> ArtifactInfo:
        single_lookup = self._get_lookup_single(task.data)
        # The field is mandatory for this input kind
        assert single_lookup is not None
        artifact = task_database.lookup_single_artifact(single_lookup)
        self.check_category(artifact)
        return artifact


class SingleInputList(ArtifactInput[list[ArtifactInfo]]):
    """A list of single artifact lookups taken from a task data field."""

    def _lookup_checked(
        self,
        task_database: TaskDatabaseInterface,
        lookup: LookupSingle | LookupMultiple,
    ) -> ArtifactInfo:
        """Look up one list element and check its artifact category."""
        assert isinstance(lookup, LookupSingle)
        artifact = task_database.lookup_single_artifact(lookup)
        self.check_category(artifact)
        return artifact

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> list[ArtifactInfo]:
        # Elements are resolved and checked in order; the first failure
        # propagates
        return [
            self._lookup_checked(task_database, lookup)
            for lookup in self._get_lookup_list(task.data)
        ]


class OptionalSingleInput(ArtifactInput[ArtifactInfo | None]):
    """
    One optional single input artifact specified in a task data field.

    Resolves to None when the task data field is unset.
    """

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> ArtifactInfo | None:
        lookup = self._get_lookup_single(task.data)
        if lookup is None:
            # Nothing configured: short-circuit explicitly instead of relying
            # on the database interface accepting a None lookup (consistent
            # with OptionalEnvironmentInput)
            return None
        info = task_database.lookup_single_artifact(lookup)
        if info is not None:
            self.check_category(info)
        return info


class MultiInput(ArtifactInput[MultipleArtifactInfo]):
    """A multiple artifacts lookup specified in a task data field."""

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> MultipleArtifactInfo:
        # _get_lookup_multiple already asserts that a LookupMultiple was found
        lookup = self._get_lookup_multiple(task.data)
        info = task_database.lookup_multiple_artifacts(lookup)
        if self.categories is not None:
            # Report failures with the index of the offending artifact
            for i, artifact in enumerate(info):
                self.check_category(artifact, f"{self.field}[{i}]")
        return info


class MultiInputList(ArtifactInput[list[MultipleArtifactInfo]]):
    """A list of multiple artifacts lookups specified in a task data field."""

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> list[MultipleArtifactInfo]:
        infos: list[MultipleArtifactInfo] = []
        for lookup_idx, lookup in enumerate(self._get_lookup_list(task.data)):
            assert isinstance(lookup, LookupMultiple)
            info = task_database.lookup_multiple_artifacts(lookup)
            if self.categories is not None:
                # Report failures as field[lookup index][result index]
                for result_idx, artifact in enumerate(info):
                    self.check_category(
                        artifact, f"{self.field}[{lookup_idx}][{result_idx}]"
                    )
            infos.append(info)
        return infos


class UploadArtifactsInput(ArtifactInput[MultipleArtifactInfo]):
    """A multiple artifact lookup finding all related binaries in an upload."""

    @override
    def __init__(
        self,
        *,
        field: str | None = None,
        categories: ABCCollection[ArtifactCategory] | None = None,
    ) -> None:
        """
        Look up artifacts, expanding uploads into related binary packages.

        :param field: dot-separated list of lookup names in task data
        :param categories: acceptable categories for looked-up artifacts
          that are not uploads (uploads are always accepted). If None, the
          default, artifacts of any category are accepted, consistently
          with ``ArtifactInput``.
        """
        # None means "no restriction": it must not become a restriction to
        # uploads only, which would reject every non-upload artifact found
        # by the lookup.
        super().__init__(
            field=field,
            categories=(
                None
                if categories is None
                else [ArtifactCategory.UPLOAD, *categories]
            ),
        )

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> MultipleArtifactInfo:
        # _get_lookup_multiple already asserts that a LookupMultiple was found
        lookup = self._get_lookup_multiple(task.data)

        # Incrementally build the result set
        result: list[ArtifactInfo] = []

        # Look up the initial set of artifacts: uploads are set aside for
        # expansion, everything else is category-checked and kept as-is
        infos = task_database.lookup_multiple_artifacts(lookup)
        upload_ids: list[int] = []
        for i, info in enumerate(infos):
            if info.category == ArtifactCategory.UPLOAD:
                upload_ids.append(info.id)
            else:
                self.check_category(info, f"{self.field}[{i}]")
                result.append(info)

        # Expand with binary packages related to the uploads
        if upload_ids:
            result.extend(
                task_database.find_related_artifacts(
                    upload_ids,
                    ArtifactCategory.BINARY_PACKAGE,
                    relation_type=RelationType.EXTENDS,
                )
            )

        return MultipleArtifactInfo(result)


class EnvironmentInput(ArtifactInput[ArtifactInfo]):
    """A mandatory environment artifact lookup taken from a task data field."""

    def __init__(
        self,
        *,
        field: str | None = None,
        categories: ABCCollection[ArtifactCategory] | None = None,
        image_category: ExecutorImageCategory | None = None,
        set_backend: bool = True,
        try_variant: bool = True,
    ) -> None:
        """
        Look up an environment artifact, enforcing its category.

        :param field: dot-separated list of lookup names in task data
        :param categories: list of acceptable artifact categories
        :param image_category: try to use an environment with this image
          category; defaults to the image category needed by the executor
          for the task's backend
        :param set_backend: if True (default), try to use an environment
          matching the task's backend
        :param try_variant: if True (default), try to use an environment
          whose variant is the task's name, but fall back to looking up an
          environment without a variant if the first lookup fails
        """
        super().__init__(field=field, categories=categories)
        self.try_variant = try_variant
        self.set_backend = set_backend
        self.image_category = image_category

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> ArtifactInfo:
        # Prevent import loop
        from debusine.tasks import BaseTaskWithExecutor

        assert isinstance(task, BaseTaskWithExecutor)
        env_lookup = self._get_lookup_single(task.data)
        assert env_lookup is not None

        architecture = task.build_architecture()
        backend = task.backend if self.set_backend else None
        variant = task.name if self.try_variant else None

        environment = task_database.get_environment(
            env_lookup,
            architecture=architecture,
            backend=backend,
            default_category=CollectionCategory.ENVIRONMENTS,
            image_category=self.image_category,
            try_variant=variant,
        )
        self.check_category(environment)
        return environment


class OptionalEnvironmentInput(ArtifactInput[ArtifactInfo | None]):
    """One optional environment artifact specified in a task data field."""

    def __init__(
        self,
        *,
        field: str | None = None,
        categories: ABCCollection[ArtifactCategory] | None = None,
        image_category: ExecutorImageCategory | None = None,
        set_backend: bool = True,
        try_variant: bool = True,
    ) -> None:
        """
        Look up an optional environment artifact, enforcing its category.

        The options match those of ``EnvironmentInput``. With the defaults,
        the lookup uses the task's backend and the task's name as variant,
        and requests no explicit image category.

        :param field: dot-separated list of lookup names in task data
        :param categories: list of acceptable artifact categories
        :param image_category: image category passed to the environment
          lookup (None by default)
        :param set_backend: if True (default), pass the task's backend to
          the environment lookup
        :param try_variant: if True (default), pass the task's name as the
          variant to try in the environment lookup
        """
        super().__init__(field=field, categories=categories)
        self.image_category = image_category
        self.set_backend = set_backend
        self.try_variant = try_variant

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> ArtifactInfo | None:
        # Prevent import loop
        from debusine.tasks import BaseTaskWithExecutor

        assert isinstance(task, BaseTaskWithExecutor)
        lookup = self._get_lookup_single(task.data)
        if lookup is None:
            # The environment is optional: nothing to look up
            return None
        info = task_database.get_environment(
            lookup,
            architecture=task.build_architecture(),
            backend=task.backend if self.set_backend else None,
            default_category=CollectionCategory.ENVIRONMENTS,
            image_category=self.image_category,
            try_variant=task.name if self.try_variant else None,
        )
        self.check_category(info)
        return info


class SuiteArchiveInput(
    DataFieldInput[tuple[CollectionInfo, CollectionInfo | None]]
):
    """
    Resolve a suite from a task data field, together with the archive.

    The result is a ``(suite, archive)`` pair; ``archive`` is whatever the
    singleton archive collection lookup returns, which may be None.
    """

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> tuple[CollectionInfo, CollectionInfo | None]:
        # Prevent import loop
        from debusine.tasks import ensure_collection_category

        suite_lookup = self._get_lookup_single(task.data)
        assert suite_lookup is not None

        suite_info = task_database.lookup_single_collection(suite_lookup)
        # The looked-up collection must actually be a suite
        ensure_collection_category(
            configuration_key=self.field,
            category=suite_info.category,
            expected=CollectionCategory.SUITE,
        )

        archive_info = task_database.lookup_singleton_collection(
            CollectionCategory.ARCHIVE
        )
        return (suite_info, archive_info)


class ExtraRepositoriesInput(
    DataFieldInput[list[ExtraExternalRepository] | None]
):
    """
    Look up extra repositories.

    Every configured repository is returned in external form: external
    repositories are passed through unchanged, while Debusine-hosted ones
    are converted using this server's archive FQDN and the suite's signing
    keys.
    """

    def _resolve_extra_repository(
        self,
        task_database: TaskDatabaseInterface,
        extra_repository: ExtraRepository,
    ) -> ExtraExternalRepository:
        """
        Resolve a single extra repository into the external form.

        :raises TaskConfigError: if a Debusine repository does not list its
          components and its suite's data does not provide any either
        """
        # Prevent import loop
        from debusine.tasks import TaskConfigError

        if not isinstance(extra_repository, ExtraDebusineRepository):
            if isinstance(extra_repository, ExtraExternalRepository):
                return extra_repository
            assert_never(extra_repository)

        fqdn = task_database.get_server_setting(
            "DEBUSINE_DEBIAN_ARCHIVE_PRIMARY_FQDN"
        )
        suite = task_database.lookup_single_collection(extra_repository.suite)

        components = extra_repository.components
        if components is None:
            # Fall back to the components declared by the suite itself
            if "components" not in suite.data:
                raise TaskConfigError(
                    f"'components' not set for {extra_repository.suite}"
                )
            components = suite.data["components"]

        # Ideally we'd use HTTPS, but that might require certificate setup
        # within environments, and we know that our own repositories are
        # signed, so HTTP should be good enough to use from workers for now.
        repository_url = (
            f"http://{fqdn}/{suite.scope_name}/{suite.workspace_name}"
        )
        return ExtraExternalRepository(
            url=repository_url,
            suite=suite.name,
            components=components,
            signing_key=task_database.export_suite_signing_keys(suite.id),
        )

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> list[ExtraExternalRepository] | None:
        # Prevent import loop
        from debusine.tasks import ExtraRepositoryMixin

        assert isinstance(task, ExtraRepositoryMixin)

        configured = self._resolve_field(task.data)
        if configured is not None:
            return [
                self._resolve_extra_repository(task_database, repository)
                for repository in configured
            ]
        return None


class DebusineFQDNInput(TaskInput[str]):
    """Input resolving to the ``DEBUSINE_FQDN`` server setting."""

    @override
    def resolve(
        self, task: "BaseTask[Any, Any]", task_database: TaskDatabaseInterface
    ) -> str:
        # The task itself is not consulted: the value is server-wide
        fqdn = task_database.get_server_setting("DEBUSINE_FQDN")
        return fqdn
