import asyncio
import re
import sys
import time
import uuid
import warnings
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Iterator,
    Sequence,
)
from functools import partial
from operator import itemgetter
from typing import Any, cast
from uuid import UUID

import pytest
from freezegun import freeze_time
from packaging import version
from pydantic import BaseModel, Field
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import TypedDict, override

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    atrace_as_chain_group,
    trace_as_chain_group,
)
from langchain_core.documents import Document
from langchain_core.language_models import (
    FakeListChatModel,
    FakeListLLM,
    FakeStreamingListLLM,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import AIMessageChunk, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers import (
    BaseOutputParser,
    CommaSeparatedListOutputParser,
    StrOutputParser,
)
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
    AddableDict,
    ConfigurableField,
    ConfigurableFieldMultiOption,
    ConfigurableFieldSingleOption,
    RouterRunnable,
    Runnable,
    RunnableAssign,
    RunnableBinding,
    RunnableBranch,
    RunnableConfig,
    RunnableGenerator,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnablePick,
    RunnableSequence,
    add,
    chain,
)
from langchain_core.runnables.base import RunnableMap, RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
    BaseTracer,
    ConsoleCallbackHandler,
    Run,
    RunLog,
    RunLogPatch,
)
from langchain_core.tracers._compat import pydantic_copy
from langchain_core.tracers.context import collect_runs
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk

# Version gates used by tests whose expected JSON schema output is only
# asserted for particular pydantic releases.
PYDANTIC_VERSION_AT_LEAST_29 = PYDANTIC_VERSION >= version.parse("2.9")
PYDANTIC_VERSION_AT_LEAST_210 = PYDANTIC_VERSION >= version.parse("2.10")


class FakeTracer(BaseTracer):
    """Tracer test double that stores a rewritten copy of every persisted run.

    Run IDs are swapped for UUIDs drawn from a fixed, sequential generator and
    message IDs are overwritten from the same generator, so recorded runs are
    stable enough to compare against snapshots.
    """

    def __init__(self) -> None:
        """Create empty run storage and the sequential UUID source."""
        super().__init__()
        self.runs: list[Run] = []
        # Original UUID -> replacement UUID handed out for it.
        self.uuids_map: dict[UUID, UUID] = {}
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        """Return the replacement for `uuid`, allocating the next one if unseen."""
        replacement = self.uuids_map.get(uuid)
        if replacement is None:
            replacement = next(self.uuids_generator)
            self.uuids_map[uuid] = replacement
        return replacement

    def _replace_message_id(self, maybe_message: Any) -> Any:
        """Overwrite message IDs found in `maybe_message`, mutating it in place.

        Messages and chat generations get a fresh ID from the generator;
        LLM results, dicts and lists are walked recursively. The (mutated)
        argument is returned.
        """
        if isinstance(maybe_message, BaseMessage):
            maybe_message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, ChatGeneration):
            maybe_message.message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, LLMResult):
            for generation_group in maybe_message.generations:
                for position, generation in enumerate(generation_group):
                    generation_group[position] = self._replace_message_id(generation)
        if isinstance(maybe_message, dict):
            for key, value in maybe_message.items():
                maybe_message[key] = self._replace_message_id(value)
        if isinstance(maybe_message, list):
            for position, item in enumerate(maybe_message):
                maybe_message[position] = self._replace_message_id(item)

        return maybe_message

    def _copy_run(self, run: Run) -> Run:
        """Return a copy of `run` (and its children) with replaced IDs."""
        # The dotted order is rewritten first so that every run ID it mentions
        # is mapped before the parent lookup below.
        new_dotted_order = None
        if run.dotted_order:
            rewritten_segments = []
            for segment in run.dotted_order.split("."):
                timestamp, run_id = segment.split("Z")
                mapped_id = self._replace_uuid(UUID(run_id))
                rewritten_segments.append(f"{timestamp}Z{mapped_id}")
            new_dotted_order = ".".join(rewritten_segments)

        # Computed in this exact order: each step may draw from the generator.
        mapped_run_id = self._replace_uuid(run.id)
        mapped_parent_id = (
            self.uuids_map[run.parent_run_id] if run.parent_run_id else None
        )
        copied_children = [self._copy_run(child) for child in run.child_runs]
        mapped_trace_id = self._replace_uuid(run.trace_id) if run.trace_id else None
        rewritten_inputs = self._replace_message_id(run.inputs)
        rewritten_outputs = self._replace_message_id(run.outputs)

        return pydantic_copy(
            run,
            update={
                "id": mapped_run_id,
                "parent_run_id": mapped_parent_id,
                "child_runs": copied_children,
                "trace_id": mapped_trace_id,
                "dotted_order": new_dotted_order,
                "inputs": rewritten_inputs,
                "outputs": rewritten_outputs,
            },
        )

    def _persist_run(self, run: Run) -> None:
        """Store a rewritten copy of `run`."""
        rewritten = self._copy_run(run)
        self.runs.append(rewritten)

    def flattened_runs(self) -> list[Run]:
        """Return every stored run and all nested child runs as a flat list."""
        pending = list(self.runs)
        flattened = []
        while pending:
            current = pending.pop()
            flattened.append(current)
            if current.child_runs:
                pending.extend(current.child_runs)
        return flattened

    @property
    def run_ids(self) -> list[uuid.UUID | None]:
        """Original run IDs for all flattened runs (`None` where unmapped)."""
        original_by_replacement = {new: old for old, new in self.uuids_map.items()}
        return [original_by_replacement.get(run.id) for run in self.flattened_runs()]


class FakeRunnable(Runnable[str, int]):
    """Minimal runnable that maps a string to its length."""

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        """Return `len(input)`; `config` and `kwargs` are not used."""
        input_length = len(input)
        return input_length


# Serializable counterpart of FakeRunnable: maps a string to its length and
# carries one string field, `hello`.
class FakeRunnableSerializable(RunnableSerializable[str, int]):
    hello: str = ""

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> int:
        """Return `len(input)`; `config` and `kwargs` are not used."""
        input_length = len(input)
        return input_length


# Retriever stub: the sync and async paths both ignore the query and return
# the same two documents, "foo" and "bar".
class FakeRetriever(BaseRetriever):
    @override
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Return the fixed "foo" and "bar" documents regardless of `query`."""
        return [Document(page_content=text) for text in ("foo", "bar")]

    @override
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Async variant: return the fixed "foo" and "bar" documents."""
        return [Document(page_content=text) for text in ("foo", "bar")]


# The literal schemas asserted below are only checked on recent pydantic, so
# the test must be skipped when pydantic is OLDER than 2.10 (not the reverse).
@pytest.mark.skipif(
    not PYDANTIC_VERSION_AT_LEAST_210,
    reason=(
        "Only test with most recent version of pydantic. "
        "Pydantic introduced small fixes to generated JSONSchema on minor versions."
    ),
)
def test_schemas(snapshot: SnapshotAssertion) -> None:
    """Check the input/output/config JSON schemas reported by various runnables.

    Covers plain runnables, bound runnables, fallbacks, typed lambdas (sync and
    async), retrievers, fake LLM/chat models, prompts, mapped prompts, output
    parsers, sequences, routers, parallel maps and `assign`.

    Args:
        snapshot: Syrupy fixture used for the schemas that are too large to
            spell out inline.
    """
    fake = FakeRunnable()  # str -> int

    assert fake.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }
    assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == {
        "properties": {
            "metadata": {
                "default": None,
                "title": "Metadata",
                "type": "object",
            },
            "run_name": {"default": None, "title": "Run Name", "type": "string"},
            "tags": {
                "default": None,
                "items": {"type": "string"},
                "title": "Tags",
                "type": "array",
            },
        },
        "title": "FakeRunnableConfig",
        "type": "object",
    }

    # Binding kwargs is expected to leave the schemas untouched.
    fake_bound = FakeRunnable().bind(a="b")  # str -> int

    assert fake_bound.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_bound.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    # Same expectation for a runnable wrapped with fallbacks.
    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int

    assert fake_w_fallbacks.get_input_jsonschema() == {
        "title": "FakeRunnableInput",
        "type": "string",
    }
    assert fake_w_fallbacks.get_output_jsonschema() == {
        "title": "FakeRunnableOutput",
        "type": "integer",
    }

    # NOTE: the function name feeds the expected schema titles below.
    def typed_lambda_impl(x: str) -> int:
        return len(x)

    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int

    assert typed_lambda.get_input_jsonschema() == {
        "title": "typed_lambda_impl_input",
        "type": "string",
    }
    assert typed_lambda.get_output_jsonschema() == {
        "title": "typed_lambda_impl_output",
        "type": "integer",
    }

    async def typed_async_lambda_impl(x: str) -> int:
        return len(x)

    typed_async_lambda = RunnableLambda(typed_async_lambda_impl)  # str -> int

    assert typed_async_lambda.get_input_jsonschema() == {
        "title": "typed_async_lambda_impl_input",
        "type": "string",
    }
    assert typed_async_lambda.get_output_jsonschema() == {
        "title": "typed_async_lambda_impl_output",
        "type": "integer",
    }

    fake_ret = FakeRetriever()  # str -> list[Document]

    assert fake_ret.get_input_jsonschema() == {
        "title": "FakeRetrieverInput",
        "type": "string",
    }
    assert _normalize_schema(fake_ret.get_output_jsonschema()) == {
        "$defs": {
            "Document": {
                "description": "Class for storing a piece of text and "
                "associated metadata.\n"
                "\n"
                "!!! note\n"
                "\n"
                "    `Document` is for **retrieval workflows**, not chat I/O. For "
                "sending text\n"
                "    to an LLM in a conversation, use message types from "
                "`langchain.messages`.\n"
                "\n"
                "Example:\n"
                "    ```python\n"
                "    from langchain_core.documents import Document\n"
                "\n"
                "    document = Document(\n"
                '        page_content="Hello, world!", '
                'metadata={"source": "https://example.com"}\n'
                "    )\n"
                "    ```",
                "properties": {
                    "id": {
                        "anyOf": [{"type": "string"}, {"type": "null"}],
                        "default": None,
                        "title": "Id",
                    },
                    "metadata": {"title": "Metadata", "type": "object"},
                    "page_content": {"title": "Page Content", "type": "string"},
                    "type": {
                        "const": "Document",
                        "default": "Document",
                        "title": "Type",
                    },
                },
                "required": ["page_content"],
                "title": "Document",
                "type": "object",
            }
        },
        "items": {"$ref": "#/$defs/Document"},
        "title": "FakeRetrieverOutput",
        "type": "array",
    }

    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_llm.input_schema) == snapshot(name="fake_llm_input_schema")
    assert _schema(fake_llm.output_schema) == {
        "title": "FakeListLLMOutput",
        "type": "string",
    }

    fake_chat = FakeListChatModel(responses=["a"])  # str -> list[list[str]]

    assert _schema(fake_chat.input_schema) == snapshot(name="fake_chat_input_schema")
    assert _schema(fake_chat.output_schema) == snapshot(name="fake_chat_output_schema")

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="history"),
            ("human", "Hello, how are you?"),
        ]
    )

    assert _normalize_schema(chat_prompt.get_input_jsonschema()) == snapshot(
        name="chat_prompt_input_schema"
    )
    assert _normalize_schema(chat_prompt.get_output_jsonschema()) == snapshot(
        name="chat_prompt_output_schema"
    )

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert _schema(prompt.output_schema) == snapshot(name="prompt_output_schema")

    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()

    assert _normalize_schema(prompt_mapper.get_input_jsonschema()) == {
        "$defs": {
            "PromptInput": {
                "properties": {"name": {"title": "Name", "type": "string"}},
                "required": ["name"],
                "title": "PromptInput",
                "type": "object",
            }
        },
        "default": None,
        "items": {"$ref": "#/$defs/PromptInput"},
        "title": "RunnableEach<PromptTemplate>Input",
        "type": "array",
    }
    assert _schema(prompt_mapper.output_schema) == snapshot(
        name="prompt_mapper_output_schema"
    )

    list_parser = CommaSeparatedListOutputParser()

    assert _schema(list_parser.input_schema) == snapshot(
        name="list_parser_input_schema"
    )
    assert _schema(list_parser.output_schema) == {
        "title": "CommaSeparatedListOutputParserOutput",
        "type": "array",
        "items": {"type": "string"},
    }

    # A sequence is expected to expose the first step's input schema and the
    # last step's output schema.
    seq = prompt | fake_llm | list_parser

    assert seq.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq.get_output_jsonschema() == {
        "type": "array",
        "items": {"type": "string"},
        "title": "CommaSeparatedListOutputParserOutput",
    }

    router: Runnable = RouterRunnable({})

    assert _schema(router.input_schema) == {
        "$ref": "#/definitions/RouterInput",
        "definitions": {
            "RouterInput": {
                "description": "Router input.",
                "properties": {
                    "input": {"title": "Input"},
                    "key": {"title": "Key", "type": "string"},
                },
                "required": ["key", "input"],
                "title": "RouterInput",
                "type": "object",
            }
        },
        "title": "RouterRunnableInput",
    }
    assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"}

    seq_w_map: Runnable = (
        prompt
        | fake_llm
        | {
            "original": RunnablePassthrough(input_type=str),
            "as_list": list_parser,
            "length": typed_lambda_impl,
        }
    )

    assert seq_w_map.get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"name": {"title": "Name", "type": "string"}},
        "required": ["name"],
    }
    assert seq_w_map.get_output_jsonschema() == {
        "title": "RunnableParallel<original,as_list,length>Output",
        "type": "object",
        "properties": {
            "original": {"title": "Original", "type": "string"},
            "length": {"title": "Length", "type": "integer"},
            "as_list": {
                "title": "As List",
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["original", "as_list", "length"],
    }

    # Add a test for schema of runnable assign
    def foo(x: int) -> int:
        return x

    foo_ = RunnableLambda(foo)

    assert foo_.assign(bar=lambda _: "foo").get_output_schema().model_json_schema() == {
        "properties": {"bar": {"title": "Bar"}, "root": {"title": "Root"}},
        "required": ["root", "bar"],
        "title": "RunnableAssignOutput",
        "type": "object",
    }


def test_passthrough_assign_schema() -> None:
    """Check schemas of sequences that start with `RunnablePassthrough.assign`."""
    doc_retriever = FakeRetriever()  # str -> list[Document]
    qa_prompt = PromptTemplate.from_template("{context} {question}")
    llm = FakeListLLM(responses=["a"])

    assign_then_prompt = (
        RunnablePassthrough.assign(context=itemgetter("question") | doc_retriever)
        | qa_prompt
        | llm
    )

    expected_input_schema = {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "RunnableSequenceInput",
        "type": "object",
        "required": ["question"],
    }
    expected_output_schema = {
        "title": "FakeListLLMOutput",
        "type": "string",
    }
    assert assign_then_prompt.get_input_jsonschema() == expected_input_schema
    assert assign_then_prompt.get_output_jsonschema() == expected_output_schema

    # Same assign step, but piped straight into the LLM with no prompt between.
    assign_then_llm = (
        RunnablePassthrough.assign(context=itemgetter("question") | doc_retriever)
        | llm
    )

    # fallback to RunnableAssign.input_schema if next runnable doesn't have
    # expected dict input_schema
    expected_fallback_schema = {
        "properties": {"question": {"title": "Question"}},
        "title": "RunnableParallel<context>Input",
        "type": "object",
        "required": ["question"],
    }
    assert assign_then_llm.get_input_jsonschema() == expected_fallback_schema


def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
    """Check the input/output schemas `RunnableLambda` reports for callables.

    Untyped callables are expected to yield an object schema whose properties
    are the keys read from the first argument; typed callables are expected to
    yield a schema built from their annotations.

    Args:
        snapshot: Syrupy fixture holding the expected typed output schema.
    """
    # Anonymous lambda: expected title falls back to "RunnableLambdaInput".
    first_lambda = lambda x: x["hello"]  # noqa: E731
    assert RunnableLambda(first_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
    }

    # Only keys read from the first argument (`x`) are expected in the schema;
    # `y["bah"]` does not appear.
    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731
    assert RunnableLambda(second_lambda).get_input_jsonschema() == {
        "title": "RunnableLambdaInput",
        "type": "object",
        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
        "required": ["bye", "hello"],
    }

    # Named function: the function name is expected in the schema title.
    def get_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return value["variable_name"]

    assert RunnableLambda(get_value).get_input_jsonschema() == {
        "title": "get_value_input",
        "type": "object",
        "properties": {"variable_name": {"title": "Variable Name"}},
        "required": ["variable_name"],
    }

    # Keys read via `.get(...)` are expected to be listed (and required) too.
    async def aget_value(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return (value["variable_name"], value.get("another"))

    assert RunnableLambda(aget_value).get_input_jsonschema() == {
        "title": "aget_value_input",
        "type": "object",
        "properties": {
            "another": {"title": "Another"},
            "variable_name": {"title": "Variable Name"},
        },
        "required": ["another", "variable_name"],
    }

    # A key read more than once is expected to appear only once in the schema.
    async def aget_values(value):  # type: ignore[no-untyped-def] # noqa: ANN001,ANN202
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    assert RunnableLambda(aget_values).get_input_jsonschema() == {
        "title": "aget_values_input",
        "type": "object",
        "properties": {
            "variable_name": {"title": "Variable Name"},
            "yo": {"title": "Yo"},
        },
        "required": ["variable_name", "yo"],
    }

    # NOTE: these class names appear in the expected schema below; no
    # docstrings are added here in case they would surface in the schema.
    class InputType(TypedDict):
        variable_name: str
        yo: int

    class OutputType(TypedDict):
        hello: str
        bye: str
        byebye: int

    async def aget_values_typed(value: InputType) -> OutputType:
        return {
            "hello": value["variable_name"],
            "bye": value["variable_name"],
            "byebye": value["yo"],
        }

    # Both sides go through _normalize_schema before being compared.
    assert _normalize_schema(
        RunnableLambda(aget_values_typed).get_input_jsonschema()
    ) == _normalize_schema(
        {
            "$defs": {
                "InputType": {
                    "properties": {
                        "variable_name": {
                            "title": "Variable Name",
                            "type": "string",
                        },
                        "yo": {"title": "Yo", "type": "integer"},
                    },
                    "required": ["variable_name", "yo"],
                    "title": "InputType",
                    "type": "object",
                }
            },
            "allOf": [{"$ref": "#/$defs/InputType"}],
            "title": "aget_values_typed_input",
        }
    )

    # The output-schema snapshot is only compared on pydantic >= 2.9.
    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            RunnableLambda(aget_values_typed).get_output_jsonschema()
        ) == snapshot(name="schema8")


def test_with_types_with_type_generics() -> None:
    """Check that `with_types` accepts generic aliases such as `list[int]`."""

    def foo(x: int) -> None:
        """Placeholder callable; it raises if it is ever invoked."""
        raise NotImplementedError

    # Only the `with_types` call itself is exercised, once per generic alias.
    for generic_alias in (list[int], Sequence[int]):
        RunnableLambda(foo).with_types(
            output_type=generic_alias,  # type: ignore[arg-type]
            input_type=generic_alias,  # type: ignore[arg-type]
        )


def test_schema_with_itemgetter() -> None:
    """Check input schemas of runnables built from `operator.itemgetter`."""
    hello_getter = RunnableLambda(itemgetter("hello"))
    expected_getter_schema = {
        "properties": {"hello": {"title": "Hello"}},
        "required": ["hello"],
        "title": "RunnableLambdaInput",
        "type": "object",
    }
    assert _schema(hello_getter.input_schema) == expected_getter_schema

    # An itemgetter used as a value of a dict that is piped into a prompt.
    language_prompt = ChatPromptTemplate.from_template("what is {language}?")
    mapped_prompt: Runnable = {"language": itemgetter("language")} | language_prompt
    expected_mapped_schema = {
        "properties": {"language": {"title": "Language"}},
        "required": ["language"],
        "title": "RunnableParallel<language>Input",
        "type": "object",
    }
    assert _schema(mapped_prompt.input_schema) == expected_mapped_schema


def test_schema_complex_seq() -> None:
    """Check schemas of a nested sequence and of its `with_types` overrides."""
    city_prompt = ChatPromptTemplate.from_template("what is the city {person} is from?")
    country_prompt = ChatPromptTemplate.from_template(
        "what country is the city {city} in? respond in {language}"
    )

    chat_model = FakeListChatModel(responses=[""])

    city_chain: Runnable = RunnableSequence(
        city_prompt, chat_model, StrOutputParser(), name="city_chain"
    )

    assert city_chain.name == "city_chain"

    country_chain: Runnable = (
        {"city": city_chain, "language": itemgetter("language")}
        | country_prompt
        | chat_model
        | StrOutputParser()
    )

    str_output_schema = {
        "title": "StrOutputParserOutput",
        "type": "string",
    }

    assert country_chain.get_input_jsonschema() == {
        "title": "RunnableParallel<city,language>Input",
        "type": "object",
        "properties": {
            "person": {"title": "Person", "type": "string"},
            "language": {"title": "Language"},
        },
        "required": ["person", "language"],
    }

    assert country_chain.get_output_jsonschema() == str_output_schema

    # Overriding the input type replaces the reported input schema...
    retyped_as_str = country_chain.with_types(input_type=str)
    assert retyped_as_str.get_input_jsonschema() == {
        "title": "RunnableSequenceInput",
        "type": "string",
    }

    # ...while the output schema stays the same.
    retyped_as_int = country_chain.with_types(input_type=int)
    assert retyped_as_int.get_output_jsonschema() == str_output_schema

    # NOTE: the class name is the expected schema title, so it is kept as-is
    # and given no docstring.
    class InputType(BaseModel):
        person: str

    retyped_as_model = country_chain.with_types(input_type=InputType)
    assert retyped_as_model.get_input_jsonschema() == {
        "title": "InputType",
        "type": "object",
        "properties": {"person": {"title": "Person", "type": "string"}},
        "required": ["person"],
    }


def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
    """Check `configurable_fields` on an LLM, a prompt, a chain and a map.

    For each runnable the test asserts the default behavior, the config JSON
    schema (snapshot, pydantic >= 2.9 only) and the behavior after overriding
    the field through `with_config(configurable=...)`.

    Args:
        snapshot: Syrupy fixture holding the expected config schemas.
    """
    fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

    assert fake_llm.invoke("...") == "a"

    # Expose the `responses` field under the config key "llm_responses".
    fake_llm_configurable = fake_llm.configurable_fields(
        responses=ConfigurableField(
            id="llm_responses",
            name="LLM Responses",
            description="A list of fake responses for this LLM",
        )
    )

    # Without an override the configurable LLM answers like the original one.
    assert fake_llm_configurable.invoke("...") == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            fake_llm_configurable.get_config_jsonschema()
        ) == snapshot(name="schema2")

    fake_llm_configured = fake_llm_configurable.with_config(
        configurable={"llm_responses": ["b"]}
    )

    assert fake_llm_configured.invoke("...") == "b"

    prompt = PromptTemplate.from_template("Hello, {name}!")

    assert prompt.invoke({"name": "John"}) == StringPromptValue(text="Hello, John!")

    # Expose the `template` field under the config key "prompt_template".
    prompt_configurable = prompt.configurable_fields(
        template=ConfigurableField(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
        )
    )

    assert prompt_configurable.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John!"
    )

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            prompt_configurable.get_config_jsonschema()
        ) == snapshot(name="schema3")

    prompt_configured = prompt_configurable.with_config(
        configurable={"prompt_template": "Hello, {name}! {name}!"}
    )

    assert prompt_configured.invoke({"name": "John"}) == StringPromptValue(
        text="Hello, John! John!"
    )

    # The input schema is expected to follow the overridden template's variables.
    assert prompt_configurable.with_config(
        configurable={"prompt_template": "Hello {name} in {lang}"}
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    # Chain the configurable prompt and LLM together.
    chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema4")

    # Both fields overridden in a single with_config call.
    assert (
        chain_configurable.with_config(
            configurable={
                "prompt_template": "A very good morning to you, {name} {lang}!",
                "llm_responses": ["c"],
            }
        ).invoke({"name": "John", "lang": "en"})
        == "c"
    )

    assert chain_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name} {lang}!",
            "llm_responses": ["c"],
        }
    ).get_input_jsonschema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "lang": {"title": "Lang", "type": "string"},
            "name": {"title": "Name", "type": "string"},
        },
        "required": ["lang", "name"],
    }

    # Map with three branches: llm1 and llm2 share the "llm_responses" field,
    # llm3 exposes `responses` under the separate id "other_responses".
    chain_with_map_configurable: Runnable = prompt_configurable | {
        "llm1": fake_llm_configurable | StrOutputParser(),
        "llm2": fake_llm_configurable | StrOutputParser(),
        "llm3": fake_llm.configurable_fields(
            responses=ConfigurableField("other_responses")
        )
        | StrOutputParser(),
    }

    assert chain_with_map_configurable.invoke({"name": "John"}) == {
        "llm1": "a",
        "llm2": "a",
        "llm3": "a",
    }

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_with_map_configurable.get_config_jsonschema()
        ) == snapshot(name="schema5")

    # "llm_responses" drives llm1/llm2 while "other_responses" drives llm3.
    assert chain_with_map_configurable.with_config(
        configurable={
            "prompt_template": "A very good morning to you, {name}!",
            "llm_responses": ["c"],
            "other_responses": ["d"],
        }
    ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"}


def test_configurable_alts_factory() -> None:
    """Check that a configurable alternative may be supplied as a factory."""
    llm_selector = ConfigurableField(id="llm", name="LLM")
    # The "chat" alternative is passed as a `partial`, not as an instance.
    chat_factory = partial(FakeListLLM, responses=["b"])
    switchable_llm = FakeListLLM(responses=["a"]).configurable_alternatives(
        llm_selector,
        chat=chat_factory,
    )

    assert switchable_llm.invoke("...") == "a"

    using_alternative = switchable_llm.with_config(configurable={"llm": "chat"})
    assert using_alternative.invoke("...") == "b"


def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
    """Check the config schema of a chain whose alternatives use `prefix_keys`.

    The chat model and the LLM both expose a field with the id "responses";
    the resulting config schema is compared against a snapshot on
    pydantic >= 2.9.
    """
    chat_responses_field = ConfigurableFieldMultiOption(
        id="responses",
        name="Chat Responses",
        options={
            "hello": "A good morning to you!",
            "bye": "See you later!",
            "helpful": "How can I help you?",
        },
        default=["hello", "bye"],
    )
    # (sleep is a configurable field in FakeListChatModel)
    chat_sleep_field = ConfigurableField(
        id="chat_sleep",
        is_shared=True,
    )
    chat_model = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=chat_responses_field,
        sleep=chat_sleep_field,
    )

    llm_responses_field = ConfigurableField(
        id="responses",
        name="LLM Responses",
        description="A list of fake responses for this LLM",
    )
    llm_with_fields = FakeListLLM(responses=["a"]).configurable_fields(
        responses=llm_responses_field
    )
    switchable_llm = llm_with_fields.configurable_alternatives(
        ConfigurableField(id="llm", name="LLM"),
        chat=chat_model | StrOutputParser(),
        prefix_keys=True,
    )

    template_field = ConfigurableFieldSingleOption(
        id="prompt_template",
        name="Prompt Template",
        description="The prompt template for this chain",
        options={
            "hello": "Hello, {name}!",
            "good_morning": "A very good morning to you, {name}!",
        },
        default="hello",
    )
    configurable_prompt = PromptTemplate.from_template(
        "Hello, {name}!"
    ).configurable_fields(template=template_field)

    pipeline = configurable_prompt | switchable_llm

    if PYDANTIC_VERSION_AT_LEAST_29:
        config_schema = _normalize_schema(_schema(pipeline.config_schema()))
        assert config_schema == snapshot(name="schema6")


