"""`LLMResult` class."""

from __future__ import annotations

from copy import deepcopy
from typing import Literal

from pydantic import BaseModel

from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.run_info import RunInfo


class LLMResult(BaseModel):
    """A container for results of an LLM call.

    Both chat models and LLMs generate an `LLMResult` object. This object contains the
    generated outputs and any additional information that the model provider wants to
    return.
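
    A minimal construction sketch (illustrative values; real results are produced
    by a model's `generate` call, and `llm_output` keys are provider-specific):

    ```python
    from langchain_core.outputs import Generation, LLMResult

    result = LLMResult(
        generations=[[Generation(text="Paris is the capital of France.")]],
        llm_output={"model_name": "example-model"},  # hypothetical provider key
    )
    ```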
    """

    generations: list[
        list[Generation | ChatGeneration | GenerationChunk | ChatGenerationChunk]
    ]
    """Generated outputs.

    The first dimension of the list represents completions for different input prompts.

    The second dimension of the list represents different candidate generations for a
    given prompt.

    - When returned from an **LLM**, the type is `list[list[Generation]]`.
    - When returned from a **chat model**, the type is `list[list[ChatGeneration]]`.

    `ChatGeneration` is a subclass of `Generation` that has a field for a structured
    chat message.
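
    For example (a sketch with illustrative values), indexing is
    `generations[prompt_index][candidate_index]`:

    ```python
    from langchain_core.outputs import Generation, LLMResult

    result = LLMResult(
        generations=[
            [Generation(text="blue"), Generation(text="azure")],  # prompt 1
            [Generation(text="red")],  # prompt 2
        ]
    )
    assert result.generations[0][1].text == "azure"
    ```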
    """

    llm_output: dict | None = None
    """For arbitrary LLM provider specific output.

    This dictionary is a free-form dictionary that can contain any information that the
    provider wants to return. It is not standardized and is provider-specific.

    Users should generally avoid relying on this field and instead rely on accessing
    relevant information from standardized fields present in AIMessage.
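
    If you must read it, a defensive access sketch (given an `LLMResult` named
    `result`; `token_usage` is a common but not guaranteed key):

    ```python
    usage = (result.llm_output or {}).get("token_usage", {})
    ```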
    """

    run: list[RunInfo] | None = None
    """List of metadata info for model call for each input.

    See `langchain_core.outputs.run_info.RunInfo` for details.
    """

    type: Literal["LLMResult"] = "LLMResult"
    """Type is used exclusively for serialization purposes."""

    def flatten(self) -> list[LLMResult]:
        """Flatten generations into a single list.

        Unpack `list[list[Generation]] -> list[LLMResult]`, where each returned
        `LLMResult` holds the generations for a single input prompt. If token usage
        information is available, it is kept only on the first `LLMResult`, to avoid
        over-counting token usage downstream.

        Returns:
            List of `LLMResult` objects, one per input prompt.
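
        A sketch with illustrative values (assumes `Generation` and `LLMResult`
        are imported from `langchain_core.outputs`):

        ```python
        result = LLMResult(
            generations=[[Generation(text="a")], [Generation(text="b")]],
            llm_output={"token_usage": {"total_tokens": 4}},
        )
        flat = result.flatten()
        assert flat[0].llm_output == {"token_usage": {"total_tokens": 4}}
        assert flat[1].llm_output == {"token_usage": {}}
        ```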
        """
        llm_results = []
        for i, gen_list in enumerate(self.generations):
            # Avoid double-counting tokens in OpenAICallback: only the first
            # result keeps the original token usage.
            if i == 0:
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=self.llm_output,
                    )
                )
            else:
                if self.llm_output is not None:
                    llm_output = deepcopy(self.llm_output)
                    llm_output["token_usage"] = {}
                else:
                    llm_output = None
                llm_results.append(
                    LLMResult(
                        generations=[gen_list],
                        llm_output=llm_output,
                    )
                )
        return llm_results

    def __eq__(self, other: object) -> bool:
        """Check for `LLMResult` equality by ignoring any metadata related to runs.

        Args:
            other: Another `LLMResult` object to compare against.

        Returns:
            `True` if the generations and `llm_output` are equal, `False` otherwise.
                Returns `NotImplemented` if `other` is not an `LLMResult`.
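
        For example (a sketch; assumes `uuid4` from `uuid` plus `Generation`,
        `LLMResult`, and `RunInfo` from `langchain_core.outputs`):

        ```python
        gens = [[Generation(text="hi")]]
        with_run = LLMResult(generations=gens, run=[RunInfo(run_id=uuid4())])
        assert LLMResult(generations=gens) == with_run  # run metadata ignored
        ```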
        """
        if not isinstance(other, LLMResult):
            return NotImplemented
        return (
            self.generations == other.generations
            and self.llm_output == other.llm_output
        )

    # Overriding __eq__ makes the class unhashable; setting __hash__ to None
    # explicitly makes that intentional and visible to type checkers.
    __hash__ = None  # type: ignore[assignment]
