diff --git a/libs/community/langchain_community/embeddings/deepinfra.py b/libs/community/langchain_community/embeddings/deepinfra.py index 046e1e481f..4dbc0e4dad 100644 --- a/libs/community/langchain_community/embeddings/deepinfra.py +++ b/libs/community/langchain_community/embeddings/deepinfra.py @@ -6,6 +6,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32" +MAX_BATCH_SIZE = 1024 class DeepInfraEmbeddings(BaseModel, Embeddings): @@ -47,8 +48,11 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): """Instruction used to embed the query.""" model_kwargs: Optional[dict] = None """Other model keyword args""" - deepinfra_api_token: Optional[str] = None + """API token for Deep Infra. If not provided, the token is + fetched from the environment variable 'DEEPINFRA_API_TOKEN'.""" + batch_size: int = MAX_BATCH_SIZE + """Batch size for embedding requests.""" class Config: """Configuration for this pydantic object.""" @@ -103,6 +107,8 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed documents using a Deep Infra deployed embedding model. + For larger batches, the input list of texts is chunked into smaller + batches to avoid exceeding the maximum request size. Args: texts: The list of texts to embed. @@ -110,8 +116,17 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + + embeddings = [] instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts] - embeddings = self._embed(instruction_pairs) + + chunks = [ + instruction_pairs[i : i + self.batch_size] + for i in range(0, len(instruction_pairs), self.batch_size) + ] + for chunk in chunks: + embeddings += self._embed(chunk) + return embeddings def embed_query(self, text: str) -> List[float]: diff --git a/libs/community/tests/integration_tests/embeddings/test_deepinfra.py b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py index f3a418ed23..8a72a5edcb 100644 --- a/libs/community/tests/integration_tests/embeddings/test_deepinfra.py +++ b/libs/community/tests/integration_tests/embeddings/test_deepinfra.py @@ -17,3 +17,13 @@ def test_deepinfra_call() -> None: assert len(r1[1]) == 768 r2 = deepinfra_emb.embed_query("What is the third letter of Greek alphabet") assert len(r2) == 768 + + +def test_deepinfra_call_with_large_batch_size() -> None: + deepinfra_emb = DeepInfraEmbeddings(model_id="BAAI/bge-base-en-v1.5") + texts = 2000 * [ + "Alpha is the first letter of Greek alphabet", + ] + r1 = deepinfra_emb.embed_documents(texts) + assert len(r1) == 2000 + assert len(r1[0]) == 768