def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
    """End-to-end example combining configurable fields and alternatives.

    Builds a chain that uses the same configurable prompt and LLM twice, then
    checks the default result, the config schema (snapshot, pydantic >= 2.9
    only) and the results after switching to the "chat" alternative.

    Args:
        snapshot: Syrupy fixture holding the expected config schema.
    """
    # Chat model whose responses are picked from named options.
    fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
        responses=ConfigurableFieldMultiOption(
            id="chat_responses",
            name="Chat Responses",
            options={
                "hello": "A good morning to you!",
                "bye": "See you later!",
                "helpful": "How can I help you?",
            },
            default=["hello", "bye"],
        )
    )
    # LLM with configurable responses, plus the chat model as the "chat"
    # alternative selected through the "llm" config key.
    fake_llm = (
        FakeListLLM(responses=["a"])
        .configurable_fields(
            responses=ConfigurableField(
                id="llm_responses",
                name="LLM Responses",
                description="A list of fake responses for this LLM",
            )
        )
        .configurable_alternatives(
            ConfigurableField(id="llm", name="LLM"),
            chat=fake_chat | StrOutputParser(),
        )
    )

    # Prompt whose template is chosen from named options.
    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableFieldSingleOption(
            id="prompt_template",
            name="Prompt Template",
            description="The prompt template for this chain",
            options={
                "hello": "Hello, {name}!",
                "good_morning": "A very good morning to you, {name}!",
            },
            default="hello",
        )
    )

    # deduplication of configurable fields
    # (`prompt` and `fake_llm` each appear twice in the chain)
    chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm

    assert chain_configurable.invoke({"name": "John"}) == "a"

    if PYDANTIC_VERSION_AT_LEAST_29:
        assert _normalize_schema(
            chain_configurable.get_config_jsonschema()
        ) == snapshot(name="schema7")

    # Selecting the "chat" alternative with its default response options.
    assert (
        chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
            {"name": "John"}
        )
        == "A good morning to you!"
    )

    # Selecting the "chat" alternative and overriding its response option.
    assert (
        chain_configurable.with_config(
            configurable={"llm": "chat", "chat_responses": ["helpful"]}
        ).invoke({"name": "John"})
        == "How can I help you?"
    )


def test_passthrough_tap(mocker: MockerFixture) -> None:
    """Verify the tap callable of ``RunnablePassthrough`` on the sync entry points.

    One mock is attached as the tap both before and after a ``FakeRunnable``
    step. For ``invoke``, ``batch``, ``batch_as_completed`` and ``stream`` the
    test asserts that the leading tap is called with the raw input plus
    ``my_kwarg`` while the trailing tap is called with only the step's output.
    """
    tap = mocker.Mock()
    step = FakeRunnable()
    seq = RunnablePassthrough[Any](tap) | step | RunnablePassthrough[Any](tap)

    # Calls expected for a single "hello" input, in pipeline order.
    single_input_calls = [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    # Calls expected for the two-item batch; only membership is asserted,
    # not the relative order between items.
    two_input_calls = [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]

    def check_batch_calls_and_reset() -> None:
        assert len(tap.call_args_list) == 4
        for expected_call in two_input_calls:
            assert expected_call in tap.call_args_list
        tap.reset_mock()

    # invoke
    assert seq.invoke("hello", my_kwarg="value") == 5
    assert tap.call_args_list == single_input_calls
    tap.reset_mock()

    # batch
    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    check_batch_calls_and_reset()

    # batch with return_exceptions
    outputs = seq.batch(
        ["hello", "byebye"], my_kwarg="value", return_exceptions=True
    )
    assert outputs == [5, 6]
    check_batch_calls_and_reset()

    # batch_as_completed: sort the (index, output) pairs before comparing
    completed = sorted(
        seq.batch_as_completed(
            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
        )
    )
    assert completed == [(0, 5), (1, 6)]
    check_batch_calls_and_reset()

    # stream
    chunks = [
        *seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
    ]
    assert chunks == [5]
    assert tap.call_args_list == single_input_calls
    tap.reset_mock()


async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
    """Verify the tap callable of ``RunnablePassthrough`` on the async entry points.

    One mock is attached as the tap both before and after a ``FakeRunnable``
    step. For ``ainvoke``, ``abatch``, ``abatch_as_completed`` and ``astream``
    the test asserts that the leading tap is called with the raw input plus
    ``my_kwarg`` while the trailing tap is called with only the step's output.
    """
    tap = mocker.Mock()
    step = FakeRunnable()
    seq = RunnablePassthrough[Any](tap) | step | RunnablePassthrough[Any](tap)

    # Calls expected for a single "hello" input, in pipeline order.
    single_input_calls = [
        mocker.call("hello", my_kwarg="value"),
        mocker.call(5),
    ]
    # Calls expected for the two-item batch; only membership is asserted,
    # not the relative order between items.
    two_input_calls = [
        mocker.call("hello", my_kwarg="value"),
        mocker.call("byebye", my_kwarg="value"),
        mocker.call(5),
        mocker.call(6),
    ]

    def check_batch_calls_and_reset() -> None:
        assert len(tap.call_args_list) == 4
        for expected_call in two_input_calls:
            assert expected_call in tap.call_args_list
        tap.reset_mock()

    # ainvoke
    assert await seq.ainvoke("hello", my_kwarg="value") == 5
    assert tap.call_args_list == single_input_calls
    tap.reset_mock()

    # abatch
    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
    check_batch_calls_and_reset()

    # abatch with return_exceptions
    outputs = await seq.abatch(
        ["hello", "byebye"], my_kwarg="value", return_exceptions=True
    )
    assert outputs == [5, 6]
    check_batch_calls_and_reset()

    # abatch_as_completed: sort the (index, output) pairs before comparing
    completed = [
        pair
        async for pair in seq.abatch_as_completed(
            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
        )
    ]
    assert sorted(completed) == [(0, 5), (1, 6)]
    check_batch_calls_and_reset()

    # astream
    chunks = [
        chunk
        async for chunk in seq.astream(
            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"
        )
    ]
    assert chunks == [5]
    assert tap.call_args_list == single_input_calls


async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
    """Check the config that reaches ``invoke`` of a configurable-fields runnable.

    The expected config shows the ``hello`` configurable value copied into
    ``metadata`` next to the caller's own metadata, while ``__secret_key``
    stays only under ``configurable``.
    """
    runnable = FakeRunnableSerializable()
    # The spy is installed on the class, so the first recorded positional
    # argument is the instance itself.
    invoke_spy = mocker.spy(runnable.__class__, "invoke")
    configurable_runnable = runnable.configurable_fields(
        hello=ConfigurableField(id="hello", name="Hello")
    )

    result = configurable_runnable.with_config(tags=["a-tag"]).invoke(
        "hello",
        {
            "configurable": {"hello": "there", "__secret_key": "nahnah"},
            "metadata": {"bye": "now"},
        },
    )
    assert result == 5

    # Skip the instance and compare only the input and the merged config.
    recorded_args = invoke_spy.call_args_list[0].args[1:]
    assert recorded_args == (
        "hello",
        {
            "tags": ["a-tag"],
            "callbacks": None,
            "recursion_limit": 25,
            "configurable": {"hello": "there", "__secret_key": "nahnah"},
            "metadata": {"hello": "there", "bye": "now"},
        },
    )
    invoke_spy.reset_mock()


def test_with_config(mocker: MockerFixture) -> None:
    """Check how ``with_config`` bindings merge with call-time config (sync paths).

    Covers ``invoke``, a sequence whose steps carry their own bound config,
    ``stream``, and ``batch`` / ``batch_as_completed`` with both per-input
    config lists and a single shared config.
    """
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    def calls_by_input() -> list[Any]:
        # Batch items may be invoked in any order, so never rely on the order
        # in which the spy recorded them: put the "hello" call first.
        return sorted(
            spy.call_args_list, key=lambda c: 0 if c.args[0] == "hello" else 1
        )

    assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
    assert spy.call_args_list == [
        mocker.call(
            "hello",
            {"tags": ["a-tag"], "metadata": {}, "configurable": {}},
        ),
    ]
    spy.reset_mock()

    fake_1 = RunnablePassthrough[Any]()
    fake_2 = RunnablePassthrough[Any]()
    # Spy on the class so both passthrough instances are recorded; args[0] is
    # then the instance, args[1] the input and args[2] the config.
    spy_seq_step = mocker.spy(fake_1.__class__, "invoke")

    sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
        tags=["b-tag"], max_concurrency=5
    )
    assert sequence.invoke("hello") == "hello"
    assert len(spy_seq_step.call_args_list) == 2
    for i, call in enumerate(spy_seq_step.call_args_list):
        assert call.args[1] == "hello"
        if i == 0:
            assert call.args[2].get("tags") == ["a-tag"]
            assert call.args[2].get("max_concurrency") is None
        else:
            assert call.args[2].get("tags") == ["b-tag"]
            assert call.args[2].get("max_concurrency") == 5
    mocker.stop(spy_seq_step)

    assert [
        *fake.with_config(tags=["a-tag"]).stream(
            "hello", {"metadata": {"key": "value"}}
        )
    ] == [5]
    assert spy.call_args_list == [
        mocker.call(
            "hello",
            {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},
        ),
    ]
    spy.reset_mock()

    # batch with one config per input
    assert fake.with_config(recursion_limit=5).batch(
        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
    ) == [5, 7]

    assert len(spy.call_args_list) == 2
    for i, call in enumerate(calls_by_input()):
        assert call.args[0] == ("hello" if i == 0 else "wooorld")
        if i == 0:
            assert call.args[1].get("recursion_limit") == 5
            assert call.args[1].get("tags") == ["a-tag"]
            assert call.args[1].get("metadata") == {}
        else:
            assert call.args[1].get("recursion_limit") == 5
            assert call.args[1].get("tags") == []
            assert call.args[1].get("metadata") == {"key": "value"}

    spy.reset_mock()

    # batch_as_completed with one config per input
    assert sorted(
        c
        for c in fake.with_config(recursion_limit=5).batch_as_completed(
            ["hello", "wooorld"],
            [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],
        )
    ) == [(0, 5), (1, 7)]

    assert len(spy.call_args_list) == 2
    for i, call in enumerate(calls_by_input()):
        assert call.args[0] == ("hello" if i == 0 else "wooorld")
        if i == 0:
            assert call.args[1].get("recursion_limit") == 5
            assert call.args[1].get("tags") == ["a-tag"]
            assert call.args[1].get("metadata") == {}
        else:
            assert call.args[1].get("recursion_limit") == 5
            assert call.args[1].get("tags") == []
            assert call.args[1].get("metadata") == {"key": "value"}

    spy.reset_mock()

    # batch with a single config shared by all inputs
    assert fake.with_config(metadata={"a": "b"}).batch(
        ["hello", "wooorld"], {"tags": ["a-tag"]}
    ) == [5, 7]
    assert len(spy.call_args_list) == 2
    for i, call in enumerate(calls_by_input()):
        assert call.args[0] == ("hello" if i == 0 else "wooorld")
        assert call.args[1].get("tags") == ["a-tag"]
        assert call.args[1].get("metadata") == {"a": "b"}
    spy.reset_mock()

    # batch_as_completed with a single shared config and no bound config
    assert sorted(
        c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})
    ) == [(0, 5), (1, 7)]
    assert len(spy.call_args_list) == 2
    for i, call in enumerate(calls_by_input()):
        assert call.args[0] == ("hello" if i == 0 else "wooorld")
        assert call.args[1].get("tags") == ["a-tag"]


async def test_with_config_async(mocker: MockerFixture) -> None:
    """Check how ``with_config`` bindings merge with call-time config (async paths).

    Exercises ``ainvoke``, ``astream``, ``abatch`` and ``abatch_as_completed``
    on a ``FakeRunnable`` while spying on its ``invoke`` method, and compares
    the config each recorded call received.
    """
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    # ainvoke: bound metadata plus call-time callbacks
    handler = ConsoleCallbackHandler()
    invoked = await fake.with_config(metadata={"a": "b"}).ainvoke(
        "hello", config={"callbacks": [handler]}
    )
    assert invoked == 5
    assert spy.call_args_list == [
        mocker.call(
            "hello",
            {
                "callbacks": [handler],
                "metadata": {"a": "b"},
                "configurable": {},
                "tags": [],
            },
        ),
    ]
    spy.reset_mock()

    # astream: bound metadata only
    streamed = [
        part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
    ]
    assert streamed == [5]
    assert spy.call_args_list == [
        mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),
    ]
    spy.reset_mock()

    # Config both batch inputs are expected to be invoked with.
    merged_config = {
        "metadata": {"key": "value"},
        "tags": ["c"],
        "callbacks": None,
        "recursion_limit": 5,
        "configurable": {},
    }
    expected_hello = mocker.call("hello", merged_config)
    expected_wooorld = mocker.call("wooorld", merged_config)

    # abatch: bound recursion_limit/tags plus a shared call-time config
    batched = await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
        ["hello", "wooorld"], {"metadata": {"key": "value"}}
    )
    assert batched == [5, 7]
    assert sorted(spy.call_args_list) == [expected_hello, expected_wooorld]
    spy.reset_mock()

    # abatch_as_completed: same configuration, results compared after sorting
    bound = fake.with_config(recursion_limit=5, tags=["c"])
    completed = [
        pair
        async for pair in bound.abatch_as_completed(
            ["hello", "wooorld"], {"metadata": {"key": "value"}}
        )
    ]
    assert sorted(completed) == [(0, 5), (1, 7)]
    assert len(spy.call_args_list) == 2
    # Look the calls up by input instead of relying on recording order.
    first_call = next(c for c in spy.call_args_list if c.args[0] == "hello")
    assert first_call == expected_hello
    second_call = next(c for c in spy.call_args_list if c.args[0] == "wooorld")
    assert second_call == expected_wooorld


def test_default_method_implementations(mocker: MockerFixture) -> None:
    """Check the config ``invoke`` receives from ``invoke``/``stream``/``batch``.

    ``FakeRunnable.invoke`` is spied on, and each entry point is asserted to
    reach it with the expected input and config.
    """
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    # invoke: the config is handed through as given
    assert fake.invoke("hello", {"tags": ["a-tag"]}) == 5
    assert spy.call_args_list == [mocker.call("hello", {"tags": ["a-tag"]})]
    spy.reset_mock()

    # stream: expected to make a single invoke call
    streamed = [*fake.stream("hello", {"metadata": {"key": "value"}})]
    assert streamed == [5]
    assert spy.call_args_list == [
        mocker.call("hello", {"metadata": {"key": "value"}}),
    ]
    spy.reset_mock()

    # batch with one config per input
    per_input_result = fake.batch(
        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
    )
    assert per_input_result == [5, 7]

    assert len(spy.call_args_list) == 2
    for recorded in spy.call_args_list:
        text = recorded.args[0]

        # Pick the expectations by input; recording order is not asserted.
        if text == "hello":
            assert text == "hello"
            expected_tags: list[str] = ["a-tag"]
            expected_metadata: dict[str, str] = {}
        else:
            assert text == "wooorld"
            expected_tags = []
            expected_metadata = {"key": "value"}

        assert recorded.args[1].get("tags") == expected_tags
        assert recorded.args[1].get("metadata") == expected_metadata

    spy.reset_mock()

    # batch with a single config shared by all inputs
    shared_result = fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]})
    assert shared_result == [5, 7]
    assert len(spy.call_args_list) == 2
    seen_inputs = {recorded.args[0] for recorded in spy.call_args_list}
    assert seen_inputs == {"hello", "wooorld"}
    for recorded in spy.call_args_list:
        assert recorded.args[1].get("tags") == ["a-tag"]
        assert recorded.args[1].get("metadata") == {}


