From 3bb9bec3149b3f1d429c56546df4a84ebf98202d Mon Sep 17 00:00:00 2001
From: ccurme
Date: Thu, 9 May 2024 11:37:03 -0400
Subject: [PATCH] bedrock: add unit test for retriever (#21485)

This was implemented in https://github.com/langchain-ai/langchain/pull/21349 but dropped before merge.
---
 .../unit_tests/retrievers/test_bedrock.py | 68 +++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 libs/community/tests/unit_tests/retrievers/test_bedrock.py

diff --git a/libs/community/tests/unit_tests/retrievers/test_bedrock.py b/libs/community/tests/unit_tests/retrievers/test_bedrock.py
new file mode 100644
index 0000000000..ad2da543d4
--- /dev/null
+++ b/libs/community/tests/unit_tests/retrievers/test_bedrock.py
@@ -0,0 +1,68 @@
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.documents import Document
+
+from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
+
+
+@pytest.fixture
+def mock_client() -> MagicMock:
+    return MagicMock()
+
+
+@pytest.fixture
+def mock_retriever_config() -> dict:
+    return {"vectorSearchConfiguration": {"numberOfResults": 4}}
+
+
+@pytest.fixture
+def amazon_retriever(
+    mock_client: MagicMock, mock_retriever_config: dict
+) -> AmazonKnowledgeBasesRetriever:
+    return AmazonKnowledgeBasesRetriever(
+        knowledge_base_id="test_kb_id",
+        retrieval_config=mock_retriever_config,
+        client=mock_client,
+    )
+
+
+def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
+    with pytest.raises(ImportError):
+        amazon_retriever.create_client({})
+
+
+def test_get_relevant_documents(
+    amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock
+) -> None:
+    query: str = "test query"
+    mock_client.retrieve.return_value = {
+        "retrievalResults": [
+            {"content": {"text": "result1"}, "metadata": {"key": "value1"}},
+            {
+                "content": {"text": "result2"},
+                "metadata": {"key": "value2"},
+                "score": 1,
+                "location": "testLocation",
+            },
+            {"content": {"text": "result3"}},
+        ]
+    }
+    documents: List[Document] = amazon_retriever._get_relevant_documents(
+        query,
+        run_manager=None,  # type: ignore
+    )
+
+    assert len(documents) == 3
+    assert isinstance(documents[0], Document)
+    assert documents[0].page_content == "result1"
+    assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}}
+    assert documents[1].page_content == "result2"
+    assert documents[1].metadata == {
+        "score": 1,
+        "source_metadata": {"key": "value2"},
+        "location": "testLocation",
+    }
+    assert documents[2].page_content == "result3"
+    assert documents[2].metadata == {"score": 0}