[community][fix](DeepInfraEmbeddings): Implement chunking for large batches (#21189)

**Description:**
This PR introduces chunking logic to the `DeepInfraEmbeddings` class so that
large inputs no longer exceed the backend's maximum batch size. Embedding
generation now splits the input texts into smaller, manageable chunks, each
within the `MAX_BATCH_SIZE` limit, and concatenates the results.
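
A minimal standalone sketch of the approach (illustrative names, not the PR's exact code):

```python
from typing import List

MAX_BATCH_SIZE = 1024


def chunk(items: List[str], size: int = MAX_BATCH_SIZE) -> List[List[str]]:
    # Consecutive slices of at most `size` elements; the last may be shorter.
    return [items[i : i + size] for i in range(0, len(items), size)]


assert len(chunk(["x"] * 2000)) == 2  # 1024 + 976
```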

**Issue:**
Fixes #21189

**Dependencies:**
No new dependencies introduced.

@@ -6,6 +6,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
 DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
+MAX_BATCH_SIZE = 1024
 
 
 class DeepInfraEmbeddings(BaseModel, Embeddings):
@@ -47,8 +48,11 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
     """Instruction used to embed the query."""
     model_kwargs: Optional[dict] = None
     """Other model keyword args"""
     deepinfra_api_token: Optional[str] = None
     """API token for Deep Infra. If not provided, the token is
     fetched from the environment variable 'DEEPINFRA_API_TOKEN'."""
+    batch_size: int = MAX_BATCH_SIZE
+    """Batch size for embedding requests."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -103,6 +107,8 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed documents using a Deep Infra deployed embedding model.
+        For larger batches, the input list of texts is chunked into smaller
+        batches to avoid exceeding the maximum request size.
 
         Args:
             texts: The list of texts to embed.
@@ -110,8 +116,17 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
+        embeddings = []
         instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts]
-        embeddings = self._embed(instruction_pairs)
+
+        chunks = [
+            instruction_pairs[i : i + self.batch_size]
+            for i in range(0, len(instruction_pairs), self.batch_size)
+        ]
+
+        for chunk in chunks:
+            embeddings += self._embed(chunk)
+
         return embeddings
 
     def embed_query(self, text: str) -> List[float]:
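
For reviewers, a hedged usage sketch of the new knob (model ID and token are placeholders; `batch_size` defaults to `MAX_BATCH_SIZE`):

```python
from langchain_community.embeddings import DeepInfraEmbeddings

# 2500 texts with batch_size=1024 -> three requests (1024 + 1024 + 452)
emb = DeepInfraEmbeddings(
    model_id="sentence-transformers/clip-ViT-B-32",
    deepinfra_api_token="...",  # or set the DEEPINFRA_API_TOKEN env var
    batch_size=1024,
)
vectors = emb.embed_documents(["hello world"] * 2500)
assert len(vectors) == 2500
```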

@@ -17,3 +17,13 @@ def test_deepinfra_call() -> None:
     assert len(r1[1]) == 768
     r2 = deepinfra_emb.embed_query("What is the third letter of Greek alphabet")
     assert len(r2) == 768
+
+
+def test_deepinfra_call_with_large_batch_size() -> None:
+    deepinfra_emb = DeepInfraEmbeddings(model_id="BAAI/bge-base-en-v1.5")
+    texts = 2000 * [
+        "Alpha is the first letter of Greek alphabet",
+    ]
+    r1 = deepinfra_emb.embed_documents(texts)
+    assert len(r1) == 2000
+    assert len(r1[0]) == 768
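
With the default `batch_size` of 1024, these 2000 texts are split into two chunks (1024 + 976), so the test exercises the multi-chunk path while still asserting one 768-dimensional embedding per input.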