async def test_default_method_implementations_async(mocker: MockerFixture) -> None:
    """Check the config ``invoke`` receives from ``ainvoke``/``astream``/``abatch``.

    ``FakeRunnable.invoke`` is spied on, and each async entry point is asserted
    to reach it with the expected input and config.
    """
    fake = FakeRunnable()
    spy = mocker.spy(fake, "invoke")

    # ainvoke: the config is handed through as given
    assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
    assert spy.call_args_list == [mocker.call("hello", {"callbacks": []})]
    spy.reset_mock()

    # astream without a config: invoke is expected to receive None
    streamed = [part async for part in fake.astream("hello")]
    assert streamed == [5]
    assert spy.call_args_list == [mocker.call("hello", None)]
    spy.reset_mock()

    # abatch: every recorded call is expected to carry the same full config
    batched = await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}})
    assert batched == [5, 7]
    seen_inputs = {recorded.args[0] for recorded in spy.call_args_list}
    assert seen_inputs == {"hello", "wooorld"}
    full_config = {
        "metadata": {"key": "value"},
        "tags": [],
        "callbacks": None,
        "recursion_limit": 25,
        "configurable": {},
    }
    for recorded in spy.call_args_list:
        assert recorded.args[1] == full_config


def test_prompt() -> None:
    """Check ``invoke``, ``batch`` and ``stream`` on a ``ChatPromptTemplate``."""

    def prompt_value(question: str) -> ChatPromptValue:
        # Expected rendering: the fixed system message, then the question.
        return ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content=question),
            ]
        )

    prompt = ChatPromptTemplate.from_messages(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    expected = prompt_value("What is your name?")

    # invoke
    assert prompt.invoke({"question": "What is your name?"}) == expected

    # batch
    batched = prompt.batch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ]
    )
    assert batched == [expected, prompt_value("What is your favorite color?")]

    # stream: expected to yield the whole prompt value as one chunk
    streamed = list(prompt.stream({"question": "What is your name?"}))
    assert streamed == [expected]


async def test_prompt_async() -> None:
    """Check the async entry points of a ``ChatPromptTemplate``.

    Covers ``ainvoke``, ``abatch``, ``astream`` and ``astream_log`` (as patches,
    as accumulated state with ``diff=False``, and nested inside
    ``atrace_as_chain_group``).
    """
    zero_id = "00000000-0000-0000-0000-000000000000"

    def prompt_value(question: str) -> ChatPromptValue:
        # Expected rendering: the fixed system message, then the question.
        return ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content=question),
            ]
        )

    def check_initial_patch(patch: Any) -> None:
        # The first patch replaces the root with an empty run state.
        assert len(patch.ops) == 1
        assert patch.ops[0]["op"] == "replace"
        assert patch.ops[0]["path"] == ""
        assert patch.ops[0]["value"]["logs"] == {}
        assert patch.ops[0]["value"]["final_output"] is None
        assert patch.ops[0]["value"]["streamed_output"] == []
        assert isinstance(patch.ops[0]["value"]["id"], str)

    prompt = ChatPromptTemplate.from_messages(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    expected = prompt_value("What is your name?")

    # ainvoke
    assert await prompt.ainvoke({"question": "What is your name?"}) == expected

    # abatch
    batched = await prompt.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ]
    )
    assert batched == [expected, prompt_value("What is your favorite color?")]

    # astream
    streamed = [
        part async for part in prompt.astream({"question": "What is your name?"})
    ]
    assert streamed == [expected]

    # astream_log as a sequence of patches
    stream_log = [
        part async for part in prompt.astream_log({"question": "What is your name?"})
    ]
    check_initial_patch(stream_log[0])
    assert stream_log[1:] == [
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": expected},
            {
                "op": "replace",
                "path": "/final_output",
                "value": prompt_value("What is your name?"),
            },
        ),
    ]

    # astream_log as accumulated state
    stream_log_state = [
        part
        async for part in prompt.astream_log(
            {"question": "What is your name?"}, diff=False
        )
    ]

    # Overwrite the ids of the two runs so they can be compared.
    stream_log[0].ops[0]["value"]["id"] = zero_id
    final_state = stream_log_state[-1]
    final_state.ops[0]["value"]["id"] = zero_id
    final_state.state["id"] = zero_id

    # The diff=False output must match the concatenated diff=True patches.
    all_ops = [op for chunk in stream_log for op in chunk.ops]
    assert final_state.ops == all_ops
    assert final_state == RunLog(
        *all_ops,
        state={
            "final_output": prompt_value("What is your name?"),
            "id": zero_id,
            "logs": {},
            "streamed_output": [prompt_value("What is your name?")],
            "type": "prompt",
            "name": "ChatPromptTemplate",
        },
    )

    # astream_log nested inside atrace_as_chain_group
    async with atrace_as_chain_group("a_group") as manager:
        stream_log_nested = [
            part
            async for part in prompt.astream_log(
                {"question": "What is your name?"}, config={"callbacks": manager}
            )
        ]

    check_initial_patch(stream_log_nested[0])
    assert stream_log_nested[1:] == [
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": expected},
            {
                "op": "replace",
                "path": "/final_output",
                "value": prompt_value("What is your name?"),
            },
        ),
    ]


def test_prompt_template_params() -> None:
    """Check that extra input keys are accepted and missing ones raise ``KeyError``."""
    prompt = ChatPromptTemplate.from_template(
        "Respond to the following question: {question}"
    )

    # "topic" is not a template variable; the rendered message ignores it.
    rendered = prompt.invoke({"question": "test", "topic": "test"})
    expected_message = HumanMessage(content="Respond to the following question: test")
    assert rendered == ChatPromptValue(messages=[expected_message])

    # Leaving out the required "question" variable must fail.
    with pytest.raises(KeyError):
        prompt.invoke({})


def test_with_listeners(mocker: MockerFixture) -> None:
    """Check that ``with_listeners`` callbacks fire once per ``invoke``.

    The chain is invoked once standalone and once inside
    ``trace_as_chain_group``; both times the start listener must receive a run
    named ``RunnableSequence``.
    """
    mock_start = mocker.Mock()
    mock_end = mocker.Mock()

    def assert_listeners_fired_once() -> None:
        assert mock_start.call_count == 1
        assert mock_start.call_args[0][0].name == "RunnableSequence"
        assert mock_end.call_count == 1

    system_and_question = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chain = system_and_question | FakeListChatModel(responses=["foo"])

    # Standalone invocation.
    listening = chain.with_listeners(on_start=mock_start, on_end=mock_end)
    listening.invoke({"question": "Who are you?"})
    assert_listeners_fired_once()

    mock_start.reset_mock()
    mock_end.reset_mock()

    # Invocation nested inside a chain group.
    with trace_as_chain_group("hello") as manager:
        listening_in_group = chain.with_listeners(
            on_start=mock_start, on_end=mock_end
        )
        listening_in_group.invoke(
            {"question": "Who are you?"}, {"callbacks": manager}
        )
    assert_listeners_fired_once()


async def test_with_listeners_async(mocker: MockerFixture) -> None:
    """Check that ``with_listeners`` callbacks fire once per ``ainvoke``.

    The chain is invoked once standalone and once inside
    ``atrace_as_chain_group``; both times the start listener must receive a run
    named ``RunnableSequence``.
    """
    mock_start = mocker.Mock()
    mock_end = mocker.Mock()

    def assert_listeners_fired_once() -> None:
        assert mock_start.call_count == 1
        assert mock_start.call_args[0][0].name == "RunnableSequence"
        assert mock_end.call_count == 1

    system_and_question = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chain = system_and_question | FakeListChatModel(responses=["foo"])

    # Standalone invocation.
    listening = chain.with_listeners(on_start=mock_start, on_end=mock_end)
    await listening.ainvoke({"question": "Who are you?"})
    assert_listeners_fired_once()

    mock_start.reset_mock()
    mock_end.reset_mock()

    # Invocation nested inside a chain group.
    async with atrace_as_chain_group("hello") as manager:
        listening_in_group = chain.with_listeners(
            on_start=mock_start, on_end=mock_end
        )
        await listening_in_group.ainvoke(
            {"question": "Who are you?"}, {"callbacks": manager}
        )
    assert_listeners_fired_once()


def test_with_listener_propagation(mocker: MockerFixture) -> None:
    """Check that listeners keep firing when the runnable is wrapped again.

    A chain with listeners is wrapped via ``with_retry``, ``with_types``,
    ``with_config``, ``bind`` and a second ``with_listeners``; each wrapped
    runnable is invoked once and the listeners must fire exactly once.
    """
    mock_start = mocker.Mock()
    mock_end = mocker.Mock()

    def assert_outer_listeners_fired_once() -> None:
        assert mock_start.call_count == 1
        assert mock_start.call_args[0][0].name == "RunnableSequence"
        assert mock_end.call_count == 1

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo"])
    chain: Runnable = prompt | chat
    chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)

    # Each wrapper below must leave the original listeners attached.
    wrapped_variants = [
        chain_with_listeners.with_retry(),
        chain_with_listeners.with_types(output_type=str),
        chain_with_listeners.with_config({"tags": ["foo"]}),
        chain_with_listeners.bind(stop=["foo"]),
    ]
    for wrapped in wrapped_variants:
        wrapped.invoke({"question": "Who are you?"})
        assert_outer_listeners_fired_once()

        mock_start.reset_mock()
        mock_end.reset_mock()

    # A second layer of listeners: both the outer and inner pairs must fire.
    mock_start_inner = mocker.Mock()
    mock_end_inner = mocker.Mock()
    doubly_listening = chain_with_listeners.with_listeners(
        on_start=mock_start_inner, on_end=mock_end_inner
    )
    doubly_listening.invoke({"question": "Who are you?"})

    assert_outer_listeners_fired_once()
    assert mock_start_inner.call_count == 1
    assert mock_start_inner.call_args[0][0].name == "RunnableSequence"
    assert mock_end_inner.call_count == 1


@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model(
    mocker: MockerFixture,
    snapshot: SnapshotAssertion,
) -> None:
    """Check a ``prompt | chat model`` sequence through invoke, batch and stream.

    The chain's structure, ``repr`` and serialized form are compared first,
    then each sync entry point is run with spies on the prompt and chat model
    classes to check what every step received.

    NOTE(review): the unnamed ``snapshot`` comparisons are presumably matched
    by the order in which they run, so keep their order stable — confirm
    against the snapshot file before reordering.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo"])

    chain = prompt | chat

    # Structure of the composed sequence: two steps, nothing in between.
    assert repr(chain) == snapshot
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == []
    assert chain.last == chat
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    # Spies are installed on the classes, so args[0] of a recorded call is the
    # instance and args[1] is the input handed to that step.
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    tracer = FakeTracer()
    assert chain.invoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    ) == _any_id_ai_message(content="foo")
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )

    assert tracer.runs == snapshot

    # Remove the spies before spying on a different method of the same classes.
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "batch")
    chat_spy = mocker.spy(chat.__class__, "batch")
    tracer = FakeTracer()
    assert chain.batch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        {"callbacks": [tracer]},
    ) == [
        _any_id_ai_message(content="foo"),
        _any_id_ai_message(content="foo"),
    ]
    # Each step's ``batch`` is expected to receive the full list of inputs.
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert chat_spy.call_args.args[1] == [
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your name?"),
            ]
        ),
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]
    assert (
        len(
            [
                r
                for r in tracer.runs
                if r.parent_run_id is None and len(r.child_runs) == 2
            ]
        )
        == 2
    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test stream
    # When streaming, the prompt step is spied via ``invoke`` and the chat
    # model via ``stream``.
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "stream")
    tracer = FakeTracer()
    assert [
        *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})
    ] == [
        _any_id_ai_message_chunk(content="f"),
        _any_id_ai_message_chunk(content="o"),
        _any_id_ai_message_chunk(content="o", chunk_position="last"),
    ]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )


@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
async def test_prompt_with_chat_model_async(
    mocker: MockerFixture,
    snapshot: SnapshotAssertion,
) -> None:
    """Check a ``prompt | chat model`` sequence through ainvoke, abatch and astream.

    The chain's structure, ``repr`` and serialized form are compared first,
    then each async entry point is run with spies on the prompt and chat model
    classes to check what every step received.

    NOTE(review): the unnamed ``snapshot`` comparisons are presumably matched
    by the order in which they run, so keep their order stable — confirm
    against the snapshot file before reordering.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo"])

    chain = prompt | chat

    # Structure of the composed sequence: two steps, nothing in between.
    assert repr(chain) == snapshot
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == []
    assert chain.last == chat
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    # Spies are installed on the classes, so args[0] of a recorded call is the
    # instance and args[1] is the input handed to that step.
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    chat_spy = mocker.spy(chat.__class__, "ainvoke")
    tracer = FakeTracer()
    assert await chain.ainvoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    ) == _any_id_ai_message(content="foo")
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )

    assert tracer.runs == snapshot

    # Remove the spies before spying on a different method of the same classes.
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "abatch")
    chat_spy = mocker.spy(chat.__class__, "abatch")
    tracer = FakeTracer()
    assert await chain.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        {"callbacks": [tracer]},
    ) == [
        _any_id_ai_message(content="foo"),
        _any_id_ai_message(content="foo"),
    ]
    # Each step's ``abatch`` is expected to receive the full list of inputs.
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert chat_spy.call_args.args[1] == [
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your name?"),
            ]
        ),
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]
    assert (
        len(
            [
                r
                for r in tracer.runs
                if r.parent_run_id is None and len(r.child_runs) == 2
            ]
        )
        == 2
    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
    mocker.stop(prompt_spy)
    mocker.stop(chat_spy)

    # Test stream
    # When streaming, the prompt step is spied via ``ainvoke`` and the chat
    # model via ``astream``.
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    chat_spy = mocker.spy(chat.__class__, "astream")
    tracer = FakeTracer()
    assert [
        a
        async for a in chain.astream(
            {"question": "What is your name?"}, {"callbacks": [tracer]}
        )
    ] == [
        _any_id_ai_message_chunk(content="f"),
        _any_id_ai_message_chunk(content="o"),
        _any_id_ai_message_chunk(content="o", chunk_position="last"),
    ]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )


@pytest.mark.skipif(
    condition=sys.version_info[1] == 13,
    reason=(
        "temporary, py3.13 exposes some invalid assumptions about order of batch async "
        "executions."
    ),
)
@freeze_time("2023-01-01")
async def test_prompt_with_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """Check a two-step ``prompt | llm`` sequence on every async entry point.

    Covers ``ainvoke``, ``abatch``, ``astream`` and ``astream_log``. The fake
    LLM is built with two canned responses and the expected outputs below show
    them being handed out in turn, so the sections must run in this order.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeListLLM(responses=["foo", "bar"])

    chain = prompt | llm

    # Piping two runnables must give a sequence with no middle steps.
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == []
    assert chain.last == llm
    assert dumps(chain, pretty=True) == snapshot

    # Prompt values the LLM is expected to receive; never passed into the chain.
    name_prompt_value = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    color_prompt_value = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your favorite color?"),
        ]
    )

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    tracer = FakeTracer()
    invoke_output = await chain.ainvoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert invoke_output == "foo"
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == name_prompt_value
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "abatch")
    llm_spy = mocker.spy(llm.__class__, "abatch")
    tracer = FakeTracer()
    batch_output = await chain.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        {"callbacks": [tracer]},
    )
    assert batch_output == ["bar", "foo"]
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert llm_spy.call_args.args[1] == [name_prompt_value, color_prompt_value]
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)

    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "astream")
    tracer = FakeTracer()
    streamed_tokens = [
        token
        async for token in chain.astream(
            {"question": "What is your name?"}, {"callbacks": [tracer]}
        )
    ]
    assert streamed_tokens == ["bar"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == name_prompt_value

    # Test stream log
    prompt_spy.reset_mock()
    llm_spy.reset_mock()
    stream_log = [
        patch async for patch in chain.astream_log({"question": "What is your name?"})
    ]

    # Remove IDs from logs
    for patch in stream_log:
        for op in patch.ops:
            value = op["value"]
            if not isinstance(value, dict) or "id" not in value:
                continue
            if isinstance(value["id"], list):  # serialized lc id
                continue
            del value["id"]

    # Every timestamp equals the frozen clock set by the decorator.
    frozen_ts = "2023-01-01T00:00:00.000+00:00"

    root_started = RunLogPatch(
        {
            "op": "replace",
            "path": "",
            "value": {
                "logs": {},
                "final_output": None,
                "streamed_output": [],
                "name": "RunnableSequence",
                "type": "chain",
            },
        }
    )
    prompt_started = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate",
            "value": {
                "end_time": None,
                "final_output": None,
                "metadata": {},
                "name": "ChatPromptTemplate",
                "start_time": frozen_ts,
                "streamed_output": [],
                "streamed_output_str": [],
                "tags": ["seq:step:1"],
                "type": "prompt",
            },
        }
    )
    prompt_finished = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate/final_output",
            "value": name_prompt_value,
        },
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate/end_time",
            "value": frozen_ts,
        },
    )
    llm_started = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/FakeListLLM",
            "value": {
                "end_time": None,
                "final_output": None,
                "metadata": {"ls_model_type": "llm", "ls_provider": "fakelist"},
                "name": "FakeListLLM",
                "start_time": frozen_ts,
                "streamed_output": [],
                "streamed_output_str": [],
                "tags": ["seq:step:2"],
                "type": "llm",
            },
        }
    )
    llm_finished = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/FakeListLLM/final_output",
            "value": {
                "generations": [
                    [{"generation_info": None, "text": "foo", "type": "Generation"}]
                ],
                "llm_output": None,
                "run": None,
                "type": "LLMResult",
            },
        },
        {
            "op": "add",
            "path": "/logs/FakeListLLM/end_time",
            "value": frozen_ts,
        },
    )
    root_output = RunLogPatch(
        {"op": "add", "path": "/streamed_output/-", "value": "foo"},
        {"op": "replace", "path": "/final_output", "value": "foo"},
    )

    expected = [
        root_started,
        prompt_started,
        prompt_finished,
        llm_started,
        llm_finished,
        root_output,
    ]
    assert stream_log == expected


@freeze_time("2023-01-01")
async def test_prompt_with_llm_parser(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """Check ``prompt | llm | parser`` on ainvoke, abatch, astream and astream_log.

    The streaming fake LLM is given two canned answers and the expectations
    below show them being consumed in turn, so the sections must keep this
    order.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
    parser = CommaSeparatedListOutputParser()

    chain = prompt | llm | parser

    # Three piped runnables: first / one middle step / last.
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [llm]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Prompt values the LLM is expected to receive; never passed into the chain.
    name_prompt_value = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    color_prompt_value = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your favorite color?"),
        ]
    )

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    parser_spy = mocker.spy(parser.__class__, "ainvoke")
    tracer = FakeTracer()
    invoke_output = await chain.ainvoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert invoke_output == ["bear", "dog", "cat"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == name_prompt_value
    assert parser_spy.call_args.args[1] == "bear, dog, cat"
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)
    mocker.stop(parser_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "abatch")
    llm_spy = mocker.spy(llm.__class__, "abatch")
    parser_spy = mocker.spy(parser.__class__, "abatch")
    tracer = FakeTracer()
    batch_output = await chain.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        {"callbacks": [tracer]},
    )
    assert batch_output == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert llm_spy.call_args.args[1] == [name_prompt_value, color_prompt_value]
    assert parser_spy.call_args.args[1] == [
        "tomato, lettuce, onion",
        "bear, dog, cat",
    ]
    # One traced root run per batch item, each with prompt, llm and parser children.
    assert len(tracer.runs) == 2
    for run in tracer.runs:
        assert run.name == "RunnableSequence"
        assert run.run_type == "chain"
        assert len(run.child_runs) == 3
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)
    mocker.stop(parser_spy)

    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "astream")
    tracer = FakeTracer()
    streamed_chunks = [
        token
        async for token in chain.astream(
            {"question": "What is your name?"}, {"callbacks": [tracer]}
        )
    ]
    assert streamed_chunks == [["tomato"], ["lettuce"], ["onion"]]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == name_prompt_value

    # Test stream log
    prompt_spy.reset_mock()
    llm_spy.reset_mock()
    stream_log = [
        patch async for patch in chain.astream_log({"question": "What is your name?"})
    ]

    # Remove IDs from logs
    for patch in stream_log:
        for op in patch.ops:
            value = op["value"]
            if not isinstance(value, dict) or "id" not in value:
                continue
            if isinstance(value["id"], list):  # serialized lc id
                continue
            del value["id"]

    # Every timestamp equals the frozen clock set by the decorator.
    frozen_ts = "2023-01-01T00:00:00.000+00:00"

    def parser_chunk(item: str) -> RunLogPatch:
        """Build the patch expected when the parser streams one list item."""
        return RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
                "value": [item],
            }
        )

    root_started = RunLogPatch(
        {
            "op": "replace",
            "path": "",
            "value": {
                "logs": {},
                "final_output": None,
                "streamed_output": [],
                "name": "RunnableSequence",
                "type": "chain",
            },
        }
    )
    prompt_started = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate",
            "value": {
                "end_time": None,
                "final_output": None,
                "metadata": {},
                "name": "ChatPromptTemplate",
                "start_time": frozen_ts,
                "streamed_output": [],
                "streamed_output_str": [],
                "tags": ["seq:step:1"],
                "type": "prompt",
            },
        }
    )
    prompt_finished = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate/final_output",
            "value": name_prompt_value,
        },
        {
            "op": "add",
            "path": "/logs/ChatPromptTemplate/end_time",
            "value": frozen_ts,
        },
    )
    llm_started = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/FakeStreamingListLLM",
            "value": {
                "end_time": None,
                "final_output": None,
                "metadata": {
                    "ls_model_type": "llm",
                    "ls_provider": "fakestreaminglist",
                },
                "name": "FakeStreamingListLLM",
                "start_time": frozen_ts,
                "streamed_output": [],
                "streamed_output_str": [],
                "tags": ["seq:step:2"],
                "type": "llm",
            },
        }
    )
    llm_finished = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/FakeStreamingListLLM/final_output",
            "value": {
                "generations": [
                    [
                        {
                            "generation_info": None,
                            "text": "bear, dog, cat",
                            "type": "Generation",
                        }
                    ]
                ],
                "llm_output": None,
                "run": None,
                "type": "LLMResult",
            },
        },
        {
            "op": "add",
            "path": "/logs/FakeStreamingListLLM/end_time",
            "value": frozen_ts,
        },
    )
    parser_started = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/CommaSeparatedListOutputParser",
            "value": {
                "end_time": None,
                "final_output": None,
                "metadata": {},
                "name": "CommaSeparatedListOutputParser",
                "start_time": frozen_ts,
                "streamed_output": [],
                "streamed_output_str": [],
                "tags": ["seq:step:3"],
                "type": "parser",
            },
        }
    )
    parser_finished = RunLogPatch(
        {
            "op": "add",
            "path": "/logs/CommaSeparatedListOutputParser/final_output",
            "value": {"output": ["bear", "dog", "cat"]},
        },
        {
            "op": "add",
            "path": "/logs/CommaSeparatedListOutputParser/end_time",
            "value": frozen_ts,
        },
    )

    expected = [
        root_started,
        prompt_started,
        prompt_finished,
        llm_started,
        llm_finished,
        parser_started,
        # Each parsed item appears first in the parser's log, then at the root.
        parser_chunk("bear"),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["bear"]},
            {"op": "replace", "path": "/final_output", "value": ["bear"]},
        ),
        parser_chunk("dog"),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["dog"]},
            {"op": "add", "path": "/final_output/1", "value": "dog"},
        ),
        parser_chunk("cat"),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["cat"]},
            {"op": "add", "path": "/final_output/2", "value": "cat"},
        ),
        parser_finished,
    ]
    assert stream_log == expected


