from __future__ import annotations

import inspect
import sys
from collections.abc import Callable
from functools import partial, wraps
from typing import Any, ClassVar, NoReturn

import typer
from click import ClickException
from typer.main import _typer_developer_exception_attr_name
from typer.models import CommandFunctionType, Default, DeveloperExceptionConfig

from clive.__private.core._async import asyncio_run
from clive.__private.core.constants.cli import HELP_FLAGS

# Exit code returned by an error handler; `CliveTyper` passes it to `sys.exit`.
type ExitCode = int
# An error handler receives the raised exception and returns an exit code, or `None` to decline handling it,
# in which case the next matching handler (if any) is tried.
type ErrorHandlingCallback[T: Exception] = Callable[[T], ExitCode | None]


class CliveTyper(typer.Typer):
    """
    A modified version of Typer that allows for registering error handlers and has a different defaults set.

    Such a handlers could be only registered for the main Typer instance, but not for sub-commands. That's because
    Typer.__call__ is not called for each sub-commands, but only for the main Typer instance.

    Args:
        name: The name of the command.
        help: The help text for the command.
        chain: Decides whether command will be chained.

    Example:
        Error handlers can be registered like this:

        ```python
        @typer_instance.error_handler(TypeError)
        def may_handle_some_error(error: TypeError) -> int | None:
            if "Some error" in str(error):
                typer.echo("Some error occurred")
                return 1
            return None

        @typer_instance.error_handler(Exception)
        def handle_any_error(error: Exception) -> None:
            raise CLIError(str(error), 1)
        ```

        Then when `TypeError("Other error")` is raised:

        - `may_handle_some_error` will ignore the error, because of the `if` condition.
        - Instead `handle_any_error` will handle it, and since it raises `CLIPrettyError` - it will be pretty printed.
    """

    __clive_error_handlers__: ClassVar[dict[type[Exception], ErrorHandlingCallback[Any]]] = {}
    """ClassVar since error handlers could be registered only for the main Typer instance, but not for sub-commands."""

    def __init__(
        self,
        *,
        name: str | None = Default(None),
        help: str | None = Default(None),  # noqa: A002
        chain: bool = Default(value=False),
    ) -> None:
        super().__init__(
            name=name,
            help=help,
            chain=chain,
            rich_markup_mode="rich",
            context_settings={"help_option_names": HELP_FLAGS},
            no_args_is_help=True,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        try:
            return super().__call__(*args, **kwargs)
        except Exception as error:  # noqa: BLE001
            self.__handle_error(error)

    def error_handler[T: Exception](
        self, error: type[T]
    ) -> Callable[[ErrorHandlingCallback[T]], ErrorHandlingCallback[T]]:
        def decorator(f: ErrorHandlingCallback[T]) -> ErrorHandlingCallback[T]:
            self.__clive_error_handlers__[error] = f
            return f

        return decorator

    def callback(self, *args: Any, **kwargs: Any) -> Callable[[CommandFunctionType], CommandFunctionType]:
        return partial(self._maybe_run_async, super().callback(*args, **kwargs))

    def command(self, *args: Any, **kwargs: Any) -> Callable[[CommandFunctionType], CommandFunctionType]:
        return partial(self._maybe_run_async, super().command(*args, **kwargs))

    @staticmethod
    def _maybe_run_async(decorator: Any, func: Any) -> Any:  # noqa: ANN401
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            return asyncio_run(func(*args, **kwargs))

        if inspect.iscoroutinefunction(func):
            return decorator(wrapper)
        return decorator(func)

    def __handle_error(self, error: Exception) -> None:
        handler = self.__get_error_handler(error)

        try:
            exit_code = handler(error)
            if exit_code is None:
                # means that error was not handled by that callback, try to handle with the next one
                self.__handle_error(error)

            sys.exit(exit_code)
        except ClickException as click_exception:
            from typer import rich_utils  # noqa: PLC0415

            # See: `typer/core.py` -> `_main` -> `except click.ClickException as e:`
            # If ClickException was raised in the registered error handler, we need to format it like Typer does.
            rich_utils.rich_format_error(click_exception)
            sys.exit(click_exception.exit_code)
        except Exception as exception:
            # See: `typer/mian.py` -> `Typer.__call__` -> `except Exception as e:`
            # If any other exception was raised in the registered error handler, we need to format it like Typer does.
            setattr(
                exception,
                _typer_developer_exception_attr_name,
                DeveloperExceptionConfig(
                    pretty_exceptions_enable=self.pretty_exceptions_enable,
                    pretty_exceptions_show_locals=self.pretty_exceptions_show_locals,
                    pretty_exceptions_short=self.pretty_exceptions_short,
                ),
            )
            raise exception from None

    def __get_error_handler[T: Exception](self, error: T) -> ErrorHandlingCallback[T]:
        for type_ in type(error).mro():
            if type_ in self.__clive_error_handlers__:
                return self.__clive_error_handlers__.pop(type_)

        raise error  # reraise if no handler is available
