update _aget_len_safe_embeddings

cc/fix_openai
Chester Curme 1 month ago
parent c1aa237bc2
commit 6a5263e4bd

@ -414,7 +414,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_iter, tokens, indices = self._tokenize(texts, _chunk_size)
batched_embeddings: List[List[float]] = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size):
for i in _iter:
response = await self.async_client.create(
input=tokens[i : i + _chunk_size], **self._invocation_params
)
@ -426,6 +426,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
if self.skip_empty and len(batched_embeddings[i]) == 1:
continue
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))

Loading…
Cancel
Save