@freeze_time("2023-01-01")
async def test_stream_log_retriever() -> None:
    """Check the sub-run names that ``astream_log`` records for a chain.

    The chain feeds a parallel map (retriever + item getter) into a prompt and
    then into a second parallel map that uses the same LLM twice.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{documents}"
        + "{question}"
    )
    llm = FakeListLLM(responses=["foo", "bar"])

    chain: Runnable = (
        {"documents": FakeRetriever(), "question": itemgetter("question")}
        | prompt
        | {"one": llm, "two": llm}
    )

    patches = [
        patch async for patch in chain.astream_log({"question": "What is your name?"})
    ]

    # Remove IDs from logs
    for patch in patches:
        for op in patch.ops:
            value = op["value"]
            if not isinstance(value, dict) or "id" not in value:
                continue
            if isinstance(value["id"], list):  # serialized lc id
                continue
            del value["id"]

    # Fold all patches into one run log and compare the recorded log keys.
    merged_log = cast("RunLog", add(patches))
    recorded_names = sorted(merged_log.state["logs"])
    assert recorded_names == [
        "ChatPromptTemplate",
        "FakeListLLM",
        "FakeListLLM:2",
        "FakeRetriever",
        "RunnableLambda",
        "RunnableParallel<documents,question>",
        "RunnableParallel<one,two>",
    ]


@freeze_time("2023-01-01")
async def test_stream_log_lists() -> None:
    """Check how ``astream_log`` reports streamed dict chunks that hold lists.

    The generator yields four ``AddableDict`` chunks; the expected patches show
    the first chunk replacing ``final_output`` and each later chunk adding one
    element to ``final_output["alist"]``.
    """

    async def list_producer(_: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
        for i in range(4):
            yield AddableDict(alist=[str(i)])

    chain = RunnableGenerator(list_producer)

    stream_log = [
        patch async for patch in chain.astream_log({"question": "What is your name?"})
    ]

    # Remove IDs from logs
    for patch in stream_log:
        for op in patch.ops:
            value = op["value"]
            if not isinstance(value, dict) or "id" not in value:
                continue
            if isinstance(value["id"], list):  # serialized lc id
                continue
            del value["id"]

    root_started = RunLogPatch(
        {
            "op": "replace",
            "path": "",
            "value": {
                "final_output": None,
                "logs": {},
                "streamed_output": [],
                "name": "list_producer",
                "type": "chain",
            },
        }
    )
    first_chunk = RunLogPatch(
        {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["0"]}},
        {"op": "replace", "path": "/final_output", "value": {"alist": ["0"]}},
    )
    # Chunks 1..3 each append a single element at their own index.
    later_chunks = [
        RunLogPatch(
            {
                "op": "add",
                "path": "/streamed_output/-",
                "value": {"alist": [str(index)]},
            },
            {
                "op": "add",
                "path": f"/final_output/alist/{index}",
                "value": str(index),
            },
        )
        for index in range(1, 4)
    ]
    assert stream_log == [root_started, first_chunk, *later_chunks]

    state = add(stream_log)

    assert isinstance(state, RunLog)

    expected_state = {
        "final_output": {"alist": ["0", "1", "2", "3"]},
        "logs": {},
        "name": "list_producer",
        "streamed_output": [{"alist": [str(index)]} for index in range(4)],
        "type": "chain",
    }
    assert state.state == expected_state


@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """Check that an async function piped after ``prompt | llm`` joins the sequence.

    The coroutine function is expected to be wrapped so that the chain's last
    step compares equal to ``RunnableLambda(func=passthrough)``.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeListLLM(responses=["foo", "bar"])

    async def passthrough(value: Any) -> Any:
        return value

    chain = prompt | llm | passthrough

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [llm]
    assert chain.last == RunnableLambda(func=passthrough)
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    tracer = FakeTracer()
    answer = await chain.ainvoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert answer == "foo"
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    expected_llm_input = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert llm_spy.call_args.args[1] == expected_llm_input
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)


@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_prompt_with_chat_model_and_parser(
    mocker: MockerFixture,
    snapshot: SnapshotAssertion,
) -> None:
    """Check a synchronous ``prompt | chat | parser`` sequence end to end.

    Verifies the sequence layout, its serialized form, the input each step
    receives during ``invoke`` and the traced runs.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo, bar"])
    parser = CommaSeparatedListOutputParser()

    chain = prompt | chat | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    parsed_items = chain.invoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert parsed_items == ["foo", "bar"]

    # Each spy's second positional argument is the input that step received.
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    expected_chat_input = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert chat_spy.call_args.args[1] == expected_chat_input
    expected_parser_input = _any_id_ai_message(content="foo, bar")
    assert parser_spy.call_args.args[1] == expected_parser_input

    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
@pytest.mark.usefixtures("deterministic_uuids")
def test_combining_sequences(
    snapshot: SnapshotAssertion,
) -> None:
    """Check that piping two sequences together yields one flat sequence.

    Builds a three-step chain and a four-step chain, joins them with ``|`` and
    verifies that the result lists all seven steps in order before invoking it.
    """
    # First chain: prompt -> chat -> parser.
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo, bar"])
    parser = CommaSeparatedListOutputParser()

    chain = prompt | chat | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Second chain: turns the first chain's list output back into a question.
    prompt2 = (
        SystemMessagePromptTemplate.from_template("You are a nicer assistant.")
        + "{question}"
    )
    chat2 = FakeListChatModel(responses=["baz, qux"])
    parser2 = CommaSeparatedListOutputParser()
    input_formatter = RunnableLambda[list[str], dict[str, Any]](
        lambda x: {"question": x[0] + x[1]}
    )

    chain2 = input_formatter | prompt2 | chat2 | parser2

    assert isinstance(chain2, RunnableSequence)
    assert chain2.first == input_formatter
    assert chain2.middle == [prompt2, chat2]
    assert chain2.last == parser2
    assert dumps(chain2, pretty=True) == snapshot

    combined_chain = chain | chain2

    # The combined chain holds the individual steps, not two nested sequences.
    assert isinstance(combined_chain, RunnableSequence)
    assert combined_chain.first == prompt
    expected_middle = [chat, parser, input_formatter, prompt2, chat2]
    assert combined_chain.middle == expected_middle
    assert combined_chain.last == parser2
    assert dumps(combined_chain, pretty=True) == snapshot

    # Test invoke
    tracer = FakeTracer()
    final_items = combined_chain.invoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert final_items == ["baz", "qux"]

    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """Check a sequence whose first step is a plain dict of runnables.

    The dict is expected to become a ``RunnableParallel`` that feeds the
    prompt, chat model and parser; the test checks step inputs and the shape
    of the traced run tree.
    """
    passthrough = mocker.Mock(side_effect=lambda x: x)

    retriever = FakeRetriever()

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + """Context:
{documents}

Question:
{question}"""
    )

    chat = FakeListChatModel(responses=["foo, bar"])

    parser = CommaSeparatedListOutputParser()

    chain: Runnable = (
        {
            "question": RunnablePassthrough[str]() | passthrough,
            "documents": passthrough | retriever,
            "just_to_test_lambda": passthrough,
        }
        | prompt
        | chat
        | parser
    )

    assert repr(chain) == snapshot
    assert isinstance(chain, RunnableSequence)
    assert isinstance(chain.first, RunnableParallel)
    assert chain.middle == [prompt, chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    parser_spy = mocker.spy(parser.__class__, "invoke")
    tracer = FakeTracer()
    parsed_items = chain.invoke("What is your name?", {"callbacks": [tracer]})
    assert parsed_items == ["foo", "bar"]

    # The prompt receives the merged output of the three parallel branches.
    expected_prompt_input = {
        "documents": [Document(page_content="foo"), Document(page_content="bar")],
        "question": "What is your name?",
        "just_to_test_lambda": "What is your name?",
    }
    assert prompt_spy.call_args.args[1] == expected_prompt_input

    expected_human_text = (
        "Context:\n"
        "[Document(metadata={}, page_content='foo'), "
        "Document(metadata={}, page_content='bar')]\n"
        "\n"
        "Question:\n"
        "What is your name?"
    )
    expected_chat_input = ChatPromptValue(
        messages=[
            SystemMessage(
                content="You are a nice assistant.",
                additional_kwargs={},
                response_metadata={},
            ),
            HumanMessage(
                content=expected_human_text,
                additional_kwargs={},
                response_metadata={},
            ),
        ]
    )
    assert chat_spy.call_args.args[1] == expected_chat_input
    assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar")

    # Exactly one root run, with four children; the first child is the map.
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 4
    map_run = parent_run.child_runs[0]
    assert map_run.name == "RunnableParallel<question,documents,just_to_test_lambda>"
    assert len(map_run.child_runs) == 3


@freeze_time("2023-01-01")
def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
    """Check a sequence whose last step is a plain dict of runnables.

    The prompt output passes through a mock callable and then fans out to a
    chat model and an LLM; both are expected to receive the same prompt value.
    """
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat = FakeListChatModel(responses=["i'm a chatbot"])

    llm = FakeListLLM(responses=["i'm a textbot"])

    chain = (
        prompt
        | passthrough
        | {
            "chat": chat,
            "llm": llm,
        }
    )

    assert repr(chain) == snapshot
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [RunnableLambda(passthrough)]
    assert isinstance(chain.last, RunnableParallel)
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    llm_spy = mocker.spy(llm.__class__, "invoke")
    tracer = FakeTracer()
    outputs = chain.invoke({"question": "What is your name?"}, {"callbacks": [tracer]})
    assert outputs == {
        "chat": _any_id_ai_message(content="i'm a chatbot"),
        "llm": "i'm a textbot",
    }
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}

    # Both parallel branches must have been handed the same rendered prompt.
    expected_model_input = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert chat_spy.call_args.args[1] == expected_model_input
    assert llm_spy.call_args.args[1] == expected_model_input

    # Exactly one root run, with three children; the last child is the map.
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 3
    map_run = parent_run.child_runs[2]
    assert map_run.name == "RunnableParallel<chat,llm>"
    assert len(map_run.child_runs) == 2


@freeze_time("2023-01-01")
def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
    """Check that ``RouterRunnable`` dispatches to the sub-chain named by ``key``.

    Covers serialization, ``invoke``, ``batch``, the input the router receives
    and the shape of the traced run tree.
    """
    math_chain = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    english_chain = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])
    router = RouterRunnable({"math": math_chain, "english": english_chain})
    # Reshape the flat input into the {"key", "input"} mapping passed to the router.
    chain: Runnable = {
        "key": lambda x: x["key"],
        "input": {"question": lambda x: x["question"]},
    } | router
    assert dumps(chain, pretty=True) == snapshot

    single_answer = chain.invoke({"key": "math", "question": "2 + 2"})
    assert single_answer == "4"

    batch_answers = chain.batch(
        [
            {"key": "math", "question": "2 + 2"},
            {"key": "english", "question": "2 + 2"},
        ]
    )
    assert batch_answers == ["4", "2"]

    # Test invoke
    router_spy = mocker.spy(router.__class__, "invoke")
    tracer = FakeTracer()
    traced_answer = chain.invoke(
        {"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}
    )
    assert traced_answer == "4"
    expected_router_input = {
        "key": "math",
        "input": {"question": "2 + 2"},
    }
    assert router_spy.call_args.args[1] == expected_router_input

    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 2
    router_run = parent_run.child_runs[1]
    assert router_run.name == "RunnableSequence"  # TODO: should be RunnableRouter
    assert len(router_run.child_runs) == 2


async def test_router_runnable_async() -> None:
    """Check ``RouterRunnable`` dispatch through ``ainvoke`` and ``abatch``."""
    math_chain = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    english_chain = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])
    router = RouterRunnable({"math": math_chain, "english": english_chain})
    # Reshape the flat input into the {"key", "input"} mapping passed to the router.
    chain: Runnable = {
        "key": lambda x: x["key"],
        "input": {"question": lambda x: x["question"]},
    } | router

    single_answer = await chain.ainvoke({"key": "math", "question": "2 + 2"})
    assert single_answer == "4"

    # One input per route; answers must come back in input order.
    batch_answers = await chain.abatch(
        [
            {"key": "math", "question": "2 + 2"},
            {"key": "english", "question": "2 + 2"},
        ]
    )
    assert batch_answers == ["4", "2"]


@freeze_time("2023-01-01")
def test_higher_order_lambda_runnable(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """Check a function step that returns another runnable to be executed.

    ``router`` picks a sub-chain from the ``key`` field; the assertions show
    the returned chain being run with the same input and traced as a child of
    the ``router`` run.
    """
    math_chain = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    english_chain = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])
    input_map = RunnableParallel(
        key=lambda x: x["key"],
        input={"question": lambda x: x["question"]},
    )

    # The traced run is asserted to be named "router", matching this function.
    def router(params: dict[str, Any]) -> Runnable:
        if params["key"] == "math":
            return itemgetter("input") | math_chain
        if params["key"] == "english":
            return itemgetter("input") | english_chain
        msg = f"Unknown key: {params['key']}"
        raise ValueError(msg)

    chain: Runnable = input_map | router
    assert dumps(chain, pretty=True) == snapshot

    single_answer = chain.invoke({"key": "math", "question": "2 + 2"})
    assert single_answer == "4"

    batch_answers = chain.batch(
        [
            {"key": "math", "question": "2 + 2"},
            {"key": "english", "question": "2 + 2"},
        ]
    )
    assert batch_answers == ["4", "2"]

    # Test invoke
    math_spy = mocker.spy(math_chain.__class__, "invoke")
    tracer = FakeTracer()
    traced_answer = chain.invoke(
        {"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}
    )
    assert traced_answer == "4"
    expected_spy_input = {
        "key": "math",
        "input": {"question": "2 + 2"},
    }
    assert math_spy.call_args.args[1] == expected_spy_input

    # Run tree: root -> [input map, router]; router -> [sequence of three steps].
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 2
    router_run = parent_run.child_runs[1]
    assert router_run.name == "router"
    assert len(router_run.child_runs) == 1
    math_run = router_run.child_runs[0]
    assert math_run.name == "RunnableSequence"
    assert len(math_run.child_runs) == 3


async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None:
    """Exercise a chain whose last step is a function that returns a runnable.

    Covers ``ainvoke`` and ``abatch`` with a plain function router, then
    ``ainvoke`` with a coroutine router while inspecting the traced run tree.
    """
    math_prompt = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    )
    math_chain = math_prompt | FakeListLLM(responses=["4"])
    english_prompt = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    )
    english_chain = english_prompt | FakeListLLM(responses=["2"])
    input_map = RunnableParallel(
        key=lambda x: x["key"],
        input={"question": lambda x: x["question"]},
    )

    def router(value: dict[str, Any]) -> Runnable:
        """Select the sub-chain matching ``value["key"]``."""
        if value["key"] == "math":
            selected = math_chain
        elif value["key"] == "english":
            selected = english_chain
        else:
            msg = f"Unknown key: {value['key']}"
            raise ValueError(msg)
        return itemgetter("input") | selected

    chain: Runnable = input_map | router

    single_output = await chain.ainvoke({"key": "math", "question": "2 + 2"})
    assert single_output == "4"

    batch_inputs = [
        {"key": "math", "question": "2 + 2"},
        {"key": "english", "question": "2 + 2"},
    ]
    batch_outputs = await chain.abatch(batch_inputs)
    assert batch_outputs == ["4", "2"]

    # Test ainvoke
    async def arouter(params: dict[str, Any]) -> Runnable:
        """Coroutine flavour of the router above."""
        if params["key"] == "math":
            selected = math_chain
        elif params["key"] == "english":
            selected = english_chain
        else:
            msg = f"Unknown key: {params['key']}"
            raise ValueError(msg)
        return itemgetter("input") | selected

    achain: Runnable = input_map | arouter
    math_spy = mocker.spy(math_chain.__class__, "ainvoke")
    tracer = FakeTracer()
    traced_output = await achain.ainvoke(
        {"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}
    )
    assert traced_output == "4"
    assert math_spy.call_args.args[1] == {
        "key": "math",
        "input": {"question": "2 + 2"},
    }

    # One root run whose second child is the router; the router's only child
    # is the sequence it returned.
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 2
    router_run = parent_run.child_runs[1]
    assert router_run.name == "arouter"
    assert len(router_run.child_runs) == 1
    math_run = router_run.child_runs[0]
    assert math_run.name == "RunnableSequence"
    assert len(math_run.child_runs) == 3


@freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
    """Pipe a prompt through a mock passthrough into a chat/llm/passthrough map.

    Verifies the structure of the composed sequence, its serialized form, the
    values each step was invoked with, and the traced run tree.
    """
    passthrough = mocker.Mock(side_effect=lambda x: x)

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat = FakeListChatModel(responses=["i'm a chatbot"])

    llm = FakeListLLM(responses=["i'm a textbot"])

    chain = (
        prompt
        | passthrough
        | {
            "chat": chat.bind(stop=["Thought:"]),
            "llm": llm,
            "passthrough": passthrough,
        }
    )

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [RunnableLambda(passthrough)]
    assert isinstance(chain.last, RunnableParallel)

    if PYDANTIC_VERSION_AT_LEAST_210:
        assert dumps(chain, pretty=True) == snapshot

    # The formatted prompt that every branch of the map is expected to receive.
    expected_prompt_value = ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "invoke")
    chat_spy = mocker.spy(chat.__class__, "invoke")
    llm_spy = mocker.spy(llm.__class__, "invoke")
    tracer = FakeTracer()
    output = chain.invoke(
        {"question": "What is your name?"}, {"callbacks": [tracer]}
    )
    assert output == {
        "chat": _any_id_ai_message(content="i'm a chatbot"),
        "llm": "i'm a textbot",
        "passthrough": expected_prompt_value,
    }
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert chat_spy.call_args.args[1] == expected_prompt_value
    assert llm_spy.call_args.args[1] == expected_prompt_value

    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    assert len(root_runs) == 1
    parent_run = root_runs[0]
    assert len(parent_run.child_runs) == 3
    map_run = parent_run.child_runs[2]
    assert map_run.name == "RunnableParallel<chat,llm,passthrough>"
    assert len(map_run.child_runs) == 3


def test_map_stream() -> None:
    """Stream a prompt piped into a map, then through ``pick`` and ``assign``.

    Checks that every streamed chunk of the map carries a single key, that the
    chunks add up to the full outputs, and that picking one or several keys
    yields the expected output schema and chunks.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat_res = "i'm a chatbot"
    # sleep to better simulate a real stream
    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    chain: Runnable = prompt | {
        "chat": chat.bind(stop=["Thought:"]),
        "llm": llm,
        "passthrough": RunnablePassthrough(),
    }

    stream = chain.stream({"question": "What is your name?"})

    final_value = None
    streamed_chunks = []
    for chunk in stream:
        streamed_chunks.append(chunk)
        if final_value is None:
            final_value = chunk
        else:
            final_value += chunk

    # Branches run concurrently, so any of them may produce the first chunk.
    assert streamed_chunks[0] in [
        {"passthrough": prompt.invoke({"question": "What is your name?"})},
        {"llm": "i"},
        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    # One chunk per character of each model response, plus one for passthrough.
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
    assert final_value is not None
    assert final_value.get("chat").content == "i'm a chatbot"
    assert final_value.get("llm") == "i'm a textbot"
    assert final_value.get("passthrough") == prompt.invoke(
        {"question": "What is your name?"}
    )

    # Picking a single key yields the bare values of that key.
    chain_pick_one = chain.pick("llm")

    assert chain_pick_one.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "string",
    }

    stream = chain_pick_one.stream({"question": "What is your name?"})

    final_value = None
    streamed_chunks = []
    for chunk in stream:
        streamed_chunks.append(chunk)
        if final_value is None:
            final_value = chunk
        else:
            final_value += chunk

    assert streamed_chunks[0] == "i"
    assert len(streamed_chunks) == len(llm_res)

    # Picking several keys (one of them added by ``assign``) yields dict chunks.
    chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick(
        [
            "llm",
            "hello",
        ]
    )

    assert chain_pick_two.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "object",
        "properties": {
            "hello": {"title": "Hello", "type": "string"},
            "llm": {"title": "Llm", "type": "string"},
        },
        "required": ["llm", "hello"],
    }

    stream = chain_pick_two.stream({"question": "What is your name?"})

    final_value = None
    streamed_chunks = []
    for chunk in stream:
        streamed_chunks.append(chunk)
        if final_value is None:
            final_value = chunk
        else:
            final_value += chunk

    # This membership assert is the only check on the first chunk. A former
    # `if not (chunk == {...} or {...})` guard was dropped: its second operand
    # was a non-empty dict (always truthy), so it could never fail.
    assert streamed_chunks[0] in [
        {"llm": "i"},
        {"chat": _any_id_ai_message_chunk(content="i")},
    ]

    assert len(streamed_chunks) == len(llm_res) + len(chat_res)


