MalteHB 2 weeks ago committed by GitHub
commit a587085f17

@ -0,0 +1,89 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "719619d3",
"metadata": {},
"source": [
"# ONNX Embeddings\n",
"\n",
"[ONNX compatible-runtime with any model from the HuggingFace](https://huggingface.co/).\n",
"\n",
"\n",
"This notebook shows how to use `ONNXEmbeddings` with a BGE model created by the [Beijing Academy of Artificial Intelligence (BAAI)](https://www.baai.ac.cn/english.html) from `Hugging Face`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7a54279",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"!pip install optimum transformers onnxruntime onnx"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e1d5b6b",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import ONNXEmbeddings\n",
"from langchain_community.embeddings.huggingface import DEFAULT_QUERY_BGE_INSTRUCTION_EN\n",
"\n",
"model_name = \"BAAI/bge-large-en-v1.5\"\n",
"query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_EN\n",
"\n",
"model_kwargs = {\"device\": \"cpu\"}\n",
"\n",
"model = ONNXEmbeddings(model_name=model_name, model_kwargs=model_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e59d1a89",
"metadata": {},
"outputs": [],
"source": [
"embedding = model.embed_query(\n",
" \"How awesome is it, that we are now able to use ONNX-compatible runtimes in LangChain?\"\n",
")\n",
"len(embedding)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
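
A quick way to sanity-check the notebook output is to score a query against a couple of documents with cosine similarity. This is a minimal sketch, not part of the diff; it assumes the `model` object from the cells above and uses numpy:

import numpy as np

docs = [
    "ONNX Runtime executes exported transformer models on CPU or GPU.",
    "The weather in Copenhagen is mild in September.",
]
doc_embs = np.array(model.embed_documents(docs))
query_emb = np.array(model.embed_query("How do I run transformers with ONNX?"))

# Cosine similarity: the first document should score noticeably higher.
scores = doc_embs @ query_emb / (
    np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(query_emb)
)
print(scores)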

@ -159,6 +159,9 @@ if TYPE_CHECKING:
from langchain_community.embeddings.ollama import (
OllamaEmbeddings,
)
from langchain_community.embeddings.onnx_embeddings import (
ONNXEmbeddings,
)
from langchain_community.embeddings.openai import (
OpenAIEmbeddings,
)
@ -290,6 +293,7 @@ __all__ = [
"TensorflowHubEmbeddings",
"VertexAIEmbeddings",
"VolcanoEmbeddings",
"ONNXEmbeddings",
"VoyageEmbeddings",
"XinferenceEmbeddings",
"YandexGPTEmbeddings",
@ -370,6 +374,7 @@ _module_lookup = {
"TitanTakeoffEmbed": "langchain_community.embeddings.titan_takeoff",
"PremAIEmbeddings": "langchain_community.embeddings.premai",
"YandexGPTEmbeddings": "langchain_community.embeddings.yandex",
"ONNXEmbeddings": "langchain_community.embeddings.onnx_embeddings",
}
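
For context on why both `__all__` and `_module_lookup` need the new entry: `langchain_community.embeddings` resolves its exports lazily through a module-level `__getattr__`, so optional heavy dependencies are only imported when a symbol is first accessed. A sketch of that mechanism (paraphrased from the package, not part of this diff):

import importlib
from typing import Any

def __getattr__(name: str) -> Any:
    # Import the backing module only when the attribute is first requested.
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")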

@ -0,0 +1,90 @@
from typing import Any, Dict, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field


class ONNXEmbeddings(BaseModel, Embeddings):
"""ONNX Embedding models.
Example:
.. code-block:: python
from langchain_community.embeddings import ONNXEmbeddings
from langchain_community.embeddings.huggingface import (
DEFAULT_QUERY_BGE_INSTRUCTION_EN
)
model_name = "BAAI/bge-large-en"
model_kwargs = {'device': 'cpu'}
query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_EN
onnx_emb = ONNXEmbeddings(
model_name,
model_kwargs=model_kwargs,
query_instruction=query_instruction,
)
"""
client: Any #: :meta private:
tokenizer: Any #: :meta private:
model_name: str
"""The name of the HuggingFace model to transform to ONNX format."""
query_instruction: str = Field(default="query:")
"""Instruction to use for embedding query."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to pass when tokenizing the input text."""
    def __init__(
        self, model_name: str, client: Any = None, tokenizer: Any = None, **kwargs: Any
    ) -> None:
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "Please ensure the required packages are installed with "
                "`pip install optimum transformers onnxruntime onnx`."
            ) from exc
        if not tokenizer:
            # Use the local ``model_name`` argument: pydantic fields are not
            # initialized until ``super().__init__`` runs, so ``self.model_name``
            # is not available yet.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        if not client:
            # Export the checkpoint to ONNX on first load, forwarding any
            # user-supplied loading kwargs (e.g. ``provider``).
            client = ORTModelForFeatureExtraction.from_pretrained(
                model_name, export=True, **kwargs.get("model_kwargs", {})
            )
        super().__init__(
            model_name=model_name, tokenizer=tokenizer, client=client, **kwargs
        )
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using an ONNX model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        texts = [t.replace("\n", " ") for t in texts]
        inputs = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        )
        outputs = self.client(**inputs)
        # Exclude padding tokens from the mean so shorter texts in the batch
        # are not skewed toward the pad embedding.
        mask = inputs["attention_mask"].unsqueeze(-1)
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        return (summed / counts).tolist()
    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using an ONNX model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        text = self.query_instruction + text if self.query_instruction else text
        inputs = self.tokenizer([text], return_tensors="pt", **self.encode_kwargs)
        embedding = self.client(**inputs)
        # A single unpadded sequence, so a plain mean over tokens is safe here.
        return embedding.last_hidden_state.mean(dim=1).squeeze().tolist()
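
One design note on pooling: this module mean-pools the last hidden state, while BGE checkpoints are commonly pooled by taking the L2-normalized hidden state of the first ([CLS]) token. A minimal sketch of that alternative, should the pooling strategy ever be made configurable (hypothetical helper, not part of this diff):

import torch

def cls_pool(last_hidden_state: torch.Tensor) -> torch.Tensor:
    # Take the first ([CLS]) token's hidden state and L2-normalize it,
    # the pooling BGE models are typically used with.
    return torch.nn.functional.normalize(last_hidden_state[:, 0], p=2, dim=1)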