pull/20398/head
Eugene Yurtsev 2 months ago
parent 30ee3c4d9c
commit 5d43392ee3

@ -0,0 +1,227 @@
{
"cells": [
{
"cell_type": "raw",
"id": "a6c3a6e0-a94f-4d40-9022-2c7ac2380f6d",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 0\n",
"---"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c160026f-aadb-4e9f-8642-b4a9e8479d77",
"metadata": {},
"source": [
"# Custom Embeddings\n",
"\n",
"We'll explore how to create a custom embedding model using LangChain's Embeddings interface. Embeddings are critical in natural language processing applications as they convert text into a numerical form that algorithms can understand, thereby enabling a wide range of applications such as similarity search, text classification, and clustering.\n",
"\n",
"Implementing embeddings using the standard `Embeddings` interface will allow your embeddings to be utilized in existing `LangChain` abstractions (e.g., as the embeddings for a particular `Vectorstore` or cached using `CacheBackedEmbeddings`).\n",
"\n",
"## Interface\n",
"\n",
"The current `Embeddings` abstraction in LangChain is designed to operate on text data. In this implementation, the inputs are either single strings or lists of strings, and the outputs are lists of numerical arrays (vectors), where each vector represents\n",
"an embedding of the input text in some n-dimensional space.\n",
"\n",
"Your custom embedding must implement the following methods:\n",
"\n",
"| Method/Property | Description | Required/Optional |\n",
"|---------------------------------|----------------------------------------------------------------------------|-------------------|\n",
"| `embed_documents(texts)` | Generates embeddings for a list of documents. | Required |\n",
"| `embed_query(text)` | Generates an embedding for a single text query. | Required |\n",
"| `aembed_documents(texts)` | Asynchronously generates embeddings for a list of documents. | Optional |\n",
"| `aembed_query(text)` | Asynchronously generates an embedding for a single text query. | Optional |\n",
"\n",
"These methods ensure that your embedding model can be integrated seamlessly into the LangChain framework, providing both synchronous and asynchronous capabilities for scalability and performance optimization.\n",
"\n",
":::{.callout-note}\n",
"`embed_documents` takes in a list of plain text, not a list of LangChain `Document` objects. The name of this method\n",
"may change in future versions of LangChain.\n",
":::\n",
"\n",
"\n",
":::{.callout-important}\n",
"`Embeddings` do not currently implement the `Runnable` interface and are also **not** instances of pydantic `BaseModel`.\n",
":::"
]
},
{
"cell_type": "markdown",
"id": "2162547f-4577-47e8-b12f-e9aa3c243797",
"metadata": {},
"source": [
"## Implementation\n",
"\n",
"As an example, we'll implement a simple embeddings model that will count the characters in the text and generate a fixed size vector containing the character counts. The model will be case insensitive, and either count the characters from a-z or only the vowels (a, e, i, o, u). This model is for illustrative purposes only."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6b838062-552c-43f8-94f8-d17e4ae4c221",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from typing import List\n",
"\n",
"from langchain_core.embeddings import Embeddings\n",
"\n",
"\n",
"class CharCountEmbeddings(Embeddings):\n",
" \"\"\"Embedding model that counts occurrences of characters in text.\n",
"\n",
" When contributing an implementation to LangChain, carefully document\n",
" the embedding model including the initialization parameters, include\n",
" an example of how to initialize the model and include any relevant\n",
" links to the underlying models documentation or API.\n",
"\n",
" Example:\n",
"\n",
" .. code-block:: python\n",
"\n",
" from langchain_community.embeddings import CharCountEmbeddings\n",
"\n",
" embeddings = ChatCountEmbeddings(only_vowels=True)\n",
" print(embeddings.embed_documents([\"Hello world\", \"Test\"]))\n",
" print(embeddings.embed_query(\"Quick Brown Fox\"))\n",
" \"\"\"\n",
"\n",
" def __init__(self, *, only_vowels: bool = False) -> None:\n",
" \"\"\"Initialize the embedding model.\n",
"\n",
" Args:\n",
" only_vowels: If True, the embedding will count only the\n",
" vowels (a, e, i, o, u) and produce a 5-dimensional vector.\n",
" If False, counts all lowercase alphabetic characters,\n",
" producing a 26-dimensional vector.\n",
" \"\"\"\n",
"\n",
" self.only_vowels = only_vowels\n",
"\n",
" def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
" \"\"\"Embed multiple documents by counting specific character sets.\"\"\"\n",
" return [self._embed_text(text) for text in texts]\n",
"\n",
" def embed_query(self, text: str) -> List[float]:\n",
" \"\"\"Embed a single query by counting specific character sets.\"\"\"\n",
" return self._embed_text(text)\n",
"\n",
" def _embed_text(self, text: str) -> List[float]:\n",
" \"\"\"Helper function to create a character count vector from text.\"\"\"\n",
" text = text.lower() # Normalize text to lowercase for case insensitivity.\n",
" count = Counter(text)\n",
" if self.only_vowels:\n",
" # Embed only vowels\n",
" vowels = \"aeiou\"\n",
" return [count.get(vowel, 0) for vowel in vowels]\n",
" else:\n",
" # Embed all letters from 'a' to 'z'\n",
" return [count.get(chr(i), 0) for i in range(ord(\"a\"), ord(\"z\") + 1)]\n",
"\n",
" # The async methods are optional.\n",
" # Delete them if you do not have an actual async imlementation.\n",
" async def aembed_documents(self, texts: List[str]) -> List[List[float]]:\n",
" \"\"\"Asynchronous embed search docs.\"\"\"\n",
" # This implementation is only for illustrative purposes.\n",
" # If you're connecting to an API, you should provide\n",
" # an actual async implementation (e.g., using httpx AsyncClient\n",
" # https://www.python-httpx.org/async/).\n",
" # If you do not have an actual async implementation, please\n",
" # DELETE this method as LangChain already provides a first pass\n",
" # optimization which involves delegating to the sync method.\n",
" # If you do not have a native async implementation, just delete this\n",
" # method. LangChain basically does this\n",
" return [self._embed_text(text) for text in texts]\n",
"\n",
" async def aembed_query(self, text: str) -> List[float]:\n",
" \"\"\"Asynchronous embed query text.\"\"\"\n",
" # See comment above for the aembed_documents regarding\n",
" # native async implementation\n",
" return self._embed_text(text)"
]
},
{
"cell_type": "markdown",
"id": "47a19044-5c3f-40da-889a-1a1cfffc137c",
"metadata": {},
"source": [
"### Let's test it 🧪"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21c218fe-8f91-437f-b523-c2b6e5cf749e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1, 1, 0, 0, 0], [0, 3, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 1, 0]]\n",
"[0, 4, 0, 0, 0]\n"
]
}
],
"source": [
"embeddings = CharCountEmbeddings(only_vowels=True)\n",
"print(embeddings.embed_documents([\"abce\", \"eee\", \"hello\", \"fox\"]))\n",
"print(embeddings.embed_query(\"eeee\"))"
]
},
{
"cell_type": "markdown",
"id": "de50f690-178e-4561-af98-14967b3c8501",
"metadata": {},
"source": [
"## Contributing\n",
"\n",
"We welcome contributions of Embedding models to the LangChain code base!\n",
"\n",
"Here's a checklist to help make sure your contribution gets added to LangChain:\n",
"\n",
"Documentation:\n",
"\n",
"* The model contains doc-strings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n",
"* The class doc-string for the model contains a link to the model API if the model is powered by a service.\n",
"\n",
"Tests:\n",
"\n",
"* [ ] Add an integration tests to test the integration with the API or model.\n",
"\n",
"Optimizations:\n",
"\n",
"If your implementation is an integration with an `API` consider providing async native support (e.g., via httpx AsyncClient).\n",
" \n",
"* [ ] Provided a native async of `aembed_documents`\n",
"* [ ] Provided a native async of `aembed_query`"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save