def test_map_stream_iterator_input() -> None:
    """Stream a map that is fed by the chunks of an upstream streaming LLM.

    The test expects single-key chunks whose sum holds the canned chat and LLM
    responses, with the passthrough branch reproducing the upstream LLM text.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat_res = "i'm a chatbot"
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    fan_out = {
        "chat": chat.bind(stop=["Thought:"]),
        "llm": llm,
        "passthrough": RunnablePassthrough(),
    }
    chain: Runnable = prompt | llm | fan_out

    # Drain the stream first, then fold the chunks together.
    streamed_chunks = list(chain.stream({"question": "What is your name?"}))
    final_value = None
    for piece in streamed_chunks:
        if final_value is not None:
            final_value += piece
        else:
            final_value = piece

    possible_first_chunks = [
        {"passthrough": "i"},
        {"llm": "i"},
        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    assert streamed_chunks[0] in possible_first_chunks
    # The LLM text is streamed twice: once by "llm" and once by "passthrough".
    assert len(streamed_chunks) == len(chat_res) + 2 * len(llm_res)
    assert all(len(piece) == 1 for piece in streamed_chunks)
    assert final_value is not None
    assert final_value.get("chat").content == "i'm a chatbot"
    assert final_value.get("llm") == "i'm a textbot"
    assert final_value.get("passthrough") == "i'm a textbot"


async def test_map_astream() -> None:
    """Test ``astream`` and ``astream_log`` on a prompt piped into a map.

    The first part accumulates the streamed chunks and checks the final value.
    The remaining parts replay the same input through ``astream_log`` without
    filters, with ``include_names`` and with ``exclude_names``, and compare the
    accumulated state with the result of the first part.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat_res = "i'm a chatbot"
    # sleep to better simulate a real stream
    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    chain: Runnable = prompt | {
        "chat": chat.bind(stop=["Thought:"]),
        "llm": llm,
        "passthrough": RunnablePassthrough(),
    }

    stream = chain.astream({"question": "What is your name?"})

    # Keep every chunk and also fold them into a single accumulated value.
    final_value = None
    streamed_chunks = []
    async for chunk in stream:
        streamed_chunks.append(chunk)
        if final_value is None:
            final_value = chunk
        else:
            final_value += chunk

    # Any of the three branches may produce the first chunk.
    assert streamed_chunks[0] in [
        {"passthrough": prompt.invoke({"question": "What is your name?"})},
        {"llm": "i"},
        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
    assert final_value is not None
    assert final_value.get("chat").content == "i'm a chatbot"
    # Replace the message id before final_value is compared with the outputs
    # of later runs below (presumably so differing ids do not break equality).
    final_value["chat"].id = AnyStr()
    assert final_value.get("llm") == "i'm a textbot"
    assert final_value.get("passthrough") == prompt.invoke(
        {"question": "What is your name?"}
    )

    # Test astream_log state accumulation

    final_state = None
    streamed_ops = []
    async for chunk in chain.astream_log({"question": "What is your name?"}):
        streamed_ops.extend(chunk.ops)
        if final_state is None:
            final_state = chunk
        else:
            final_state += chunk
    final_state = cast("RunLog", final_state)

    assert final_state.state["final_output"] == final_value
    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
    assert isinstance(final_state.state["id"], str)
    assert len(final_state.ops) == len(streamed_ops)
    # Without filters, five sub-runs are expected in the log (named below).
    assert len(final_state.state["logs"]) == 5
    assert (
        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
    )
    assert final_state.state["logs"]["ChatPromptTemplate"][
        "final_output"
    ] == prompt.invoke({"question": "What is your name?"})
    assert (
        final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
        == "RunnableParallel<chat,llm,passthrough>"
    )
    assert sorted(final_state.state["logs"]) == [
        "ChatPromptTemplate",
        "FakeListChatModel",
        "FakeStreamingListLLM",
        "RunnableParallel<chat,llm,passthrough>",
        "RunnablePassthrough",
    ]

    # Test astream_log with include filters
    final_state = None
    async for chunk in chain.astream_log(
        {"question": "What is your name?"}, include_names=["FakeListChatModel"]
    ):
        if final_state is None:
            final_state = chunk
        else:
            final_state += chunk
    final_state = cast("RunLog", final_state)

    # The output is unchanged; only the named run is expected in the logs.
    assert final_state.state["final_output"] == final_value
    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
    assert len(final_state.state["logs"]) == 1
    assert final_state.state["logs"]["FakeListChatModel"]["name"] == "FakeListChatModel"

    # Test astream_log with exclude filters
    final_state = None
    async for chunk in chain.astream_log(
        {"question": "What is your name?"}, exclude_names=["FakeListChatModel"]
    ):
        if final_state is None:
            final_state = chunk
        else:
            final_state += chunk
    final_state = cast("RunLog", final_state)

    # The output is unchanged; every run but the excluded one is expected.
    assert final_state.state["final_output"] == final_value
    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
    assert len(final_state.state["logs"]) == 4
    assert (
        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
    )
    assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
        prompt.invoke({"question": "What is your name?"})
    )
    assert (
        final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
        == "RunnableParallel<chat,llm,passthrough>"
    )
    assert sorted(final_state.state["logs"]) == [
        "ChatPromptTemplate",
        "FakeStreamingListLLM",
        "RunnableParallel<chat,llm,passthrough>",
        "RunnablePassthrough",
    ]


async def test_map_astream_iterator_input() -> None:
    """Async-stream a map that is fed by the chunks of an upstream streaming LLM.

    The test expects single-key chunks whose sum holds the canned chat and LLM
    responses, with the passthrough branch reproducing the upstream LLM text.
    It ends with a dump/load round-trip check of a minimal ``RunnableMap``.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )

    chat_res = "i'm a chatbot"
    # sleep to better simulate a real stream
    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    chain: Runnable = (
        prompt
        | llm
        | {
            "chat": chat.bind(stop=["Thought:"]),
            "llm": llm,
            "passthrough": RunnablePassthrough(),
        }
    )

    stream = chain.astream({"question": "What is your name?"})

    final_value = None
    streamed_chunks = []
    async for chunk in stream:
        streamed_chunks.append(chunk)
        if final_value is None:
            final_value = chunk
        else:
            final_value += chunk

    # Use the id-agnostic chunk helper, like the sync twin of this test, so
    # the "chat" alternative does not depend on the chunk's id.
    assert streamed_chunks[0] in [
        {"passthrough": "i"},
        {"llm": "i"},
        {"chat": _any_id_ai_message_chunk(content="i")},
    ]
    # The LLM text is streamed twice: once by "llm" and once by "passthrough".
    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
    assert all(len(c.keys()) == 1 for c in streamed_chunks)
    assert final_value is not None
    assert final_value.get("chat").content == "i'm a chatbot"
    assert final_value.get("llm") == "i'm a textbot"
    assert final_value.get("passthrough") == llm_res

    # A minimal map must survive a dumps/loads round trip unchanged.
    simple_map = RunnableMap(passthrough=RunnablePassthrough())
    assert loads(dumps(simple_map)) == simple_map


def test_with_config_with_config() -> None:
    """Chained ``with_config`` calls must serialize like one merged call."""
    llm = FakeListLLM(responses=["i'm a textbot"])

    chained = llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
    merged = llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]})

    assert dumpd(chained) == dumpd(merged)


def test_metadata_is_merged() -> None:
    """Test that metadata from with_config and from the invocation are merged."""
    configured = RunnableLambda(lambda x: x).with_config(
        {"metadata": {"my_key": "my_value"}}
    )
    with collect_runs() as cb:
        configured.invoke("hi", {"metadata": {"my_other_key": "my_other_value"}})
        run = cb.traced_runs[0]
    assert run.extra is not None
    # Both the bound key and the invocation-time key must be present.
    assert run.extra["metadata"] == {
        "my_key": "my_value",
        "my_other_key": "my_other_value",
    }


def test_tags_are_appended() -> None:
    """Test that with_config tags are combined with tags passed at invocation."""
    tagged = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]})
    with collect_runs() as cb:
        tagged.invoke("hi", {"tags": ["invoked_key"]})
        run = cb.traced_runs[0]
    assert isinstance(run.tags, list)
    # Order is not part of the contract, so compare sorted.
    assert sorted(run.tags) == ["invoked_key", "my_key"]


def test_bind_bind() -> None:
    """Chained ``bind`` calls must serialize like one call with merged kwargs.

    A kwarg given to both calls (``stop``) is expected to take the later value.
    """
    llm = FakeListLLM(responses=["i'm a textbot"])

    rebound = llm.bind(stop=["Thought:"], one="two").bind(
        stop=["Observation:"], hello="world"
    )
    expected = llm.bind(stop=["Observation:"], one="two", hello="world")

    assert dumpd(rebound) == dumpd(expected)


def test_bind_with_lambda() -> None:
    """Kwargs bound on a ``RunnableLambda`` must reach the wrapped function."""

    def my_function(_: Any, **kwargs: Any) -> int:
        increment = int(kwargs.get("n", 0))
        return 3 + increment

    bound = RunnableLambda(my_function).bind(n=1)
    assert bound.invoke({}) == 4
    # Streaming a non-generator function yields its single result.
    assert list(bound.stream({})) == [4]


async def test_bind_with_lambda_async() -> None:
    """Async variant: bound kwargs must reach the wrapped function."""

    def my_function(_: Any, **kwargs: Any) -> int:
        increment = int(kwargs.get("n", 0))
        return 3 + increment

    bound = RunnableLambda(my_function).bind(n=1)
    invoked = await bound.ainvoke({})
    assert invoked == 4
    streamed = [item async for item in bound.astream({})]
    assert streamed == [4]


def test_deep_stream() -> None:
    """Streaming must propagate through a prompt | llm | parser sequence.

    One chunk per character of the LLM response is expected, both for the
    plain chain and for the chain extended with a passthrough step.
    """
    expected_text = "foo-lish"
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=[expected_text])

    chain = prompt | llm | StrOutputParser()

    chunks = list(chain.stream({"question": "What up"}))
    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text

    # Appending a passthrough must not collapse the stream into one chunk.
    extended_chain = chain | RunnablePassthrough()
    chunks = list(extended_chain.stream({"question": "What up"}))
    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text


def test_deep_stream_assign() -> None:
    """Stream a map chain extended with ``assign``, with and without shadowing.

    Covers the input/output JSON schemas, the exact order of streamed chunks
    when a new key is assigned, and the case where ``assign`` overrides a key
    that the upstream map already produces.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=["foo-lish"])

    chain: Runnable = prompt | llm | {"str": StrOutputParser()}

    chunks = list(chain.stream({"question": "What up"}))

    assert len(chunks) == len("foo-lish")
    assert add(chunks) == {"str": "foo-lish"}

    # Both assign variants below take the prompt's input, so they share this.
    prompt_input_schema = {
        "title": "PromptInput",
        "type": "object",
        "properties": {"question": {"title": "Question", "type": "string"}},
        "required": ["question"],
    }

    chain_with_assign = chain.assign(hello=itemgetter("str") | llm)

    assert chain_with_assign.get_input_jsonschema() == prompt_input_schema
    assert chain_with_assign.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "object",
        "properties": {
            "str": {"title": "Str", "type": "string"},
            "hello": {"title": "Hello", "type": "string"},
        },
        "required": ["str", "hello"],
    }

    chunks = list(chain_with_assign.stream({"question": "What up"}))

    assert len(chunks) == len("foo-lish") * 2
    assert chunks == [
        # first stream passthrough input chunks
        *({"str": char} for char in "foo-lish"),
        # then stream assign output chunks
        *({"hello": char} for char in "foo-lish"),
    ]
    assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
    assert chain_with_assign.invoke({"question": "What up"}) == {
        "str": "foo-lish",
        "hello": "foo-lish",
    }

    # Shadowing: "str" is produced upstream and also assigned here.
    chain_with_assign_shadow = chain.assign(
        str=lambda _: "shadow",
        hello=itemgetter("str") | llm,
    )

    assert chain_with_assign_shadow.get_input_jsonschema() == prompt_input_schema
    assert chain_with_assign_shadow.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "object",
        "properties": {
            "str": {"title": "Str"},
            "hello": {"title": "Hello", "type": "string"},
        },
        "required": ["str", "hello"],
    }

    chunks = list(chain_with_assign_shadow.stream({"question": "What up"}))

    # One chunk for the shadowing value plus one per character of "hello".
    assert len(chunks) == len("foo-lish") + 1
    assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
    assert chain_with_assign_shadow.invoke({"question": "What up"}) == {
        "str": "shadow",
        "hello": "foo-lish",
    }


async def test_deep_astream() -> None:
    """Async streaming must propagate through a prompt | llm | parser sequence.

    One chunk per character of the LLM response is expected, both for the
    plain chain and for the chain extended with a passthrough step.
    """
    expected_text = "foo-lish"
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=[expected_text])

    chain = prompt | llm | StrOutputParser()

    chunks = [piece async for piece in chain.astream({"question": "What up"})]
    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text

    # Appending a passthrough must not collapse the stream into one chunk.
    extended_chain = chain | RunnablePassthrough()
    chunks = [
        piece async for piece in extended_chain.astream({"question": "What up"})
    ]
    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text


async def test_deep_astream_assign() -> None:
    """Async-stream a map chain extended with ``assign``, incl. shadowing.

    Covers the input/output JSON schemas, the exact order of streamed chunks
    when a new key is assigned, and the case where ``assign`` overrides a key
    that the upstream map already produces.
    """
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=["foo-lish"])

    chain: Runnable = prompt | llm | {"str": StrOutputParser()}

    chunks = [piece async for piece in chain.astream({"question": "What up"})]

    assert len(chunks) == len("foo-lish")
    assert add(chunks) == {"str": "foo-lish"}

    # Both assign variants below take the prompt's input, so they share this.
    prompt_input_schema = {
        "title": "PromptInput",
        "type": "object",
        "properties": {"question": {"title": "Question", "type": "string"}},
        "required": ["question"],
    }

    chain_with_assign = chain.assign(
        hello=itemgetter("str") | llm,
    )

    assert chain_with_assign.get_input_jsonschema() == prompt_input_schema
    assert chain_with_assign.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "object",
        "properties": {
            "str": {"title": "Str", "type": "string"},
            "hello": {"title": "Hello", "type": "string"},
        },
        "required": ["str", "hello"],
    }

    chunks = [
        piece async for piece in chain_with_assign.astream({"question": "What up"})
    ]

    assert len(chunks) == len("foo-lish") * 2
    assert chunks == [
        # first stream passthrough input chunks
        *({"str": char} for char in "foo-lish"),
        # then stream assign output chunks
        *({"hello": char} for char in "foo-lish"),
    ]
    assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
    assert await chain_with_assign.ainvoke({"question": "What up"}) == {
        "str": "foo-lish",
        "hello": "foo-lish",
    }

    # Shadowing: "str" is produced upstream and also assigned here.
    chain_with_assign_shadow = chain | RunnablePassthrough.assign(
        str=lambda _: "shadow",
        hello=itemgetter("str") | llm,
    )

    assert chain_with_assign_shadow.get_input_jsonschema() == prompt_input_schema
    assert chain_with_assign_shadow.get_output_jsonschema() == {
        "title": "RunnableSequenceOutput",
        "type": "object",
        "properties": {
            "str": {"title": "Str"},
            "hello": {"title": "Hello", "type": "string"},
        },
        "required": ["str", "hello"],
    }

    chunks = [
        piece
        async for piece in chain_with_assign_shadow.astream({"question": "What up"})
    ]

    # One chunk for the shadowing value plus one per character of "hello".
    assert len(chunks) == len("foo-lish") + 1
    assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
    assert await chain_with_assign_shadow.ainvoke({"question": "What up"}) == {
        "str": "shadow",
        "hello": "foo-lish",
    }


def test_runnable_sequence_transform() -> None:
    """``transform`` on a sequence must stream when given an input iterator."""
    expected_text = "foo-lish"
    llm = FakeStreamingListLLM(responses=[expected_text])
    chain = llm | StrOutputParser()

    # Feed the sequence with the chunk stream of another LLM call.
    upstream = llm.stream("Hi there!")
    chunks = list(chain.transform(upstream))

    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text


async def test_runnable_sequence_atransform() -> None:
    """``atransform`` on a sequence must stream when given an async iterator."""
    expected_text = "foo-lish"
    llm = FakeStreamingListLLM(responses=[expected_text])
    chain = llm | StrOutputParser()

    # Feed the sequence with the async chunk stream of another LLM call.
    upstream = llm.astream("Hi there!")
    chunks = [piece async for piece in chain.atransform(upstream)]

    assert len(chunks) == len(expected_text)
    assert "".join(chunks) == expected_text


class FakeSplitIntoListParser(BaseOutputParser[list[str]]):
    """Test parser that splits text on ``", "`` into a list of strings."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this parser as serializable."""
        return True

    @override
    def get_format_instructions(self) -> str:
        instructions = (
            "Your response should be a list of comma separated values, "
            "eg: `foo, bar, baz`"
        )
        return instructions

    @override
    def parse(self, text: str) -> list[str]:
        """Strip surrounding whitespace, then split on comma-plus-space."""
        trimmed = text.strip()
        return trimmed.split(", ")


def test_each_simple() -> None:
    """Test that ``map()`` applies a simple runnable per element, also nested."""
    parser = FakeSplitIntoListParser()

    single = parser.invoke("first item, second item")
    assert single == ["first item", "second item"]

    mapped_once = parser.map().invoke(["a, b", "c"])
    assert mapped_once == [["a", "b"], ["c"]]

    # Two levels of map() descend into a list of lists.
    mapped_twice = parser.map().map().invoke([["a, b", "c"], ["c, e"]])
    assert mapped_twice == [
        [["a", "b"], ["c"]],
        [["c", "e"]],
    ]


def test_each(snapshot: SnapshotAssertion) -> None:
    """Map a second LLM over the list parsed from the first LLM's output."""
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
    parser = FakeSplitIntoListParser()
    second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])

    chain = prompt | first_llm | parser | second_llm.map()

    assert dumps(chain, pretty=True) == snapshot

    # Three parsed items consume the first three canned responses.
    three_items = chain.invoke({"question": "What up"})
    assert three_items == ["this", "is", "a"]

    # The next two calls get the fourth response and then wrap to the first.
    short_chain = parser | second_llm.map()
    two_items = short_chain.invoke("first item, second item")
    assert two_items == ["test", "this"]


def test_recursive_lambda() -> None:
    """A lambda returning runnables is unrolled until a plain value appears.

    A ``recursion_limit`` lower than the needed depth must raise RecursionError.
    """

    def _simple_recursion(x: int) -> int | Runnable:
        # Base case first: values of 10 or more are returned as-is.
        if x >= 10:
            return x
        return RunnableLambda(lambda *_: _simple_recursion(x + 1))

    recursive = RunnableLambda(_simple_recursion)
    assert recursive.invoke(5) == 10

    with pytest.raises(RecursionError):
        recursive.invoke(0, {"recursion_limit": 9})


