"""A fake callback handler for testing purposes."""

from __future__ import annotations

from itertools import chain
from typing import Any
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from pydantic import BaseModel


class BaseFakeCallbackHandler(BaseModel):
    """Base fake callback handler for testing.

    Declares the counters and flags shared by the fake handlers in this
    module. Every counter defaults to ``0`` and every flag to ``False``.
    """

    # Aggregate counters across all kinds of events.
    starts: int = 0
    ends: int = 0
    errors: int = 0
    # One ``{"args": ..., "kwargs": ...}`` dict per recorded LLM error.
    # NOTE(review): a mutable ``[]`` default relies on pydantic giving each
    # instance its own copy -- confirm for the pydantic version in use.
    errors_args: list[Any] = []
    # Number of text events recorded.
    text: int = 0
    # Flags that concrete handlers may expose through their ``ignore_*``
    # properties (where such a property is defined).
    ignore_llm_: bool = False
    ignore_chain_: bool = False
    ignore_agent_: bool = False
    ignore_retriever_: bool = False
    ignore_chat_model_: bool = False

    # Optional identifier, to allow for similar callback handlers that are
    # not technically equal.
    fake_id: str | None = None

    # Finer-grained, per-event counters for easier debugging of failing tests.
    chain_starts: int = 0
    chain_ends: int = 0
    llm_starts: int = 0
    llm_ends: int = 0
    llm_streams: int = 0
    tool_starts: int = 0
    tool_ends: int = 0
    agent_actions: int = 0
    agent_ends: int = 0
    chat_model_starts: int = 0
    retriever_starts: int = 0
    retriever_ends: int = 0
    retriever_errors: int = 0
    retries: int = 0


class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
    """Counter-updating helpers shared by the fake callback handlers.

    Each ``*_common`` method bumps the counter(s) for one kind of event;
    the concrete handlers call these from their ``on_*`` callbacks.
    """

    def on_llm_start_common(self) -> None:
        """Record an LLM start (aggregate and LLM-specific counters)."""
        self.starts += 1
        self.llm_starts += 1

    def on_llm_end_common(self) -> None:
        """Record an LLM end (aggregate and LLM-specific counters)."""
        self.ends += 1
        self.llm_ends += 1

    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
        """Record an LLM error and remember the arguments it came with.

        Args:
            *args: Positional arguments of the error callback.
            **kwargs: Keyword arguments of the error callback.
        """
        call_record = {"args": args, "kwargs": kwargs}
        self.errors_args.append(call_record)
        self.errors += 1

    def on_llm_new_token_common(self) -> None:
        """Record one streamed LLM token."""
        self.llm_streams += 1

    def on_retry_common(self) -> None:
        """Record one retry."""
        self.retries += 1

    def on_chain_start_common(self) -> None:
        """Record a chain start (aggregate and chain-specific counters)."""
        self.starts += 1
        self.chain_starts += 1

    def on_chain_end_common(self) -> None:
        """Record a chain end (aggregate and chain-specific counters)."""
        self.ends += 1
        self.chain_ends += 1

    def on_chain_error_common(self) -> None:
        """Record a chain error in the aggregate error counter."""
        self.errors += 1

    def on_tool_start_common(self) -> None:
        """Record a tool start (aggregate and tool-specific counters)."""
        self.starts += 1
        self.tool_starts += 1

    def on_tool_end_common(self) -> None:
        """Record a tool end (aggregate and tool-specific counters)."""
        self.ends += 1
        self.tool_ends += 1

    def on_tool_error_common(self) -> None:
        """Record a tool error in the aggregate error counter."""
        self.errors += 1

    def on_agent_action_common(self) -> None:
        """Record an agent action; it also counts as a start."""
        self.starts += 1
        self.agent_actions += 1

    def on_agent_finish_common(self) -> None:
        """Record an agent finish; it also counts as an end."""
        self.ends += 1
        self.agent_ends += 1

    def on_chat_model_start_common(self) -> None:
        """Record a chat model start (aggregate and chat-model counters)."""
        self.starts += 1
        self.chat_model_starts += 1

    def on_text_common(self) -> None:
        """Record one text event."""
        self.text += 1

    def on_retriever_start_common(self) -> None:
        """Record a retriever start (aggregate and retriever counters)."""
        self.retriever_starts += 1
        self.starts += 1

    def on_retriever_end_common(self) -> None:
        """Record a retriever end (aggregate and retriever counters)."""
        self.retriever_ends += 1
        self.ends += 1

    def on_retriever_error_common(self) -> None:
        """Record a retriever error (aggregate and retriever counters)."""
        self.retriever_errors += 1
        self.errors += 1


