diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py index 5e5394507f..d5cb86d534 100644 --- a/libs/partners/chroma/langchain_chroma/vectorstores.py +++ b/libs/partners/chroma/langchain_chroma/vectorstores.py @@ -60,7 +60,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: Y = np.array(Y) if X.shape[1] != Y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " + "Number of columns in X and Y must be the same. X has shape" + f"{X.shape} " f"and Y has shape {Y.shape}." ) @@ -133,6 +134,7 @@ class Chroma(VectorStore): collection_metadata: Optional[Dict] = None, client: Optional[chromadb.ClientAPI] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, + create_collection_if_not_exists: Optional[bool] = True, ) -> None: """Initialize with a Chroma client.""" @@ -161,11 +163,14 @@ class Chroma(VectorStore): ) self._embedding_function = embedding_function - self._collection = self._client.get_or_create_collection( - name=collection_name, - embedding_function=None, - metadata=collection_metadata, - ) + if create_collection_if_not_exists: + self._collection = self._client.get_or_create_collection( + name=collection_name, + embedding_function=None, + metadata=collection_metadata, + ) + else: + self._collection = self._client.get_collection(name=collection_name) self.override_relevance_score_fn = relevance_score_fn @property @@ -650,7 +655,8 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], documents: List[Document]) -> None: # type: ignore + # type: ignore + def update_documents(self, ids: List[str], documents: List[Document]) -> None: """Update a document in the collection. Args: diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py index 156a825621..97018f8733 100644 --- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py @@ -1,10 +1,12 @@ """Test Chroma functionality.""" import uuid +from typing import Generator import chromadb import pytest import requests +from chromadb.api.client import SharedSystemClient from langchain_core.documents import Document from langchain_core.embeddings.fake import FakeEmbeddings as Fak @@ -15,6 +17,13 @@ from tests.integration_tests.fake_embeddings import ( ) +@pytest.fixture() +def client() -> Generator[chromadb.ClientAPI, None, None]: + SharedSystemClient.clear_system_cache() + client = chromadb.Client(chromadb.config.Settings()) + yield client + + def test_chroma() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -271,10 +280,7 @@ def test_chroma_with_relevance_score_custom_normalization_fn() -> None: ] -def test_init_from_client() -> None: - import chromadb - - client = chromadb.Client(chromadb.config.Settings()) +def test_init_from_client(client: chromadb.ClientAPI) -> None: Chroma(client=client) @@ -414,3 +420,72 @@ def test_chroma_legacy_batching() -> None: ) db.delete_collection() + + +def test_create_collection_if_not_exist_default() -> None: + """Tests existing behaviour without the new create_collection_if_not_exists flag.""" + texts = ["foo", "bar", "baz"] + docsearch = Chroma.from_texts( + collection_name="test_collection", texts=texts, embedding=FakeEmbeddings() + ) + assert docsearch._client.get_collection("test_collection") is not None + docsearch.delete_collection() + + +def test_create_collection_if_not_exist_true_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=True and collection already existing.""" + client.create_collection("test_collection") + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=True, + ) + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection() + + +def test_create_collection_if_not_exist_false_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=False and collection already existing.""" + client.create_collection("test_collection") + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=False, + ) + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection() + + +def test_create_collection_if_not_exist_false_non_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=False and collection not-existing, + should raise.""" + with pytest.raises(Exception, match="does not exist"): + Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=False, + ) + + +def test_create_collection_if_not_exist_true_non_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=True and collection non-existing. .""" + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=True, + ) + + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection()