mirror of https://github.com/hwchase17/langchain
Merge 1d9b77e08f
into 242eeb537f
commit
a587085f17
@ -0,0 +1,89 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "719619d3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ONNX Embeddings\n",
|
||||
"\n",
|
||||
"Use an ONNX-compatible runtime with any model from the [Hugging Face Hub](https://huggingface.co/).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This notebook shows how to use `ONNXEmbeddings` with a BGE model created by the [Beijing Academy of Artificial Intelligence (BAAI)](https://www.baai.ac.cn/english.html) from `Hugging Face`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f7a54279",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install optimum transformers onnxruntime onnx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e1d5b6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.embeddings import ONNXEmbeddings\n",
|
||||
"from langchain_community.embeddings.huggingface import DEFAULT_QUERY_BGE_INSTRUCTION_EN\n",
|
||||
"\n",
|
||||
"model_name = \"BAAI/bge-large-en-v1.5\"\n",
|
||||
"query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_EN\n",
|
||||
"\n",
|
||||
"model_kwargs = {\"device\": \"cpu\"}\n",
|
||||
"\n",
|
||||
"model = ONNXEmbeddings(model_name=model_name, model_kwargs=model_kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e59d1a89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embedding = model.embed_query(\n",
|
||||
" \"How awesome is it, that we are now able to use ONNX-compatible runtimes in LangChain?\"\n",
|
||||
")\n",
|
||||
"len(embedding)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e596315f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,90 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
|
||||
|
||||
|
||||
class ONNXEmbeddings(BaseModel, Embeddings):
    """Embedding model that runs a HuggingFace model through ONNX Runtime.

    Unless a ready-made ``client`` / ``tokenizer`` is supplied, the model
    named by ``model_name`` is loaded with ``optimum`` (exported to ONNX via
    ``export=True``) and its tokenizer is loaded with ``transformers``.
    Embeddings are the attention-mask-weighted mean of the last hidden state.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import ONNXEmbeddings
            from langchain_community.embeddings.huggingface import (
                DEFAULT_QUERY_BGE_INSTRUCTION_EN
            )

            model_name = "BAAI/bge-large-en"
            model_kwargs = {'device': 'cpu'}
            query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_EN
            onnx_emb = ONNXEmbeddings(
                model_name,
                model_kwargs=model_kwargs,
                query_instruction=query_instruction,
            )
    """

    client: Any  #: :meta private:
    tokenizer: Any  #: :meta private:
    model_name: str
    """The name of the HuggingFace model to transform to ONNX format."""
    query_instruction: str = Field(default="query:")
    """Instruction prepended verbatim to the text when embedding a query."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments intended for the model.

    NOTE: these are stored on the instance but are not currently forwarded
    to the ONNX model loader.
    """
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to pass when tokenizing the input text."""

    def __init__(
        self, model_name: str, client: Any = None, tokenizer: Any = None, **kwargs: Any
    ) -> None:
        """Load the tokenizer and ONNX model (unless injected) and validate.

        Args:
            model_name: HuggingFace model identifier to load and export.
            client: Optional pre-built model; loaded from ``model_name``
                when ``None``.
            tokenizer: Optional pre-built tokenizer; loaded from
                ``model_name`` when ``None``.
            **kwargs: Remaining pydantic fields (``query_instruction``,
                ``model_kwargs``, ``encode_kwargs``).

        Raises:
            ImportError: If something has to be loaded and the optional
                packages are not installed.
        """
        # The optional dependencies are only needed when something has to be
        # loaded; callers that inject both objects can skip them entirely.
        if tokenizer is None or client is None:
            try:
                from optimum.onnxruntime import ORTModelForFeatureExtraction
                from transformers import AutoTokenizer
            except ImportError as exc:
                raise ImportError(
                    "Please ensure the required packages are installed with "
                    "`pip install optimum transformers onnxruntime onnx`."
                ) from exc

            # Use the ``model_name`` argument here: pydantic fields such as
            # ``self.model_name`` do not exist until ``super().__init__`` ran.
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(model_name)
            if client is None:
                client = ORTModelForFeatureExtraction.from_pretrained(
                    model_name, export=True
                )

        super().__init__(
            model_name=model_name, tokenizer=tokenizer, client=client, **kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Tokenize ``texts``, run the model and mean-pool the token states.

        Args:
            texts: Already pre-processed texts to embed as one batch.

        Returns:
            One embedding per input text; an empty list for empty input.
        """
        if not texts:
            return []

        # Merge into a single dict so user-supplied ``encode_kwargs`` can
        # override padding/truncation without a duplicate-keyword error.
        # ``return_tensors`` is forced last because the pooling below relies
        # on the tensor API the original code used (``.mean(dim=1)``).
        tokenizer_kwargs: Dict[str, Any] = {
            "padding": True,
            "truncation": True,
            **self.encode_kwargs,
            "return_tensors": "pt",
        }
        inputs = self.tokenizer(texts, **tokenizer_kwargs)
        # Assumes the usual (batch, sequence, hidden) layout.
        hidden = self.client(**inputs).last_hidden_state

        if "attention_mask" in inputs:
            # Weight by the attention mask so padding positions do not skew
            # the mean; otherwise a text's vector would depend on how long
            # the other texts in the batch are.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            pooled = hidden.mean(dim=1)
        return pooled.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using an ONNX model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._embed([t.replace("\n", " ") for t in texts])

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using an ONNX model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        if self.query_instruction:
            text = self.query_instruction + text
        return self._embed([text])[0]
|
Loading…
Reference in New Issue