def test_retrying(mocker: MockerFixture) -> None:
    """Check ``with_retry`` call counts and outcomes for ``invoke`` and ``batch``.

    The wrapped function raises ValueError for 1 and RuntimeError for 2, and
    returns any other input. Only ValueError is configured as retryable.
    """

    def _lambda(x: int) -> int:
        if x == 1:
            msg = "x is 1"
            raise ValueError(msg)
        if x == 2:
            msg = "x is 2"
            raise RuntimeError(msg)
        return x

    lambda_mock = mocker.Mock(side_effect=_lambda)
    runnable = RunnableLambda(lambda_mock)

    def retry_value_errors() -> Runnable:
        """Build a fresh two-attempt, no-jitter retry wrapper around ``runnable``."""
        return runnable.with_retry(
            stop_after_attempt=2,
            wait_exponential_jitter=False,
            retry_if_exception_type=(ValueError,),
        )

    # Baseline: without retry the function is called exactly once.
    with pytest.raises(ValueError, match="x is 1"):
        runnable.invoke(1)
    assert lambda_mock.call_count == 1
    lambda_mock.reset_mock()

    # Retryable error, with jitter parameters supplied.
    jittered = runnable.with_retry(
        stop_after_attempt=2,
        retry_if_exception_type=(ValueError,),
        exponential_jitter_params={"initial": 0.1},
    )
    with pytest.raises(ValueError, match="x is 1"):
        jittered.invoke(1)
    assert lambda_mock.call_count == 2  # retried
    lambda_mock.reset_mock()

    # Non-retryable error.
    with pytest.raises(RuntimeError):
        retry_value_errors().invoke(2)
    assert lambda_mock.call_count == 1  # did not retry
    lambda_mock.reset_mock()

    # Batch that re-raises the first failure.
    with pytest.raises(ValueError, match="x is 1"):
        retry_value_errors().batch([1, 2, 0])
    # 3rd input isn't retried because it succeeded
    assert lambda_mock.call_count == 3 + 2
    lambda_mock.reset_mock()

    # Batch that returns exceptions in place of results.
    output = retry_value_errors().batch([1, 2, 0], return_exceptions=True)
    # 3rd input isn't retried because it succeeded
    assert lambda_mock.call_count == 3 + 2
    assert len(output) == 3
    first, second, third = output
    assert isinstance(first, ValueError)
    assert isinstance(second, RuntimeError)
    assert third == 0
    lambda_mock.reset_mock()


def test_retry_batch_preserves_order() -> None:
    """Regression test: ``batch`` with retry must keep outputs in input order.

    An earlier implementation keyed successful results by their position in
    the filtered list of still-pending inputs instead of the original input
    index. After a retry those keys collided, which duplicated some outputs
    and dropped earlier successes (e.g. [0,1,2] -> [1,1,2]).
    """
    # Only the middle element fails, and only on its first attempt.
    pending_failures: set[int] = {1}

    def sometimes_fail(x: int) -> int:  # pragma: no cover - trivial
        if x not in pending_failures:
            return x
        pending_failures.remove(x)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(sometimes_fail).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )

    results = retrying.batch([0, 1, 2])

    # Expect exact ordering preserved.
    assert results == [0, 1, 2]


async def test_async_retry_batch_preserves_order() -> None:
    """Async regression test: ``abatch`` with retry must keep input order."""
    # Only the middle element fails, and only on its first attempt.
    pending_failures: set[int] = {1}

    def sometimes_fail(x: int) -> int:  # pragma: no cover - trivial
        if x not in pending_failures:
            return x
        pending_failures.remove(x)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(sometimes_fail).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )

    results = await retrying.abatch([0, 1, 2])

    assert results == [0, 1, 2]


async def test_async_retrying(mocker: MockerFixture) -> None:
    """Check `with_retry` behaviour through `ainvoke` and `abatch`.

    The wrapped function raises `ValueError` for input 1, `RuntimeError` for
    input 2 and returns any other input unchanged. Call counts recorded on the
    mock show how many attempts were made in each scenario.
    """
    failures: dict[int, tuple[type[Exception], str]] = {
        1: (ValueError, "x is 1"),
        2: (RuntimeError, "x is 2"),
    }

    def _lambda(x: int) -> int:
        if x not in failures:
            return x
        error_type, msg = failures[x]
        raise error_type(msg)

    lambda_mock = mocker.Mock(side_effect=_lambda)
    runnable = RunnableLambda(lambda_mock)

    def retry_on(*error_types: type[BaseException]) -> Runnable[Any, Any]:
        """Wrap `runnable` with two attempts, retrying only on `error_types`."""
        return runnable.with_retry(
            stop_after_attempt=2,
            wait_exponential_jitter=False,
            retry_if_exception_type=error_types,
        )

    # Baseline without retry: a single call, and the error propagates.
    with pytest.raises(ValueError, match="x is 1"):
        await runnable.ainvoke(1)
    assert lambda_mock.call_count == 1
    lambda_mock.reset_mock()

    # ValueError is in the retry list, so a second attempt is made.
    with pytest.raises(ValueError, match="x is 1"):
        await retry_on(ValueError, KeyError).ainvoke(1)
    assert lambda_mock.call_count == 2  # retried
    lambda_mock.reset_mock()

    # RuntimeError is not in the retry list, so it surfaces immediately.
    with pytest.raises(RuntimeError):
        await retry_on(ValueError).ainvoke(2)
    assert lambda_mock.call_count == 1  # did not retry
    lambda_mock.reset_mock()

    with pytest.raises(ValueError, match="x is 1"):
        await retry_on(ValueError).abatch([1, 2, 0])
    # 3rd input isn't retried because it succeeded
    assert lambda_mock.call_count == 3 + 2
    lambda_mock.reset_mock()

    output = await retry_on(ValueError).abatch([1, 2, 0], return_exceptions=True)
    # 3rd input isn't retried because it succeeded
    assert lambda_mock.call_count == 3 + 2
    assert len(output) == 3
    first, second, third = output
    assert isinstance(first, ValueError)
    assert isinstance(second, RuntimeError)
    assert third == 0
    lambda_mock.reset_mock()


def test_runnable_lambda_stream() -> None:
    """Check `stream` for plain functions and for functions returning a Runnable."""
    # A plain function's return value arrives as a single chunk.
    plain_chunks: list[Any] = list(RunnableLambda(range).stream(5))
    assert plain_chunks == [range(5)]

    # When the function returns a Runnable, that runnable's stream is forwarded.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    delegating = RunnableLambda[str, str](lambda _: llm)

    llm_chunks = list(delegating.stream(""))
    assert llm_chunks == list(llm_res)


def test_runnable_lambda_stream_with_callbacks() -> None:
    """Check that `RunnableLambda.stream` reports runs to callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    delegating = RunnableLambda[str, str](lambda _: llm)

    streamed = list(delegating.stream("", config=config))
    assert streamed == list(llm_res)

    # The successful stream is recorded as one run holding the joined output.
    assert len(tracer.runs) == 1
    succeeded = tracer.runs[0]
    assert succeeded.error is None
    assert succeeded.outputs == {"output": llm_res}

    def raise_value_error(_: int) -> int:
        """Raise a value error."""
        msg = "x is too large"
        raise ValueError(msg)

    # A failing stream must trigger the chain error callback.
    failing = RunnableLambda(raise_value_error)
    with pytest.raises(ValueError, match="x is too large"):
        _ = list(failing.stream(1000, config=config))

    assert len(tracer.runs) == 2
    failed = tracer.runs[1]
    assert "ValueError('x is too large')" in str(failed.error)
    assert not failed.outputs


async def test_runnable_lambda_astream() -> None:
    """Check `astream` for plain functions and for functions returning a Runnable."""

    def awrapper(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
        """Return a coroutine function that simply calls the sync `func`."""

        async def afunc(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        return afunc

    async def drain(chunks: AsyncIterator[Any]) -> list[Any]:
        """Collect every chunk of an async stream into a list."""
        return [chunk async for chunk in chunks]

    # Explicit `afunc`; `func=id` is only a dummy for the sync slot.
    with_afunc = RunnableLambda(func=id, afunc=awrapper(range))
    assert await drain(with_afunc.astream(5)) == [range(5)]

    # A sync-only RunnableLambda can be streamed asynchronously as well.
    sync_only = RunnableLambda(range)
    assert await drain(sync_only.astream(5)) == [range(5)]

    # When the function returns a Runnable, that runnable's stream is forwarded.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    expected_chunks = list(llm_res)

    afunc_returning_llm = RunnableLambda(func=id, afunc=awrapper(lambda _: llm))
    assert await drain(afunc_returning_llm.astream("")) == expected_chunks

    func_returning_llm = RunnableLambda[str, str](lambda _: llm)
    assert await drain(func_returning_llm.astream("")) == expected_chunks


async def test_runnable_lambda_astream_with_callbacks() -> None:
    """Check that `RunnableLambda.astream` reports runs to callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    delegating = RunnableLambda[str, str](lambda _: llm)

    streamed = [chunk async for chunk in delegating.astream("", config=config)]
    assert streamed == list(llm_res)

    # The successful stream is recorded as one run holding the joined output.
    assert len(tracer.runs) == 1
    succeeded = tracer.runs[0]
    assert succeeded.error is None
    assert succeeded.outputs == {"output": llm_res}

    def raise_value_error(_: int) -> int:
        """Raise a value error."""
        msg = "x is too large"
        raise ValueError(msg)

    # A failing stream must trigger the chain error callback.
    failing = RunnableLambda(raise_value_error)
    with pytest.raises(ValueError, match="x is too large"):
        _ = [chunk async for chunk in failing.astream(1000, config=config)]

    assert len(tracer.runs) == 2
    failed = tracer.runs[1]
    assert "ValueError('x is too large')" in str(failed.error)
    assert not failed.outputs


@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
    """Check `batch` on a four-step sequence with and without `return_exceptions`.

    Each step appends "a" to its input, or yields a `ValueError` when the input
    starts with that step's configured prefix. The test verifies the outputs,
    the inputs each step's `batch` received, and the traced parent/child runs.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a configured prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            # Prefix that marks an input as failing for this step.
            self.fail_starts_with = fail_starts_with

        @override
        def invoke(
            self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
        ) -> Any:
            # Only the batch path is exercised by this test.
            raise NotImplementedError

        def _batch(
            self,
            inputs: list[str],
        ) -> list[str | Exception]:
            """Return `value + "a"` per input, or a `ValueError` on a prefix match.

            Errors are returned as list items rather than raised.
            """
            outputs: list[str | Exception] = []
            for value in inputs:
                if value.startswith(self.fail_starts_with):
                    outputs.append(
                        ValueError(
                            f"ControlledExceptionRunnable({self.fail_starts_with}) "
                            f"fail for {value}"
                        )
                    )
                else:
                    outputs.append(value + "a")
            return outputs

        def batch(
            self,
            inputs: list[str],
            config: RunnableConfig | list[RunnableConfig] | None = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> list[str]:
            """Run `_batch` through `_batch_with_config`."""
            return self._batch_with_config(
                self._batch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    # No input starts with "bux", so step 0 never fails for the inputs used below.
    chain = (
        ControlledExceptionRunnable("bux")
        | ControlledExceptionRunnable("bar")
        | ControlledExceptionRunnable("baz")
        | ControlledExceptionRunnable("foo")
    )

    assert isinstance(chain, RunnableSequence)

    # Test batch
    # Without return_exceptions an error is raised; the expected message is the
    # step-1 ("bar") failure.
    with pytest.raises(
        ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
    ):
        chain.batch(["foo", "bar", "baz", "qux"])

    # The spy is installed only now, so the failing call above is not counted.
    spy = mocker.spy(ControlledExceptionRunnable, "batch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    outputs = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True)
    assert len(outputs) == 4
    assert isinstance(outputs[0], ValueError)
    assert isinstance(outputs[1], ValueError)
    assert isinstance(outputs[2], ValueError)
    assert outputs[3] == "quxaaaa"
    # One batch call per step of the sequence.
    assert spy.call_count == 4
    # Positional arg 0 of each spied call is `self`; arg 1 is the inputs list.
    inputs_to_batch = [c[0][1] for c in spy.call_args_list]
    assert inputs_to_batch == [
        # inputs to sequence step 0
        # same as inputs to sequence.batch()
        ["foo", "bar", "baz", "qux"],
        # inputs to sequence step 1
        # == outputs of sequence step 0 as no exceptions were raised
        ["fooa", "bara", "baza", "quxa"],
        # inputs to sequence step 2
        # 'bar' was dropped as it raised an exception in step 1
        ["fooaa", "bazaa", "quxaa"],
        # inputs to sequence step 3
        # 'baz' was dropped as it raised an exception in step 2
        ["fooaaa", "quxaaa"],
    ]
    # Order the top-level runs by the position of their input in `inputs`.
    parent_runs = sorted(
        (r for r in tracer.runs if r.parent_run_id is None),
        key=lambda run: inputs.index(run.inputs["input"]),
    )
    assert len(parent_runs) == 4

    # "foo" passes steps 0-2 and fails in the last step.
    parent_run_foo = parent_runs[0]
    assert parent_run_foo.inputs["input"] == "foo"
    assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
        parent_run_foo.error
    )
    assert len(parent_run_foo.child_runs) == 4
    assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
        None,
        None,
        None,
    ]
    assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
        parent_run_foo.child_runs[-1].error
    )

    # "bar" fails in step 1, so only two child runs exist.
    parent_run_bar = parent_runs[1]
    assert parent_run_bar.inputs["input"] == "bar"
    assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
        parent_run_bar.error
    )
    assert len(parent_run_bar.child_runs) == 2
    assert parent_run_bar.child_runs[0].error is None
    assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
        parent_run_bar.child_runs[1].error
    )

    # "baz" fails in step 2, so three child runs exist.
    parent_run_baz = parent_runs[2]
    assert parent_run_baz.inputs["input"] == "baz"
    assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
        parent_run_baz.error
    )
    assert len(parent_run_baz.child_runs) == 3

    assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
        None,
        None,
    ]
    assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
        parent_run_baz.child_runs[-1].error
    )

    # "qux" passes every step.
    parent_run_qux = parent_runs[3]
    assert parent_run_qux.inputs["input"] == "qux"
    assert parent_run_qux.error is None
    assert parent_run_qux.outputs is not None
    assert parent_run_qux.outputs["output"] == "quxaaaa"
    assert len(parent_run_qux.child_runs) == 4
    assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]


@freeze_time("2023-01-01")
async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
    """Check `abatch` on a four-step sequence with and without `return_exceptions`.

    Each step appends "a" to its input, or yields a `ValueError` when the input
    starts with that step's configured prefix. The test verifies the outputs,
    the inputs each step's `abatch` received, and the traced parent/child runs.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a configured prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            # Prefix that marks an input as failing for this step.
            self.fail_starts_with = fail_starts_with

        @override
        def invoke(
            self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
        ) -> Any:
            # Only the async batch path is exercised by this test.
            raise NotImplementedError

        async def _abatch(
            self,
            inputs: list[str],
        ) -> list[str | Exception]:
            """Return `value + "a"` per input, or a `ValueError` on a prefix match.

            Errors are returned as list items rather than raised.
            """
            outputs: list[str | Exception] = []
            for value in inputs:
                if value.startswith(self.fail_starts_with):
                    outputs.append(
                        ValueError(
                            f"ControlledExceptionRunnable({self.fail_starts_with}) "
                            f"fail for {value}"
                        )
                    )
                else:
                    outputs.append(value + "a")
            return outputs

        async def abatch(
            self,
            inputs: list[str],
            config: RunnableConfig | list[RunnableConfig] | None = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> list[str]:
            """Run `_abatch` through `_abatch_with_config`."""
            return await self._abatch_with_config(
                self._abatch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    # No input starts with "bux", so step 0 never fails for the inputs used below.
    chain = (
        ControlledExceptionRunnable("bux")
        | ControlledExceptionRunnable("bar")
        | ControlledExceptionRunnable("baz")
        | ControlledExceptionRunnable("foo")
    )

    assert isinstance(chain, RunnableSequence)

    # Test abatch
    # Without return_exceptions an error is raised; the expected message is the
    # step-1 ("bar") failure.
    with pytest.raises(
        ValueError, match=re.escape("ControlledExceptionRunnable(bar) fail for bara")
    ):
        await chain.abatch(["foo", "bar", "baz", "qux"])

    # The spy is installed only now, so the failing call above is not counted.
    spy = mocker.spy(ControlledExceptionRunnable, "abatch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    outputs = await chain.abatch(
        inputs, {"callbacks": [tracer]}, return_exceptions=True
    )
    assert len(outputs) == 4
    assert isinstance(outputs[0], ValueError)
    assert isinstance(outputs[1], ValueError)
    assert isinstance(outputs[2], ValueError)
    assert outputs[3] == "quxaaaa"
    # One abatch call per step of the sequence.
    assert spy.call_count == 4
    # Positional arg 0 of each spied call is `self`; arg 1 is the inputs list.
    inputs_to_batch = [c[0][1] for c in spy.call_args_list]
    assert inputs_to_batch == [
        # inputs to sequence step 0
        # same as inputs to sequence.abatch()
        ["foo", "bar", "baz", "qux"],
        # inputs to sequence step 1
        # == outputs of sequence step 0 as no exceptions were raised
        ["fooa", "bara", "baza", "quxa"],
        # inputs to sequence step 2
        # 'bar' was dropped as it raised an exception in step 1
        ["fooaa", "bazaa", "quxaa"],
        # inputs to sequence step 3
        # 'baz' was dropped as it raised an exception in step 2
        ["fooaaa", "quxaaa"],
    ]
    # Order the top-level runs by the position of their input in `inputs`.
    parent_runs = sorted(
        (r for r in tracer.runs if r.parent_run_id is None),
        key=lambda run: inputs.index(run.inputs["input"]),
    )
    assert len(parent_runs) == 4

    # "foo" passes steps 0-2 and fails in the last step.
    parent_run_foo = parent_runs[0]
    assert parent_run_foo.inputs["input"] == "foo"
    assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
        parent_run_foo.error
    )
    assert len(parent_run_foo.child_runs) == 4
    assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
        None,
        None,
        None,
    ]
    assert repr(ValueError("ControlledExceptionRunnable(foo) fail for fooaaa")) in str(
        parent_run_foo.child_runs[-1].error
    )

    # "bar" fails in step 1, so only two child runs exist.
    parent_run_bar = parent_runs[1]
    assert parent_run_bar.inputs["input"] == "bar"
    assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
        parent_run_bar.error
    )
    assert len(parent_run_bar.child_runs) == 2
    assert parent_run_bar.child_runs[0].error is None
    assert repr(ValueError("ControlledExceptionRunnable(bar) fail for bara")) in str(
        parent_run_bar.child_runs[1].error
    )

    # "baz" fails in step 2, so three child runs exist.
    parent_run_baz = parent_runs[2]
    assert parent_run_baz.inputs["input"] == "baz"
    assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
        parent_run_baz.error
    )
    assert len(parent_run_baz.child_runs) == 3
    assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
        None,
        None,
    ]
    assert repr(ValueError("ControlledExceptionRunnable(baz) fail for bazaa")) in str(
        parent_run_baz.child_runs[-1].error
    )

    # "qux" passes every step.
    parent_run_qux = parent_runs[3]
    assert parent_run_qux.inputs["input"] == "qux"
    assert parent_run_qux.error is None
    assert parent_run_qux.outputs is not None
    assert parent_run_qux.outputs["output"] == "quxaaaa"
    assert len(parent_run_qux.child_runs) == 4
    assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]


def test_runnable_branch_init() -> None:
    """Verify that `RunnableBranch` rejects construction with a single argument."""
    increment = RunnableLambda[int, int](lambda x: x + 1)
    is_positive = RunnableLambda[int, bool](lambda x: x > 0)
    expected_message = "RunnableBranch requires at least two branches"

    # A lone (condition, runnable) pair without a default must be rejected.
    with pytest.raises(ValueError, match=expected_message):
        RunnableBranch((is_positive, increment))

    # A lone runnable with no conditional pair must be rejected as well.
    with pytest.raises(ValueError, match=expected_message):
        RunnableBranch(is_positive)


