"""Test in memory indexer."""

import pytest
from langchain_tests.integration_tests.indexer import (
    AsyncDocumentIndexTestSuite,
    DocumentIndexerTestSuite,
)
from typing_extensions import override

from langchain_core.documents import Document
from langchain_core.indexing.in_memory import (
    InMemoryDocumentIndex,
)


class TestDocumentIndexerTestSuite(DocumentIndexerTestSuite):
    """Run the shared synchronous indexer test suite on ``InMemoryDocumentIndex``."""

    @pytest.fixture
    @override
    def index(self) -> InMemoryDocumentIndex:
        """Supply a newly constructed in-memory index to each suite test."""
        fresh_index = InMemoryDocumentIndex()
        return fresh_index


class TestAsyncDocumentIndexerTestSuite(AsyncDocumentIndexTestSuite):
    """Run the shared asynchronous indexer test suite on ``InMemoryDocumentIndex``."""

    # NOTE(review): the original author flagged an unexplained mypy quirk with
    # async pytest fixtures here; confirm whether it still applies.
    @pytest.fixture
    @override
    async def index(self) -> InMemoryDocumentIndex:
        """Supply a newly constructed in-memory index from an async fixture."""
        fresh_index = InMemoryDocumentIndex()
        return fresh_index


def test_sync_retriever() -> None:
    """Check sync retrieval ordering on a two-document index.

    The document whose content contains the query term is returned first,
    followed by the other document.
    """
    greeting = Document(id="1", page_content="hello world")
    farewell = Document(id="2", page_content="goodbye cat")

    store = InMemoryDocumentIndex()
    store.upsert([greeting, farewell])

    hello_hits = store.invoke("hello")
    assert hello_hits == [greeting, farewell]

    cat_hits = store.invoke("cat")
    assert cat_hits == [farewell, greeting]


async def test_async_retriever() -> None:
    """Check async retrieval ordering on a two-document index.

    This mirrors ``test_sync_retriever`` but goes through ``aupsert`` and
    ``ainvoke``.
    """
    greeting = Document(id="1", page_content="hello world")
    farewell = Document(id="2", page_content="goodbye cat")

    store = InMemoryDocumentIndex()
    await store.aupsert([greeting, farewell])

    hello_hits = await store.ainvoke("hello")
    assert hello_hits == [greeting, farewell]

    cat_hits = await store.ainvoke("cat")
    assert cat_hits == [farewell, greeting]
