import os
import re
import sys
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

import pytest
from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import BaseModel as PydanticV1BaseModel
from pydantic.v1 import Field as PydanticV1Field

from langchain_core import utils
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import (
    check_package_version,
    from_env,
    get_pydantic_field_names,
    guard_import,
)
from langchain_core.utils._merge import merge_dicts, merge_lists, merge_obj
from langchain_core.utils.utils import secret_from_env

if TYPE_CHECKING:
    from collections.abc import Callable


@pytest.mark.parametrize(
    ("package", "check_kwargs", "actual_version", "expected"),
    [
        ("stub", {"gt_version": "0.1"}, "0.1.2", None),
        ("stub", {"gt_version": "0.1.2"}, "0.1.12", None),
        ("stub", {"gt_version": "0.1.2"}, "0.1.2", (ValueError, "> 0.1.2")),
        ("stub", {"gte_version": "0.1"}, "0.1.2", None),
        ("stub", {"gte_version": "0.1.2"}, "0.1.2", None),
    ],
)
def test_check_package_version(
    package: str,
    check_kwargs: dict[str, str | None],
    actual_version: str,
    expected: tuple[type[Exception], str] | None,
) -> None:
    """Check `check_package_version` against a patched installed version.

    `expected` is `None` when the check must pass, otherwise an
    `(exception type, message regex)` pair that the check must raise.
    """
    outcome = (
        nullcontext()
        if expected is None
        else pytest.raises(expected[0], match=expected[1])
    )
    # The patched lookup reports `actual_version` regardless of its arguments.
    version_stub = patch(
        "langchain_core.utils.utils.version", return_value=actual_version
    )
    with version_stub, outcome:
        check_package_version(package, **check_kwargs)


@pytest.mark.parametrize(
    ("left", "right", "expected"),
    [
        # Merge `None` and `1`.
        ({"a": None}, {"a": 1}, {"a": 1}),
        # Merge `1` and `None`.
        ({"a": 1}, {"a": None}, {"a": 1}),
        # Merge `None` and a value.
        ({"a": None}, {"a": 0}, {"a": 0}),
        ({"a": None}, {"a": "txt"}, {"a": "txt"}),
        # Merge equal values: equal scalars are kept once, while equal strings,
        # lists and nested dicts are still concatenated / merged.
        ({"a": 1}, {"a": 1}, {"a": 1}),
        ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
        ({"a": True}, {"a": True}, {"a": True}),
        ({"a": False}, {"a": False}, {"a": False}),
        ({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
        # Merge strings.
        ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
        # Merge dicts.
        ({"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 1, "c": 2}}),
        (
            {"function_call": {"arguments": None}},
            {"function_call": {"arguments": "{\n"}},
            {"function_call": {"arguments": "{\n"}},
        ),
        # Merge lists.
        ({"a": [1, 2]}, {"a": [3]}, {"a": [1, 2, 3]}),
        # Keys present on only one side are carried over unchanged.
        ({"a": 1, "b": 2}, {"a": 1}, {"a": 1, "b": 2}),
        ({"a": 1, "b": 2}, {"c": None}, {"a": 1, "b": 2, "c": None}),
        #
        # Invalid inputs.
        #
        # Same key holding values of different types.
        (
            {"a": 1},
            {"a": "1"},
            pytest.raises(
                TypeError,
                match=re.escape(
                    'additional_kwargs["a"] already exists in this message, '
                    "but with a different type."
                ),
            ),
        ),
        # Same key holding tuples, which are not a mergeable type.
        (
            {"a": (1, 2)},
            {"a": (3,)},
            pytest.raises(
                TypeError,
                match=(
                    "Additional kwargs key a already exists in left dict and value "
                    r"has unsupported type .+tuple.+."
                ),
            ),
        ),
        # 'index' keyword has special handling: list items sharing an 'index'
        # are merged together, whereas any other key (e.g. 'idx') just appends.
        (
            {"a": [{"index": 0, "b": "{"}]},
            {"a": [{"index": 0, "b": "f"}]},
            {"a": [{"index": 0, "b": "{f"}]},
        ),
        (
            {"a": [{"idx": 0, "b": "{"}]},
            {"a": [{"idx": 0, "b": "f"}]},
            {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
        ),
    ],
)
def test_merge_dicts(
    left: dict, right: dict, expected: dict | AbstractContextManager
) -> None:
    """`merge_dicts` yields `expected` (or raises it) without mutating inputs.

    When `expected` is a context manager (a `pytest.raises`), the merge is
    expected to raise inside it; otherwise it is the exact merged result.
    """
    expects_error = isinstance(expected, AbstractContextManager)
    guard = expected if expects_error else nullcontext()

    # Snapshot both inputs so we can prove the merge does not mutate them.
    left_before, right_before = deepcopy(left), deepcopy(right)
    with guard:
        merged = merge_dicts(left, right)
        assert merged == expected
        assert left == left_before
        assert right == right_before


@pytest.mark.parametrize(
    ("left", "right", "expected"),
    [
        # 'type' special key handling: equal values collapse to one, differing
        # values should be rejected.
        ({"type": "foo"}, {"type": "foo"}, {"type": "foo"}),
        (
            {"type": "foo"},
            {"type": "bar"},
            pytest.raises(ValueError, match="Unable to merge"),
        ),
    ],
)
@pytest.mark.xfail(reason="Refactors to make in 0.3")
def test_merge_dicts_0_3(
    left: dict, right: dict, expected: dict | AbstractContextManager
) -> None:
    """Planned `merge_dicts` handling of the special 'type' key (xfail for now).

    When `expected` is a context manager the merge must raise inside it;
    otherwise it is the exact merged result. Inputs must never be mutated.
    """
    left_before = deepcopy(left)
    right_before = deepcopy(right)
    if isinstance(expected, AbstractContextManager):
        guard = expected
    else:
        guard = nullcontext()
    with guard:
        merged = merge_dicts(left, right)
        assert merged == expected
        assert left == left_before
        assert right == right_before


@pytest.mark.parametrize(
    ("module_name", "pip_name", "package", "expected"),
    [
        ("langchain_core.utils", None, None, utils),
        ("langchain_core.utils", "langchain-core", None, utils),
        ("langchain_core.utils", None, "langchain-core", utils),
        ("langchain_core.utils", "langchain-core", "langchain-core", utils),
    ],
)
def test_guard_import(
    module_name: str, pip_name: str | None, package: str | None, expected: Any
) -> None:
    """`guard_import` returns the requested module for every keyword combination.

    Only keyword arguments whose value is not `None` are forwarded. The
    "argument omitted" call shape is therefore exercised as well as the
    explicit one.
    """
    # Build the keyword set from whichever options this case provides. This
    # covers all four combinations directly and leaves no unreachable
    # fallback branch.
    kwargs: dict[str, str] = {}
    if pip_name is not None:
        kwargs["pip_name"] = pip_name
    if package is not None:
        kwargs["package"] = package

    ret = guard_import(module_name, **kwargs)
    # Modules compare by identity, so `is` states the intent exactly.
    assert ret is expected


@pytest.mark.parametrize(
    ("module_name", "pip_name", "package", "expected_pip_name"),
    [
        ("langchain_core.utilsW", None, None, "langchain-core"),
        ("langchain_core.utilsW", "langchain-core-2", None, "langchain-core-2"),
        ("langchain_core.utilsW", None, "langchain-coreWX", "langchain-core"),
        (
            "langchain_core.utilsW",
            "langchain-core-2",
            "langchain-coreWX",
            "langchain-core-2",
        ),
        ("langchain_coreW", None, None, "langchain-coreW"),  # ModuleNotFoundError
    ],
)
def test_guard_import_failure(
    module_name: str,
    pip_name: str | None,
    package: str | None,
    expected_pip_name: str,
) -> None:
    """`guard_import` raises `ImportError` that names the pip package to install.

    `pytest.raises(match=...)` treats its argument as a regular expression.
    The expected message contains metacharacters (`.`), so it is escaped and
    the whole text is matched literally instead of loosely.
    """
    expected_message = (
        f"Could not import {module_name} python package. "
        f"Please install it with `pip install {expected_pip_name}`."
    )
    with pytest.raises(ImportError, match=re.escape(expected_message)):
        guard_import(module_name, pip_name=pip_name, package=package)


@pytest.mark.skipif(
    sys.version_info >= (3, 14),
    reason="pydantic.v1 namespace not supported with Python 3.14+",
)
def test_get_pydantic_field_names_v1_in_2() -> None:
    """Field names of a `pydantic.v1` model include both names and aliases."""

    class PydanticV1Model(PydanticV1BaseModel):
        field1: str
        field2: int
        alias_field: int = PydanticV1Field(alias="aliased_field")

    names = get_pydantic_field_names(PydanticV1Model)
    assert names == {"field1", "field2", "alias_field", "aliased_field"}


def test_get_pydantic_field_names_v2_in_2() -> None:
    """Field names of a pydantic v2 model include both names and aliases."""

    class PydanticModel(BaseModel):
        field1: str
        field2: int
        alias_field: int = Field(alias="aliased_field")

    names = get_pydantic_field_names(PydanticModel)
    assert names == {"field1", "field2", "alias_field", "aliased_field"}


def test_from_env_with_env_variable() -> None:
    """`from_env` returns the variable's value when it is set."""
    env_name, env_value = "TEST_KEY", "test_value"
    with patch.dict(os.environ, {env_name: env_value}):
        reader = from_env(env_name)
        assert reader() == env_value


def test_from_env_with_default_value() -> None:
    """`from_env` falls back to `default` when the variable is absent."""
    fallback = "default_value"
    # Empty the environment entirely so the key is guaranteed to be missing.
    with patch.dict(os.environ, {}, clear=True):
        reader = from_env("TEST_KEY", default=fallback)
        assert reader() == fallback


def test_from_env_with_error_message() -> None:
    """A missing variable raises `ValueError` carrying the custom message."""
    custom_message = "Custom error message"
    with patch.dict(os.environ, {}, clear=True):
        reader = from_env("TEST_KEY", error_message=custom_message)
        with pytest.raises(ValueError, match=custom_message):
            reader()


def test_from_env_with_default_error_message() -> None:
    """A missing variable with no default raises the built-in error message."""
    env_name = "TEST_KEY"
    expected_fragment = f"Did not find {env_name}"
    with patch.dict(os.environ, {}, clear=True):
        reader = from_env(env_name)
        with pytest.raises(ValueError, match=expected_fragment):
            reader()


def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None:
    """`secret_from_env` wraps a set environment variable in a `SecretStr`."""
    monkeypatch.setenv("TEST_KEY", "secret_value")
    expected = SecretStr("secret_value")

    get_secret: Callable[[], SecretStr | None] = secret_from_env("TEST_KEY")
    assert get_secret() == expected


def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the variable unset, the string default is wrapped in a `SecretStr`."""
    monkeypatch.delenv("TEST_KEY", raising=False)
    fallback = "default_value"

    get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY", default=fallback)
    assert get_secret() == SecretStr(fallback)


def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """An explicit `default=None` makes the factory return `None` when unset."""
    monkeypatch.delenv("TEST_KEY", raising=False)

    get_secret: Callable[[], SecretStr | None]
    get_secret = secret_from_env("TEST_KEY", default=None)
    assert get_secret() is None


def test_secret_from_env_without_default_raises_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no default and the variable unset, calling the factory raises."""
    monkeypatch.delenv("TEST_KEY", raising=False)
    get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY")

    expected_error = pytest.raises(ValueError, match="Did not find TEST_KEY")
    with expected_error:
        get_secret()


def test_secret_from_env_with_custom_error_message(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A custom `error_message` is used when the variable is unset."""
    monkeypatch.delenv("TEST_KEY", raising=False)
    custom_message = "Custom error message"

    get_secret: Callable[[], SecretStr] = secret_from_env(
        "TEST_KEY", error_message=custom_message
    )
    with pytest.raises(ValueError, match=custom_message):
        get_secret()


def test_using_secret_from_env_as_default_factory(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """`secret_from_env` works as a pydantic `Field(default_factory=...)`.

    Every environment variable the factories read is cleared first. The
    outcome then does not depend on the developer's or CI runner's
    environment.
    """
    # Without this, a stray TEST_KEY_2 or FOOFOOFOOBAR in the real environment
    # would break the `None` / default / raise assertions below.
    for name in ("TEST_KEY", "TEST_KEY_2", "FOOFOOFOOBAR"):
        monkeypatch.delenv(name, raising=False)

    class Foo(BaseModel):
        secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))

    # Pass the secret as a parameter; the factory is not consulted.
    foo = Foo(secret="super_secret")
    assert foo.secret.get_secret_value() == "super_secret"

    # Set the environment variable; the factory now supplies the value.
    monkeypatch.setenv("TEST_KEY", "secret_value")
    assert Foo().secret.get_secret_value() == "secret_value"

    class Bar(BaseModel):
        secret: SecretStr | None = Field(
            default_factory=secret_from_env("TEST_KEY_2", default=None)
        )

    assert Bar().secret is None

    class Buzz(BaseModel):
        secret: SecretStr | None = Field(
            default_factory=secret_from_env("TEST_KEY_2", default="hello")
        )

    # We know it will be SecretStr rather than SecretStr | None
    assert Buzz().secret.get_secret_value() == "hello"  # type: ignore[union-attr]

    class OhMy(BaseModel):
        secret: SecretStr | None = Field(
            default_factory=secret_from_env("FOOFOOFOOBAR")
        )

    # No default and no variable: instantiation surfaces the factory's error.
    with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"):
        OhMy()


def test_generation_chunk_addition_type_error() -> None:
    """Adding an empty-text chunk to a non-empty one yields the expected chunk.

    The result carries the right-hand text and `generation_info={"len": 14}`.
    NOTE(review): despite the name, no error is expected here. The name
    presumably refers to a `TypeError` this addition used to raise — confirm.
    """
    empty = GenerationChunk(text="", generation_info={"len": 0})
    filled = GenerationChunk(text="Non-empty text", generation_info={"len": 14})
    expected = GenerationChunk(text="Non-empty text", generation_info={"len": 14})
    assert empty + filled == expected


@pytest.mark.parametrize(
    ("left", "right", "expected"),
    [
        # Both None
        (None, None, None),
        # Left None
        (None, [1, 2], [1, 2]),
        # Right None
        ([1, 2], None, [1, 2]),
        # Simple merge
        ([1, 2], [3, 4], [1, 2, 3, 4]),
        # Empty lists
        ([], [], []),
        ([], [1], [1]),
        ([1], [], [1]),
        # Merge with index handling: items sharing an "index" are merged.
        (
            [{"index": 0, "text": "hello"}],
            [{"index": 0, "text": " world"}],
            [{"index": 0, "text": "hello world"}],
        ),
        # Multiple elements with different indexes stay separate.
        (
            [{"index": 0, "a": "x"}],
            [{"index": 1, "b": "y"}],
            [{"index": 0, "a": "x"}, {"index": 1, "b": "y"}],
        ),
        # Elements without index key are simply appended.
        (
            [{"no_index": "a"}],
            [{"no_index": "b"}],
            [{"no_index": "a"}, {"no_index": "b"}],
        ),
    ],
)
def test_merge_lists(
    left: list | None, right: list | None, expected: list | None
) -> None:
    """`merge_lists` returns `expected` and leaves both inputs untouched."""
    snapshots = (deepcopy(left), deepcopy(right))
    assert merge_lists(left, right) == expected
    # Neither input may be mutated by the merge.
    assert (left, right) == snapshots


def test_merge_lists_multiple_others() -> None:
    """`merge_lists` accepts more than two lists and concatenates them in order."""
    pieces = ([1], [2], [3])
    assert merge_lists(*pieces) == [1, 2, 3]


def test_merge_lists_all_none() -> None:
    """`merge_lists` returns `None` when every argument is `None`."""
    nones = (None, None, None)
    assert merge_lists(*nones) is None


@pytest.mark.parametrize(
    ("left", "right", "expected"),
    [
        # Both None
        (None, None, None),
        # Left None
        (None, "hello", "hello"),
        # Right None
        ("hello", None, "hello"),
        # String merge (concatenation)
        ("hello", " world", "hello world"),
        # Dict merge
        ({"a": 1}, {"b": 2}, {"a": 1, "b": 2}),
        # List merge
        ([1, 2], [3], [1, 2, 3]),
        # Equal values are returned as-is
        (42, 42, 42),
        (3.14, 3.14, 3.14),
        (True, True, True),
    ],
)
def test_merge_obj(left: Any, right: Any, expected: Any) -> None:
    """`merge_obj` combines each supported pair of values into `expected`."""
    assert merge_obj(left, right) == expected


def test_merge_obj_type_mismatch() -> None:
    """`merge_obj` raises `TypeError` when the operands' types differ."""
    mismatched = ("string", 123)
    with pytest.raises(TypeError, match="left and right are of different types"):
        merge_obj(*mismatched)


def test_merge_obj_unmergeable_values() -> None:
    """`merge_obj` raises `ValueError` for two different integers."""
    first, second = 1, 2
    with pytest.raises(ValueError, match="Unable to merge"):
        merge_obj(first, second)


def test_merge_obj_tuple_raises() -> None:
    """`merge_obj` raises `ValueError` for two unequal tuples."""
    lhs, rhs = (1, 2), (3, 4)
    with pytest.raises(ValueError, match="Unable to merge"):
        merge_obj(lhs, rhs)