@pytest.mark.parametrize(
    "branches",
    [
        [
            (RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
            RunnableLambda(lambda x: x - 1),
        ],
        [
            (RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
            (RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)),
            RunnableLambda(lambda x: x - 1),
        ],
        [
            (lambda x: x > 0, lambda x: x + 1),
            (lambda x: x > 5, lambda x: x + 1),
            lambda x: x - 1,
        ],
    ],
)
def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
    """Verify that branch members end up as Runnables after construction."""
    runnable = RunnableBranch[int, int](*branches)

    # Every (condition, body) pair must hold Runnables, whether the inputs were
    # already Runnables or bare callables.
    for condition, body in runnable.branches:
        assert isinstance(condition, Runnable)
        assert isinstance(body, Runnable)
    assert isinstance(runnable.default, Runnable)

    expected_schema = {
        "title": "RunnableBranchInput",
        "type": "integer",
    }
    assert _schema(runnable.input_schema) == expected_schema


def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
    """Verify that runnables are invoked only when necessary."""
    add = RunnableLambda[int, int](lambda x: x + 1)
    sub = RunnableLambda[int, int](lambda x: x - 1)
    condition = RunnableLambda[int, bool](lambda x: x > 0)
    condition_spy = mocker.spy(condition, "invoke")
    add_spy = mocker.spy(add, "invoke")

    # The same (condition, add) pair is registered twice, with `sub` as default.
    branch = RunnableBranch[int, int]((condition, add), (condition, add), sub)
    assert (condition_spy.call_count, add_spy.call_count) == (0, 0)

    # A positive input matches the first pair: one condition call, one add call.
    assert branch.invoke(1) == 2
    assert (condition_spy.call_count, add_spy.call_count) == (1, 1)

    assert branch.invoke(2) == 3
    assert (condition_spy.call_count, add_spy.call_count) == (2, 2)

    assert branch.invoke(-3) == -4
    # Should fall through to default branch with condition being evaluated twice!
    assert condition_spy.call_count == 4
    # Add should not be invoked
    assert add_spy.call_count == 2


def test_runnable_branch_invoke() -> None:
    """Verify that `invoke` routes each input to the first matching branch."""

    def raise_value_error(_: int) -> int:
        """Raise a value error."""
        msg = "x is too large"
        raise ValueError(msg)

    branch = RunnableBranch[int, int](
        (lambda x: x > 100, raise_value_error),
        # mypy cannot infer types from the lambda
        (lambda x: x > 0 and x < 5, lambda x: x + 1),
        (lambda x: x > 5, lambda x: x * 10),
        lambda x: x - 1,
    )

    # (input, expected output): second branch, third branch, default.
    routed_cases = [(1, 2), (10, 100), (0, -1)]
    for given, expected in routed_cases:
        assert branch.invoke(given) == expected

    # An error raised inside the selected branch propagates to the caller.
    with pytest.raises(ValueError, match="x is too large"):
        branch.invoke(1000)


def test_runnable_branch_batch() -> None:
    """Verify that `batch` routes every input independently."""
    branch = RunnableBranch[int, int](
        (lambda x: x > 0 and x < 5, lambda x: x + 1),
        (lambda x: x > 5, lambda x: x * 10),
        lambda x: x - 1,
    )

    # One input per route: first branch, second branch, default.
    batch_inputs = [1, 10, 0]
    expected_outputs = [2, 100, -1]
    assert branch.batch(batch_inputs) == expected_outputs


async def test_runnable_branch_ainvoke() -> None:
    """Verify `ainvoke` routing with sync callables and with coroutine functions."""
    sync_branch = RunnableBranch[int, int](
        (lambda x: x > 0 and x < 5, lambda x: x + 1),
        (lambda x: x > 5, lambda x: x * 10),
        lambda x: x - 1,
    )

    # (input, expected output): first branch, second branch, default.
    for given, expected in [(1, 2), (10, 100), (0, -1)]:
        assert await sync_branch.ainvoke(given) == expected

    # Verify that the async variant is used if available
    async def condition(x: int) -> bool:
        return x > 0

    async def add(x: int) -> int:
        return x + 1

    async def sub(x: int) -> int:
        return x - 1

    async_branch = RunnableBranch[int, int]((condition, add), sub)

    for given, expected in [(1, 2), (-10, -11)]:
        assert await async_branch.ainvoke(given) == expected


def test_runnable_branch_invoke_callbacks() -> None:
    """Verify that `invoke` reports success and failure runs to callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    def raise_value_error(_: int) -> int:
        """Raise a value error."""
        msg = "x is too large"
        raise ValueError(msg)

    branch = RunnableBranch[int, int](
        (lambda x: x > 100, raise_value_error),
        lambda x: x - 1,
    )

    # The default branch succeeds and is traced with its output.
    assert branch.invoke(1, config=config) == 0
    assert len(tracer.runs) == 1
    succeeded = tracer.runs[0]
    assert succeeded.error is None
    assert succeeded.outputs == {"output": 0}

    # The failing branch must be traced with its error and no outputs.
    with pytest.raises(ValueError, match="x is too large"):
        branch.invoke(1000, config=config)

    assert len(tracer.runs) == 2
    failed = tracer.runs[1]
    assert "ValueError('x is too large')" in str(failed.error)
    assert not failed.outputs


async def test_runnable_branch_ainvoke_callbacks() -> None:
    """Verify that `ainvoke` reports success and failure runs to callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    async def raise_value_error(_: int) -> int:
        """Raise a value error."""
        msg = "x is too large"
        raise ValueError(msg)

    branch = RunnableBranch[int, int](
        (lambda x: x > 100, raise_value_error),
        lambda x: x - 1,
    )

    # The default branch succeeds and is traced with its output.
    assert await branch.ainvoke(1, config=config) == 0
    assert len(tracer.runs) == 1
    succeeded = tracer.runs[0]
    assert succeeded.error is None
    assert succeeded.outputs == {"output": 0}

    # The failing branch must be traced with its error and no outputs.
    with pytest.raises(ValueError, match="x is too large"):
        await branch.ainvoke(1000, config=config)

    assert len(tracer.runs) == 2
    failed = tracer.runs[1]
    assert "ValueError('x is too large')" in str(failed.error)
    assert not failed.outputs


async def test_runnable_branch_abatch() -> None:
    """Test async variant of batch."""
    branch = RunnableBranch[int, int](
        (lambda x: x > 0 and x < 5, lambda x: x + 1),
        (lambda x: x > 5, lambda x: x * 10),
        lambda x: x - 1,
    )

    # One input per route: first branch, second branch, default.
    batch_inputs = [1, 10, 0]
    expected_outputs = [2, 100, -1]
    assert await branch.abatch(batch_inputs) == expected_outputs


def test_runnable_branch_stream() -> None:
    """Verify that stream works for RunnableBranch."""
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    branch = RunnableBranch[str, Any](
        (lambda x: x == "hello", llm),
        lambda x: x,
    )

    # The matching branch streams the LLM output character by character.
    hello_chunks = list(branch.stream("hello"))
    assert hello_chunks == list(llm_res)

    # The default branch echoes its input as a single chunk.
    bye_chunks = list(branch.stream("bye"))
    assert bye_chunks == ["bye"]


def test_runnable_branch_stream_with_callbacks() -> None:
    """Verify that stream works for RunnableBranch when using callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    def raise_value_error(x: str) -> Any:
        """Raise a value error."""
        msg = f"x is {x}"
        raise ValueError(msg)

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    branch = RunnableBranch[str, Any](
        (lambda x: x == "error", raise_value_error),
        (lambda x: x == "hello", llm),
        lambda x: x,
    )

    # 1) LLM branch: streamed chunks plus one successful run.
    hello_chunks = list(branch.stream("hello", config=config))
    assert hello_chunks == list(llm_res)

    assert len(tracer.runs) == 1
    llm_run = tracer.runs[0]
    assert llm_run.error is None
    assert llm_run.outputs == {"output": llm_res}

    # 2) Error branch: verify that the chain on error is invoked.
    with pytest.raises(ValueError, match="x is error"):
        _ = list(branch.stream("error", config=config))

    assert len(tracer.runs) == 2
    error_run = tracer.runs[1]
    assert "ValueError('x is error')" in str(error_run.error)
    assert not error_run.outputs

    # 3) Default branch: input echoed back plus one successful run.
    bye_chunks = list(branch.stream("bye", config=config))
    assert bye_chunks == ["bye"]

    assert len(tracer.runs) == 3
    default_run = tracer.runs[2]
    assert default_run.error is None
    assert default_run.outputs == {"output": "bye"}


async def test_runnable_branch_astream() -> None:
    """Verify that astream works for RunnableBranch."""
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    llm_chunks = list(llm_res)

    async def drain(chunks: AsyncIterator[Any]) -> list[Any]:
        """Collect every chunk of an async stream into a list."""
        return [chunk async for chunk in chunks]

    sync_branch = RunnableBranch[str, Any](
        (lambda x: x == "hello", llm),
        lambda x: x,
    )

    assert await drain(sync_branch.astream("hello")) == llm_chunks
    assert await drain(sync_branch.astream("bye")) == ["bye"]

    # Verify that the async variant is used if available
    async def condition(x: str) -> bool:
        return x == "hello"

    async def repeat(x: str) -> str:
        return x + x

    async def reverse(x: str) -> str:
        return x[::-1]

    # Async body on the conditional branch, LLM as the default.
    llm_as_default = RunnableBranch[str, Any]((condition, repeat), llm)

    assert await drain(llm_as_default.astream("hello")) == ["hello" * 2]
    assert await drain(llm_as_default.astream("bye")) == llm_chunks

    # LLM on the conditional branch, async function as the default.
    llm_as_branch = RunnableBranch[str, Any]((condition, llm), reverse)

    assert await drain(llm_as_branch.astream("hello")) == llm_chunks
    assert await drain(llm_as_branch.astream("bye")) == ["eyb"]


async def test_runnable_branch_astream_with_callbacks() -> None:
    """Verify that astream works for RunnableBranch when using callbacks."""
    tracer = FakeTracer()
    config: RunnableConfig = {"callbacks": [tracer]}

    def raise_value_error(x: str) -> Any:
        """Raise a value error."""
        msg = f"x is {x}"
        raise ValueError(msg)

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    branch = RunnableBranch[str, Any](
        (lambda x: x == "error", raise_value_error),
        (lambda x: x == "hello", llm),
        lambda x: x,
    )

    # 1) LLM branch: streamed chunks plus one successful run.
    hello_chunks = [chunk async for chunk in branch.astream("hello", config=config)]
    assert hello_chunks == list(llm_res)

    assert len(tracer.runs) == 1
    llm_run = tracer.runs[0]
    assert llm_run.error is None
    assert llm_run.outputs == {"output": llm_res}

    # 2) Error branch: verify that the chain on error is invoked.
    with pytest.raises(ValueError, match="x is error"):
        _ = [chunk async for chunk in branch.astream("error", config=config)]

    assert len(tracer.runs) == 2
    error_run = tracer.runs[1]
    assert "ValueError('x is error')" in str(error_run.error)
    assert not error_run.outputs

    # 3) Default branch: input echoed back plus one successful run.
    bye_chunks = [chunk async for chunk in branch.astream("bye", config=config)]
    assert bye_chunks == ["bye"]

    assert len(tracer.runs) == 3
    default_run = tracer.runs[2]
    assert default_run.error is None
    assert default_run.outputs == {"output": "bye"}


def test_representation_of_runnables() -> None:
    """Test representation of runnables.

    Covers the `repr` of a lambda-backed `RunnableLambda`, of a
    `RunnableLambda` built from a named function (with and without an async
    counterpart), and of a sequence that pipes into a dict of runnables.
    """
    # A stand-alone lambda is rendered with its source text.
    runnable = RunnableLambda[int, int](lambda x: x * 2)
    assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"

    def f(_: int) -> int:
        """Return 2."""
        return 2

    # A named function is rendered by its name.
    assert repr(RunnableLambda(func=f)) == "RunnableLambda(f)"

    async def af(_: int) -> int:
        """Return 2."""
        return 2

    # With both sync and async callables, the repr shows the sync function's name.
    assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(f)"

    # In a sequence the dict step is laid out over several lines, and the
    # lambdas nested inside it are abbreviated to `...`.
    assert repr(
        RunnableLambda(lambda x: x * 2)
        | {
            "a": RunnableLambda(lambda x: x * 2),
            "b": RunnableLambda(lambda x: x * 3),
        }
    ) == (
        "RunnableLambda(lambda x: x * 2)\n"
        "| {\n"
        "    a: RunnableLambda(...),\n"
        "    b: RunnableLambda(...)\n"
        "  }"
    )


async def test_tool_from_runnable() -> None:
    """Verify that `tool` wraps a runnable chain as a working `BaseTool`."""
    system_part = SystemMessagePromptTemplate.from_template(
        "You are a nice assistant."
    )
    prompt = system_part + "{question}"
    llm = FakeStreamingListLLM(responses=["foo-lish"])

    chain = prompt | llm | StrOutputParser()
    chain_tool = tool("chain_tool", chain)

    assert isinstance(chain_tool, BaseTool)
    assert chain_tool.name == "chain_tool"

    # Running the tool must give the same result as calling the chain directly.
    tool_output = chain_tool.run({"question": "What up"})
    assert tool_output == chain.invoke({"question": "What up"})

    async_tool_output = await chain_tool.arun({"question": "What up"})
    assert async_tool_output == await chain.ainvoke({"question": "What up"})

    # The description ends with the chain's repr; the args schema mirrors the
    # chain's input schema.
    assert chain_tool.description.endswith(repr(chain))
    assert _schema(chain_tool.args_schema) == chain.get_input_jsonschema()

    expected_args_schema = {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "PromptInput",
        "type": "object",
        "required": ["question"],
    }
    assert _schema(chain_tool.args_schema) == expected_args_schema


def test_runnable_gen() -> None:
    """Test that a generator can be used as a runnable."""

    def gen(_: Iterator[Any]) -> Iterator[int]:
        yield from (1, 2, 3)

    runnable = RunnableGenerator(gen)

    # Schema titles carry the generator function's name.
    expected_input_schema = {"title": "gen_input"}
    expected_output_schema = {
        "title": "gen_output",
        "type": "integer",
    }
    assert runnable.get_input_jsonschema() == expected_input_schema
    assert runnable.get_output_jsonschema() == expected_output_schema

    # `invoke` and `batch` return the sum of the chunks; `stream` yields each one.
    assert runnable.invoke(None) == 6
    streamed = list(runnable.stream(None))
    assert streamed == [1, 2, 3]
    assert runnable.batch([None, None]) == [6, 6]


async def test_runnable_gen_async() -> None:
    """Test that an async generator can be used as a runnable."""

    async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
        for value in (1, 2, 3):
            yield value

    arunnable = RunnableGenerator(agen)

    # `ainvoke` and `abatch` return the sum of the chunks; `astream` yields each one.
    assert await arunnable.ainvoke(None) == 6
    streamed = [chunk async for chunk in arunnable.astream(None)]
    assert streamed == [1, 2, 3]
    assert await arunnable.abatch([None, None]) == [6, 6]

    class AsyncGen:
        """Callable object whose `__call__` is an async generator."""

        async def __call__(self, _: AsyncIterator[Any]) -> AsyncIterator[int]:
            for value in (1, 2, 3):
                yield value

    arunnablecallable = RunnableGenerator(AsyncGen())
    assert await arunnablecallable.ainvoke(None) == 6
    streamed = [chunk async for chunk in arunnablecallable.astream(None)]
    assert streamed == [1, 2, 3]
    assert await arunnablecallable.abatch([None, None]) == [6, 6]

    # The sync entry points are unavailable for an async-only generator.
    with pytest.raises(NotImplementedError):
        await asyncio.to_thread(arunnablecallable.invoke, None)
    with pytest.raises(NotImplementedError):
        await asyncio.to_thread(arunnablecallable.stream, None)
    with pytest.raises(NotImplementedError):
        await asyncio.to_thread(arunnablecallable.batch, [None, None])


def test_runnable_gen_context_config() -> None:
    """Test config propagation out of a sync generator runnable.

    The generator calls another runnable without passing a config; those
    nested calls must be traced as child runs of the generator's own run
    when it is driven through `invoke`, `stream` and `batch`.
    """
    fake = RunnableLambda(len)

    def gen(_: Iterator[Any]) -> Iterator[int]:
        yield fake.invoke("a")
        yield fake.invoke("aa")
        yield fake.invoke("aaa")

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    def check_ids(recorder: Any, expected_id: UUID) -> None:
        # The caller-supplied id must be present and no run id may repeat.
        seen = recorder.run_ids
        assert expected_id in seen
        assert len(seen) == len(set(seen))

    runnable = RunnableGenerator(gen)

    input_schema = runnable.get_input_jsonschema()
    output_schema = runnable.get_output_jsonschema()
    assert input_schema == {"title": "gen_input"}
    assert output_schema == {"title": "gen_output", "type": "integer"}

    # --- invoke ---
    tracer = FakeTracer()
    run_id = uuid.uuid4()
    invoked = runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id})
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    check_ids(tracer, run_id)
    tracer.runs.clear()

    # --- stream without config: the cleared tracer must stay empty ---
    assert list(runnable.stream(None)) == [1, 2, 3]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- stream with config ---
    tracer = FakeTracer()
    run_id = uuid.uuid4()
    streamed = list(
        runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})
    )
    assert streamed == [1, 2, 3]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    check_ids(tracer, run_id)
    tracer.runs.clear()

    # --- batch ---
    tracer = FakeTracer()
    run_id = uuid.uuid4()

    # A RuntimeWarning is expected here (presumably because a single run_id
    # is supplied for a batch of two inputs).
    with pytest.warns(RuntimeWarning):
        assert runnable.batch(
            [None, None], {"callbacks": [tracer], "run_id": run_id}
        ) == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


@pytest.mark.skipif(
    sys.version_info < (3, 11),
    reason="Python 3.10 and below don't support running "
    "async tasks in a specific context",
)
async def test_runnable_gen_context_config_async() -> None:
    """Test config propagation out of an async generator runnable.

    The async generator awaits another runnable without passing a config;
    those nested calls must be traced as child runs of the generator's own
    run when it is driven through `ainvoke`, `astream` and `abatch`.
    """
    fake = RunnableLambda(len)

    async def agen(_: AsyncIterator[Any]) -> AsyncIterator[int]:
        yield await fake.ainvoke("a")
        yield await fake.ainvoke("aa")
        yield await fake.ainvoke("aaa")

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    def check_ids(recorder: Any, expected_id: UUID) -> None:
        # The caller-supplied id must be present and no run id may repeat.
        seen = recorder.run_ids
        assert expected_id in seen
        assert len(seen) == len(set(seen))

    arunnable = RunnableGenerator(agen)

    # --- ainvoke ---
    tracer = FakeTracer()

    run_id = uuid.uuid4()
    invoked = await arunnable.ainvoke(
        None, {"callbacks": [tracer], "run_id": run_id}
    )
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    check_ids(tracer, run_id)
    tracer.runs.clear()

    # --- astream without config: the cleared tracer must stay empty ---
    assert [item async for item in arunnable.astream(None)] == [1, 2, 3]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- astream with config ---
    tracer = FakeTracer()
    run_id = uuid.uuid4()
    streamed = [
        item
        async for item in arunnable.astream(
            None, {"callbacks": [tracer], "run_id": run_id}
        )
    ]
    assert streamed == [1, 2, 3]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    check_ids(tracer, run_id)

    # --- abatch ---
    tracer = FakeTracer()
    run_id = uuid.uuid4()
    # A RuntimeWarning is expected here (presumably because a single run_id
    # is supplied for a batch of two inputs).
    with pytest.warns(RuntimeWarning):
        assert await arunnable.abatch(
            [None, None], {"callbacks": [tracer], "run_id": run_id}
        ) == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


def test_runnable_iter_context_config() -> None:
    """Test config propagation out of a `@chain`-decorated generator.

    The decorated generator calls another runnable without passing a config;
    those nested calls must be traced as child runs of the decorated
    runnable's own run for `invoke`, `stream` and `batch`.
    """
    fake = RunnableLambda(len)

    @chain
    def gen(value: str) -> Iterator[int]:
        yield fake.invoke(value)
        yield fake.invoke(value * 2)
        yield fake.invoke(value * 3)

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    input_schema = gen.get_input_jsonschema()
    output_schema = gen.get_output_jsonschema()
    assert input_schema == {"title": "gen_input", "type": "string"}
    assert output_schema == {"title": "gen_output", "type": "integer"}

    # --- invoke ---
    tracer = FakeTracer()
    invoked = gen.invoke("a", {"callbacks": [tracer]})
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    tracer.runs.clear()

    # --- stream without config: the cleared tracer must stay empty ---
    assert list(gen.stream("a")) == [1, 2, 3]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- stream with config ---
    tracer = FakeTracer()
    streamed = list(gen.stream("a", {"callbacks": [tracer]}))
    assert streamed == [1, 2, 3]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])

    # --- batch: one root run per input ---
    tracer = FakeTracer()
    batched = gen.batch(["a", "a"], {"callbacks": [tracer]})
    assert batched == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


@pytest.mark.skipif(
    sys.version_info < (3, 11),
    reason="Python 3.10 and below don't support running "
    "async tasks in a specific context",
)
async def test_runnable_iter_context_config_async() -> None:
    """Test config propagation out of a `@chain`-decorated async generator.

    The decorated async generator awaits another runnable without passing a
    config; those nested calls must be traced as child runs of the decorated
    runnable's own run for `ainvoke`, `astream`, `astream_log` and `abatch`.
    """
    fake = RunnableLambda(len)

    @chain
    async def agen(value: str) -> AsyncIterator[int]:
        yield await fake.ainvoke(value)
        yield await fake.ainvoke(value * 2)
        yield await fake.ainvoke(value * 3)

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    input_schema = agen.get_input_jsonschema()
    output_schema = agen.get_output_jsonschema()
    assert input_schema == {"title": "agen_input", "type": "string"}
    assert output_schema == {"title": "agen_output", "type": "integer"}

    # --- ainvoke ---
    tracer = FakeTracer()
    invoked = await agen.ainvoke("a", {"callbacks": [tracer]})
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    tracer.runs.clear()

    # --- astream without config: the cleared tracer must stay empty ---
    assert [item async for item in agen.astream("a")] == [1, 2, 3]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- astream with config ---
    tracer = FakeTracer()
    streamed = [item async for item in agen.astream("a", {"callbacks": [tracer]})]
    assert streamed == [1, 2, 3]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])

    # --- astream_log: only require that some patches are produced ---
    tracer = FakeTracer()
    log_patches = [
        patch async for patch in agen.astream_log("a", {"callbacks": [tracer]})
    ]
    assert log_patches
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])

    # --- abatch: one root run per input ---
    tracer = FakeTracer()
    batched = await agen.abatch(["a", "a"], {"callbacks": [tracer]})
    assert batched == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


def test_runnable_lambda_context_config() -> None:
    """Test config propagation out of a `@chain`-decorated function.

    The decorated function calls another runnable without passing a config;
    those nested calls must be traced as child runs of the decorated
    runnable's own run for `invoke`, `stream` and `batch`.
    """
    fake = RunnableLambda(len)

    @chain
    def fun(value: str) -> int:
        output = fake.invoke(value)
        output += fake.invoke(value * 2)
        output += fake.invoke(value * 3)
        return output

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    input_schema = fun.get_input_jsonschema()
    output_schema = fun.get_output_jsonschema()
    assert input_schema == {"title": "fun_input", "type": "string"}
    assert output_schema == {"title": "fun_output", "type": "integer"}

    # --- invoke ---
    tracer = FakeTracer()
    invoked = fun.invoke("a", {"callbacks": [tracer]})
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    tracer.runs.clear()

    # --- stream without config: a plain function streams one final chunk ---
    assert list(fun.stream("a")) == [6]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- stream with config ---
    tracer = FakeTracer()
    streamed = list(fun.stream("a", {"callbacks": [tracer]}))
    assert streamed == [6]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])

    # --- batch: one root run per input ---
    tracer = FakeTracer()
    batched = fun.batch(["a", "a"], {"callbacks": [tracer]})
    assert batched == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


@pytest.mark.skipif(
    sys.version_info < (3, 11),
    reason="Python 3.10 and below don't support running "
    "async tasks in a specific context",
)
async def test_runnable_lambda_context_config_async() -> None:
    """Test config propagation out of a `@chain`-decorated coroutine function.

    The decorated coroutine awaits another runnable without passing a config;
    those nested calls must be traced as child runs of the decorated
    runnable's own run for `ainvoke`, `astream` and `abatch`.
    """
    fake = RunnableLambda(len)

    @chain
    async def afun(value: str) -> int:
        output = await fake.ainvoke(value)
        output += await fake.ainvoke(value * 2)
        output += await fake.ainvoke(value * 3)
        return output

    def check_root(root: Any) -> None:
        # A root run holds the summed output plus one child per `fake` call.
        assert root.outputs == {"output": 6}
        nested = root.child_runs
        assert len(nested) == 3
        assert [child.inputs["input"] for child in nested] == ["a", "aa", "aaa"]
        assert [(child.outputs or {})["output"] for child in nested] == [1, 2, 3]

    input_schema = afun.get_input_jsonschema()
    output_schema = afun.get_output_jsonschema()
    assert input_schema == {"title": "afun_input", "type": "string"}
    assert output_schema == {"title": "afun_output", "type": "integer"}

    # --- ainvoke ---
    tracer = FakeTracer()
    invoked = await afun.ainvoke("a", {"callbacks": [tracer]})
    assert invoked == 6
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])
    tracer.runs.clear()

    # --- astream without config: a coroutine streams one final chunk ---
    assert [item async for item in afun.astream("a")] == [6]
    assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"

    # --- astream with config ---
    tracer = FakeTracer()
    streamed = [item async for item in afun.astream("a", {"callbacks": [tracer]})]
    assert streamed == [6]
    assert len(tracer.runs) == 1
    check_root(tracer.runs[0])

    # --- abatch: one root run per input ---
    tracer = FakeTracer()
    batched = await afun.abatch(["a", "a"], {"callbacks": [tracer]})
    assert batched == [6, 6]
    assert len(tracer.runs) == 2
    for root in tracer.runs:
        check_root(root)


async def test_runnable_gen_transform() -> None:
    """Test piping generator runnables together and streaming through them.

    A `RunnableGenerator` built from a sync/async generator pair is piped into
    a sync generator (for `stream`) and into an async generator (for `astream`).
    """

    def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]:
        yield from range(next(length_iter))

    async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]:
        async for length in length_iter:
            for i in range(length):
                yield i

    def plus_one(ints: Iterator[int]) -> Iterator[int]:
        for i in ints:
            yield i + 1

    async def aplus_one(ints: AsyncIterator[int]) -> AsyncIterator[int]:
        async for i in ints:
            yield i + 1

    # Local names avoid shadowing the module-level `chain` decorator.
    sync_chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
    async_chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one

    # Both chains start with the same step, so they share an input schema; the
    # output schema title carries the name of the last step.
    expected_input = {"title": "gen_indexes_input", "type": "integer"}

    assert sync_chain.get_input_jsonschema() == expected_input
    assert sync_chain.get_output_jsonschema() == {
        "title": "plus_one_output",
        "type": "integer",
    }
    assert async_chain.get_input_jsonschema() == expected_input
    assert async_chain.get_output_jsonschema() == {
        "title": "aplus_one_output",
        "type": "integer",
    }

    # Indexes 0..n-1 are each shifted up by one.
    sync_result = list(sync_chain.stream(3))
    assert sync_result == [1, 2, 3]
    async_result = [item async for item in async_chain.astream(4)]
    assert async_result == [1, 2, 3, 4]


def test_with_config_callbacks() -> None:
    """Test that `with_config` accepts an empty callbacks list."""
    identity = RunnableLambda(lambda x: x)
    bound = identity.with_config({"callbacks": []})
    # Bugfix from version 0.0.325
    # The error reported at the time was:
    # ConfigError: field "callbacks" not yet prepared so type is still a ForwardRef,
    # you might need to call RunnableConfig.update_forward_refs().
    assert isinstance(bound, RunnableBinding)


async def test_ainvoke_on_returned_runnable() -> None:
    """Test `ainvoke` on a runnable returned by another runnable.

    Verify that a runnable returned by a sync function, when executed in the
    async path, is itself run through the async path (issue #13407).
    """

    def idchain_sync(_input: dict[str, Any], /) -> bool:
        return False

    async def idchain_async(_input: dict[str, Any], /) -> bool:
        return True

    # The sync and async implementations return different values, so the
    # result reveals which of the two was executed.
    idchain = RunnableLambda(func=idchain_sync, afunc=idchain_async)

    def func(_input: dict[str, Any], /) -> Runnable[dict[str, Any], bool]:
        return idchain

    outer = RunnableLambda(func)
    result = await outer.ainvoke({})
    assert result


def test_invoke_stream_passthrough_assign_trace() -> None:
    """Test the run names traced for `RunnablePassthrough.assign` (sync).

    Both `invoke` and `stream` must record a root run named
    `RunnableAssign<urls>` whose first child is `RunnableParallel<urls>`.
    """

    def idchain_sync(_input: dict, /) -> bool:
        return False

    def check_names(recorder: Any) -> None:
        root = recorder.runs[0]
        assert root.name == "RunnableAssign<urls>"
        assert root.child_runs[0].name == "RunnableParallel<urls>"

    assign_chain = RunnablePassthrough.assign(urls=idchain_sync)

    tracer = FakeTracer()
    assign_chain.invoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})
    check_names(tracer)

    tracer = FakeTracer()
    stream = assign_chain.stream({"example": [1, 2, 3]}, {"callbacks": [tracer]})
    # Drain the stream; only the recorded trace is of interest.
    for _ in stream:
        pass
    check_names(tracer)


async def test_ainvoke_astream_passthrough_assign_trace() -> None:
    """Test the run names traced for `RunnablePassthrough.assign` (async).

    Both `ainvoke` and `astream` must record a root run named
    `RunnableAssign<urls>` whose first child is `RunnableParallel<urls>`.
    """

    def idchain_sync(_input: dict, /) -> bool:
        return False

    def check_names(recorder: Any) -> None:
        root = recorder.runs[0]
        assert root.name == "RunnableAssign<urls>"
        assert root.child_runs[0].name == "RunnableParallel<urls>"

    assign_chain = RunnablePassthrough.assign(urls=idchain_sync)

    tracer = FakeTracer()
    await assign_chain.ainvoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})
    check_names(tracer)

    tracer = FakeTracer()
    stream = assign_chain.astream({"example": [1, 2, 3]}, {"callbacks": [tracer]})
    # Drain the stream; only the recorded trace is of interest.
    async for _ in stream:
        pass
    check_names(tracer)


async def test_astream_log_deep_copies() -> None:
    """Verify that deep copies are used when using jsonpatch in astream log.

    jsonpatch re-uses objects in its API; e.g.,

    import jsonpatch
    obj1 = { "a": 1 }
    value = { "b": 2 }
    obj2 = { "a": 1, "value": value }

    ops = list(jsonpatch.JsonPatch.from_diff(obj1, obj2))
    assert id(ops[0]['value']) == id(value)

    This can create unexpected consequences for downstream code.
    """

    def add_one(x: int) -> int:
        """Add one."""
        return x + 1

    def _get_run_log(run_log_patches: Sequence[RunLogPatch]) -> RunLog:
        """Fold the patches, in order, into a run log that starts empty."""
        merged = RunLog(state=None)  # type: ignore[arg-type]
        for patch in run_log_patches:
            merged += patch
        return merged

    chain = RunnableLambda(add_one)
    chunks = []
    final_output: RunLogPatch | None = None
    async for chunk in chain.astream_log(1):
        chunks.append(chunk)
        # Also accumulate incrementally while the stream is being consumed.
        if final_output is None:
            final_output = chunk
        else:
            final_output = final_output + chunk

    # Rebuild the log from the collected patches after streaming has finished.
    state = _get_run_log(chunks).state.copy()
    # Ignoring type here since we know that the state is a dict
    # so we can delete `id` for testing purposes
    state.pop("id")  # type: ignore[misc]
    expected_state = {
        "final_output": 2,
        "logs": {},
        "streamed_output": [2],
        "name": "add_one",
        "type": "chain",
    }
    assert state == expected_state


def test_transform_of_runnable_lambda_with_dicts() -> None:
    """Test `transform` of a runnable lambda fed with dict chunks."""
    runnable = RunnableLambda(lambda x: x)
    single_chunk = {"foo": "n"}

    # Stand-alone lambda
    transformed = list(runnable.transform(iter([{"foo": "n"}])))
    assert transformed == [single_chunk]

    # Test as part of a sequence
    seq = runnable | runnable
    transformed = list(seq.transform(iter([{"foo": "n"}])))
    assert transformed == [single_chunk]
    # Test some other edge cases
    streamed = list(seq.stream({"foo": "n"}))
    assert streamed == [single_chunk]


async def test_atransform_of_runnable_lambda_with_dicts() -> None:
    """Test `atransform` of a runnable lambda fed with dict chunks.

    Two plain dict chunks go in; a single chunk equal to the last one is
    expected out, both stand-alone and as part of a sequence.
    """

    async def identity(x: dict[str, str]) -> dict[str, str]:
        """Return x."""
        return x

    async def chunk_iterator() -> AsyncIterator[dict[str, str]]:
        yield {"foo": "a"}
        yield {"foo": "n"}

    expected = [{"foo": "n"}]
    runnable = RunnableLambda(identity)

    standalone = [item async for item in runnable.atransform(chunk_iterator())]
    assert standalone == expected

    seq = runnable | runnable
    chained = [item async for item in seq.atransform(chunk_iterator())]
    assert chained == expected


def test_default_transform_with_dicts() -> None:
    """Test the default `transform` and `stream` with dict inputs.

    The custom runnable only overrides `invoke`, so `transform` and `stream`
    are the inherited implementations.
    """

    class CustomRunnable(RunnableSerializable[Input, Output]):
        @override
        def invoke(
            self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
        ) -> Output:
            return cast("Output", input)

    runnable = CustomRunnable[dict[str, str], dict[str, str]]()
    incoming = iter([{"foo": "a"}, {"foo": "n"}])

    # Two plain dict chunks in, one chunk equal to the last one out.
    transformed = list(runnable.transform(incoming))
    assert transformed == [{"foo": "n"}]

    streamed = list(runnable.stream({"foo": "n"}))
    assert streamed == [{"foo": "n"}]


async def test_default_atransform_with_dicts() -> None:
    """Test the default `atransform` with plain dicts and `AddableDict` chunks.

    The custom runnable only overrides `invoke`, so `atransform` is the
    inherited implementation.
    """

    class CustomRunnable(RunnableSerializable[Input, Output]):
        @override
        def invoke(
            self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
        ) -> Output:
            return cast("Output", input)

    async def chunk_iterator() -> AsyncIterator[dict[str, str]]:
        yield {"foo": "a"}
        yield {"foo": "n"}

    # Test with addable dict
    async def chunk_iterator_with_addable() -> AsyncIterator[dict[str, str]]:
        yield AddableDict({"foo": "a"})
        yield AddableDict({"foo": "n"})

    runnable = CustomRunnable[dict[str, str], dict[str, str]]()

    # Plain dicts: the single output chunk equals the last input chunk.
    from_plain = [item async for item in runnable.atransform(chunk_iterator())]
    assert from_plain == [{"foo": "n"}]

    # Addable dicts: the values of both input chunks are concatenated.
    from_addable = [
        item async for item in runnable.atransform(chunk_iterator_with_addable())
    ]
    assert from_addable == [{"foo": "an"}]


def test_passthrough_transform_with_dicts() -> None:
    """Test that `RunnablePassthrough.transform` forwards each dict chunk."""
    runnable = RunnablePassthrough(lambda x: x)
    incoming = iter([{"foo": "a"}, {"foo": "n"}])
    forwarded = list(runnable.transform(incoming))
    # Every chunk comes out individually and unchanged.
    assert forwarded == [{"foo": "a"}, {"foo": "n"}]


async def test_passthrough_atransform_with_dicts() -> None:
    """Test that `RunnablePassthrough.atransform` forwards each dict chunk."""

    async def chunk_iterator() -> AsyncIterator[dict[str, str]]:
        yield {"foo": "a"}
        yield {"foo": "n"}

    runnable = RunnablePassthrough(lambda x: x)

    forwarded = [item async for item in runnable.atransform(chunk_iterator())]
    # Every chunk comes out individually and unchanged.
    assert forwarded == [{"foo": "a"}, {"foo": "n"}]


def test_listeners() -> None:
    """Test that `with_listeners` hooks see each mapped run's inputs and outputs.

    `on_start` records the run's inputs and `on_end` records its outputs, keyed
    by run id, for every element processed through `.map()`.
    """

    def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
        return {**inputs, "key": "extra"}

    # run id -> {"inputs": ..., "outputs": ...} as observed by the listeners
    shared_state: dict[UUID, dict[str, Any]] = {}
    value1 = {
        "inputs": {"name": "one"},
        "outputs": {"name": "one", "key": "extra"},
    }
    value2 = {
        "inputs": {"name": "two"},
        "outputs": {"name": "two", "key": "extra"},
    }

    def on_start(run: Run) -> None:
        shared_state[run.id] = {"inputs": run.inputs}

    def on_end(run: Run) -> None:
        # Record what the run produced. Storing `run.inputs` here would leave
        # the result of `fake_chain` completely unchecked.
        shared_state[run.id]["outputs"] = run.outputs

    chain = (
        RunnableLambda(fake_chain)
        .with_listeners(on_end=on_end, on_start=on_start)
        .map()
    )

    data = [{"name": "one"}, {"name": "two"}]
    chain.invoke(data, config={"max_concurrency": 1})
    # One entry per input element, each with matching inputs and outputs.
    assert len(shared_state) == 2
    assert value1 in shared_state.values(), "Value not found in the dictionary."
    assert value2 in shared_state.values(), "Value not found in the dictionary."


async def test_listeners_async() -> None:
    """Test `with_listeners` hooks when the mapped chain is run with `ainvoke`.

    `on_start` records the run's inputs and `on_end` records its outputs, keyed
    by run id, for every element processed through `.map()`.
    """

    def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
        return {**inputs, "key": "extra"}

    # run id -> {"inputs": ..., "outputs": ...} as observed by the listeners
    shared_state: dict[UUID, dict[str, Any]] = {}
    value1 = {
        "inputs": {"name": "one"},
        "outputs": {"name": "one", "key": "extra"},
    }
    value2 = {
        "inputs": {"name": "two"},
        "outputs": {"name": "two", "key": "extra"},
    }

    def on_start(run: Run) -> None:
        shared_state[run.id] = {"inputs": run.inputs}

    def on_end(run: Run) -> None:
        # Record what the run produced. Storing `run.inputs` here would leave
        # the result of `fake_chain` completely unchecked.
        shared_state[run.id]["outputs"] = run.outputs

    chain = (
        RunnableLambda(fake_chain)
        .with_listeners(on_end=on_end, on_start=on_start)
        .map()
    )

    data = [{"name": "one"}, {"name": "two"}]
    await chain.ainvoke(data, config={"max_concurrency": 1})

    # One entry per input element, each with matching inputs and outputs.
    assert len(shared_state) == 2
    assert value1 in shared_state.values(), "Value not found in the dictionary."
    assert value2 in shared_state.values(), "Value not found in the dictionary."


def test_closing_iterator_doesnt_raise_error() -> None:
    """Test that closing an iterator calls on_chain_end rather than on_chain_error."""
    # Names of the chain callbacks that have fired so far.
    fired: set[str] = set()

    class MyHandler(BaseCallbackHandler):
        @override
        def on_chain_error(
            self,
            error: BaseException,
            *,
            run_id: UUID,
            parent_run_id: UUID | None = None,
            tags: list[str] | None = None,
            **kwargs: Any,
        ) -> None:
            """Run when chain errors."""
            fired.add("error")

        @override
        def on_chain_end(
            self,
            outputs: dict[str, Any],
            *,
            run_id: UUID,
            parent_run_id: UUID | None = None,
            **kwargs: Any,
        ) -> None:
            """Run when chain ends."""
            fired.add("end")

    llm = GenericFakeChatModel(messages=iter(["hi there"]))
    handled_chain = (llm | StrOutputParser()).with_config(
        {"callbacks": [MyHandler()]}
    )
    st = handled_chain.stream("hello")
    # Pull one chunk so the stream is running, then abandon it early.
    next(st)
    # This is a generator so close is defined on it.
    st.close()  # type: ignore[attr-defined]
    # Wait for a bit to make sure that the callback is called.
    time.sleep(0.05)
    assert "error" not in fired
    assert "end" in fired


def test_pydantic_protected_namespaces() -> None:
    """Test that declaring a `model_`-prefixed field emits no warning."""
    # Check that protected namespaces (e.g., `model_kwargs`) do not raise warnings
    with warnings.catch_warnings():
        # Escalate every warning to an exception, so a warning emitted while
        # the class below is being defined makes the test fail.
        warnings.simplefilter("error")

        class CustomChatModel(RunnableSerializable[str, str]):
            model_kwargs: dict[str, Any] = Field(default_factory=dict)


def test_schema_for_prompt_and_chat_model() -> None:
    """Test schema generation for a prompt piped into a chat model.

    The prompt variables are named so that they collide with pydantic
    attributes (`model_json_schema`, `json`) or look private (`_private`).
    """
    template = "{model_json_schema}, {_private}, {json}"
    prompt = ChatPromptTemplate([("system", template)])
    chat_res = "i'm a chatbot"
    # sleep to better simulate a real stream
    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
    pipeline = prompt | chat

    # The chain must run with these variable names ...
    reply = pipeline.invoke(
        {
            "model_json_schema": "hello",
            "_private": "goodbye",
            "json": "json",
        }
    )
    assert reply.content == chat_res

    # ... and must expose each of them as a required string property.
    expected_schema = {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "model_json_schema": {"title": "Model Json Schema", "type": "string"},
            "_private": {"title": "Private", "type": "string"},
            "json": {"title": "Json", "type": "string"},
        },
        "required": [
            "_private",
            "json",
            "model_json_schema",
        ],
    }
    assert pipeline.get_input_jsonschema() == expected_schema


def test_runnable_assign() -> None:
    """Test that `RunnableAssign` merges the mapper output into the input dict."""

    def add_ten(x: dict[str, int]) -> dict[str, int]:
        return {"added": x["input"] + 10}

    mapper = RunnableParallel({"add_step": RunnableLambda(add_ten)})
    runnable_assign = RunnableAssign(mapper)

    # The original key is kept and the mapper's result is added under its key.
    expected = {"input": 5, "add_step": {"added": 15}}
    assert runnable_assign.invoke({"input": 5}) == expected


class _Foo(TypedDict):
    # Base TypedDict that `_InputData` below inherits from.
    foo: str


class _InputData(_Foo):
    # Adds `bar` on top of the inherited `foo` key; used to annotate the
    # functions in `test_runnable_typed_dict_schema`.
    bar: str


def test_runnable_typed_dict_schema() -> None:
    """Test schema generation for TypedDict-annotated functions.

    Testing that the schema is generated properly (not empty) when using
    TypedDict subclasses to annotate the arguments of a RunnableParallel's
    children.
    """

    def forward_foo(input_data: _InputData) -> str:
        return input_data["foo"]

    def transform_input(input_data: _InputData) -> dict[str, str]:
        foo = input_data["foo"]
        bar = input_data["bar"]

        return {"transformed": foo + bar}

    foo_runnable = RunnableLambda(forward_foo)
    other_runnable = RunnableLambda(transform_input)

    parallel = RunnableParallel(
        foo=foo_runnable,
        other=other_runnable,
    )
    # The validated input model must carry both TypedDict keys in its root.
    assert (
        repr(parallel.input_schema.model_validate({"foo": "Y", "bar": "Z"}))
        == "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})"
    )
