From 5b35f077f913d5a28093a049c5e6c4eef84c8f0d Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Thu, 9 May 2024 00:45:42 +0300 Subject: [PATCH] [community][fix](DeepInfraEmbeddings): Implement chunking for large batches (#21189) **Description:** This PR introduces chunking logic to the `DeepInfraEmbeddings` class to handle large batch sizes without exceeding maximum batch size of the backend. This enhancement ensures that embedding generation processes large batches by breaking them down into smaller, manageable chunks, each conforming to the maximum batch size limit. **Issue:** Fixes #21189 **Dependencies:** No new dependencies introduced. --- .../embeddings/deepinfra.py | 19 +++++++++++++++++-- .../embeddings/test_deepinfra.py | 10 ++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/embeddings/deepinfra.py b/libs/community/langchain_community/embeddings/deepinfra.py index 046e1e481f..4dbc0e4dad 100644 --- a/libs/community/langchain_community/embeddings/deepinfra.py +++ b/libs/community/langchain_community/embeddings/deepinfra.py @@ -6,6 +6,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32" +MAX_BATCH_SIZE = 1024 class DeepInfraEmbeddings(BaseModel, Embeddings): @@ -47,8 +48,11 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): """Instruction used to embed the query.""" model_kwargs: Optional[dict] = None """Other model keyword args""" - deepinfra_api_token: Optional[str] = None + """API token for Deep Infra. If not provided, the token is + fetched from the environment variable 'DEEPINFRA_API_TOKEN'.""" + batch_size: int = MAX_BATCH_SIZE + """Batch size for embedding requests.""" class Config: """Configuration for this pydantic object.""" @@ -103,6 +107,8 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed documents using a Deep Infra deployed embedding model. + For larger batches, the input list of texts is chunked into smaller + batches to avoid exceeding the maximum request size. Args: texts: The list of texts to embed. @@ -110,8 +116,17 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + + embeddings = [] instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts] - embeddings = self._embed(instruction_pairs) + + chunks = [ + instruction_pairs[i : i + self.batch_size] + for i in range(0, len(instruction_pairs), self.batch_size) + ] + for chunk in chunks: + embeddings += self._embed(chunk) + return embeddings def embed_query(self, text: str) -> List[float]: diff --git a/libs/community/tests/integration_tests/embeddings/test_deepinfra.py b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py index f3a418ed23..8a72a5edcb 100644 --- a/libs/community/tests/integration_tests/embeddings/test_deepinfra.py +++ b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py @@ -17,3 +17,13 @@ def test_deepinfra_call() -> None: assert len(r1[1]) == 768 r2 = deepinfra_emb.embed_query("What is the third letter of Greek alphabet") assert len(r2) == 768 + + +def test_deepinfra_call_with_large_batch_size() -> None: + deepinfra_emb = DeepInfraEmbeddings(model_id="BAAI/bge-base-en-v1.5") + texts = 2000 * [ + "Alpha is the first letter of Greek alphabet", + ] + r1 = deepinfra_emb.embed_documents(texts) + assert len(r1) == 2000 + assert len(r1[0]) == 768