class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
    """Fake callback handler for testing.

    Every ``on_*`` callback ignores its arguments (except ``on_llm_error``,
    which records them) and delegates to the matching ``*_common`` counter
    helper. The ``ignore_*`` properties return the corresponding
    ``ignore_*_`` fields, so they can be set when constructing the handler.
    """

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        # Exposes the ``ignore_chat_model_`` field like the other flags.
        return self.ignore_chat_model_

    def on_llm_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM start."""
        self.on_llm_start_common()

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any:
        """Count a streamed LLM token."""
        self.on_llm_new_token_common()

    def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM end."""
        self.on_llm_end_common()

    def on_llm_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM error and record the arguments it was called with."""
        self.on_llm_error_common(*args, **kwargs)

    def on_retry(self, *args: Any, **kwargs: Any) -> Any:
        """Count a retry."""
        self.on_retry_common()

    def on_chain_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain start."""
        self.on_chain_start_common()

    def on_chain_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain end."""
        self.on_chain_end_common()

    def on_chain_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain error."""
        self.on_chain_error_common()

    def on_tool_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool start."""
        self.on_tool_start_common()

    def on_tool_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool end."""
        self.on_tool_end_common()

    def on_tool_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool error."""
        self.on_tool_error_common()

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        """Count an agent action."""
        self.on_agent_action_common()

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any:
        """Count an agent finish."""
        self.on_agent_finish_common()

    def on_text(self, *args: Any, **kwargs: Any) -> Any:
        """Count a text event."""
        self.on_text_common()

    def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count a retriever start."""
        self.on_retriever_start_common()

    def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count a retriever end."""
        self.on_retriever_end_common()

    def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count a retriever error."""
        self.on_retriever_error_common()

    # Overriding since BaseModel has __deepcopy__ method as well
    def __deepcopy__(self, memo: dict) -> FakeCallbackHandler:  # type: ignore[override]
        """Return ``self`` rather than a copy, so a deep copy is this same handler."""
        return self


class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
    """Fake callback handler that also handles chat model starts."""

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        """Check the message batches and count a chat model start.

        Args:
            serialized: Unused.
            messages: Batches of messages; every item in every batch must be
                a ``BaseMessage``.
            run_id: Unused.
            parent_run_id: Unused.
            **kwargs: Unused.

        Raises:
            AssertionError: If any item in ``messages`` is not a
                ``BaseMessage``.
        """
        flattened = chain.from_iterable(messages)
        assert all(isinstance(msg, BaseMessage) for msg in flattened)
        self.on_chat_model_start_common()


class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
    """Fake async callback handler for testing.

    Every ``on_*`` coroutine ignores its arguments (except ``on_llm_error``,
    which records them) and delegates to the matching ``*_common`` counter
    helper. The ``ignore_*`` properties return the corresponding
    ``ignore_*_`` fields, so they can be set when constructing the handler.
    """

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        # Mirrors the sync handler so ``ignore_retriever_`` is exposed here too.
        return self.ignore_retriever_

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        # Exposes the ``ignore_chat_model_`` field like the other flags.
        return self.ignore_chat_model_

    async def on_retry(self, *args: Any, **kwargs: Any) -> Any:
        """Count a retry."""
        self.on_retry_common()

    async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM start."""
        self.on_llm_start_common()

    async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        """Count a streamed LLM token."""
        self.on_llm_new_token_common()

    async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM end."""
        self.on_llm_end_common()

    async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM error and record the arguments it was called with."""
        self.on_llm_error_common(*args, **kwargs)

    async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain start."""
        self.on_chain_start_common()

    async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain end."""
        self.on_chain_end_common()

    async def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain error."""
        self.on_chain_error_common()

    async def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool start."""
        self.on_tool_start_common()

    async def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool end."""
        self.on_tool_end_common()

    async def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool error."""
        self.on_tool_error_common()

    async def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
        """Count an agent action."""
        self.on_agent_action_common()

    async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        """Count an agent finish."""
        self.on_agent_finish_common()

    async def on_text(self, *args: Any, **kwargs: Any) -> None:
        """Count a text event."""
        self.on_text_common()

    # Overriding since BaseModel has __deepcopy__ method as well
    def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler:  # type: ignore[override]
        """Return ``self`` rather than a copy, so a deep copy is this same handler."""
        return self
