Merge branch 'master' into bagatur/0.2

pull/19246/head
Bagatur 3 months ago
commit 0495ca0d10

.gitignore vendored

@ -116,6 +116,7 @@ celerybeat.pid
.env
.envrc
.venv*
venv*
env/
ENV/
env.bak/

@ -14,19 +14,20 @@ For the most part, new integrations should be added to the Community package. Pa
In the following sections, we'll walk through how to contribute to each of these packages from a fake company, `Parrot Link AI`.
## Community Package
## Community package
The `langchain-community` package is in `libs/community` and contains most integrations.
It is installed by users with `pip install langchain-community`, and exported members can be imported with code like
It can be installed with `pip install langchain-community`, and exported members can be imported with code like
```python
from langchain_community.chat_models import ParrotLinkLLM
from langchain_community.llms import ChatParrotLink
from langchain_community.chat_models import ChatParrotLink
from langchain_community.llms import ParrotLinkLLM
from langchain_community.vectorstores import ParrotLinkVectorStore
```
The community package relies on manually-installed dependent packages, so you will see errors if you try to import a package that is not installed. In our fake example, if you tried to import `ParrotLinkLLM` without installing `parrot-link-sdk`, you will see an `ImportError` telling you to install it when trying to use it.
The `community` package relies on manually-installed dependent packages, so you will see errors
if you try to import a package that is not installed. In our fake example, if you tried to import `ParrotLinkLLM` without installing `parrot-link-sdk`, you would see an `ImportError` telling you to install it when you try to use it.
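As an illustration of why that works, community integrations typically import the provider SDK lazily and raise a helpful `ImportError` when it is missing. A minimal sketch of that pattern, using the made-up `parrot-link-sdk` from this example:

```python
# Hypothetical sketch of the lazy-import pattern; `parrot_link_sdk` is a made-up package.
from typing import Any


def _import_parrot_link_sdk() -> Any:
    try:
        import parrot_link_sdk  # optional dependency, installed separately by the user
    except ImportError as e:
        raise ImportError(
            "Could not import parrot-link-sdk. "
            "Please install it with `pip install parrot-link-sdk`."
        ) from e
    return parrot_link_sdk
```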
Let's say we wanted to implement a chat model for Parrot Link AI. We would create a new file in `libs/community/langchain_community/chat_models/parrot_link.py` with the following code:
@ -39,7 +40,7 @@ class ChatParrotLink(BaseChatModel):
Example:
.. code-block:: python
from langchain_parrot_link import ChatParrotLink
from langchain_community.chat_models import ChatParrotLink
model = ChatParrotLink()
"""
@ -56,9 +57,16 @@ And add documentation to:
- `docs/docs/integrations/chat/parrot_link.ipynb`
## Partner Packages
## Partner package in LangChain repo
Partner packages are in `libs/partners/*` and are installed by users with `pip install langchain-{partner}`, and exported members can be imported with code like
Partner packages can be hosted in the `LangChain` monorepo or in an external repo.
A partner package in the `LangChain` repo is placed in `libs/partners/{partner}`,
and its source code lives in `libs/partners/{partner}/langchain_{partner}`.
The package is installed by users with `pip install langchain-{partner}`,
and its members can be imported with code like:
```python
from langchain_{partner} import X
@ -123,13 +131,49 @@ By default, this will include stubs for a Chat Model, an LLM, and/or a Vector St
### Write Unit and Integration Tests
Some basic tests are generated in the tests/ directory. You should add more tests to cover your package's functionality.
Some basic tests are provided in the `tests/` directory. You should add more tests to cover your package's functionality.
For information on running and implementing tests, see the [Testing guide](./testing).
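For instance, a minimal unit test for the hypothetical `ChatParrotLink` model could be a sketch like this:

```python
# Hypothetical unit-test sketch; `langchain_parrot_link` is the fake package from this guide.
from langchain_core.messages import AIMessage, HumanMessage


def test_chat_parrot_link_invoke() -> None:
    from langchain_parrot_link import ChatParrotLink

    model = ChatParrotLink()
    result = model.invoke([HumanMessage(content="hello parrot")])
    assert isinstance(result, AIMessage)
    assert result.content  # the stubbed model should echo something back
```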
### Write documentation
Documentation is generated from Jupyter notebooks in the `docs/` directory. You should move the generated notebooks to the relevant `docs/docs/integrations` directory in the monorepo root.
Documentation is generated from Jupyter notebooks in the `docs/` directory. You should place the notebooks with examples
in the relevant `docs/docs/integrations` directory in the monorepo root.
### (If Necessary) Deprecate community integration
Note: this is only necessary if you're migrating an existing community integration into
a partner package. If the component you're integrating is net-new to LangChain (i.e.
not already in the `community` package), you can skip this step.
Let's pretend we migrated our `ChatParrotLink` chat model from the community package to
the partner package. We would need to deprecate the old model in the community package.
We would do that by adding a `@deprecated` decorator to the old model as follows, in
`libs/community/langchain_community/chat_models/parrot_link.py`.
Before our change, our chat model might look like this:
```python
class ChatParrotLink(BaseChatModel):
...
```
After our change, it would look like this:
```python
from langchain_core._api.deprecation import deprecated
@deprecated(
since="0.0.<next community version>",
removal="0.2.0",
alternative_import="langchain_parrot_link.ChatParrotLink"
)
class ChatParrotLink(BaseChatModel):
...
```
You should do this for *each* component that you're migrating to the partner package.
### Additional steps
@ -143,3 +187,15 @@ Maintainer steps (Contributors should **not** do these):
- [ ] set up pypi and test pypi projects
- [ ] add credential secrets to Github Actions
- [ ] add package to conda-forge
## Partner package in external repo
If you are creating a partner package in an external repo, you should follow the same steps as above,
but you will need to set up your own CI/CD and package management.
Name your package `langchain-{partner}-{integration}`.
You still need to create a `libs/partners/{partner}-{integration}` folder in the `LangChain` monorepo
and add a `README.md` file with a link to the external repo.
See this [example](https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai).
This allows keeping track of all the partner packages in the `LangChain` documentation.

@ -20,9 +20,11 @@
]
},
{
"cell_type": "raw",
"cell_type": "code",
"id": "0f316b5c",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-openai"
]

@ -20,9 +20,11 @@
]
},
{
"cell_type": "raw",
"cell_type": "code",
"id": "b3121aa8",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-openai"
]

@ -23,7 +23,7 @@ We also are working to share guides and cookbooks that demonstrate how to use th
## LangSmith Evaluation
LangSmith provides an integrated evaluation and tracing framework that allows you to check for regressions, compare systems, and easily identify and fix any sources of errors and performance issues. Check out the docs on [LangSmith Evaluation](https://docs.smith.langchain.com/category/testing--evaluation) and additional [cookbooks](https://docs.smith.langchain.com/category/langsmith-cookbook) for more detailed information on evaluating your applications.
LangSmith provides an integrated evaluation and tracing framework that allows you to check for regressions, compare systems, and easily identify and fix any sources of errors and performance issues. Check out the docs on [LangSmith Evaluation](https://docs.smith.langchain.com/evaluation) and additional [cookbooks](https://docs.smith.langchain.com/cookbook) for more detailed information on evaluating your applications.
## LangChain benchmarks

@ -129,7 +129,7 @@
"Who was famed for their Christian spirit?\n",
"Who assimilted the Roman language?\n",
"Who ruled the country of Normandy?\n",
"What principality did William the conquerer found?\n",
"What principality did William the conqueror found?\n",
"What is the original meaning of the word Norman?\n",
"When was the Latin version of the word Norman first recorded?\n",
"What name comes from the English words Normans/Normanz?\"\"\"\n",

@ -22,7 +22,7 @@
"outputs": [],
"source": [
"# You need the dgml-utils package to use the DocugamiLoader (run pip install directly without \"poetry run\" if you are not using poetry)\n",
"!poetry run pip install dgml-utils==0.3.0 --upgrade --quiet"
"!poetry run pip install docugami-langchain dgml-utils==0.3.0 --upgrade --quiet"
]
},
{
@ -56,7 +56,7 @@
"source": [
"import os\n",
"\n",
"from langchain_community.document_loaders import DocugamiLoader"
"from docugami_langchain.document_loaders import DocugamiLoader"
]
},
{
@ -470,7 +470,7 @@
"source": [
"from typing import Dict, List\n",
"\n",
"from langchain_community.document_loaders import DocugamiLoader\n",
"from docugami_langchain.document_loaders import DocugamiLoader\n",
"from langchain_core.documents import Document\n",
"\n",
"loader = DocugamiLoader(docset_id=\"zo954yqy53wp\")\n",
@ -655,7 +655,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.9.18"
}
},
"nbformat": 4,

@ -56,7 +56,7 @@ See the [usage example](/docs/integrations/memory/astradb_chat_message_history#e
```python
from langchain.globals import set_llm_cache
from langchain_community.cache import AstraDBCache
from langchain_astradb import AstraDBCache
set_llm_cache(AstraDBCache(
api_endpoint=ASTRA_DB_API_ENDPOINT,
@ -71,7 +71,7 @@ Learn more in the [example notebook](/docs/integrations/llms/llm_caching#astra-d
```python
from langchain.globals import set_llm_cache
from langchain_community.cache import
from langchain_astradb import AstraDBSemanticCache
set_llm_cache(AstraDBSemanticCache(
embedding=my_embedding,

@ -9,6 +9,7 @@
```bash
pip install dgml-utils
pip install docugami-langchain
```
## Document Loader
@ -16,5 +17,5 @@ pip install dgml-utils
See a [usage example](/docs/integrations/document_loaders/docugami).
```python
from langchain_community.document_loaders import DocugamiLoader
from docugami_langchain.document_loaders import DocugamiLoader
```
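A short usage sketch (assuming Docugami credentials are supplied via the `DOCUGAMI_API_KEY` environment variable and that you already have a docset id; both values below are placeholders):

```python
# Sketch only; the API key and docset id are placeholders.
import os

from docugami_langchain.document_loaders import DocugamiLoader

os.environ["DOCUGAMI_API_KEY"] = "<your-docugami-api-key>"

loader = DocugamiLoader(docset_id="<your-docset-id>")
docs = loader.load()
print(f"Loaded {len(docs)} chunks")
```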

@ -28,17 +28,17 @@
},
"outputs": [],
"source": [
"% pip install --upgrade --quiet flashrank\n",
"% pip install --upgrade --quiet faiss\n",
"%pip install --upgrade --quiet flashrank\n",
"%pip install --upgrade --quiet faiss\n",
"\n",
"# OR (depending on Python version)\n",
"\n",
"% pip install --upgrade --quiet faiss_cpu"
"%pip install --upgrade --quiet faiss_cpu"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
@ -53,7 +53,10 @@
"def pretty_print_docs(docs):\n",
" print(\n",
" f\"\\n{'-' * 100}\\n\".join(\n",
" [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
" [\n",
" f\"Document {i+1}:\\n\\n{d.page_content}\\nMetadata: {d.metadata}\"\n",
" for i, d in enumerate(docs)\n",
" ]\n",
" )\n",
" )"
]
@ -73,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
@ -90,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
@ -247,14 +250,6 @@
"----------------------------------------------------------------------------------------------------\n",
"Document 15:\n",
"\n",
"My plan to fight inflation will lower your costs and lower the deficit. \n",
"\n",
"17 Nobel laureates in economics say my plan will ease long-term inflationary pressures. Top business leaders and most Americans support my plan. And heres the plan: \n",
"\n",
"First cut the cost of prescription drugs. Just look at insulin. One in ten Americans has diabetes. In Virginia, I met a 13-year-old boy named Joshua Davis.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 16:\n",
"\n",
"And soon, well strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things. \n",
"\n",
"So tonight Im offering a Unity Agenda for the Nation. Four big things we can do together. \n",
@ -263,15 +258,15 @@
"\n",
"There is so much we can do. Increase funding for prevention, treatment, harm reduction, and recovery.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 17:\n",
"Document 16:\n",
"\n",
"So lets not abandon our streets. Or choose between safety and equal justice. \n",
"My plan to fight inflation will lower your costs and lower the deficit. \n",
"\n",
"Lets come together to protect our communities, restore trust, and hold law enforcement accountable. \n",
"17 Nobel laureates in economics say my plan will ease long-term inflationary pressures. Top business leaders and most Americans support my plan. And heres the plan: \n",
"\n",
"Thats why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers.\n",
"First cut the cost of prescription drugs. Just look at insulin. One in ten Americans has diabetes. In Virginia, I met a 13-year-old boy named Joshua Davis.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 18:\n",
"Document 17:\n",
"\n",
"My plan will not only lower costs to give families a fair shot, it will lower the deficit. \n",
"\n",
@ -281,6 +276,14 @@
"\n",
"Were going after the criminals who stole billions in relief money meant for small businesses and millions of Americans.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 18:\n",
"\n",
"So lets not abandon our streets. Or choose between safety and equal justice. \n",
"\n",
"Lets come together to protect our communities, restore trust, and hold law enforcement accountable. \n",
"\n",
"Thats why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 19:\n",
"\n",
"I understand. \n",
@ -316,6 +319,8 @@
").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"for idx, text in enumerate(texts):\n",
" text.metadata[\"id\"] = idx\n",
"\n",
"embedding = OpenAIEmbeddings(model=\"text-embedding-ada-002\")\n",
"retriever = FAISS.from_documents(texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n",
@ -340,16 +345,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 5, 3]\n"
]
}
],
"source": [
"from langchain.retrievers import ContextualCompressionRetriever, FlashrankRerank\n",
"from langchain.retrievers import ContextualCompressionRetriever\n",
"from langchain.retrievers.document_compressors import FlashrankRerank\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(temperature=0)\n",
@ -379,7 +393,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
@ -399,29 +413,27 @@
"----------------------------------------------------------------------------------------------------\n",
"Document 2:\n",
"\n",
"And tonight, Im announcing that the Justice Department will name a chief prosecutor for pandemic fraud. \n",
"\n",
"By the end of this year, the deficit will be down to less than half what it was before I took office. \n",
"\n",
"The only president ever to cut the deficit by more than one trillion dollars in a single year. \n",
"He met the Ukrainian people. \n",
"\n",
"Lowering your costs also means demanding more competition. \n",
"From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n",
"\n",
"Im a capitalist, but capitalism without competition isnt capitalism. \n",
"Groups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. \n",
"\n",
"Its exploitation—and it drives up prices.\n",
"In this struggle as President Zelenskyy said in his speech to the European Parliament “Light will win over darkness.” The Ukrainian Ambassador to the United States is here tonight.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 3:\n",
"\n",
"As Ohio Senator Sherrod Brown says, “Its time to bury the label “Rust Belt.” \n",
"And tonight, Im announcing that the Justice Department will name a chief prosecutor for pandemic fraud. \n",
"\n",
"Its time. \n",
"By the end of this year, the deficit will be down to less than half what it was before I took office. \n",
"\n",
"But with all the bright spots in our economy, record job growth and higher wages, too many families are struggling to keep up with the bills. \n",
"The only president ever to cut the deficit by more than one trillion dollars in a single year. \n",
"\n",
"Inflation is robbing them of the gains they might otherwise feel. \n",
"Lowering your costs also means demanding more competition. \n",
"\n",
"Im a capitalist, but capitalism without competition isnt capitalism. \n",
"\n",
"I get it. Thats why my top priority is getting prices under control.\n"
"Its exploitation—and it drives up prices.\n"
]
}
],
@ -443,7 +455,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@ -459,7 +471,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@ -471,10 +483,10 @@
"data": {
"text/plain": [
"{'query': 'What did the president say about Ketanji Brown Jackson',\n",
" 'result': \"The President said that Ketanji Brown Jackson is one of our nation's top legal minds and will continue Justice Breyer's legacy of excellence.\"}"
" 'result': \"The President mentioned that Ketanji Brown Jackson is one of the nation's top legal minds and will continue Justice Breyer's legacy of excellence.\"}"
]
},
"execution_count": 19,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -500,7 +512,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.2"
}
},
"nbformat": 4,

@ -124,7 +124,7 @@
"outputs": [],
"source": [
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, create_react_agent\n",
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
"from langchain_openai import ChatOpenAI"
]
},
@ -135,8 +135,8 @@
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"prompt = hub.pull(\"hwchase17/react\")\n",
"agent = create_react_agent(\n",
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
"agent = create_openai_tools_agent(\n",
" tools=toolkit.get_tools(),\n",
" llm=llm,\n",
" prompt=prompt,\n",
@ -151,7 +151,9 @@
"outputs": [],
"source": [
"agent_executor.invoke(\n",
" {\"input\": \"Send a greeting to my coworkers in the #general channel.\"}\n",
" {\n",
" \"input\": \"Send a greeting to my coworkers in the #general channel. Note use `channel` as key of channel id, and `message` as key of content to sent in the channel.\"\n",
" }\n",
")"
]
},

@ -0,0 +1,787 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f63dfcf9-fd9d-4ac1-a0b3-c02d4dce7faf",
"metadata": {},
"source": [
"# Couchbase \n",
"[Couchbase](http://couchbase.com/) is an award-winning distributed NoSQL cloud database that delivers unmatched versatility, performance, scalability, and financial value for all of your cloud, mobile, AI, and edge computing applications. Couchbase embraces AI with coding assistance for developers and vector search for their applications.\n",
"\n",
"Vector Search is a part of the [Full Text Search Service](https://docs.couchbase.com/server/current/learn/services-and-indexes/services/search-service.html) (Search Service) in Couchbase.\n",
"\n",
"This tutorial explains how to use Vector Search in Couchbase. You can work with both [Couchbase Capella](https://www.couchbase.com/products/capella/) and your self-managed Couchbase Server."
]
},
{
"cell_type": "markdown",
"id": "43326be4-4433-4de2-ad42-6eb91a722bad",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bec8d532-fec7-4dc7-9be3-020aa7bdb01f",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-openai couchbase"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a972cbc-bf59-46eb-9b50-e5dc3a69dcf0",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
]
},
{
"cell_type": "markdown",
"id": "acf1b168-622f-465c-a9a5-d27a6d7e7a8f",
"metadata": {},
"source": [
"## Import the Vector Store and Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "23ce45ab-bfd2-42e1-b681-514a550f0232",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import CouchbaseVectorStore\n",
"from langchain_openai import OpenAIEmbeddings"
]
},
{
"cell_type": "markdown",
"id": "3144ba02-1eaa-4449-853e-f034ca5706bf",
"metadata": {},
"source": [
"## Create Couchbase Connection Object\n",
"We create a connection to the Couchbase cluster initially and then pass the cluster object to the Vector Store. \n",
"\n",
"Here, we are connecting using the username and password. You can also connect using any other supported way to your cluster. \n",
"\n",
"For more information on connecting to the Couchbase cluster, please check the [Python SDK documentation](https://docs.couchbase.com/python-sdk/current/hello-world/start-using-sdk.html#connect)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "52fe583a-12db-4dc2-9281-1174bf1d4e5c",
"metadata": {},
"outputs": [],
"source": [
"COUCHBASE_CONNECTION_STRING = (\n",
" \"couchbase://localhost\" # or \"couchbases://localhost\" if using TLS\n",
")\n",
"DB_USERNAME = \"Administrator\"\n",
"DB_PASSWORD = \"Password\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9986c6b9",
"metadata": {},
"outputs": [],
"source": [
"from datetime import timedelta\n",
"\n",
"from couchbase.auth import PasswordAuthenticator\n",
"from couchbase.cluster import Cluster\n",
"from couchbase.options import ClusterOptions\n",
"\n",
"auth = PasswordAuthenticator(DB_USERNAME, DB_PASSWORD)\n",
"options = ClusterOptions(auth)\n",
"cluster = Cluster(COUCHBASE_CONNECTION_STRING, options)\n",
"\n",
"# Wait until the cluster is ready for use.\n",
"cluster.wait_until_ready(timedelta(seconds=5))"
]
},
{
"cell_type": "markdown",
"id": "90c5dec9-f6cb-41eb-9f30-13cab7b107db",
"metadata": {},
"source": [
"We will now set the bucket, scope, and collection names in the Couchbase cluster that we want to use for Vector Search. \n",
"\n",
"For this example, we are using the default scope & collections."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1b1d0a26-e9d4-4823-9800-9549d24d3d16",
"metadata": {},
"outputs": [],
"source": [
"BUCKET_NAME = \"testing\"\n",
"SCOPE_NAME = \"_default\"\n",
"COLLECTION_NAME = \"_default\"\n",
"SEARCH_INDEX_NAME = \"vector-index\""
]
},
{
"cell_type": "markdown",
"id": "efbac6ff-c2ac-4443-9250-7cc88061346b",
"metadata": {},
"source": [
"For this tutorial, we will use OpenAI embeddings"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "87625579-86d7-4de4-8a4d-cee674a6b676",
"metadata": {},
"outputs": [],
"source": [
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "markdown",
"id": "3677b4b0-3711-419c-89ff-32ef4d3e3022",
"metadata": {},
"source": [
"## Create the Search Index\n",
"Currently, the Search index needs to be created from the Couchbase Capella or Server UI or using the REST interface. \n",
"\n",
"Let us define a Search index with the name `vector-index` on the testing bucket\n",
"\n",
"For this example, let us use the Import Index feature on the Search Service on the UI. \n",
"\n",
"We are defining an index on the `testing` bucket's `_default` scope on the `_default` collection with the vector field set to `embedding` with 1536 dimensions and the text field set to `text`. We are also indexing and storing all the fields under `metadata` in the document as a dynamic mapping to account for varying document structures. The similarity metric is set to `dot_product`."
]
},
{
"cell_type": "markdown",
"id": "655117ae-9b1f-4139-b437-ca7685975a54",
"metadata": {},
"source": [
"### How to Import an Index to the Full Text Search service?\n",
" - [Couchbase Server](https://docs.couchbase.com/server/current/search/import-search-index.html)\n",
" - Click on Search -> Add Index -> Import\n",
" - Copy the following Index definition in the Import screen\n",
" - Click on Create Index to create the index.\n",
" - [Couchbase Capella](https://docs.couchbase.com/cloud/search/import-search-index.html)\n",
" - Copy the index definition to a new file `index.json`\n",
" - Import the file in Capella using the instructions in the documentation.\n",
" - Click on Create Index to create the index.\n",
" \n"
]
},
{
"cell_type": "markdown",
"id": "f85bc468-d9b8-487d-999a-3b5d2fb78e41",
"metadata": {},
"source": [
"### Index Definition\n",
"```\n",
"{\n",
" \"name\": \"vector-index\",\n",
" \"type\": \"fulltext-index\",\n",
" \"params\": {\n",
" \"doc_config\": {\n",
" \"docid_prefix_delim\": \"\",\n",
" \"docid_regexp\": \"\",\n",
" \"mode\": \"type_field\",\n",
" \"type_field\": \"type\"\n",
" },\n",
" \"mapping\": {\n",
" \"default_analyzer\": \"standard\",\n",
" \"default_datetime_parser\": \"dateTimeOptional\",\n",
" \"default_field\": \"_all\",\n",
" \"default_mapping\": {\n",
" \"dynamic\": true,\n",
" \"enabled\": true,\n",
" \"properties\": {\n",
" \"metadata\": {\n",
" \"dynamic\": true,\n",
" \"enabled\": true\n",
" },\n",
" \"embedding\": {\n",
" \"enabled\": true,\n",
" \"dynamic\": false,\n",
" \"fields\": [\n",
" {\n",
" \"dims\": 1536,\n",
" \"index\": true,\n",
" \"name\": \"embedding\",\n",
" \"similarity\": \"dot_product\",\n",
" \"type\": \"vector\",\n",
" \"vector_index_optimized_for\": \"recall\"\n",
" }\n",
" ]\n",
" },\n",
" \"text\": {\n",
" \"enabled\": true,\n",
" \"dynamic\": false,\n",
" \"fields\": [\n",
" {\n",
" \"index\": true,\n",
" \"name\": \"text\",\n",
" \"store\": true,\n",
" \"type\": \"text\"\n",
" }\n",
" ]\n",
" }\n",
" }\n",
" },\n",
" \"default_type\": \"_default\",\n",
" \"docvalues_dynamic\": false,\n",
" \"index_dynamic\": true,\n",
" \"store_dynamic\": true,\n",
" \"type_field\": \"_type\"\n",
" },\n",
" \"store\": {\n",
" \"indexType\": \"scorch\",\n",
" \"segmentVersion\": 16\n",
" }\n",
" },\n",
" \"sourceType\": \"gocbcore\",\n",
" \"sourceName\": \"testing\",\n",
" \"sourceParams\": {},\n",
" \"planParams\": {\n",
" \"maxPartitionsPerPIndex\": 103,\n",
" \"indexPartitions\": 10,\n",
" \"numReplicas\": 0\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "556dc68c-9089-4390-8dc9-b77051e7fc34",
"metadata": {},
"source": [
"For more details on how to create a Search index with support for Vector fields, please refer to the documentation.\n",
"\n",
"- [Couchbase Capella](https://docs.couchbase.com/cloud/vector-search/create-vector-search-index-ui.html)\n",
" \n",
"- [Couchbase Server](https://docs.couchbase.com/server/current/vector-search/create-vector-search-index-ui.html)"
]
},
{
"cell_type": "markdown",
"id": "75f4037d-e509-4de7-a8d1-63a05de24e9d",
"metadata": {},
"source": [
"## Create Vector Store\n",
"We create the vector store object with the cluster information and the search index name."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "33db4670-76c5-49ba-94d6-a8fa35583058",
"metadata": {},
"outputs": [],
"source": [
"vector_store = CouchbaseVectorStore(\n",
" cluster=cluster,\n",
" bucket_name=BUCKET_NAME,\n",
" scope_name=SCOPE_NAME,\n",
" collection_name=COLLECTION_NAME,\n",
" embedding=embeddings,\n",
" index_name=SEARCH_INDEX_NAME,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0aa98793-5ac2-4f76-bbba-2d40856c2d58",
"metadata": {},
"source": [
"### Specify the Text & Embeddings Field\n",
"You can optionally specify the text & embeddings field for the document using the `text_key` and `embedding_key` fields.\n",
"```\n",
"vector_store = CouchbaseVectorStore(\n",
" cluster=cluster,\n",
" bucket_name=BUCKET_NAME,\n",
" scope_name=SCOPE_NAME,\n",
" collection_name=COLLECTION_NAME,\n",
" embedding=embeddings,\n",
" index_name=SEARCH_INDEX_NAME,\n",
" text_key=\"text\",\n",
" embedding_key=\"embedding\",\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "790dc1ac-0ab8-4cb5-989d-31ca7c241068",
"metadata": {},
"source": [
"## Basic Vector Search Example\n",
"For this example, we are going to load the \"state_of_the_union.txt\" file via the TextLoader, chunk the text into 500 character chunks with no overlaps and index all these chunks into Couchbase.\n",
"\n",
"After the data is indexed, we perform a simple query to find the top 4 chunks that are similar to the query \"What did president say about Ketanji Brown Jackson\".\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "440350df-cbc6-48f7-8009-2e783be18306",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_community.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9d3b4c7c-abd6-4dfa-ad63-470f16661319",
"metadata": {},
"outputs": [],
"source": [
"vector_store = CouchbaseVectorStore.from_documents(\n",
" documents=docs,\n",
" embedding=embeddings,\n",
" cluster=cluster,\n",
" bucket_name=BUCKET_NAME,\n",
" scope_name=SCOPE_NAME,\n",
" collection_name=COLLECTION_NAME,\n",
" index_name=SEARCH_INDEX_NAME,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "91fdce6c-8f7c-4060-865a-2fd742846664",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.' metadata={'source': '../../modules/state_of_the_union.txt'}\n"
]
}
],
"source": [
"query = \"What did president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search(query)\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "d9b46c93-65f6-4e4f-87a2-5cebea3b7a6b",
"metadata": {},
"source": [
"## Similarity Search with Score\n",
"You can fetch the scores for the results by calling the `similarity_search_with_score` method."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "24b146b2-55a2-4fe8-8659-3649032f5dc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.' metadata={'source': '../../modules/state_of_the_union.txt'}\n",
"Score: 0.8211871385574341\n"
]
}
],
"source": [
"query = \"What did president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search_with_score(query)\n",
"document, score = results[0]\n",
"print(document)\n",
"print(f\"Score: {score}\")"
]
},
{
"cell_type": "markdown",
"id": "9983e83d-efd0-4b75-80db-150e0694e822",
"metadata": {},
"source": [
"## Specifying Fields to Return\n",
"You can specify the fields to return from the document using `fields` parameter in the searches. These fields are returned as part of the `metadata` object in the returned Document. You can fetch any field that is stored in the Search index. The `text_key` of the document is returned as part of the document's `page_content`.\n",
"\n",
"If you do not specify any fields to be fetched, all the fields stored in the index are returned.\n",
"\n",
"If you want to fetch one of the fields in the metadata, you need to specify it using `.`\n",
"\n",
"For example, to fetch the `source` field in the metadata, you need to specify `metadata.source`.\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ffa743dc-4e89-405b-ad71-7390338889e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.' metadata={'source': '../../modules/state_of_the_union.txt'}\n"
]
}
],
"source": [
"query = \"What did president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search(query, fields=[\"metadata.source\"])\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "a5e45eb2-aa97-45df-bcc5-410e9626e506",
"metadata": {},
"source": [
"## Hybrid Search\n",
"Couchbase allows you to do hybrid searches by combining Vector Search results with searches on non-vector fields of the document like the `metadata` object. \n",
"\n",
"The results will be based on the combination of the results from both Vector Search and the searches supported by Search Service. The scores of each of the component searches are added up to get the total score of the result.\n",
"\n",
"To perform hybrid searches, there is an optional parameter, `search_options` that can be passed to all the similarity searches. \n",
"The different search/query possibilities for the `search_options` can be found [here](https://docs.couchbase.com/server/current/search/search-request-params.html#query-object)."
]
},
{
"cell_type": "markdown",
"id": "a5db3685-1918-4c63-8148-0bb3a71ea677",
"metadata": {},
"source": [
"### Create Diverse Metadata for Hybrid Search\n",
"In order to simulate hybrid search, let us create some random metadata from the existing documents. \n",
"We uniformly add three fields to the metadata, `date` between 2010 & 2020, `rating` between 1 & 5 and `author` set to either John Doe or Jane Doe. "
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7d2e607d-6bbc-4cef-83e3-b6a28bb269ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'author': 'John Doe', 'date': '2016-01-01', 'rating': 2, 'source': '../../modules/state_of_the_union.txt'}\n"
]
}
],
"source": [
"# Adding metadata to documents\n",
"for i, doc in enumerate(docs):\n",
" doc.metadata[\"date\"] = f\"{range(2010, 2020)[i % 10]}-01-01\"\n",
" doc.metadata[\"rating\"] = range(1, 6)[i % 5]\n",
" doc.metadata[\"author\"] = [\"John Doe\", \"Jane Doe\"][i % 2]\n",
"\n",
"vector_store.add_documents(docs)\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search(query)\n",
"print(results[0].metadata)"
]
},
{
"cell_type": "markdown",
"id": "6cad893b-3977-4556-ab1d-d12bce68b306",
"metadata": {},
"source": [
"### Example: Search by Exact Value\n",
"We can search for exact matches on a textual field like the author in the `metadata` object."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dc06ba4a-8a6b-4c55-bb69-95cd92db273f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='This is personal to me and Jill, to Kamala, and to so many of you. \\n\\nCancer is the #2 cause of death in Americasecond only to heart disease. \\n\\nLast month, I announced our plan to supercharge \\nthe Cancer Moonshot that President Obama asked me to lead six years ago. \\n\\nOur goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases. \\n\\nMore support for patients and families.' metadata={'author': 'John Doe'}\n"
]
}
],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search(\n",
" query,\n",
" search_options={\"query\": {\"field\": \"metadata.author\", \"match\": \"John Doe\"}},\n",
" fields=[\"metadata.author\"],\n",
")\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "9106b594-b41e-4329-b98c-9b9f8a34d6f7",
"metadata": {},
"source": [
"### Example: Search by Partial Match\n",
"We can search for partial matches by specifying a fuzziness for the search. This is useful when you want to search for slight variations or misspellings of a search query.\n",
"\n",
"Here, \"Jae\" is close (fuzziness of 1) to \"Jane\"."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "fd4749e6-ef4f-4cb5-95ff-37c4fa8283d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since shes been nominated, shes received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \\n\\nAnd if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.' metadata={'author': 'Jane Doe'}\n"
]
}
],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"results = vector_store.similarity_search(\n",
" query,\n",
" search_options={\n",
" \"query\": {\"field\": \"metadata.author\", \"match\": \"Jae\", \"fuzziness\": 1}\n",
" },\n",
" fields=[\"metadata.author\"],\n",
")\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "1bbf9449-6e30-4bd1-9eeb-f3b60952fcab",
"metadata": {},
"source": [
"### Example: Search by Date Range Query\n",
"We can search for documents that are within a date range query on a date field like `metadata.date`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b7b47e7d-c32f-4999-bce9-3c3c3cebffd0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='He will never extinguish their love of freedom. He will never weaken the resolve of the free world. \\n\\nWe meet tonight in an America that has lived through two of the hardest years this nation has ever faced. \\n\\nThe pandemic has been punishing. \\n\\nAnd so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \\n\\nI understand.' metadata={'author': 'Jane Doe', 'date': '2017-01-01', 'rating': 3, 'source': '../../modules/state_of_the_union.txt'}\n"
]
}
],
"source": [
"query = \"Any mention about independence?\"\n",
"results = vector_store.similarity_search(\n",
" query,\n",
" search_options={\n",
" \"query\": {\n",
" \"start\": \"2016-12-31\",\n",
" \"end\": \"2017-01-02\",\n",
" \"inclusive_start\": True,\n",
" \"inclusive_end\": False,\n",
" \"field\": \"metadata.date\",\n",
" }\n",
" },\n",
")\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "a18d4ea2-bfab-4f15-9839-674faf1c6f0d",
"metadata": {},
"source": [
"### Example: Search by Numeric Range Query\n",
"We can search for documents that are within a range for a numeric field like `metadata.rating`."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7e8bf7c5-07d1-4c3f-86d7-1fa3a454dc7f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Document(page_content='He will never extinguish their love of freedom. He will never weaken the resolve of the free world. \\n\\nWe meet tonight in an America that has lived through two of the hardest years this nation has ever faced. \\n\\nThe pandemic has been punishing. \\n\\nAnd so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \\n\\nI understand.', metadata={'author': 'Jane Doe', 'date': '2017-01-01', 'rating': 3, 'source': '../../modules/state_of_the_union.txt'}), 0.9000703597577832)\n"
]
}
],
"source": [
"query = \"Any mention about independence?\"\n",
"results = vector_store.similarity_search_with_score(\n",
" query,\n",
" search_options={\n",
" \"query\": {\n",
" \"min\": 3,\n",
" \"max\": 5,\n",
" \"inclusive_min\": True,\n",
" \"inclusive_max\": True,\n",
" \"field\": \"metadata.rating\",\n",
" }\n",
" },\n",
")\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "0f16bf86-f01c-4a77-8406-275f7313f493",
"metadata": {},
"source": [
"### Example: Combining Multiple Search Queries\n",
"Different search queries can be combined using AND (conjuncts) or OR (disjuncts) operators.\n",
"\n",
"In this example, we are checking for documents with a rating between 3 & 4 and dated between 2015 & 2018."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "dd0fe7f1-aa40-4c6f-889b-99ad5efcd88b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Document(page_content='He will never extinguish their love of freedom. He will never weaken the resolve of the free world. \\n\\nWe meet tonight in an America that has lived through two of the hardest years this nation has ever faced. \\n\\nThe pandemic has been punishing. \\n\\nAnd so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. \\n\\nI understand.', metadata={'author': 'Jane Doe', 'date': '2017-01-01', 'rating': 3, 'source': '../../modules/state_of_the_union.txt'}), 1.3598770370389914)\n"
]
}
],
"source": [
"query = \"Any mention about independence?\"\n",
"results = vector_store.similarity_search_with_score(\n",
" query,\n",
" search_options={\n",
" \"query\": {\n",
" \"conjuncts\": [\n",
" {\"min\": 3, \"max\": 4, \"inclusive_max\": True, \"field\": \"metadata.rating\"},\n",
" {\"start\": \"2016-12-31\", \"end\": \"2017-01-02\", \"field\": \"metadata.date\"},\n",
" ]\n",
" }\n",
" },\n",
")\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "39258571-3233-45c3-a6ad-5c3c90ea2b1c",
"metadata": {},
"source": [
"### Other Queries\n",
"Similarly, you can use any of the supported Query methods like Geo Distance, Polygon Search, Wildcard, Regular Expressions, etc in the `search_options` parameter. Please refer to the documentation for more details on the available query methods and their syntax.\n",
"\n",
"- [Couchbase Capella](https://docs.couchbase.com/cloud/search/search-request-params.html#query-object)\n",
"- [Couchbase Server](https://docs.couchbase.com/server/current/search/search-request-params.html#query-object)"
]
},
{
"cell_type": "markdown",
"id": "80958c2b-6a67-45e6-b7f0-fd2461d75e0f",
"metadata": {},
"source": [
"# Frequently Asked Questions"
]
},
{
"cell_type": "markdown",
"id": "4f7f9838-cc20-44bc-a72d-06f2cb6c3fca",
"metadata": {},
"source": [
"## Question: Should I create the Search index before creating the CouchbaseVectorStore object?\n",
"Yes, currently you need to create the Search index before creating the `CouchbaseVectoreStore` object.\n"
]
},
{
"cell_type": "markdown",
"id": "3f0dbc1b-9e82-4ec3-9330-6b54de00661e",
"metadata": {},
"source": [
"## Question: I am not seeing all the fields that I specified in my search results. \n",
"\n",
"In Couchbase, we can only return the fields stored in the Search index. Please ensure that the field that you are trying to access in the search results is part of the Search index.\n",
"\n",
"One way to handle this is to index and store a document's fields dynamically in the index. \n",
"\n",
"- In Capella, you need to go to \"Advanced Mode\" then under the chevron \"General Settings\" you can check \"[X] Store Dynamic Fields\" or \"[X] Index Dynamic Fields\"\n",
"- In Couchbase Server, in the Index Editor (not Quick Editor) under the chevron \"Advanced\" you can check \"[X] Store Dynamic Fields\" or \"[X] Index Dynamic Fields\"\n",
"\n",
"Note that these options will increase the size of the index.\n",
"\n",
"For more details on dynamic mappings, please refer to the [documentation](https://docs.couchbase.com/cloud/search/customize-index.html).\n"
]
},
{
"cell_type": "markdown",
"id": "3702977a-2e25-48b6-b662-edd5cb94cdec",
"metadata": {},
"source": [
"## Question: I am unable to see the metadata object in my search results. \n",
"This is most likely due to the `metadata` field in the document not being indexed and/or stored by the Couchbase Search index. In order to index the `metadata` field in the document, you need to add it to the index as a child mapping. \n",
"\n",
"If you select to map all the fields in the mapping, you will be able to search by all metadata fields. Alternatively, to optimize the index, you can select the specific fields inside `metadata` object to be indexed. You can refer to the [docs](https://docs.couchbase.com/cloud/search/customize-index.html) to learn more about indexing child mappings.\n",
"\n",
"Creating Child Mappings\n",
"\n",
"* [Couchbase Capella](https://docs.couchbase.com/cloud/search/create-child-mapping.html)\n",
"* [Couchbase Server](https://docs.couchbase.com/server/current/search/create-child-mapping.html)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
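Pulling the notebook's setup together, a condensed sketch of the basic flow (assuming a local cluster, the `testing` bucket, and the `vector-index` Search index described above already exist; the credentials are placeholders):

```python
# Condensed sketch of the Couchbase notebook above; connection details are placeholders.
from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from langchain_community.vectorstores import CouchbaseVectorStore
from langchain_openai import OpenAIEmbeddings

cluster = Cluster(
    "couchbase://localhost",
    ClusterOptions(PasswordAuthenticator("Administrator", "Password")),
)
cluster.wait_until_ready(timedelta(seconds=5))

vector_store = CouchbaseVectorStore(
    cluster=cluster,
    bucket_name="testing",
    scope_name="_default",
    collection_name="_default",
    embedding=OpenAIEmbeddings(),
    index_name="vector-index",
)

docs = vector_store.similarity_search(
    "What did president say about Ketanji Brown Jackson", k=4
)
print(docs[0].page_content)
```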

@ -37,9 +37,21 @@
"\n",
"To run this demo we need a running Infinispan instance without authentication and a data file.\n",
"In the next three cells we're going to:\n",
"- download the data file\n",
"- create the configuration\n",
"- run Infinispan in docker\n",
"- download the data file"
"- run Infinispan in docker"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9678d5ce-894c-4e28-bf68-20d45507122f",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"#get an archive of news\n",
"wget https://raw.githubusercontent.com/rigazilla/infinispan-vector/main/bbc_news.csv.gz"
]
},
{
@ -76,18 +88,6 @@
"' > infinispan-noauth.yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9678d5ce-894c-4e28-bf68-20d45507122f",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"#get an archive of news\n",
"wget https://raw.githubusercontent.com/rigazilla/infinispan-vector/main/bbc_news.csv.gz"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -95,7 +95,8 @@
"metadata": {},
"outputs": [],
"source": [
"!docker run -d --name infinispanvs-demo -v $(pwd):/user-config -p 11222:11222 infinispan/server:15.0.0.Dev09 -c /user-config/infinispan-noauth.yaml "
"!docker rm --force infinispanvs-demo\n",
"!docker run -d --name infinispanvs-demo -v $(pwd):/user-config -p 11222:11222 infinispan/server:15.0 -c /user-config/infinispan-noauth.yaml"
]
},
{
@ -133,80 +134,8 @@
"## Setup Infinispan cache\n",
"\n",
"Infinispan is a very flexible key-value store, it can store raw bits as well as complex data type.\n",
"We need to configure it to store data containing embedded vectors.\n",
"\n",
"In the next cells we're going to:\n",
"- create an empty Infinispan VectoreStore\n",
"- deploy a protobuf definition of our data\n",
"- create a cache"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49668bf1-778b-466d-86fb-41747ed52b74",
"metadata": {},
"outputs": [],
"source": [
"# Creating a langchain_core.VectorStore\n",
"from langchain_community.vectorstores import InfinispanVS\n",
"\n",
"ispnvs = InfinispanVS.from_texts(\n",
" texts={}, embedding=hf, cache_name=\"demo_cache\", entity_name=\"demo_entity\"\n",
")\n",
"ispn = ispnvs.ispn"
]
},
{
"cell_type": "markdown",
"id": "0cedf066-aaab-4185-b049-93eea9b48329",
"metadata": {},
"source": [
"### Protobuf definition\n",
"\n",
"Below there's the protobuf definition of our data type that contains:\n",
"- embedded vector (field 1)\n",
"- text of the news (2)\n",
"- title of the news (3)\n",
"\n",
"As you can see, there are additional annotations in the comments that tell Infinispan that:\n",
"- data type must be indexed (`@Indexed`)\n",
"- field 1 is an embeddeded vector (`@Vector`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fa0add0-8317-4667-9b8c-5d91c47f752a",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Infinispan supports protobuf schemas\n",
"schema_vector = \"\"\"\n",
"/**\n",
" * @Indexed\n",
" */\n",
"message demo_entity {\n",
"/**\n",
" * @Vector(dimension=384)\n",
" */\n",
"repeated float vector = 1;\n",
"optional string text = 2;\n",
"optional string title = 3;\n",
"}\n",
"\"\"\"\n",
"# Cleanup before deploy a new schema\n",
"ispnvs.schema_delete()\n",
"output = ispnvs.schema_create(schema_vector)\n",
"assert output.status_code == 200\n",
"assert json.loads(output.text)[\"error\"] is None\n",
"# Create the cache\n",
"ispnvs.cache_create()\n",
"# Cleanup old data and index\n",
"ispnvs.cache_clear()\n",
"ispnvs.cache_index_reindex()"
"User has complete freedom in the datagrid configuration, but for simple data type everything is automatically\n",
"configured by the python layer. We take advantage of this feature so we can focus on our application."
]
},
{
@ -216,8 +145,7 @@
"source": [
"## Prepare the data\n",
"\n",
"In this demo we choose to store text,vector and metadata in the same cache, but other options\n",
"are possible: i.e. content can be store somewhere else and vector store could contain only a reference to the actual content."
"In this demo we rely on the default configuration, thus texts, metadatas and vectors in the same cache, but other options are possible: i.e. content can be store somewhere else and vector store could contain only a reference to the actual content."
]
},
{
@ -239,15 +167,12 @@
" metas = []\n",
" embeds = []\n",
" for row in spamreader:\n",
" # first and fifth value are joined to form the content\n",
" # first and fifth values are joined to form the content\n",
" # to be processed\n",
" text = row[0] + \".\" + row[4]\n",
" texts.append(text)\n",
" # Storing meta\n",
" # Store text and title as metadata\n",
" meta = {}\n",
" meta[\"text\"] = row[4]\n",
" meta[\"title\"] = row[0]\n",
" meta = {\"text\": row[4], \"title\": row[0]}\n",
" metas.append(meta)\n",
" i = i + 1\n",
" # Change this to change the number of news you want to load\n",
@ -271,7 +196,10 @@
"outputs": [],
"source": [
"# add texts and fill vector db\n",
"keys = ispnvs.add_texts(texts, metas)"
"\n",
"from langchain_community.vectorstores import InfinispanVS\n",
"\n",
"ispnvs = InfinispanVS.from_texts(texts, hf, metas)"
]
},
{
@ -361,18 +289,6 @@
"print_docs(ispnvs.similarity_search(\"How to stay young\", 5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "862e4af2-9f8a-4985-90cb-997477901b1e",
"metadata": {},
"outputs": [],
"source": [
"# Clean up\n",
"ispnvs.schema_delete()\n",
"ispnvs.cache_delete()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -400,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.9.18"
}
},
"nbformat": 4,

@ -10,7 +10,7 @@
"Splits the text based on semantic similarity.\n",
"\n",
"Taken from Greg Kamradt's wonderful notebook:\n",
"https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb\n",
"[5_Levels_Of_Text_Splitting](https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb)\n",
"\n",
"All credit to him.\n",
"\n",

@ -49,6 +49,14 @@
"from langchain_text_splitters import CharacterTextSplitter"
]
},
{
"cell_type": "markdown",
"id": "a3ba1d8a",
"metadata": {},
"source": [
"The `.from_tiktoken_encoder()` method takes either `encoding` as an argument (e.g. `cl100k_base`), or the `model_name` (e.g. `gpt-4`). All additional arguments like `chunk_size`, `chunk_overlap`, and `separators` are used to instantiate `CharacterTextSplitter`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
@ -57,7 +65,7 @@
"outputs": [],
"source": [
"text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n",
" chunk_size=100, chunk_overlap=0\n",
" encoding=\"cl100k_base\", chunk_size=100, chunk_overlap=0\n",
")\n",
"texts = text_splitter.split_text(state_of_the_union)"
]
@ -91,9 +99,31 @@
"id": "de5b6a6e",
"metadata": {},
"source": [
"Note that if we use `CharacterTextSplitter.from_tiktoken_encoder`, text is only split by `CharacterTextSplitter` and `tiktoken` tokenizer is used to merge splits. It means that split can be larger than chunk size measured by `tiktoken` tokenizer. We can use `RecursiveCharacterTextSplitter.from_tiktoken_encoder` to make sure splits are not larger than chunk size of tokens allowed by the language model, where each split will be recursively split if it has a larger size.\n",
"Note that if we use `CharacterTextSplitter.from_tiktoken_encoder`, text is only split by `CharacterTextSplitter` and `tiktoken` tokenizer is used to merge splits. It means that split can be larger than chunk size measured by `tiktoken` tokenizer. We can use `RecursiveCharacterTextSplitter.from_tiktoken_encoder` to make sure splits are not larger than chunk size of tokens allowed by the language model, where each split will be recursively split if it has a larger size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0262a991",
"metadata": {},
"outputs": [],
"source": [
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"We can also load a tiktoken splitter directly, which ensure each split is smaller than chunk size."
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
" model_name=\"gpt-4\",\n",
" chunk_size=100,\n",
" chunk_overlap=0,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "04457e3a",
"metadata": {},
"source": [
"We can also load a tiktoken splitter directly, which will ensure each split is smaller than chunk size."
]
},
{
@ -111,6 +141,14 @@
"print(texts[0])"
]
},
{
"cell_type": "markdown",
"id": "3bc155d0",
"metadata": {},
"source": [
"Some written languages (e.g. Chinese and Japanese) have characters which encode to 2 or more tokens. Using the `TokenTextSplitter` directly can split the tokens for a character between two chunks causing malformed Unicode characters. Use `RecursiveCharacterTextSplitter.from_tiktoken_encoder` or `CharacterTextSplitter.from_tiktoken_encoder` to ensure chunks contain valid Unicode strings."
]
},
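For example (an illustrative sketch with an arbitrary Japanese sentence), the tiktoken-aware recursive splitter keeps every chunk a valid Unicode string:

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    model_name="gpt-4",
    chunk_size=20,
    chunk_overlap=0,
)
# Each chunk stays a well-formed string even though some characters span multiple tokens.
chunks = text_splitter.split_text("吾輩は猫である。名前はまだ無い。どこで生れたかとんと見当がつかぬ。")
print(chunks)
```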
{
"cell_type": "markdown",
"id": "55f95f06",

@ -60,7 +60,7 @@
" * document addition by id (`add_documents` method with `ids` argument)\n",
" * delete by id (`delete` method with `ids` argument)\n",
"\n",
"Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n",
"Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n",
" \n",
"## Caution\n",
"\n",

@ -22,10 +22,11 @@
"Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
"embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
"\n",
"The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
"The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
"\n",
"- underlying_embedder: The embedder to use for embedding.\n",
"- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
"- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
"- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
"\n",
"**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."

@ -39,7 +39,7 @@
},
"devDependencies": {
"@babel/eslint-parser": "^7.18.2",
"@langchain/scripts": "^0.0.9",
"@langchain/scripts": "^0.0.10",
"docusaurus-plugin-typedoc": "next",
"dotenv": "^16.4.5",
"eslint": "^8.19.0",

@ -3,5 +3,5 @@ const { checkBrokenLinks } = require("@langchain/scripts/check_broken_links");
checkBrokenLinks("docs", {
timeout: 10000,
whitelist: ["microsoft.com"],
retryFailed: true,
});

@ -0,0 +1,103 @@
/* eslint-disable react/jsx-props-no-spreading */
import React from "react";
import Tabs from "@theme/Tabs";
import TabItem from "@theme/TabItem";
import CodeBlock from "@theme-original/CodeBlock";
function Setup({ apiKeyName, packageName }) {
const apiKeyText = `import getpass
import os
os.environ["${apiKeyName}"] = getpass.getpass()`;
return (
<>
<h5>Install dependencies</h5>
<CodeBlock language="bash">{`pip install -qU ${packageName}`}</CodeBlock>
<h5>Set environment variables</h5>
<CodeBlock language="python">{apiKeyText}</CodeBlock>
</>
);
}
/**
* @param {{ openaiParams?: string, anthropicParams?: string, fireworksParams?: string, mistralParams?: string, googleParams?: string, hideOpenai?: boolean, hideAnthropic?: boolean, hideFireworks?: boolean, hideMistral?: boolean, hideGoogle?: boolean }} props
*/
export default function ChatModelTabs(props) {
const {
openaiParams,
anthropicParams,
fireworksParams,
mistralParams,
googleParams,
hideOpenai,
hideAnthropic,
hideFireworks,
hideMistral,
hideGoogle,
} = props;
const openAIParamsOrDefault = openaiParams ?? `model="gpt-3.5-turbo-0125"`
const anthropicParamsOrDefault = anthropicParams ?? `model="claude-3-sonnet-20240229"`
const fireworksParamsOrDefault = fireworksParams ?? `model="accounts/fireworks/models/mixtral-8x7b-instruct"`
const mistralParamsOrDefault = mistralParams ?? `model="mistral-large-latest"`
const googleParamsOrDefault = googleParams ?? `model="gemini-pro"`
const tabItems = [
{
value: "OpenAI",
label: "OpenAI",
text: `from langchain_openai import ChatOpenAI\n\nmodel = ChatOpenAI(${openAIParamsOrDefault})`,
apiKeyName: "OPENAI_API_KEY",
packageName: "langchain-openai",
default: true,
shouldHide: hideOpenai,
},
{
value: "Anthropic",
label: "Anthropic",
text: `from langchain_anthropic import ChatAnthropic\n\nmodel = ChatAnthropic(${anthropicParamsOrDefault})`,
apiKeyName: "ANTHROPIC_API_KEY",
packageName: "langchain-anthropic",
default: false,
shouldHide: hideAnthropic,
},
{
value: "FireworksAI",
label: "FireworksAI",
text: `from langchain_fireworks import ChatFireworks\n\nmodel = ChatFireworks(${fireworksParamsOrDefault})`,
apiKeyName: "FIREWORKS_API_KEY",
packageName: "langchain-fireworks",
default: false,
shouldHide: hideFireworks,
},
{
value: "MistralAI",
label: "MistralAI",
text: `from langchain_mistralai import ChatMistralAI\n\nmodel = ChatMistralAI(${mistralParamsOrDefault})`,
apiKeyName: "MISTRAL_API_KEY",
packageName: "langchain-mistralai",
default: false,
shouldHide: hideMistral,
},
{
value: "Google",
label: "Google",
text: `from langchain_google_genai import ChatGoogleGenerativeAI\n\nmodel = ChatGoogleGenerativeAI(${googleParamsOrDefault})`,
apiKeyName: "GOOGLE_API_KEY",
packageName: "langchain-google-genai",
default: false,
shouldHide: hideGoogle,
}
]
return (
<Tabs groupId="modelTabs">
{tabItems.filter((tabItem) => !tabItem.shouldHide).map((tabItem) => (
<TabItem key={tabItem.value} value={tabItem.value} label={tabItem.label} default={tabItem.default}>
<Setup apiKeyName={tabItem.apiKeyName} packageName={tabItem.packageName} />
<CodeBlock language="python">{tabItem.text}</CodeBlock>
</TabItem>
))}
</Tabs>
);
}

@ -2505,9 +2505,9 @@ __metadata:
languageName: node
linkType: hard
"@langchain/scripts@npm:^0.0.9":
version: 0.0.9
resolution: "@langchain/scripts@npm:0.0.9"
"@langchain/scripts@npm:^0.0.10":
version: 0.0.10
resolution: "@langchain/scripts@npm:0.0.10"
dependencies:
axios: ^1.6.7
commander: ^11.1.0
@ -2517,7 +2517,7 @@ __metadata:
typescript: <5.2.0
bin:
lc-build: build.js
checksum: 8b0e2e73b84e5997155d8f73208b86fc3178ad72ec013f7b9fc976e52b55ce4b39e6eaf2315f3b841cc4c40e954851655e3cee3b236f557b7421bfad3be83f32
checksum: 2051be819a5fb9863f81c06e5504626377ac922093953c559f84ec9931108a1b72942d5f5b7e22f44dd60ca8cc7a34532d2eea03a0055b96bf2acf300ec454f6
languageName: node
linkType: hard
@ -5922,7 +5922,7 @@ __metadata:
"@docusaurus/preset-classic": 2.4.3
"@docusaurus/remark-plugin-npm2yarn": ^2.4.3
"@docusaurus/theme-mermaid": 2.4.3
"@langchain/scripts": ^0.0.9
"@langchain/scripts": ^0.0.10
"@mdx-js/react": ^1.6.22
"@supabase/supabase-js": ^2.39.7
clsx: ^1.2.1

@ -63,6 +63,7 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core._api.deprecation import deprecated
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
@ -1394,6 +1395,11 @@ class SQLAlchemyMd5Cache(BaseCache):
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
@deprecated(
since="0.0.28",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBCache",
)
class AstraDBCache(BaseCache):
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
@ -1592,6 +1598,11 @@ def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
return decorating_function
@deprecated(
since="0.0.28",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBSemanticCache",
)
class AstraDBSemanticCache(BaseCache):
def __init__(
self,

@ -143,6 +143,44 @@ class LlamaChatContentFormatter(ContentFormatterBase):
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class MistralChatContentFormatter(LlamaChatContentFormatter):
"""Content formatter for `Mistral`."""
def format_messages_request_payload(
self,
messages: List[BaseMessage],
model_kwargs: Dict,
api_type: AzureMLEndpointApiType,
) -> bytes:
"""Formats the request according to the chosen api"""
chat_messages = [self._convert_message_to_dict(message) for message in messages]
if chat_messages and chat_messages[0]["role"] == "system":
# Mistral OSS models do not explicitly support system prompts, so we have to
# stash in the first user prompt
chat_messages[1]["content"] = (
chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
)
del chat_messages[0]
if api_type == AzureMLEndpointApiType.realtime:
request_payload = json.dumps(
{
"input_data": {
"input_string": chat_messages,
"parameters": model_kwargs,
}
}
)
elif api_type == AzureMLEndpointApiType.serverless:
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
else:
raise ValueError(
f"`api_type` {api_type} is not supported by this formatter"
)
return str.encode(request_payload)
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
"""Azure ML Online Endpoint chat models.

@ -6,6 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
@ -26,6 +27,11 @@ DEFAULT_API_ENDPOINT = "https://api.docugami.com/v1preview1"
logger = logging.getLogger(__name__)
@deprecated(
since="0.0.24",
removal="0.2.0",
alternative_import="docugami_langchain.DocugamiLoader",
)
class DocugamiLoader(BaseLoader, BaseModel):
"""Load from `Docugami`.

@ -162,6 +162,7 @@ def _import_databricks() -> Type[BaseLLM]:
return Databricks
# deprecated / only for back compat - do not add to __all__
def _import_databricks_chat() -> Any:
warn_deprecated(
since="0.0.22",
@ -325,6 +326,7 @@ def _import_mlflow() -> Type[BaseLLM]:
return Mlflow
# deprecated / only for back compat - do not add to __all__
def _import_mlflow_chat() -> Any:
warn_deprecated(
since="0.0.22",
@ -631,7 +633,7 @@ def __getattr__(name: str) -> Any:
return _import_aviary()
elif name == "AzureMLOnlineEndpoint":
return _import_azureml_endpoint()
elif name == "Baichuan":
elif name == "BaichuanLLM" or name == "Baichuan":
return _import_baichuan()
elif name == "QianfanLLMEndpoint":
return _import_baidu_qianfan_endpoint()
@ -701,6 +703,8 @@ def __getattr__(name: str) -> Any:
return _import_konko()
elif name == "LlamaCpp":
return _import_llamacpp()
elif name == "Llamafile":
return _import_llamafile()
elif name == "ManifestWrapper":
return _import_manifest()
elif name == "Minimax":
@ -818,6 +822,7 @@ __all__ = [
"Aviary",
"AzureMLOnlineEndpoint",
"AzureOpenAI",
"BaichuanLLM",
"Banana",
"Baseten",
"Beam",
@ -836,8 +841,8 @@ __all__ = [
"Fireworks",
"ForefrontAI",
"Friendli",
"GigaChat",
"GPT4All",
"GigaChat",
"GooglePalm",
"GooseAI",
"GradientLLM",
@ -846,22 +851,26 @@ __all__ = [
"HuggingFacePipeline",
"HuggingFaceTextGenInference",
"HumanInputLLM",
"JavelinAIGateway",
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"TextGen",
"Llamafile",
"ManifestWrapper",
"Minimax",
"Mlflow",
"MlflowAIGateway",
"Modal",
"MosaicML",
"Nebula",
"NIBittensorLLM",
"NLPCloud",
"Nebula",
"OCIGenAI",
"OCIModelDeploymentTGI",
"OCIModelDeploymentVLLM",
"OCIGenAI",
"OctoAIEndpoint",
"Ollama",
"OpaquePrompts",
"OpenAI",
"OpenAIChat",
"OpenLLM",
@ -873,30 +882,29 @@ __all__ = [
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
"OpaquePrompts",
"QianfanLLMEndpoint",
"RWKV",
"Replicate",
"SagemakerEndpoint",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"SparkLLM",
"StochasticAI",
"TextGen",
"TitanTakeoff",
"TitanTakeoffPro",
"Together",
"Tongyi",
"VertexAI",
"VertexAIModelGarden",
"VLLM",
"VLLMOpenAI",
"VertexAI",
"VertexAIModelGarden",
"VolcEngineMaasLLM",
"WatsonxLLM",
"Writer",
"OctoAIEndpoint",
"Xinference",
"JavelinAIGateway",
"QianfanLLMEndpoint",
"YandexGPT",
"Yuan2",
"VolcEngineMaasLLM",
"SparkLLM",
]
@ -912,6 +920,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"aviary": _import_aviary,
"azure": _import_azure_openai,
"azureml_endpoint": _import_azureml_endpoint,
"baichuan": _import_baichuan,
"bananadev": _import_bananadev,
"baseten": _import_baseten,
"beam": _import_beam,
@ -922,7 +931,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"ctransformers": _import_ctransformers,
"ctranslate2": _import_ctranslate2,
"databricks": _import_databricks,
"databricks-chat": _import_databricks_chat,
"databricks-chat": _import_databricks_chat, # deprecated / only for back compat
"deepinfra": _import_deepinfra,
"deepsparse": _import_deepsparse,
"edenai": _import_edenai,
@ -942,10 +951,11 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"koboldai": _import_koboldai,
"konko": _import_konko,
"llamacpp": _import_llamacpp,
"llamafile": _import_llamafile,
"textgen": _import_textgen,
"minimax": _import_minimax,
"mlflow": _import_mlflow,
"mlflow-chat": _import_mlflow_chat,
"mlflow-chat": _import_mlflow_chat, # deprecated / only for back compat
"mlflow-ai-gateway": _import_mlflow_ai_gateway,
"modal": _import_modal,
"mosaic": _import_mosaicml,

@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
DEFAULT_TIMEOUT = 50
class AzureMLEndpointClient(object):
"""AzureML Managed Endpoint client."""
def __init__(
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
self,
endpoint_url: str,
endpoint_api_key: str,
deployment_name: str = "",
timeout: int = DEFAULT_TIMEOUT,
) -> None:
"""Initialize the class."""
if not endpoint_api_key or not endpoint_url:
@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
self.endpoint_url = endpoint_url
self.endpoint_api_key = endpoint_api_key
self.deployment_name = deployment_name
self.timeout = timeout
def call(
self,
@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
headers["azureml-model-deployment"] = self.deployment_name
req = urllib.request.Request(self.endpoint_url, body, headers)
response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
response = urllib.request.urlopen(
req, timeout=kwargs.get("timeout", self.timeout)
)
result = response.read()
return result
@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
timeout: int = DEFAULT_TIMEOUT
"""Request timeout for calls to the endpoint"""
http_client: Any = None #: :meta private:
content_formatter: Any = None
@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
"AZUREML_ENDPOINT_API_TYPE",
AzureMLEndpointApiType.realtime,
)
values["timeout"] = get_from_dict_or_env(
values,
"timeout",
"AZUREML_TIMEOUT",
str(DEFAULT_TIMEOUT),
)
return values
@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
endpoint_url = values.get("endpoint_url")
endpoint_key = values.get("endpoint_api_key")
deployment_name = values.get("deployment_name")
timeout = values.get("timeout", DEFAULT_TIMEOUT)
http_client = AzureMLEndpointClient(
endpoint_url, # type: ignore
endpoint_key.get_secret_value(), # type: ignore
deployment_name, # type: ignore
timeout, # type: ignore
)
return http_client
@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_type=AzureMLApiType.realtime,
endpoint_api_key="my-api-key",
timeout=120,
content_formatter=content_formatter,
)
""" # noqa: E501

@ -1,5 +1,5 @@
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union
import aiohttp
import requests
@ -111,6 +111,18 @@ class _OllamaCommon(BaseLanguageModel):
timeout: Optional[int] = None
"""Timeout for the request stream"""
keep_alive: Optional[Union[int, str]] = None
"""How long the model will stay loaded into memory.
The parameter (Default: 5 minutes) can be set to:
1. a duration string in Golang (such as "10m" or "24h");
2. a number in seconds (such as 3600);
3. any negative number which will keep the model loaded \
in memory (e.g. -1 or "-1m");
4. 0 which will unload the model immediately after generating a response;
See the [Ollama documents](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-keep-a-model-loaded-in-memory-or-make-it-unload-immediately)"""
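# Illustrative usage (assumption, not from the original source): keep the model
# loaded for ten minutes after each call, e.g. Ollama(model="llama2", keep_alive="10m").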
headers: Optional[dict] = None
"""Additional headers to pass to endpoint (e.g. Authorization, Referer).
This is useful when Ollama is hosted on cloud services that require
@ -141,6 +153,7 @@ class _OllamaCommon(BaseLanguageModel):
},
"system": self.system,
"template": self.template,
"keep_alive": self.keep_alive,
}
@property
@ -462,12 +475,12 @@ class Ollama(BaseLLM, _OllamaCommon):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _astream(
self,
@ -479,9 +492,9 @@ class Ollama(BaseLLM, _OllamaCommon):
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk

@ -365,7 +365,6 @@ class BaseOpenAI(BaseLLM):
if not isinstance(stream_resp, dict):
stream_resp = stream_resp.dict()
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
@ -375,6 +374,7 @@ class BaseOpenAI(BaseLLM):
if chunk.generation_info
else None,
)
yield chunk
async def _astream(
self,

@ -231,9 +231,9 @@ class PaiEasEndpoint(LLM):
# yield text, if any
if text:
res = GenerationChunk(text=text)
yield res
if run_manager:
run_manager.on_llm_new_token(res.text)
yield res
# break if stop sequence found
if stop_seq_found:

@ -177,12 +177,12 @@ class Replicate(LLM):
if not output:
break
if output:
yield GenerationChunk(text=output)
if run_manager:
run_manager.on_llm_new_token(
output,
verbose=self.verbose,
)
yield GenerationChunk(text=output)
if stop_condition_reached:
break

@ -169,9 +169,9 @@ class SparkLLM(LLM):
if "data" not in content:
continue
delta = content["data"]
yield GenerationChunk(text=delta["content"])
if run_manager:
run_manager.on_llm_new_token(delta)
yield GenerationChunk(text=delta["content"])
class _SparkLLMClient:

@ -207,9 +207,9 @@ class TitanTakeoffPro(LLM):
# Yield any remaining content in the buffer.
if buffer:
chunk = GenerationChunk(text=buffer.replace("</s>", ""))
yield chunk
if run_manager:
run_manager.on_llm_new_token(token=chunk.text)
yield chunk
@property
def _identifying_params(self) -> Mapping[str, Any]:

@ -86,7 +86,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
def _run(
@ -94,7 +94,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
"""Get a comma-separated list of table names."""
return ", ".join(self.db.get_usable_table_names())

@ -42,6 +42,7 @@ _module_lookup = {
"Clarifai": "langchain_community.vectorstores.clarifai",
"Clickhouse": "langchain_community.vectorstores.clickhouse",
"ClickhouseSettings": "langchain_community.vectorstores.clickhouse",
"CouchbaseVectorStore": "langchain_community.vectorstores.couchbase",
"DashVector": "langchain_community.vectorstores.dashvector",
"DatabricksVectorSearch": "langchain_community.vectorstores.databricks_vector_search", # noqa: E501
"DeepLake": "langchain_community.vectorstores.deeplake",

@ -211,12 +211,48 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
@property
def embeddings(self) -> Embeddings:
"""Provides access to the embedding mechanism used by the Clickhouse instance.
This property allows direct access to the embedding function or model being
used by the Clickhouse instance to convert text documents into embedding vectors
for vector similarity search.
Returns:
The `Embeddings` instance associated with this Clickhouse instance.
"""
return self.embedding_function
def escape_str(self, value: str) -> str:
"""Escape special characters in a string for Clickhouse SQL queries.
This method is used internally to prepare strings for safe insertion
into SQL queries by escaping special characters that might otherwise
interfere with the query syntax.
Args:
value: The string to be escaped.
Returns:
The escaped string, safe for insertion into SQL queries.
"""
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
"""Construct an SQL query for inserting data into the Clickhouse database.
This method formats and constructs an SQL `INSERT` query string using the
provided transaction data and column names. It is utilized internally during
the process of batch insertion of documents and their embeddings into the
database.
Args:
transac: iterable of tuples, representing a row of data to be inserted.
column_names: iterable of strings representing the names of the columns
into which data will be inserted.
Returns:
A string containing the constructed SQL `INSERT` query.
"""
ks = ",".join(column_names)
_data = []
for n in transac:
@ -231,6 +267,17 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
return i_str
def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
"""Execute an SQL query to insert data into the Clickhouse database.
This method performs the actual insertion of data into the database by
executing the SQL query constructed by `_build_insert_sql`. It's a critical
step in adding new documents and their associated data into the vector store.
Args:
transac: iterable of tuples, each representing a row of data to be inserted.
column_names: An iterable of strings representing the names of the columns
into which data will be inserted.
"""
_insert_query = self._build_insert_sql(transac, column_names)
self.client.command(_insert_query)
@ -345,6 +392,21 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
def _build_query_sql(
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
) -> str:
"""Construct an SQL query for performing a similarity search.
This internal method generates an SQL query for finding the top-k most similar
vectors in the database to a given query vector. It allows for optional filtering
conditions to be applied via a WHERE clause.
Args:
q_emb: The query vector as a list of floats.
topk: The number of top similar items to retrieve.
where_str: optional string with additional WHERE conditions for the query.
Defaults to None.
Returns:
A string containing the SQL query for the similarity search.
"""
q_emb_str = ",".join(map(str, q_emb))
if where_str:
where_str = f"PREWHERE {where_str}"

@ -0,0 +1,617 @@
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from couchbase.cluster import Cluster
class CouchbaseVectorStore(VectorStore):
"""`Couchbase Vector Store` vector store.
To use it, you need
- a recent installation of the `couchbase` library
- a Couchbase database with a pre-defined Search index with support for
vector fields
Example:
.. code-block:: python
from langchain_community.vectorstores import CouchbaseVectorStore
from langchain_openai import OpenAIEmbeddings
from couchbase.cluster import Cluster
from couchbase.auth import PasswordAuthenticator
from couchbase.options import ClusterOptions
from datetime import timedelta
auth = PasswordAuthenticator(username, password)
options = ClusterOptions(auth)
connect_string = "couchbases://localhost"
cluster = Cluster(connect_string, options)
# Wait until the cluster is ready for use.
cluster.wait_until_ready(timedelta(seconds=5))
embeddings = OpenAIEmbeddings()
vectorstore = CouchbaseVectorStore(
cluster=cluster,
bucket_name="",
scope_name="",
collection_name="",
embedding=embeddings,
index_name="vector-index",
)
vectorstore.add_texts(["hello", "world"])
results = vectorstore.similarity_search("ola", k=1)
"""
# Default batch size
DEFAULT_BATCH_SIZE = 100
_metadata_key = "metadata"
_default_text_key = "text"
_default_embedding_key = "embedding"
def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster"""
bucket_manager = self._cluster.buckets()
try:
bucket_manager.get_bucket(self._bucket_name)
return True
except Exception:
return False
def _check_scope_and_collection_exists(self) -> bool:
"""Check if the scope and collection exists in the linked Couchbase bucket
Raises a ValueError if either is not found"""
scope_collection_map: Dict[str, Any] = {}
# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []
# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)
# Check if the scope exists
if self._scope_name not in scope_collection_map.keys():
raise ValueError(
f"Scope {self._scope_name} not found in Couchbase "
f"bucket {self._bucket_name}"
)
# Check if the collection exists in the scope
if self._collection_name not in scope_collection_map[self._scope_name]:
raise ValueError(
f"Collection {self._collection_name} not found in scope "
f"{self._scope_name} in Couchbase bucket {self._bucket_name}"
)
return True
def _check_index_exists(self) -> bool:
"""Check if the Search index exists in the linked Couchbase cluster
Raises a ValueError if the index does not exist"""
if self._scoped_index:
all_indexes = [
index.name for index in self._scope.search_indexes().get_all_indexes()
]
if self._index_name not in all_indexes:
raise ValueError(
f"Index {self._index_name} does not exist. "
" Please create the index before searching."
)
else:
all_indexes = [
index.name for index in self._cluster.search_indexes().get_all_indexes()
]
if self._index_name not in all_indexes:
raise ValueError(
f"Index {self._index_name} does not exist. "
" Please create the index before searching."
)
return True
def __init__(
self,
cluster: Cluster,
bucket_name: str,
scope_name: str,
collection_name: str,
embedding: Embeddings,
index_name: str,
*,
text_key: Optional[str] = _default_text_key,
embedding_key: Optional[str] = _default_embedding_key,
scoped_index: bool = True,
) -> None:
"""
Initialize the Couchbase Vector Store.
Args:
cluster (Cluster): couchbase cluster object with active connection.
bucket_name (str): name of bucket to store documents in.
scope_name (str): name of scope in the bucket to store documents in.
collection_name (str): name of collection in the scope to store documents in.
embedding (Embeddings): embedding function to use.
index_name (str): name of the Search index to use.
text_key (optional[str]): key in document to use as text.
Set to text by default.
embedding_key (optional[str]): key in document to use for the embeddings.
Set to embedding by default.
scoped_index (optional[bool]): specify whether the index is a scoped index.
Set to True by default.
"""
try:
from couchbase.cluster import Cluster
except ImportError as e:
raise ImportError(
"Could not import couchbase python package. "
"Please install couchbase SDK with `pip install couchbase`."
) from e
if not isinstance(cluster, Cluster):
raise ValueError(
f"cluster should be an instance of couchbase.Cluster, "
f"got {type(cluster)}"
)
self._cluster = cluster
if not embedding:
raise ValueError("Embeddings instance must be provided.")
if not bucket_name:
raise ValueError("bucket_name must be provided.")
if not scope_name:
raise ValueError("scope_name must be provided.")
if not collection_name:
raise ValueError("collection_name must be provided.")
if not index_name:
raise ValueError("index_name must be provided.")
self._bucket_name = bucket_name
self._scope_name = scope_name
self._collection_name = collection_name
self._embedding_function = embedding
self._text_key = text_key
self._embedding_key = embedding_key
self._index_name = index_name
self._scoped_index = scoped_index
# Check if the bucket exists
if not self._check_bucket_exists():
raise ValueError(
f"Bucket {self._bucket_name} does not exist. "
" Please create the bucket before searching."
)
try:
self._bucket = self._cluster.bucket(self._bucket_name)
self._scope = self._bucket.scope(self._scope_name)
self._collection = self._scope.collection(self._collection_name)
except Exception as e:
raise ValueError(
"Error connecting to couchbase. "
"Please check the connection and credentials."
) from e
# Check if the scope and collection exists. Throws ValueError if they don't
try:
self._check_scope_and_collection_exists()
except Exception as e:
raise e
# Check if the index exists. Throws ValueError if it doesn't
try:
self._check_index_exists()
except Exception as e:
raise e
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""Run texts through the embeddings and persist in vectorstore.
If the document IDs are passed, the existing documents (if any) will be
overwritten with the new ones.
Args:
texts (Iterable[str]): Iterable of strings to add to the vectorstore.
metadatas (Optional[List[Dict]]): Optional list of metadatas associated
with the texts.
ids (Optional[List[str]]): Optional list of ids associated with the texts.
IDs have to be unique strings across the collection.
If not specified, UUIDs are generated and used as ids.
batch_size (Optional[int]): Optional batch size for bulk insertions.
Default is 100.
Returns:
List[str]: List of ids from adding the texts into the vectorstore.
"""
from couchbase.exceptions import DocumentExistsException
if not batch_size:
batch_size = self.DEFAULT_BATCH_SIZE
doc_ids: List[str] = []
if ids is None:
ids = [uuid.uuid4().hex for _ in texts]
if metadatas is None:
metadatas = [{} for _ in texts]
embedded_texts = self._embedding_function.embed_documents(list(texts))
documents_to_insert = [
{
id: {
self._text_key: text,
self._embedding_key: vector,
self._metadata_key: metadata,
}
for id, text, vector, metadata in zip(
ids, texts, embedded_texts, metadatas
)
}
]
# Insert in batches
for i in range(0, len(documents_to_insert), batch_size):
batch = documents_to_insert[i : i + batch_size]
try:
result = self._collection.upsert_multi(batch[0])
if result.all_ok:
doc_ids.extend(batch[0].keys())
except DocumentExistsException as e:
raise ValueError(f"Document already exists: {e}")
return doc_ids
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete documents from the vector store by ids.
Args:
ids (List[str]): List of IDs of the documents to delete.
batch_size (Optional[int]): Optional batch size for bulk deletions.
Returns:
bool: True if all the documents were deleted successfully, False otherwise.
"""
from couchbase.exceptions import DocumentNotFoundException
if ids is None:
raise ValueError("No document ids provided to delete.")
batch_size = kwargs.get("batch_size", self.DEFAULT_BATCH_SIZE)
deletion_status = True
# Delete in batches
for i in range(0, len(ids), batch_size):
batch = ids[i : i + batch_size]
try:
result = self._collection.remove_multi(batch)
except DocumentNotFoundException as e:
deletion_status = False
raise ValueError(f"Document not found: {e}")
deletion_status &= result.all_ok
return deletion_status
@property
def embeddings(self) -> Embeddings:
"""Return the query embedding object."""
return self._embedding_function
def _format_metadata(self, row_fields: Dict[str, Any]) -> Dict[str, Any]:
"""Helper method to format the metadata from the Couchbase Search API.
Args:
row_fields (Dict[str, Any]): The fields to format.
Returns:
Dict[str, Any]: The formatted metadata.
"""
metadata = {}
for key, value in row_fields.items():
# Couchbase Search returns the metadata key with a prefix
# `metadata.` We remove it to get the original metadata key
if key.startswith(self._metadata_key):
new_key = key.split(self._metadata_key + ".")[-1]
metadata[new_key] = value
else:
metadata[key] = value
return metadata
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
search_options: Optional[Dict[str, Any]] = {},
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector with their scores.
Args:
embedding (List[float]): Embedding vector to look up documents similar to.
k (int): Number of Documents to return.
Defaults to 4.
search_options (Optional[Dict[str, Any]]): Optional search options that are
passed to Couchbase search.
Defaults to empty dictionary.
fields (Optional[List[str]]): Optional list of fields to include in the
metadata of results. Note that these need to be stored in the index.
If nothing is specified, defaults to all the fields stored in the index.
Returns:
List of (Document, score) that are the most similar to the query vector.
"""
import couchbase.search as search
from couchbase.options import SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
fields = kwargs.get("fields", ["*"])
# Document text field needs to be returned from the search
if fields != ["*"] and self._text_key not in fields:
fields.append(self._text_key)
search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
self._embedding_key,
embedding,
k,
)
)
)
try:
if self._scoped_index:
search_iter = self._scope.search(
self._index_name,
search_req,
SearchOptions(
limit=k,
fields=fields,
raw=search_options,
),
)
else:
search_iter = self._cluster.search(
index=self._index_name,
request=search_req,
options=SearchOptions(limit=k, fields=fields, raw=search_options),
)
docs_with_score = []
# Parse the results
for row in search_iter.rows():
text = row.fields.pop(self._text_key, "")
# Format the metadata from Couchbase
metadata = self._format_metadata(row.fields)
score = row.score
doc = Document(page_content=text, metadata=metadata)
docs_with_score.append((doc, score))
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs_with_score
def similarity_search(
self,
query: str,
k: int = 4,
search_options: Optional[Dict[str, Any]] = {},
**kwargs: Any,
) -> List[Document]:
"""Return documents most similar to embedding vector with their scores.
Args:
query (str): Query to look up for similar documents
k (int): Number of Documents to return.
Defaults to 4.
search_options (Optional[Dict[str, Any]]): Optional search options that are
passed to Couchbase search.
Defaults to empty dictionary
fields (Optional[List[str]]): Optional list of fields to include in the
metadata of results. Note that these need to be stored in the index.
If nothing is specified, defaults to all the fields stored in the index.
Returns:
List of Documents most similar to the query.
"""
query_embedding = self.embeddings.embed_query(query)
docs_with_scores = self.similarity_search_with_score_by_vector(
query_embedding, k, search_options, **kwargs
)
return [doc for doc, _ in docs_with_scores]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
search_options: Optional[Dict[str, Any]] = {},
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return documents that are most similar to the query with their scores.
Args:
query (str): Query to look up for similar documents
k (int): Number of Documents to return.
Defaults to 4.
search_options (Optional[Dict[str, Any]]): Optional search options that are
passed to Couchbase search.
Defaults to empty dictionary.
fields (Optional[List[str]]): Optional list of fields to include in the
metadata of results. Note that these need to be stored in the index.
If nothing is specified, defaults to text and metadata fields.
Returns:
List of (Document, score) that are most similar to the query.
"""
query_embedding = self.embeddings.embed_query(query)
docs_with_score = self.similarity_search_with_score_by_vector(
query_embedding, k, search_options, **kwargs
)
return docs_with_score
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
search_options: Optional[Dict[str, Any]] = {},
**kwargs: Any,
) -> List[Document]:
"""Return documents that are most similar to the vector embedding.
Args:
embedding (List[float]): Embedding to look up documents similar to.
k (int): Number of Documents to return.
Defaults to 4.
search_options (Optional[Dict[str, Any]]): Optional search options that are
passed to Couchbase search.
Defaults to empty dictionary.
fields (Optional[List[str]]): Optional list of fields to include in the
metadata of results. Note that these need to be stored in the index.
If nothing is specified, defaults to document text and metadata fields.
Returns:
List of Documents most similar to the query.
"""
docs_with_score = self.similarity_search_with_score_by_vector(
embedding, k, search_options, **kwargs
)
return [doc for doc, _ in docs_with_score]
@classmethod
def _from_kwargs(
cls: Type[CouchbaseVectorStore],
embedding: Embeddings,
**kwargs: Any,
) -> CouchbaseVectorStore:
"""Initialize the Couchbase vector store from keyword arguments for the
vector store.
Args:
embedding: Embedding object to use to embed text.
**kwargs: Keyword arguments to initialize the vector store with.
Accepted arguments are:
- cluster
- bucket_name
- scope_name
- collection_name
- index_name
- text_key
- embedding_key
- scoped_index
"""
cluster = kwargs.get("cluster", None)
bucket_name = kwargs.get("bucket_name", None)
scope_name = kwargs.get("scope_name", None)
collection_name = kwargs.get("collection_name", None)
index_name = kwargs.get("index_name", None)
text_key = kwargs.get("text_key", cls._default_text_key)
embedding_key = kwargs.get("embedding_key", cls._default_embedding_key)
scoped_index = kwargs.get("scoped_index", True)
return cls(
embedding=embedding,
cluster=cluster,
bucket_name=bucket_name,
scope_name=scope_name,
collection_name=collection_name,
index_name=index_name,
text_key=text_key,
embedding_key=embedding_key,
scoped_index=scoped_index,
)
@classmethod
def from_texts(
cls: Type[CouchbaseVectorStore],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict[Any, Any]]] = None,
**kwargs: Any,
) -> CouchbaseVectorStore:
"""Construct a Couchbase vector store from a list of texts.
Example:
.. code-block:: python
from langchain_community.vectorstores import CouchbaseVectorStore
from langchain_openai import OpenAIEmbeddings
from couchbase.cluster import Cluster
from couchbase.auth import PasswordAuthenticator
from couchbase.options import ClusterOptions
from datetime import timedelta
auth = PasswordAuthenticator(username, password)
options = ClusterOptions(auth)
connect_string = "couchbases://localhost"
cluster = Cluster(connect_string, options)
# Wait until the cluster is ready for use.
cluster.wait_until_ready(timedelta(seconds=5))
embeddings = OpenAIEmbeddings()
texts = ["hello", "world"]
vectorstore = CouchbaseVectorStore.from_texts(
texts,
embedding=embeddings,
cluster=cluster,
bucket_name="",
scope_name="",
collection_name="",
index_name="vector-index",
)
Args:
texts (List[str]): list of texts to add to the vector store.
embedding (Embeddings): embedding function to use.
metadatas (Optional[List[Dict]]): list of metadatas to add to documents.
**kwargs: Keyword arguments used to initialize the vector store with and/or
passed to `add_texts` method. Check the constructor and/or `add_texts`
for the list of accepted arguments.
Returns:
A Couchbase vector store.
"""
vector_store = cls._from_kwargs(embedding, **kwargs)
batch_size = kwargs.get("batch_size", vector_store.DEFAULT_BATCH_SIZE)
ids = kwargs.get("ids", None)
vector_store.add_texts(
texts, metadatas=metadatas, ids=ids, batch_size=batch_size
)
return vector_store

@ -5,14 +5,7 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import (
Any,
Iterable,
List,
Optional,
Tuple,
Type,
)
from typing import Any, Iterable, List, Optional, Tuple, Type, cast
import requests
from langchain_core.documents import Document
@ -25,29 +18,44 @@ logger = logging.getLogger(__name__)
class InfinispanVS(VectorStore):
"""`Infinispan` VectorStore interface.
This class exposes the method to present Infinispan as a
VectorStore. It relies on the Infinispan class (below) which takes care
of the REST interface with the server.
This class exposes the method to present Infinispan as a
VectorStore. It relies on the Infinispan class (below) which takes care
of the REST interface with the server.
Example:
.. code-block:: python
... code-block:: python
from langchain_community.vectorstores import InfinispanVS
from mymodels import RGBEmbeddings
...
vectorDb = InfinispanVS.from_documents(docs,
embedding=RGBEmbeddings(),
output_fields=["texture", "color"],
lambda_key=lambda text,meta: str(meta["_key"]),
lambda_content=lambda item: item["color"])
or an empty InfinispanVS instance can be created if preliminary setup
is required before populating the store
... code-block:: python
from langchain_community.vectorstores import InfinispanVS
from mymodels import RGBEmbeddings
...
ispnVS = InfinispanVS()
# configure Infinispan here
# i.e. create cache and schema
# then populate the store
vectorDb = InfinispanVS.from_documents(docs,
embedding=RGBEmbeddings(),
output_fields: ["texture", "color"],
lambda_key: lambda text,meta: str(meta["_key"]),
lambda_content: lambda item: item["color"]})
"""
def __init__(
self,
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,
clear_old: Optional[bool] = True,
**kwargs: Any,
):
self.ispn = Infinispan(**kwargs)
@ -65,8 +73,6 @@ class InfinispanVS(VectorStore):
)
self._output_fields = self._configuration.get("output_fields")
self._ids = ids
if clear_old:
self.ispn.cache_clear(self._cache_name)
def _default_metadata(self, item: dict) -> dict:
meta = dict(item)
@ -78,6 +84,43 @@ class InfinispanVS(VectorStore):
def _default_content(self, item: dict[str, Any]) -> Any:
return item.get(self._textfield)
def schema_builder(self, templ: dict, dimension: int) -> str:
metadata_proto_tpl = """
/**
* @Indexed
*/
message %s {
/**
* @Vector(dimension=%d)
*/
repeated float %s = 1;
"""
metadata_proto = metadata_proto_tpl % (
self._entity_name,
dimension,
self._vectorfield,
)
idx = 2
for f, v in templ.items():
if isinstance(v, str):
metadata_proto += "optional string " + f + " = " + str(idx) + ";\n"
elif isinstance(v, int):
metadata_proto += "optional int64 " + f + " = " + str(idx) + ";\n"
elif isinstance(v, float):
metadata_proto += "optional double " + f + " = " + str(idx) + ";\n"
elif isinstance(v, bytes):
metadata_proto += "optional bytes " + f + " = " + str(idx) + ";\n"
elif isinstance(v, bool):
metadata_proto += "optional bool " + f + " = " + str(idx) + ";\n"
else:
raise Exception(
"Unable to build proto schema for metadata. "
"Unhandled type for field: " + f
)
idx += 1
metadata_proto += "}\n"
return metadata_proto
def schema_create(self, proto: str) -> requests.Response:
"""Deploy the schema for the vector db
Args:
@ -143,6 +186,13 @@ class InfinispanVS(VectorStore):
"""
return self.ispn.cache_clear(self._cache_name)
def cache_exists(self) -> bool:
"""Checks if the cache exists
Returns:
true if exists
"""
return self.ispn.cache_exists(self._cache_name)
def cache_index_clear(self) -> requests.Response:
"""Clear the index for the vector db
Returns:
@ -161,10 +211,16 @@ class InfinispanVS(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
last_vector: Optional[List[float]] = None,
**kwargs: Any,
) -> List[str]:
result = []
embeds = self._embedding.embed_documents(list(texts)) # type: ignore
texts_l = list(texts)
if last_vector:
texts_l.pop()
embeds = self._embedding.embed_documents(texts_l) # type: ignore
if last_vector:
embeds.append(last_vector)
if not metadatas:
metadatas = [{} for _ in texts]
ids = self._ids or [str(uuid.uuid4()) for _ in texts]
@ -266,6 +322,23 @@ class InfinispanVS(VectorStore):
documents.append((doc, hit["score()"]))
return documents
def configure(self, metadata: dict, dimension: int) -> None:
schema = self.schema_builder(metadata, dimension)
output = self.schema_create(schema)
assert output.ok, "Unable to create schema. Already exists? "
"Consider using clear_old=True"
assert json.loads(output.text)["error"] is None
if not self.cache_exists():
output = self.cache_create()
assert output.ok, "Unable to create cache. Already exists? "
"Consider using clear_old=True"
# Ensure index is clean
self.cache_index_clear()
def config_clear(self) -> None:
self.schema_delete()
self.cache_delete()
@classmethod
def from_texts(
cls: Type[InfinispanVS],
@ -273,13 +346,24 @@ class InfinispanVS(VectorStore):
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
clear_old: Optional[bool] = None,
clear_old: Optional[bool] = True,
auto_config: Optional[bool] = True,
**kwargs: Any,
) -> InfinispanVS:
"""Return VectorStore initialized from texts and embeddings."""
infinispanvs = cls(embedding=embedding, ids=ids, clear_old=clear_old, **kwargs)
infinispanvs = cls(embedding=embedding, ids=ids, **kwargs)
if auto_config and len(metadatas or []) > 0:
if clear_old:
infinispanvs.config_clear()
vec = embedding.embed_query(texts[len(texts) - 1])
metadatas = cast(List[dict], metadatas)
infinispanvs.configure(metadatas[0], len(vec))
else:
if clear_old:
infinispanvs.cache_clear()
vec = embedding.embed_query(texts[len(texts) - 1])
if texts:
infinispanvs.add_texts(texts, metadatas)
infinispanvs.add_texts(texts, metadatas, last_vector=vec)
return infinispanvs
@ -293,7 +377,8 @@ class Infinispan:
create and set up a vector db.
You need a running Infinispan (15+) server without authentication.
You can easily start one, see: https://github.com/rigazilla/infinispan-vector#run-infinispan
You can easily start one, see:
https://github.com/rigazilla/infinispan-vector#run-infinispan
"""
def __init__(self, **kwargs: Any):
@ -473,6 +558,29 @@ class Infinispan:
response = requests.post(api_url, timeout=REST_TIMEOUT)
return response
def cache_exists(self, cache_name: str) -> bool:
"""Check if a cache exists
Args:
cache_name(str): name of the cache.
Returns:
True if cache exists
"""
api_url = (
self._default_node + self._cache_url + "/" + cache_name + "?action=clear"
)
return self.resource_exists(api_url)
@staticmethod
def resource_exists(api_url: str) -> bool:
"""Check if a resource exists
Args:
api_url(str): url of the resource.
Returns:
true if resource exists
"""
response = requests.head(api_url, timeout=REST_TIMEOUT)
return response.ok
def index_clear(self, cache_name: str) -> requests.Response:
"""Clear an index on a cache
Args:

@ -0,0 +1,199 @@
import uuid
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_community.utils.math import cosine_similarity
from langchain_community.vectorstores.utils import maximal_marginal_relevance
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary.
Uses numpy to compute cosine similarity for search.
Args:
embedding: embedding function to use.
"""
def __init__(self, embedding: Embeddings) -> None:
self.store: Dict[str, Dict[str, Any]] = {}
self.embedding = embedding
@property
def embeddings(self) -> Embeddings:
return self.embedding
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
if ids:
for _id in ids:
self.store.pop(_id, None)
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
self.delete(ids)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
ids = []
vectors = self.embedding.embed_documents(list(texts))
for i, text in enumerate(texts):
doc_id = str(uuid.uuid4())
ids.append(doc_id)
self.store[doc_id] = {
"id": doc_id,
"vector": vectors[i],
"text": text,
"metadata": metadatas[i] if metadatas else {},
}
return ids
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
return self.add_texts(texts, metadatas, **kwargs)
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
) -> List[Tuple[Document, float]]:
docs_with_similarity = []
for doc in self.store.values():
similarity = float(cosine_similarity([embedding], [doc["vector"]]).item(0))
docs_with_similarity.append(
(
Document(page_content=doc["text"], metadata=doc["metadata"]),
similarity,
)
)
docs_with_similarity.sort(key=lambda x: x[1], reverse=True)
return docs_with_similarity[:k]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
)
return docs
async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
**kwargs: Any,
) -> List[Document]:
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k,
)
return [doc for doc, _ in docs_and_scores]
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
return self.similarity_search_by_vector(embedding, k, **kwargs)
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
return self.similarity_search(query, k, **kwargs)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
docs_with_similarity = []
for doc in self.store.values():
similarity = float(cosine_similarity([embedding], [doc["vector"]]).item(0))
docs_with_similarity.append(
(
doc,
similarity,
)
)
docs_with_similarity.sort(key=lambda x: x[1], reverse=True)
prefetch_hits = docs_with_similarity[:fetch_k]
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[doc["vector"] for doc, _ in prefetch_hits],
k=k,
lambda_mult=lambda_mult,
)
return [
Document(
page_content=prefetch_hits[idx][0]["text"],
metadata=prefetch_hits[idx][0]["metadata"],
)
for idx in mmr_chosen_indices
]
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
embedding_vector = self.embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "InMemoryVectorStore":
store = cls(
embedding=embedding,
)
store.add_texts(texts=texts, metadatas=metadatas)
return store
@classmethod
async def afrom_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "InMemoryVectorStore":
return cls.from_texts(texts, embedding, metadatas, **kwargs)

@ -467,6 +467,7 @@ class Milvus(VectorStore):
from pymilvus import Collection, utility
from pymilvus.client.types import LoadState
timeout = self.timeout or timeout
if (
isinstance(self.col, Collection)
and self._get_index() is not None
@ -483,7 +484,7 @@ class Milvus(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
batch_size: int = 1000,
*,
ids: Optional[List[str]] = None,
@ -504,7 +505,7 @@ class Milvus(VectorStore):
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
the texts. Defaults to None.
should be less than 65535 bytes. Required and work when auto_id is False.
timeout (Optional[int]): Timeout for each batch insert. Defaults
timeout (Optional[float]): Timeout for each batch insert. Defaults
to None.
batch_size (int, optional): Batch size to use for insertion.
Defaults to 1000.
@ -592,6 +593,7 @@ class Milvus(VectorStore):
# Insert into the collection.
try:
res: Collection
timeout = self.timeout or timeout
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
@ -608,7 +610,7 @@ class Milvus(VectorStore):
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
@ -629,6 +631,7 @@ class Milvus(VectorStore):
if self.col is None:
logger.debug("No existing collection to search.")
return []
timeout = self.timeout or timeout
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
@ -640,7 +643,7 @@ class Milvus(VectorStore):
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
@ -661,6 +664,7 @@ class Milvus(VectorStore):
if self.col is None:
logger.debug("No existing collection to search.")
return []
timeout = self.timeout or timeout
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
@ -672,7 +676,7 @@ class Milvus(VectorStore):
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
@ -687,7 +691,7 @@ class Milvus(VectorStore):
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
timeout (float, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
@ -700,7 +704,7 @@ class Milvus(VectorStore):
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
timeout = self.timeout or timeout
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
@ -712,7 +716,7 @@ class Milvus(VectorStore):
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
@ -727,7 +731,7 @@ class Milvus(VectorStore):
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
timeout (float, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
@ -744,7 +748,7 @@ class Milvus(VectorStore):
# Determine result metadata fields with PK.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
timeout = self.timeout or timeout
# Perform the search.
res = self.col.search(
data=[embedding],
@ -774,7 +778,7 @@ class Milvus(VectorStore):
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
@ -791,7 +795,7 @@ class Milvus(VectorStore):
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
timeout (float, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
@ -804,7 +808,7 @@ class Milvus(VectorStore):
return []
embedding = self.embedding_func.embed_query(query)
timeout = self.timeout or timeout
return self.max_marginal_relevance_search_by_vector(
embedding=embedding,
k=k,
@ -824,7 +828,7 @@ class Milvus(VectorStore):
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
@ -841,7 +845,7 @@ class Milvus(VectorStore):
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
timeout (float, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
@ -858,7 +862,7 @@ class Milvus(VectorStore):
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
timeout = self.timeout or timeout
# Perform the search.
res = self.col.search(
data=[embedding],
@ -1049,7 +1053,7 @@ class Milvus(VectorStore):
except MilvusException:
pass
try:
return self.add_documents(documents=documents)
return self.add_documents(documents=documents, **kwargs)
except MilvusException as exc:
logger.error(
"Failed to upsert entities: %s error: %s", self.collection_name, exc

@ -0,0 +1,367 @@
"""Test Couchbase Vector Store functionality"""
import os
import time
from typing import Any
import pytest
from langchain_core.documents import Document
from langchain_community.vectorstores.couchbase import CouchbaseVectorStore
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
CONNECTION_STRING = os.getenv("COUCHBASE_CONNECTION_STRING", "")
BUCKET_NAME = os.getenv("COUCHBASE_BUCKET_NAME", "")
SCOPE_NAME = os.getenv("COUCHBASE_SCOPE_NAME", "")
COLLECTION_NAME = os.getenv("COUCHBASE_COLLECTION_NAME", "")
USERNAME = os.getenv("COUCHBASE_USERNAME", "")
PASSWORD = os.getenv("COUCHBASE_PASSWORD", "")
INDEX_NAME = os.getenv("COUCHBASE_INDEX_NAME", "")
SLEEP_DURATION = 1
def set_all_env_vars() -> bool:
return all(
[
CONNECTION_STRING,
BUCKET_NAME,
SCOPE_NAME,
COLLECTION_NAME,
USERNAME,
PASSWORD,
INDEX_NAME,
]
)
def get_cluster() -> Any:
"""Get a couchbase cluster object"""
from datetime import timedelta
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
auth = PasswordAuthenticator(USERNAME, PASSWORD)
options = ClusterOptions(auth)
connect_string = CONNECTION_STRING
cluster = Cluster(connect_string, options)
# Wait until the cluster is ready for use.
cluster.wait_until_ready(timedelta(seconds=5))
return cluster
@pytest.fixture()
def cluster() -> Any:
"""Get a couchbase cluster object"""
return get_cluster()
def delete_documents(
cluster: Any, bucket_name: str, scope_name: str, collection_name: str
) -> None:
"""Delete all the documents in the collection"""
query = f"DELETE FROM `{bucket_name}`.`{scope_name}`.`{collection_name}`"
cluster.query(query).execute()
@pytest.mark.requires("couchbase")
@pytest.mark.skipif(
not set_all_env_vars(), reason="Missing Couchbase environment variables"
)
class TestCouchbaseVectorStore:
@classmethod
def setup_method(self) -> None:
cluster = get_cluster()
# Delete all the documents in the collection
delete_documents(cluster, BUCKET_NAME, SCOPE_NAME, COLLECTION_NAME)
def test_from_documents(self, cluster: Any) -> None:
"""Test end to end search using a list of documents."""
documents = [
Document(page_content="foo", metadata={"page": 1}),
Document(page_content="bar", metadata={"page": 2}),
Document(page_content="baz", metadata={"page": 3}),
]
vectorstore = CouchbaseVectorStore.from_documents(
documents,
ConsistentFakeEmbeddings(),
cluster=cluster,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
index_name=INDEX_NAME,
)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("baz", k=1)
assert output[0].page_content == "baz"
assert output[0].metadata["page"] == 3
def test_from_texts(self, cluster: Any) -> None:
"""Test end to end search using a list of texts."""
texts = [
"foo",
"bar",
"baz",
]
vectorstore = CouchbaseVectorStore.from_texts(
texts,
ConsistentFakeEmbeddings(),
cluster=cluster,
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("foo", k=1)
assert len(output) == 1
assert output[0].page_content == "foo"
def test_from_texts_with_metadatas(self, cluster: Any) -> None:
"""Test end to end search using a list of texts and metadatas."""
texts = [
"foo",
"bar",
"baz",
]
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
vectorstore = CouchbaseVectorStore.from_texts(
texts,
ConsistentFakeEmbeddings(),
metadatas=metadatas,
cluster=cluster,
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("baz", k=1)
assert output[0].page_content == "baz"
assert output[0].metadata["c"] == 3
def test_add_texts_with_ids_and_metadatas(self, cluster: Any) -> None:
"""Test end to end search by adding a list of texts, ids and metadatas."""
texts = [
"foo",
"bar",
"baz",
]
ids = ["a", "b", "c"]
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
results = vectorstore.add_texts(
texts,
ids=ids,
metadatas=metadatas,
)
assert results == ids
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("foo", k=1)
assert output[0].page_content == "foo"
assert output[0].metadata["a"] == 1
def test_delete_texts_with_ids(self, cluster: Any) -> None:
"""Test deletion of documents by ids."""
texts = [
"foo",
"bar",
"baz",
]
ids = ["a", "b", "c"]
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
results = vectorstore.add_texts(
texts,
ids=ids,
metadatas=metadatas,
)
assert results == ids
assert vectorstore.delete(ids)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("foo", k=1)
assert len(output) == 0
def test_similarity_search_with_scores(self, cluster: Any) -> None:
"""Test similarity search with scores."""
texts = ["foo", "bar", "baz"]
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
vectorstore.add_texts(texts, metadatas=metadatas)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search_with_score("foo", k=2)
assert len(output) == 2
assert output[0][0].page_content == "foo"
# check if the scores are sorted
assert output[0][0].metadata["a"] == 1
assert output[0][1] > output[1][1]
def test_similarity_search_by_vector(self, cluster: Any) -> None:
"""Test similarity search by vector."""
texts = ["foo", "bar", "baz"]
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
vectorstore.add_texts(texts, metadatas=metadatas)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
vector = ConsistentFakeEmbeddings().embed_query("foo")
vector_output = vectorstore.similarity_search_by_vector(vector, k=1)
assert vector_output[0].page_content == "foo"
similarity_output = vectorstore.similarity_search("foo", k=1)
assert similarity_output == vector_output
def test_output_fields(self, cluster: Any) -> None:
"""Test that output fields are set correctly."""
texts = [
"foo",
"bar",
"baz",
]
metadatas = [{"page": 1, "a": 1}, {"page": 2, "b": 2}, {"page": 3, "c": 3}]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
ids = vectorstore.add_texts(texts, metadatas)
assert len(ids) == len(texts)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
output = vectorstore.similarity_search("foo", k=1, fields=["metadata.page"])
assert output[0].page_content == "foo"
assert output[0].metadata["page"] == 1
assert "a" not in output[0].metadata
def test_hybrid_search(self, cluster: Any) -> None:
"""Test hybrid search."""
texts = [
"foo",
"bar",
"baz",
]
metadatas = [
{"section": "index"},
{"section": "glossary"},
{"section": "appendix"},
]
vectorstore = CouchbaseVectorStore(
cluster=cluster,
embedding=ConsistentFakeEmbeddings(),
index_name=INDEX_NAME,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
vectorstore.add_texts(texts, metadatas=metadatas)
# Wait for the documents to be indexed
time.sleep(SLEEP_DURATION)
result, score = vectorstore.similarity_search_with_score("foo", k=1)[0]
# Wait for the documents to be indexed for hybrid search
time.sleep(SLEEP_DURATION)
hybrid_result, hybrid_score = vectorstore.similarity_search_with_score(
"foo",
k=1,
search_options={"query": {"match": "index", "field": "metadata.section"}},
)[0]
assert result == hybrid_result
assert score <= hybrid_score
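For readers who want to try the hybrid search path exercised by `test_hybrid_search`, here is a hedged, standalone sketch. It assumes a reachable Couchbase cluster, the same environment variables the tests read, and a search index whose vector dimension matches the embedding size; the fake embedding model is only a placeholder for a real one.

```python
import os
from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.vectorstores.couchbase import CouchbaseVectorStore

# Connect the same way get_cluster() does in the tests above.
auth = PasswordAuthenticator(os.environ["COUCHBASE_USERNAME"], os.environ["COUCHBASE_PASSWORD"])
cluster = Cluster(os.environ["COUCHBASE_CONNECTION_STRING"], ClusterOptions(auth))
cluster.wait_until_ready(timedelta(seconds=5))

store = CouchbaseVectorStore(
    cluster=cluster,
    embedding=DeterministicFakeEmbedding(size=8),  # size must match the index's vector dimension
    index_name=os.environ["COUCHBASE_INDEX_NAME"],
    bucket_name=os.environ["COUCHBASE_BUCKET_NAME"],
    scope_name=os.environ["COUCHBASE_SCOPE_NAME"],
    collection_name=os.environ["COUCHBASE_COLLECTION_NAME"],
)
store.add_texts(["foo"], metadatas=[{"section": "index"}])

# Hybrid search: vector similarity plus a full-text match on a metadata field.
docs_and_scores = store.similarity_search_with_score(
    "foo",
    k=1,
    search_options={"query": {"match": "index", "field": "metadata.section"}},
)
```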

@ -1,17 +1,19 @@
"""Test Infinispan functionality."""
from typing import Any, List, Optional
import pytest
from langchain_core.documents import Document
from langchain_community.vectorstores import InfinispanVS
from langchain_community.vectorstores.infinispanvs import InfinispanVS
from tests.integration_tests.vectorstores.fake_embeddings import (
FakeEmbeddings,
fake_texts,
)
def _infinispan_setup() -> None:
ispnvs = InfinispanVS()
def _infinispan_setup_noautoconf() -> None:
ispnvs = InfinispanVS(auto_config=False)
ispnvs.cache_delete()
ispnvs.schema_delete()
proto = """
@ -37,6 +39,7 @@ def _infinispanvs_from_texts(
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
clear_old: Optional[bool] = True,
auto_config: Optional[bool] = False,
**kwargs: Any,
) -> InfinispanVS:
texts = [{"text": t} for t in fake_texts]
@ -50,86 +53,109 @@ def _infinispanvs_from_texts(
metadatas=metadatas,
ids=ids,
clear_old=clear_old,
auto_config=auto_config,
**kwargs,
)
def test_infinispan() -> None:
"""Test end to end construction and search."""
_infinispan_setup()
docsearch = _infinispanvs_from_texts()
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_infinispan_with_metadata() -> None:
"""Test with metadata"""
_infinispan_setup()
meta = []
for _ in range(len(fake_texts)):
meta.append({"label": "test"})
docsearch = _infinispanvs_from_texts(metadatas=meta)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"label": "test"})]
def test_infinispan_with_metadata_with_output_fields() -> None:
"""Test with metadata"""
_infinispan_setup()
metadatas = [{"page": i, "label": "label" + str(i)} for i in range(len(fake_texts))]
c = {"output_fields": ["label", "page", "text"]}
docsearch = _infinispanvs_from_texts(metadatas=metadatas, configuration=c)
output = docsearch.similarity_search("foo", k=1)
assert output == [
Document(page_content="foo", metadata={"label": "label0", "page": 0})
]
def test_infinispanvs_with_id() -> None:
"""Test with ids"""
ids = ["id_" + str(i) for i in range(len(fake_texts))]
docsearch = _infinispanvs_from_texts(ids=ids)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_infinispan_with_score() -> None:
"""Test end to end construction and search with scores and IDs."""
_infinispan_setup()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(metadatas=metadatas)
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
]
assert scores[0] >= scores[1] >= scores[2]
def test_infinispan_add_texts() -> None:
"""Test end to end construction and MRR search."""
_infinispan_setup()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(metadatas=metadatas)
docsearch.add_texts(texts, metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
def test_infinispan_no_clear_old() -> None:
"""Test end to end construction and MRR search."""
_infinispan_setup()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(metadatas=metadatas)
del docsearch
docsearch = _infinispanvs_from_texts(metadatas=metadatas, clear_old=False)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
@pytest.mark.parametrize("autoconfig", [False, True])
class TestBasic:
def test_infinispan(self, autoconfig: bool) -> None:
"""Test end to end construction and search."""
if not autoconfig:
_infinispan_setup_noautoconf()
docsearch = _infinispanvs_from_texts(auto_config=autoconfig)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_infinispan_with_metadata(self, autoconfig: bool) -> None:
"""Test with metadata"""
if not autoconfig:
_infinispan_setup_noautoconf()
meta = []
for _ in range(len(fake_texts)):
meta.append({"label": "test"})
docsearch = _infinispanvs_from_texts(metadatas=meta, auto_config=autoconfig)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"label": "test"})]
def test_infinispan_with_metadata_with_output_fields(
self, autoconfig: bool
) -> None:
"""Test with metadata"""
if not autoconfig:
_infinispan_setup_noautoconf()
metadatas = [
{"page": i, "label": "label" + str(i)} for i in range(len(fake_texts))
]
c = {"output_fields": ["label", "page", "text"]}
docsearch = _infinispanvs_from_texts(
metadatas=metadatas, configuration=c, auto_config=autoconfig
)
output = docsearch.similarity_search("foo", k=1)
assert output == [
Document(page_content="foo", metadata={"label": "label0", "page": 0})
]
def test_infinispanvs_with_id(self, autoconfig: bool) -> None:
"""Test with ids"""
ids = ["id_" + str(i) for i in range(len(fake_texts))]
docsearch = _infinispanvs_from_texts(ids=ids, auto_config=autoconfig)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_infinispan_with_score(self, autoconfig: bool) -> None:
"""Test end to end construction and search with scores and IDs."""
if not autoconfig:
_infinispan_setup_noautoconf()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(
metadatas=metadatas, auto_config=autoconfig
)
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
]
assert scores[0] >= scores[1] >= scores[2]
def test_infinispan_add_texts(self, autoconfig: bool) -> None:
"""Test end to end construction and MRR search."""
if not autoconfig:
_infinispan_setup_noautoconf()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(
metadatas=metadatas, auto_config=autoconfig
)
docsearch.add_texts(texts, metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
def test_infinispan_no_clear_old(self, autoconfig: bool) -> None:
"""Test end to end construction and MRR search."""
if not autoconfig:
_infinispan_setup_noautoconf()
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _infinispanvs_from_texts(
metadatas=metadatas, auto_config=autoconfig
)
del docsearch
try:
docsearch = _infinispanvs_from_texts(
metadatas=metadatas, clear_old=False, auto_config=autoconfig
)
except AssertionError:
if autoconfig:
return
else:
raise
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6

@ -13,6 +13,7 @@ EXPECT_ALL = [
"Aviary",
"AzureMLOnlineEndpoint",
"AzureOpenAI",
"BaichuanLLM",
"Banana",
"Baseten",
"Beam",
@ -44,9 +45,11 @@ EXPECT_ALL = [
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"Llamafile",
"TextGen",
"ManifestWrapper",
"Minimax",
"Mlflow",
"MlflowAIGateway",
"Modal",
"MosaicML",
@ -77,6 +80,7 @@ EXPECT_ALL = [
"StochasticAI",
"TitanTakeoff",
"TitanTakeoffPro",
"Together",
"Tongyi",
"VertexAI",
"VertexAIModelGarden",

@ -100,6 +100,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"prompt": "Test prompt",
"system": "Test system prompt",
"template": None,
"keep_alive": None,
}
assert stream is True
assert timeout == 300
@ -147,6 +148,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"prompt": "Test prompt",
"system": None,
"template": None,
"keep_alive": None,
}
assert stream is True
assert timeout == 300
@ -178,6 +180,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"prompt": "Test prompt",
"system": None,
"template": None,
"keep_alive": None,
}
assert stream is True
assert timeout == 300

@ -55,6 +55,7 @@ def test_compatible_vectorstore_documentation() -> None:
"BigQueryVectorSearch",
"Cassandra",
"Chroma",
"CouchbaseVectorStore",
"DashVector",
"DatabricksVectorSearch",
"TiDBVectorStore",

@ -0,0 +1,33 @@
from langchain_core.documents import Document
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
async def test_inmemory() -> None:
"""Test end to end construction and search."""
store = await InMemoryVectorStore.afrom_texts(
["foo", "bar", "baz"], ConsistentFakeEmbeddings()
)
output = await store.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
output = await store.asimilarity_search("bar", k=2)
assert output == [Document(page_content="bar"), Document(page_content="baz")]
output2 = await store.asimilarity_search_with_score("bar", k=2)
assert output2[0][1] > output2[1][1]
async def test_inmemory_mmr() -> None:
texts = ["foo", "foo", "fou", "foy"]
docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())
# make sure we can use k > docstore size
output = await docsearch.amax_marginal_relevance_search(
"foo", k=10, lambda_mult=0.1
)
assert len(output) == len(texts)
assert output[0] == Document(page_content="foo")
assert output[1] == Document(page_content="foy")

@ -85,6 +85,7 @@ _EXPECTED = [
"VectorStore",
"Yellowbrick",
"NeuralDBVectorStore",
"CouchbaseVectorStore",
]

@ -307,7 +307,46 @@ class ContextSet(RunnableSerializable):
class Context:
"""Context for a runnable."""
"""
Context for a runnable.
The `Context` class provides methods for creating context scopes,
getters, and setters within a runnable. It allows for managing
and accessing contextual information throughout the execution
of a program.
Example:
.. code-block:: python
from langchain_core.beta.runnables.context import Context
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
from tests.unit_tests.fake.llm import FakeListLLM
chain = (
Context.setter("input")
| {
"context": RunnablePassthrough()
| Context.setter("context"),
"question": RunnablePassthrough(),
}
| PromptTemplate.from_template("{context} {question}")
| FakeListLLM(responses=["hello"])
| StrOutputParser()
| {
"result": RunnablePassthrough(),
"context": Context.getter("context"),
"input": Context.getter("input"),
}
)
# Use the chain
output = chain.invoke("What's your name?")
print(output["result"]) # Output: "hello"
print(output["context"]) # Output: "What's your name?"
print(output["input"]) # Output: "What's your name?
"""
@staticmethod
def create_scope(scope: str, /) -> "PrefixContext":

@ -1183,6 +1183,7 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1197,8 +1198,9 @@ class CallbackManager(BaseCallbackManager):
prompt as an LLM run.
"""
managers = []
for prompt in prompts:
run_id_ = uuid.uuid4()
for i, prompt in enumerate(prompts):
# Can't have duplicate runs with the same run ID (if provided)
run_id_ = run_id if i == 0 and run_id is not None else uuid.uuid4()
handle_event(
self.handlers,
"on_llm_start",
@ -1231,6 +1233,7 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1247,7 +1250,11 @@ class CallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
handle_event(
self.handlers,
"on_chat_model_start",
@ -1520,6 +1527,7 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1539,7 +1547,11 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for prompt in prompts:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(
@ -1577,6 +1589,7 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
@ -1595,7 +1608,11 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid.uuid4()
if run_id is not None:
run_id_ = run_id
run_id = None
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(

@ -0,0 +1,4 @@
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
__all__ = ["DeterministicFakeEmbedding", "Embeddings", "FakeEmbeddings"]

@ -0,0 +1,52 @@
import hashlib
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
class FakeEmbeddings(Embeddings, BaseModel):
"""Fake embedding model."""
size: int
"""The size of the embedding vector."""
def _get_embedding(self) -> List[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]
return list(np.random.normal(size=self.size))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self._get_embedding() for _ in texts]
def embed_query(self, text: str) -> List[float]:
return self._get_embedding()
class DeterministicFakeEmbedding(Embeddings, BaseModel):
"""
Fake embedding model that always returns
the same embedding vector for the same text.
"""
size: int
"""The size of the embedding vector."""
def _get_embedding(self, seed: int) -> List[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]
# set the seed for the random generator
np.random.seed(seed)
return list(np.random.normal(size=self.size))
def _get_seed(self, text: str) -> int:
"""
Get a seed for the random generator, using the hash of the text.
"""
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
def embed_query(self, text: str) -> List[float]:
return self._get_embedding(seed=self._get_seed(text))
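Since `DeterministicFakeEmbedding` seeds NumPy's RNG with a SHA-256 hash of the input text, identical texts map to identical vectors. A quick hedged check (assumes `numpy` is installed; the import path is the one exported by the new `embeddings/__init__.py` above):

```python
from langchain_core.embeddings import DeterministicFakeEmbedding

emb = DeterministicFakeEmbedding(size=4)
assert emb.embed_query("hello") == emb.embed_query("hello")  # same text, same vector
# Different texts get different seeds, so (with overwhelming probability) different vectors.
assert emb.embed_query("hello") != emb.embed_query("world")
```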

@ -30,6 +30,13 @@ from langchain_core.language_models.base import (
get_tokenizer,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.fake import FakeListLLM, FakeStreamingListLLM
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
FakeMessagesListChatModel,
GenericFakeChatModel,
ParrotFakeChatModel,
)
from langchain_core.language_models.llms import LLM, BaseLLM
__all__ = [
@ -42,4 +49,10 @@ __all__ = [
"get_tokenizer",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",
"FakeStreamingListLLM",
"FakeListChatModel",
"FakeMessagesListChatModel",
"GenericFakeChatModel",
"ParrotFakeChatModel",
]

@ -31,6 +31,7 @@ from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain_core.caches import BaseCache
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult
@ -78,8 +79,16 @@ class BaseLanguageModel(
All language model wrappers inherit from BaseLanguageModel.
"""
cache: Optional[bool] = None
"""Whether to cache the response."""
cache: Union[BaseCache, bool, None] = None
"""Whether to cache the response.
* If true, will use the global cache.
* If false, will not use a cache.
* If None, will use the global cache if it's set, otherwise no cache.
* If instance of BaseCache, will use the provided cache.
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
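With `cache` now typed as `Union[BaseCache, bool, None]`, a chat model can carry its own cache instance instead of relying on the global one. A minimal sketch, assuming the `DictCache` below (a toy `BaseCache` implementation written for illustration) and the `FakeListChatModel` exported elsewhere in this diff; note that the `llms.py` hunks later in this change still raise `NotImplementedError` when a plain LLM is given a cache instance:

```python
from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import BaseCache
from langchain_core.language_models import FakeListChatModel


class DictCache(BaseCache):
    """Toy in-process cache keyed by (prompt, llm_string)."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], Any] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[Any]:
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: Any) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


model = FakeListChatModel(responses=["hi"], cache=DictCache())  # this model's own cache
model.invoke("hello")  # cache miss: generate, then store
model.invoke("hello")  # cache hit: served from DictCache
```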

@ -2,14 +2,13 @@ from __future__ import annotations
import asyncio
import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
@ -20,6 +19,7 @@ from typing import (
)
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
@ -97,26 +97,6 @@ async def agenerate_from_stream(
)
def _as_async_iterator(sync_iterator: Callable) -> Callable:
"""Convert a sync iterator into an async iterator."""
async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
return _as_sync_iterator
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
@ -234,6 +214,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
@ -267,26 +248,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream is not BaseChatModel._astream:
# Then astream is implemented
_stream_implementation = self._astream
elif type(self)._stream is not BaseChatModel._stream:
# Then stream is implemented, so we can create an async iterator from it
# The typing is hard to type correctly with mypy here, so we cast
# and do a type ignore, this code is unit tested and should be fine.
_stream_implementation = cast( # type: ignore
Callable[
[
List[BaseMessage],
Optional[List[str]],
CallbackManagerForLLMRun,
Any,
],
AsyncIterator[ChatGenerationChunk],
],
_as_async_iterator(self._stream),
)
else: # No async or sync stream is implemented, so fall back to ainvoke
if (
type(self)._astream is BaseChatModel._astream
and type(self)._stream is BaseChatModel._stream
):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
await self.ainvoke(input, config=config, stop=stop, **kwargs),
@ -312,12 +278,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
try:
async for chunk in _stream_implementation(
messages, stop=stop, run_manager=run_manager, **kwargs
async for chunk in self._astream(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
@ -371,6 +342,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to the model and return model generations.
@ -415,6 +387,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=run_name,
run_id=run_id,
batch_size=len(messages),
)
results = []
@ -456,6 +429,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@ -502,6 +476,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
options=options,
name=run_name,
batch_size=len(messages),
run_id=run_id,
)
results = await asyncio.gather(
@ -589,7 +564,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_cache = get_llm_cache()
if isinstance(self.cache, BaseCache):
llm_cache = self.cache
else:
llm_cache = get_llm_cache()
# We should check the cache unless it's explicitly set to False
# A None cache means we should use the default global cache
# if it's configured.
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
@ -611,10 +592,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
result = self._generate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if check_cache and llm_cache:
llm_cache.update(prompt, llm_string, result.generations)
return result
@ -626,7 +613,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_cache = get_llm_cache()
if isinstance(self.cache, BaseCache):
llm_cache = self.cache
else:
llm_cache = get_llm_cache()
# We should check the cache unless it's explicitly set to False
# A None cache means we should use the default global cache
# if it's configured.
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
@ -647,10 +640,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if check_cache and llm_cache:
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
@ -691,14 +691,32 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
iterator = await run_in_executor(
None,
self._stream,
messages,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
@deprecated("0.1.7", alternative="invoke", removal="0.2.0")
def __call__(

@ -2,11 +2,12 @@ import asyncio
import time
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import (
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, LanguageModelInput
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import LLM
from langchain_core.runnables import RunnableConfig

@ -1,21 +1,16 @@
"""Fake Chat Model wrapper for testing purposes."""
"""Fake ChatModel for testing purposes."""
import asyncio
import re
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from langchain_core.callbacks.manager import (
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
class FakeMessagesListChatModel(BaseChatModel):
@ -283,25 +278,6 @@ class GenericFakeChatModel(BaseChatModel):
)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
"""Stream the output of the model."""
result = await run_in_executor(
None,
self._stream,
messages,
stop=stop,
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
)
for chunk in result:
yield chunk
@property
def _llm_type(self) -> str:
return "generic-fake-chat-model"

@ -7,12 +7,12 @@ import functools
import inspect
import json
import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
@ -38,6 +38,7 @@ from tenacity import (
)
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
@ -114,26 +115,6 @@ def create_base_retry_decorator(
)
def _as_async_iterator(sync_iterator: Callable) -> Callable:
"""Convert a sync iterator into an async iterator."""
async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
return _as_sync_iterator
def get_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@ -271,6 +252,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
.generations[0][0]
@ -293,6 +275,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
return llm_result.generations[0][0].text
@ -423,6 +406,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[GenerationChunk] = None
@ -455,26 +439,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if type(self)._astream is not BaseLLM._astream:
# model doesn't implement streaming, so use default implementation
_stream_implementation = self._astream
elif type(self)._stream is not BaseLLM._stream:
# Then stream is implemented, so we can create an async iterator from it
# The typing is hard to type correctly with mypy here, so we cast
# and do a type ignore, this code is unit tested and should be fine.
_stream_implementation = cast( # type: ignore
Callable[
[
str,
Optional[List[str]],
CallbackManagerForLLMRun,
Any,
],
AsyncIterator[GenerationChunk],
],
_as_async_iterator(self._stream),
)
else:
if (
type(self)._astream is BaseLLM._astream
and type(self)._stream is BaseLLM._stream
):
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
return
@ -499,12 +467,16 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[GenerationChunk] = None
try:
async for chunk in _stream_implementation(
prompt, stop=stop, run_manager=run_manager, **kwargs
async for chunk in self._astream(
prompt,
stop=stop,
run_manager=run_manager,
**kwargs,
):
yield chunk.text
if generation is None:
@ -559,14 +531,32 @@ class BaseLLM(BaseLanguageModel[str], ABC):
) -> Iterator[GenerationChunk]:
raise NotImplementedError()
def _astream(
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
raise NotImplementedError()
iterator = await run_in_executor(
None,
self._stream,
prompt,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
def generate_prompt(
self,
@ -632,6 +622,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to a model and return generations.
@ -717,7 +708,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -727,6 +718,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if isinstance(self.cache, BaseCache):
raise NotImplementedError(
"Local cache is not yet supported for " "LLMs (only chat models)"
)
disregard_cache = self.cache is not None and not self.cache
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
@ -744,9 +739,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
options=options,
name=run_name,
batch_size=len(prompts),
run_id=run_id_,
)[0]
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
for callback_manager, prompt, run_name, run_id_ in zip(
callback_managers, prompts, run_name_list, run_ids_list
)
]
output = self._generate_helper(
@ -782,6 +778,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
@staticmethod
def _get_run_ids_list(
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
if isinstance(run_id, list):
if len(run_id) != len(prompts):
raise ValueError(
"Number of manually provided run_id's does not match batch length."
f" {len(run_id)} != {len(prompts)}"
)
return run_id
return [run_id] + [None] * (len(prompts) - 1)
async def _agenerate_helper(
self,
prompts: List[str],
@ -833,6 +844,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@ -909,7 +921,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
] * len(prompts)
run_name_list = [cast(Optional[str], run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -919,6 +931,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
missing_prompt_idxs,
missing_prompts,
) = await aget_prompts(params, prompts)
if isinstance(self.cache, BaseCache):
raise NotImplementedError(
"Local cache is not yet supported for " "LLMs (only chat models)"
)
disregard_cache = self.cache is not None and not self.cache
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
@ -937,9 +954,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
options=options,
name=run_name,
batch_size=len(prompts),
run_id=run_id_,
)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
for callback_manager, prompt, run_name, run_id_ in zip(
callback_managers, prompts, run_name_list, run_ids_list
)
]
)
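The run_id plumbing above lets a caller pin the root run's UUID: `CallbackManager.on_llm_start` reuses a provided `run_id` for the first run in a batch, and `_get_run_ids_list` fans a single value or a per-prompt list out across the prompts. A hedged end-to-end sketch using the fake LLM exported earlier in this diff (the recorder handler is illustrative):

```python
import uuid
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import FakeListLLM


class RunIdRecorder(BaseCallbackHandler):
    """Records the run_id passed to on_llm_start."""

    def __init__(self) -> None:
        self.run_ids: List[uuid.UUID] = []

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], *, run_id: uuid.UUID, **kwargs: Any
    ) -> None:
        self.run_ids.append(run_id)


recorder = RunIdRecorder()
my_run_id = uuid.uuid4()
FakeListLLM(responses=["hi"]).invoke(
    "hello", config={"callbacks": [recorder], "run_id": my_run_id}
)
assert recorder.run_ids == [my_run_id]
```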

@ -14,7 +14,6 @@
ChatPromptTemplate
""" # noqa: E501
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
@ -29,223 +28,15 @@ from langchain_core.messages.function import FunctionMessage, FunctionMessageChu
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert a sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ToolMessage):
role = "Tool"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
elif _type == "AIMessageChunk":
return AIMessageChunk(**message["data"])
elif _type == "HumanMessageChunk":
return HumanMessageChunk(**message["data"])
elif _type == "FunctionMessageChunk":
return FunctionMessageChunk(**message["data"])
elif _type == "ToolMessageChunk":
return ToolMessageChunk(**message["data"])
elif _type == "SystemMessageChunk":
return SystemMessageChunk(**message["data"])
elif _type == "ChatMessageChunk":
return ChatMessageChunk(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
"""Convert a message chunk to a message.
Args:
chunk: Message chunk to convert.
Returns:
Message.
"""
if not isinstance(chunk, BaseMessageChunk):
return chunk
# chunk classes always have the equivalent non-chunk class as their first parent
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
)
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
def _create_message_from_message_type(
message_type: str,
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
Args:
message_type: str the type of the message (e.g., "human", "ai", etc.)
content: str the content string.
Returns:
a message of the appropriate type.
"""
kwargs: Dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
kwargs["tool_call_id"] = tool_call_id
if additional_kwargs:
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"):
message = AIMessage(content=content, **kwargs)
elif message_type == "system":
message = SystemMessage(content=content, **kwargs)
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
message = ToolMessage(content=content, **kwargs)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- dict: a message dict with role and content keys
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_message_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
_message = _create_message_from_message_type(message_type_str, template)
elif isinstance(message, dict):
msg_kwargs = message.copy()
try:
msg_type = msg_kwargs.pop("role")
msg_content = msg_kwargs.pop("content")
except KeyError:
raise ValueError(
f"Message dict must contain 'role' and 'content' keys, got {message}"
)
_message = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
messages: Sequence of messages to convert.
Returns:
List of messages (BaseMessages).
"""
return [_convert_to_message(m) for m in messages]
from langchain_core.messages.utils import (
AnyMessage,
MessageLikeRepresentation,
_message_from_dict,
convert_to_messages,
get_buffer_string,
message_chunk_to_message,
messages_from_dict,
)
__all__ = [
"AIMessage",
@ -259,15 +50,17 @@ __all__ = [
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"MessageLikeRepresentation",
"SystemMessage",
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"_message_from_dict",
"convert_to_messages",
"get_buffer_string",
"merge_content",
"message_chunk_to_message",
"message_to_dict",
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"merge_content",
]

@ -0,0 +1,228 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert a sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain_core import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ToolMessage):
role = "Tool"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
elif _type == "AIMessageChunk":
return AIMessageChunk(**message["data"])
elif _type == "HumanMessageChunk":
return HumanMessageChunk(**message["data"])
elif _type == "FunctionMessageChunk":
return FunctionMessageChunk(**message["data"])
elif _type == "ToolMessageChunk":
return ToolMessageChunk(**message["data"])
elif _type == "SystemMessageChunk":
return SystemMessageChunk(**message["data"])
elif _type == "ChatMessageChunk":
return ChatMessageChunk(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
"""Convert a message chunk to a message.
Args:
chunk: Message chunk to convert.
Returns:
Message.
"""
if not isinstance(chunk, BaseMessageChunk):
return chunk
# chunk classes always have the equivalent non-chunk class as their first parent
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
)
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
def _create_message_from_message_type(
message_type: str,
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
Args:
message_type: str the type of the message (e.g., "human", "ai", etc.)
content: str the content string.
Returns:
a message of the appropriate type.
"""
kwargs: Dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
kwargs["tool_call_id"] = tool_call_id
if additional_kwargs:
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"):
message = AIMessage(content=content, **kwargs)
elif message_type == "system":
message = SystemMessage(content=content, **kwargs)
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
message = ToolMessage(content=content, **kwargs)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- dict: a message dict with role and content keys
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_message_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
_message = _create_message_from_message_type(message_type_str, template)
elif isinstance(message, dict):
msg_kwargs = message.copy()
try:
msg_type = msg_kwargs.pop("role")
msg_content = msg_kwargs.pop("content")
except KeyError:
raise ValueError(
f"Message dict must contain 'role' and 'content' keys, got {message}"
)
_message = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
messages: Sequence of messages to convert.
Returns:
List of messages (BaseMessages).
"""
return [_convert_to_message(m) for m in messages]

@ -2,6 +2,7 @@ import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
@ -44,13 +45,13 @@ class XMLOutputParser(BaseTransformOutputParser):
text = encoding_match.group(2)
text = text.strip()
if (text.startswith("<") or text.startswith("\n<")) and (
text.endswith(">") or text.endswith(">\n")
):
try:
root = ET.fromstring(text)
return self._root_to_dict(root)
else:
raise ValueError(f"Could not parse output: {text}")
except ET.ParseError as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
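With this change, malformed model output surfaces as an `OutputParserException` (carrying the raw completion) instead of a bare `ValueError` from the old prefix/suffix check. A short hedged example of both paths:

```python
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import XMLOutputParser

parser = XMLOutputParser()
print(parser.parse("<root><item>ok</item></root>"))  # e.g. {'root': [{'item': 'ok'}]}

try:
    parser.parse("not xml at all")
except OutputParserException as exc:
    print("parse failed:", exc)
```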

@ -89,10 +89,15 @@ class BasePromptTemplate(
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
if not isinstance(inner_input, dict):
raise TypeError(
f"Expected mapping type as input to {self.__class__.__name__}. "
f"Received {type(inner_input)}."
)
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
inner_input = {var_name: inner_input}
else:
raise TypeError(
f"Expected mapping type as input to {self.__class__.__name__}. "
f"Received {type(inner_input)}."
)
missing = set(self.input_variables).difference(inner_input)
if missing:
raise KeyError(

@ -574,11 +574,51 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
("human", "{user_input}"),
])
messages = template.format_messages(
name="Bob",
user_input="What is your name?"
prompt_value = template.invoke(
{
"name": "Bob",
"user_input": "What is your name?"
}
)
"""
# Output:
# ChatPromptValue(
# messages=[
# SystemMessage(content='You are a helpful AI bot. Your name is Bob.'),
# HumanMessage(content='Hello, how are you doing?'),
# AIMessage(content="I'm doing well, thanks!"),
# HumanMessage(content='What is your name?')
# ]
#)
Single-variable template:
If your prompt has only a single input variable (i.e., one instance of "{variable_name}"),
and you invoke the template with a non-dict object, the prompt template will
inject the provided argument into that variable location.
.. code-block:: python
from langchain_core.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI bot. Your name is Carl."),
("human", "{user_input}"),
])
prompt_value = template.invoke("Hello, there!")
# Equivalent to
# prompt_value = template.invoke({"user_input": "Hello, there!"})
# Output:
# ChatPromptValue(
# messages=[
# SystemMessage(content='You are a helpful AI bot. Your name is Carl.'),
# HumanMessage(content='Hello, there!'),
# ]
# )
""" # noqa: E501
input_variables: List[str]
"""List of input variables in template messages. Used for validation."""

@ -230,6 +230,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}
@ -286,6 +287,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}

@ -1344,6 +1344,35 @@ class Runnable(Generic[Input, Output], ABC):
) -> Runnable[Input, Output]:
"""Create a new Runnable that retries the original runnable on exceptions.
Example:
.. code-block:: python
from langchain_core.runnables import RunnableLambda
count = 0
def _lambda(x: int) -> None:
global count
count = count + 1
if x == 1:
raise ValueError("x is 1")
else:
pass
runnable = RunnableLambda(_lambda)
try:
runnable.with_retry(
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
).invoke(1)
except ValueError:
pass
assert (count == 2)
Args:
retry_if_exception_type: A tuple of exception types to retry on
wait_exponential_jitter: Whether to add jitter to the wait time
@ -1448,6 +1477,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1495,6 +1525,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1547,6 +1578,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1619,6 +1651,7 @@ class Runnable(Generic[Input, Output], ABC):
input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1694,6 +1727,7 @@ class Runnable(Generic[Input, Output], ABC):
{"input": ""},
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -1781,6 +1815,7 @@ class Runnable(Generic[Input, Output], ABC):
{"input": ""},
run_type=run_type,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@ -2262,7 +2297,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
@ -2296,7 +2334,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
@ -2354,6 +2395,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
@ -2478,6 +2520,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
@ -2885,7 +2928,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# gather results from all steps
@ -2925,7 +2971,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
dumpd(self),
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# gather results from all steps

@ -183,6 +183,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
try:
@ -231,6 +232,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
try:
for idx, branch in enumerate(self.branches):
@ -282,6 +284,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
final_output: Optional[Output] = None
final_output_supported = True
@ -356,6 +359,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
final_output: Optional[Output] = None
final_output_supported = True

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import uuid
import warnings
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
@ -95,6 +97,12 @@ class RunnableConfig(TypedDict, total=False):
configurable.
"""
run_id: Optional[uuid.UUID]
"""
Unique identifier for the tracer run for this call. If not provided, a new UUID
will be generated.
"""
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
@ -116,6 +124,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
metadata={},
callbacks=None,
recursion_limit=25,
run_id=None,
)
if var_config := var_child_runnable_config.get():
empty.update(
@ -158,11 +167,21 @@ def get_config_list(
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [ensure_config(config) for _ in range(length)]
)
if isinstance(config, list):
return list(map(ensure_config, config))
if length > 1 and isinstance(config, dict) and config.get("run_id") is not None:
warnings.warn(
"Provided run_id be used only for the first element of the batch.",
category=RuntimeWarning,
)
subsequent = cast(
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
)
return [
ensure_config(subsequent) if i else ensure_config(config)
for i in range(length)
]
return [ensure_config(config) for i in range(length)]
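As the branch above implies, supplying one `run_id` for a multi-element batch only applies it to the first element; the remaining elements get fresh ids and a `RuntimeWarning` is emitted. A hedged sketch mirroring the updated tests:

```python
import uuid
import warnings

from langchain_core.runnables import RunnableLambda

runnable = RunnableLambda(lambda x: x + 1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Only the first element of the batch reuses this id.
    runnable.batch([1, 2], config={"run_id": uuid.uuid4()})

assert any(issubclass(w.category, RuntimeWarning) for w in caught)
```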
def patch_config(
@ -199,6 +218,8 @@ def patch_config(
config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if "run_id" in config:
del config["run_id"]
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
if max_concurrency is not None:

@ -156,7 +156,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None
@ -200,7 +203,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
@ -270,6 +276,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
@ -362,6 +369,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
@ -436,7 +444,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None
@ -493,7 +504,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self),
input,
name=config.get("run_name"),
run_id=config.pop("run_id", None),
)
first_error = None
last_error = None

@ -0,0 +1,15 @@
# from langchain_core.runnables.base import RunnableBinding
# class RunnableLearnable(RunnableBinding):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.parameters = []
# def backward(self):
# for param in self.parameters:
# param.backward()
# def update(self, optimizer):
# for param in self.parameters:
# optimizer.update(param)

@ -610,8 +610,30 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
Runnable that picks keys from Dict[str, Any] inputs.
"""Runnable that picks keys from Dict[str, Any] inputs.
RunnablePick class represents a runnable that selectively picks keys from a
dictionary input. It allows you to specify one or more keys to extract
from the input dictionary. It returns a new dictionary containing only
the selected keys.
Example:
.. code-block:: python
from langchain_core.runnables.passthrough import RunnablePick
input_data = {
'name': 'John',
'age': 30,
'city': 'New York',
'country': 'USA'
}
runnable = RunnablePick(keys=['name', 'age'])
output_data = runnable.invoke(input_data)
print(output_data) # Output: {'name': 'John', 'age': 30}
"""
keys: Union[str, List[str]]

@ -20,6 +20,7 @@ tool for the job.
from __future__ import annotations
import inspect
import uuid
import warnings
from abc import abstractmethod, ABC
from inspect import signature
@ -243,6 +244,7 @@ class ChildTool(BaseTool):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
@ -259,6 +261,7 @@ class ChildTool(BaseTool):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
)
@ -339,6 +342,7 @@ class ChildTool(BaseTool):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -362,6 +366,7 @@ class ChildTool(BaseTool):
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
@ -430,6 +435,7 @@ class ChildTool(BaseTool):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -453,6 +459,7 @@ class ChildTool(BaseTool):
color=start_color,
name=run_name,
inputs=tool_input,
run_id=run_id,
**kwargs,
)
try:

@ -1,4 +1,5 @@
"""Methods for creating function specs in the style of OpenAI Functions"""
from __future__ import annotations
import inspect
@ -51,13 +52,16 @@ class ToolDescription(TypedDict):
function: FunctionDescription
def _rm_titles(kv: dict) -> dict:
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
new_kv = {}
for k, v in kv.items():
if k == "title":
continue
if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys():
new_kv[k] = _rm_titles(v, k)
else:
continue
elif isinstance(v, dict):
new_kv[k] = _rm_titles(v)
new_kv[k] = _rm_titles(v, k)
else:
new_kv[k] = v
return new_kv
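The new `prev_key` argument lets `_rm_titles` tell pydantic's auto-generated `"title"` metadata apart from a user-defined property that is literally named `title`: a `title` entry is only recursed into (rather than dropped) when it sits directly under a `properties` mapping. A sketch consistent with the parametrized schemas added in the tests later in this diff (`_rm_titles` is a private helper, imported here purely for illustration):

```python
from langchain_core.utils.function_calling import _rm_titles

schema = {
    "type": "object",
    "properties": {
        "title": {"title": "Title", "type": "string"},  # real field named "title"
        "age": {"title": "Age", "type": "integer"},
    },
}

print(_rm_titles(schema))
# {'type': 'object',
#  'properties': {'title': {'type': 'string'}, 'age': {'type': 'integer'}}}
```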

@ -165,8 +165,16 @@ class Tee(Generic[T]):
safetee = Tee
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
"""Utility batching function."""
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
"""Utility batching function.
Args:
size: The size of the batch. If None, returns a single batch.
iterable: The iterable to batch.
Returns:
An iterator over the batches.
"""
it = iter(iterable)
while True:
chunk = list(islice(it, size))
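With `size` now `Optional`, passing `None` means `islice(it, None)` drains the iterator in one go, so the whole iterable comes back as a single batch. A small usage sketch (the `langchain_core.utils.iter` module path is assumed):

```python
from langchain_core.utils.iter import batch_iterate  # assumed module path

print(list(batch_iterate(2, [1, 2, 3, 4, 5])))  # [[1, 2], [3, 4], [5]]
print(list(batch_iterate(None, [1, 2, 3])))     # [[1, 2, 3]] -- one batch
```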

libs/core/poetry.lock (generated)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -679,13 +679,13 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs
[[package]]
name = "importlib-resources"
version = "6.1.3"
version = "6.3.1"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_resources-6.1.3-py3-none-any.whl", hash = "sha256:4c0269e3580fe2634d364b39b38b961540a7738c02cb984e98add8b4221d793d"},
{file = "importlib_resources-6.1.3.tar.gz", hash = "sha256:56fb4525197b78544a3354ea27793952ab93f935bb4bf746b846bb1015020f2b"},
{file = "importlib_resources-6.3.1-py3-none-any.whl", hash = "sha256:4811639ca7fa830abdb8e9ca0a104dc6ad13de691d9fe0d3173a71304f068159"},
{file = "importlib_resources-6.3.1.tar.gz", hash = "sha256:29a3d16556e330c3c8fb8202118c5ff41241cc34cbfb25989bbad226d99b7995"},
]
[package.dependencies]
@ -851,18 +851,15 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "json5"
version = "0.9.22"
version = "0.9.24"
description = "A Python implementation of the JSON5 data format."
optional = false
python-versions = ">=3.8"
files = [
{file = "json5-0.9.22-py3-none-any.whl", hash = "sha256:6621007c70897652f8b5d03885f732771c48d1925591ad989aa80c7e0e5ad32f"},
{file = "json5-0.9.22.tar.gz", hash = "sha256:b729bde7650b2196a35903a597d2b704b8fdf8648bfb67368cfb79f1174a17bd"},
{file = "json5-0.9.24-py3-none-any.whl", hash = "sha256:4ca101fd5c7cb47960c055ef8f4d0e31e15a7c6c48c3b6f1473fc83b6c462a13"},
{file = "json5-0.9.24.tar.gz", hash = "sha256:0c638399421da959a20952782800e5c1a78c14e08e1dc9738fa10d8ec14d58c8"},
]
[package.extras]
dev = ["hypothesis"]
[[package]]
name = "jsonpatch"
version = "1.33"
@ -1118,13 +1115,13 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>
[[package]]
name = "jupyterlab"
version = "4.1.4"
version = "4.1.5"
description = "JupyterLab computational environment"
optional = false
python-versions = ">=3.8"
files = [
{file = "jupyterlab-4.1.4-py3-none-any.whl", hash = "sha256:f92c3f2b12b88efcf767205f49be9b2f86b85544f9c4f342bb5e9904a16cf931"},
{file = "jupyterlab-4.1.4.tar.gz", hash = "sha256:e03c82c124ad8a0892e498b9dde79c50868b2c267819aca3f55ce47c57ebeb1d"},
{file = "jupyterlab-4.1.5-py3-none-any.whl", hash = "sha256:3bc843382a25e1ab7bc31d9e39295a9f0463626692b7995597709c0ab236ab2c"},
{file = "jupyterlab-4.1.5.tar.gz", hash = "sha256:c9ad75290cb10bfaff3624bf3fbb852319b4cce4c456613f8ebbaa98d03524db"},
]
[package.dependencies]
@ -1219,13 +1216,13 @@ url = "../text-splitters"
[[package]]
name = "langsmith"
version = "0.1.23"
version = "0.1.27"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.1.23-py3-none-any.whl", hash = "sha256:69984268b9867cb31b875965b3f86b6f56ba17dd5454d487d3a1a999bdaeea69"},
{file = "langsmith-0.1.23.tar.gz", hash = "sha256:327c66ec0de8c1bc57bfa47bbc70a29ef749e97c3e5571b9baf754d1e0644220"},
{file = "langsmith-0.1.27-py3-none-any.whl", hash = "sha256:d223176952b1525c958189ab1b894f5bd9891ec9177222f7a978aeee4bf1cc95"},
{file = "langsmith-0.1.27.tar.gz", hash = "sha256:e0a339d976362051adf3fdbc43fcc7c00bb4615a401321ad7e556bd2dab556c0"},
]
[package.dependencies]
@ -1387,13 +1384,13 @@ files = [
[[package]]
name = "nbclient"
version = "0.9.1"
version = "0.10.0"
description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "nbclient-0.9.1-py3-none-any.whl", hash = "sha256:2c50a866e8dd6c5f655de47d2e252c82d2ebe978574e760ac229f5950593a434"},
{file = "nbclient-0.9.1.tar.gz", hash = "sha256:4f7b78c6c2a380e228f8a3bb469b847cb24e5b8ad6fda410691b5621e05ce5a2"},
{file = "nbclient-0.10.0-py3-none-any.whl", hash = "sha256:f13e3529332a1f1f81d82a53210322476a168bb7090a0289c795fe9cc11c9d3f"},
{file = "nbclient-0.10.0.tar.gz", hash = "sha256:4b3f1b7dba531e498449c4db4f53da339c91d449dc11e9af3a43b4eb5c5abb09"},
]
[package.dependencies]
@ -1447,13 +1444,13 @@ webpdf = ["playwright"]
[[package]]
name = "nbformat"
version = "5.10.2"
version = "5.10.3"
description = "The Jupyter Notebook format"
optional = false
python-versions = ">=3.8"
files = [
{file = "nbformat-5.10.2-py3-none-any.whl", hash = "sha256:7381189a0d537586b3f18bae5dbad347d7dd0a7cf0276b09cdcd5c24d38edd99"},
{file = "nbformat-5.10.2.tar.gz", hash = "sha256:c535b20a0d4310167bf4d12ad31eccfb0dc61e6392d6f8c570ab5b45a06a49a3"},
{file = "nbformat-5.10.3-py3-none-any.whl", hash = "sha256:d9476ca28676799af85385f409b49d95e199951477a159a576ef2a675151e5e8"},
{file = "nbformat-5.10.3.tar.gz", hash = "sha256:60ed5e910ef7c6264b87d644f276b1b49e24011930deef54605188ddeb211685"},
]
[package.dependencies]
@ -1479,13 +1476,13 @@ files = [
[[package]]
name = "notebook"
version = "7.1.1"
version = "7.1.2"
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
optional = false
python-versions = ">=3.8"
files = [
{file = "notebook-7.1.1-py3-none-any.whl", hash = "sha256:197d8e0595acabf4005851c8716e952a81b405f7aefb648067a761fbde267ce7"},
{file = "notebook-7.1.1.tar.gz", hash = "sha256:818e7420fa21f402e726afb9f02df7f3c10f294c02e383ed19852866c316108b"},
{file = "notebook-7.1.2-py3-none-any.whl", hash = "sha256:fc6c24b9aef18d0cd57157c9c47e95833b9b0bdc599652639acf0bdb61dc7d5f"},
{file = "notebook-7.1.2.tar.gz", hash = "sha256:efc2c80043909e0faa17fce9e9b37c059c03af0ec99a4d4db84cb21d9d2e936a"},
]
[package.dependencies]
@ -1517,6 +1514,43 @@ jupyter-server = ">=1.8,<3"
[package.extras]
test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
[[package]]
name = "numpy"
version = "1.24.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"},
{file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"},
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"},
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"},
{file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"},
{file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"},
{file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"},
{file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"},
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"},
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"},
{file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"},
{file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"},
{file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"},
{file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"},
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"},
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"},
{file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"},
{file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"},
{file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"},
{file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"},
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"},
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"},
{file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"},
{file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"},
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
]
[[package]]
name = "orjson"
version = "3.9.15"
@ -2111,7 +2145,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -2295,13 +2328,13 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"]
[[package]]
name = "referencing"
version = "0.33.0"
version = "0.34.0"
description = "JSON Referencing + Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "referencing-0.33.0-py3-none-any.whl", hash = "sha256:39240f2ecc770258f28b642dd47fd74bc8b02484de54e1882b74b35ebd779bd5"},
{file = "referencing-0.33.0.tar.gz", hash = "sha256:c775fedf74bc0f9189c2a3be1c12fd03e8c23f4d371dce795df44e06c5b412f7"},
{file = "referencing-0.34.0-py3-none-any.whl", hash = "sha256:d53ae300ceddd3169f1ffa9caf2cb7b769e92657e4fafb23d34b93679116dfd4"},
{file = "referencing-0.34.0.tar.gz", hash = "sha256:5773bd84ef41799a5a8ca72dc34590c041eb01bf9aa02632b4a973fb0181a844"},
]
[package.dependencies]
@ -2731,13 +2764,13 @@ files = [
[[package]]
name = "types-python-dateutil"
version = "2.8.19.20240311"
version = "2.9.0.20240316"
description = "Typing stubs for python-dateutil"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-python-dateutil-2.8.19.20240311.tar.gz", hash = "sha256:51178227bbd4cbec35dc9adffbf59d832f20e09842d7dcb8c73b169b8780b7cb"},
{file = "types_python_dateutil-2.8.19.20240311-py3-none-any.whl", hash = "sha256:ef813da0809aca76472ca88807addbeea98b19339aebe56159ae2f4b4f70857a"},
{file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"},
{file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"},
]
[[package]]
@ -2914,18 +2947,18 @@ files = [
[[package]]
name = "zipp"
version = "3.17.0"
version = "3.18.1"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
{file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
{file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"},
{file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[extras]
extended-testing = ["jinja2"]
@ -2933,4 +2966,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "9d6e9c9613b31dbbe35772bf8d8d5aaba637228de7abbf4a7b271971c2a81ba9"
content-hash = "ca611429e3dd84ce6dac7ef69d7d9b4da78bf467356946e37016b821e5fe752e"

@ -60,6 +60,7 @@ pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
numpy = "^1.24.0"
[tool.poetry.group.test_integration]

@ -0,0 +1,16 @@
from langchain_core.embeddings import DeterministicFakeEmbedding
def test_deterministic_fake_embeddings() -> None:
"""
Test that the deterministic fake embeddings return the same
embedding vector for the same text.
"""
fake = DeterministicFakeEmbedding(size=10)
text = "Hello world!"
assert fake.embed_query(text) == fake.embed_query(text)
assert fake.embed_query(text) != fake.embed_query("Goodbye world!")
assert fake.embed_documents([text, text]) == fake.embed_documents([text, text])
assert fake.embed_documents([text, text]) != fake.embed_documents(
[text, "Goodbye world!"]
)

@ -4,10 +4,10 @@ from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:

@ -5,7 +5,7 @@ from typing import Any, AsyncIterator, Iterator, List, Optional
import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models import BaseChatModel, FakeListChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -21,7 +21,6 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
from tests.unit_tests.fake.chat_model import FakeListChatModel
@pytest.fixture

@ -0,0 +1,268 @@
"""Module tests interaction of chat model with caching abstraction.."""
from typing import Any, Dict, Optional, Tuple
import pytest
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
GenericFakeChatModel,
)
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
class InMemoryCache(BaseCache):
"""In-memory cache used for testing purposes."""
def __init__(self) -> None:
"""Initialize with empty cache."""
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
return self._cache.get((prompt, llm_string), None)
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}
def test_local_cache_sync() -> None:
"""Test that the local cache is being populated but not the global one."""
global_cache = InMemoryCache()
local_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=local_cache, responses=["hello", "goodbye"]
)
assert chat_model.invoke("How are you?").content == "hello"
# If the cache works we should get the same response since
# the prompt is the same
assert chat_model.invoke("How are you?").content == "hello"
# The global cache should be empty
assert global_cache._cache == {}
# The local cache should be populated
assert len(local_cache._cache) == 1
llm_result = list(local_cache._cache.values())
chat_generation = llm_result[0][0]
assert isinstance(chat_generation, ChatGeneration)
assert chat_generation.message.content == "hello"
# Verify that another prompt will trigger the call to the model
assert chat_model.invoke("meow?").content == "goodbye"
# The global cache should be empty
assert global_cache._cache == {}
# The local cache should be populated
assert len(local_cache._cache) == 2
finally:
set_llm_cache(None)
async def test_local_cache_async() -> None:
# Use InMemoryCache as the cache
global_cache = InMemoryCache()
local_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=local_cache, responses=["hello", "goodbye"]
)
assert (await chat_model.ainvoke("How are you?")).content == "hello"
# If the cache works we should get the same response since
# the prompt is the same
assert (await chat_model.ainvoke("How are you?")).content == "hello"
# The global cache should be empty
assert global_cache._cache == {}
# The local cache should be populated
assert len(local_cache._cache) == 1
llm_result = list(local_cache._cache.values())
chat_generation = llm_result[0][0]
assert isinstance(chat_generation, ChatGeneration)
assert chat_generation.message.content == "hello"
# Verify that another prompt will trigger the call to the model
assert chat_model.invoke("meow?").content == "goodbye"
# The global cache should be empty
assert global_cache._cache == {}
# The local cache should be populated
assert len(local_cache._cache) == 2
finally:
set_llm_cache(None)
def test_global_cache_sync() -> None:
"""Test that the global cache gets populated when cache = True."""
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=True, responses=["hello", "goodbye", "meow", "woof"]
)
assert (chat_model.invoke("How are you?")).content == "hello"
# If the cache works we should get the same response since
# the prompt is the same
assert (chat_model.invoke("How are you?")).content == "hello"
# The global cache should be populated
assert len(global_cache._cache) == 1
llm_result = list(global_cache._cache.values())
chat_generation = llm_result[0][0]
assert isinstance(chat_generation, ChatGeneration)
assert chat_generation.message.content == "hello"
# Verify that another prompt will trigger the call to the model
assert chat_model.invoke("nice").content == "goodbye"
# The local cache should be populated
assert len(global_cache._cache) == 2
finally:
set_llm_cache(None)
async def test_global_cache_async() -> None:
"""Test that the global cache gets populated when cache = True."""
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=True, responses=["hello", "goodbye", "meow", "woof"]
)
assert (await chat_model.ainvoke("How are you?")).content == "hello"
# If the cache works we should get the same response since
# the prompt is the same
assert (await chat_model.ainvoke("How are you?")).content == "hello"
# The global cache should be populated
assert len(global_cache._cache) == 1
llm_result = list(global_cache._cache.values())
chat_generation = llm_result[0][0]
assert isinstance(chat_generation, ChatGeneration)
assert chat_generation.message.content == "hello"
# Verify that another prompt will trigger the call to the model
assert chat_model.invoke("nice").content == "goodbye"
# The local cache should be populated
assert len(global_cache._cache) == 2
finally:
set_llm_cache(None)
def test_no_cache_sync() -> None:
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=False, responses=["hello", "goodbye"]
) # Set cache=False
assert (chat_model.invoke("How are you?")).content == "hello"
# The global cache should not be populated since cache=False
# so we should get the second response
assert (chat_model.invoke("How are you?")).content == "goodbye"
# The global cache should not be populated since cache=False
assert len(global_cache._cache) == 0
finally:
set_llm_cache(None)
async def test_no_cache_async() -> None:
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=False, responses=["hello", "goodbye"]
) # Set cache=False
assert (await chat_model.ainvoke("How are you?")).content == "hello"
# The global cache should not be populated since cache=False
# so we should get the second response
assert (await chat_model.ainvoke("How are you?")).content == "goodbye"
# The global cache should not be populated since cache=False
assert len(global_cache._cache) == 0
finally:
set_llm_cache(None)
async def test_global_cache_abatch() -> None:
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=True, responses=["hello", "goodbye", "meow", "woof"]
)
results = await chat_model.abatch(["first prompt", "second prompt"])
assert results[0].content == "hello"
assert results[1].content == "goodbye"
# Now try with the same prompt
results = await chat_model.abatch(["first prompt", "first prompt"])
assert results[0].content == "hello"
assert results[1].content == "hello"
## RACE CONDITION -- note behavior is different from sync
# Now, reset cache and test the race condition
# For now we just hard-code the result, if this changes
# we can investigate further
global_cache = InMemoryCache()
set_llm_cache(global_cache)
assert global_cache._cache == {}
results = await chat_model.abatch(["prompt", "prompt"])
# suspecting that tasks will be scheduled and executed in order
# if this ever fails, we can relax to a set comparison
# Cache misses likely guaranteed?
assert results[0].content == "meow"
assert results[1].content == "woof"
finally:
set_llm_cache(None)
def test_global_cache_batch() -> None:
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
chat_model = FakeListChatModel(
cache=True, responses=["hello", "goodbye", "meow", "woof"]
)
results = chat_model.batch(["first prompt", "second prompt"])
# These may be in any order
assert {results[0].content, results[1].content} == {"hello", "goodbye"}
# Now try with the same prompt
results = chat_model.batch(["first prompt", "first prompt"])
# These could be either "hello" or "goodbye" and should be identical
assert results[0].content == results[1].content
assert {results[0].content, results[1].content}.issubset({"hello", "goodbye"})
## RACE CONDITION -- note behavior is different from async
# Now, reset cache and test the race condition
# For now we just hard-code the result, if this changes
# we can investigate further
global_cache = InMemoryCache()
set_llm_cache(global_cache)
assert global_cache._cache == {}
results = chat_model.batch(
[
"prompt",
"prompt",
]
)
assert {results[0].content, results[1].content} == {"meow"}
finally:
set_llm_cache(None)
@pytest.mark.xfail(reason="Abstraction does not support caching for streaming yet.")
def test_global_cache_stream() -> None:
"""Test streaming."""
global_cache = InMemoryCache()
try:
set_llm_cache(global_cache)
messages = [
AIMessage(content="hello world"),
AIMessage(content="goodbye world"),
]
model = GenericFakeChatModel(messages=iter(messages), cache=True)
chunks = [chunk for chunk in model.stream("some input")]
assert len(chunks) == 3
# Assert that streaming information gets cached
assert global_cache._cache != {}
finally:
set_llm_cache(None)

@ -6,7 +6,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
@ -14,7 +14,6 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
def test_batch() -> None:

@ -11,6 +11,12 @@ EXPECTED_ALL = [
"LanguageModelLike",
"get_tokenizer",
"LanguageModelLike",
"FakeMessagesListChatModel",
"FakeListChatModel",
"GenericFakeChatModel",
"FakeStreamingListLLM",
"FakeListLLM",
"ParrotFakeChatModel",
]

@ -1,6 +1,8 @@
from langchain_core.messages import __all__
EXPECTED_ALL = [
"MessageLikeRepresentation",
"_message_from_dict",
"AIMessage",
"AIMessageChunk",
"AnyMessage",
@ -18,11 +20,11 @@ EXPECTED_ALL = [
"ToolMessageChunk",
"convert_to_messages",
"get_buffer_string",
"merge_content",
"message_chunk_to_message",
"message_to_dict",
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"merge_content",
]

@ -2,13 +2,13 @@
from typing import List
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseTransformOutputParser,
)
from langchain_core.outputs import ChatGeneration, Generation
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
def test_base_generation_parser() -> None:

@ -1,6 +1,7 @@
"""Test XMLOutputParser"""
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.xml import XMLOutputParser
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
@ -59,6 +60,6 @@ def test_xml_output_parser_fail(result: str) -> None:
xml_parser = XMLOutputParser()
with pytest.raises(ValueError) as e:
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Could not parse output" in str(e)
assert "Failed to parse" in str(e)

@ -533,3 +533,16 @@ def test_chat_prompt_message_placeholder_partial() -> None:
assert prompt.format_messages() == []
prompt = prompt.partial(history=[("system", "foo")])
assert prompt.format_messages() == [SystemMessage(content="foo")]
def test_messages_prompt_accepts_list() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
value = prompt.invoke([("user", "Hi there")]) # type: ignore
assert value.to_messages() == [HumanMessage(content="Hi there")]
# Assert still raises a nice error
prompt = ChatPromptTemplate.from_messages(
[("system", "You are a {foo}"), MessagesPlaceholder("history")]
)
with pytest.raises(TypeError):
prompt.invoke([("user", "Hi there")]) # type: ignore

@ -2,12 +2,12 @@ from functools import partial
from inspect import isclass
from typing import Any, Dict, Type, Union, cast
from langchain_core.language_models import FakeListChatModel
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableLambda
from tests.unit_tests.fake.chat_model import FakeListChatModel
def _fake_runnable(

@ -606,10 +606,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)",
@ -1083,10 +1082,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -1150,10 +1148,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -1694,10 +1691,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])",
@ -2171,10 +2167,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -2238,10 +2233,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -2713,10 +2707,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)",
@ -3190,10 +3183,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -3225,10 +3217,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])",
@ -3702,10 +3693,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -4263,10 +4253,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)",
@ -4740,10 +4729,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -4775,10 +4763,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['baz'], i=1)",
@ -5252,10 +5239,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"
@ -5286,10 +5272,9 @@
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])",
@ -5763,10 +5748,9 @@
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"langchain_core",
"language_models",
"fake",
"llm",
"FakeListLLM"
],
"name": "FakeListLLM"

File diff suppressed because one or more lines are too long

@ -3,13 +3,13 @@ from typing import Any, Callable, List, NamedTuple, Union
import pytest
from langchain_core.beta.runnables.context import Context
from langchain_core.language_models import FakeListLLM, FakeStreamingListLLM
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import aadd, add
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
class _TestCase(NamedTuple):

@ -4,6 +4,7 @@ from typing import Any, AsyncIterator, Iterator
import pytest
from syrupy import SnapshotAssertion
from langchain_core.language_models import FakeListLLM
from langchain_core.load import dumps
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
@ -14,7 +15,6 @@ from langchain_core.runnables import (
RunnablePassthrough,
RunnableWithFallbacks,
)
from tests.unit_tests.fake.llm import FakeListLLM
@pytest.fixture()

@ -1,11 +1,11 @@
from syrupy import SnapshotAssertion
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable
from tests.unit_tests.fake.llm import FakeListLLM
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
@ -54,7 +54,7 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
"id": 2,
"type": "runnable",
"data": {
"id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
"id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
"name": "FakeListLLM",
},
},
@ -136,7 +136,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"id": 2,
"type": "runnable",
"data": {
"id": ["tests", "unit_tests", "fake", "llm", "FakeListLLM"],
"id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
"name": "FakeListLLM",
},
},

@ -1,4 +1,5 @@
import sys
import uuid
from functools import partial
from operator import itemgetter
from typing import (
@ -28,6 +29,11 @@ from langchain_core.callbacks.manager import (
trace_as_chain_group,
)
from langchain_core.documents import Document
from langchain_core.language_models import (
FakeListChatModel,
FakeListLLM,
FakeStreamingListLLM,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@ -80,8 +86,6 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.chat_model import FakeListChatModel
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
class FakeTracer(BaseTracer):
@ -133,6 +137,22 @@ class FakeTracer(BaseTracer):
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> List[Run]:
q = [] + self.runs
result = []
while q:
parent = q.pop()
result.append(parent)
if parent.child_runs:
q.extend(parent.child_runs)
return result
@property
def run_ids(self) -> List[Optional[uuid.UUID]]:
runs = self.flattened_runs()
uuids_map = {v: k for k, v in self.uuids_map.items()}
return [uuids_map.get(r.id) for r in runs]
class FakeRunnable(Runnable[str, int]):
def invoke(
@ -1364,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
recursion_limit=25,
configurable={"hello": "there"},
metadata={"hello": "there", "bye": "now"},
run_id=None,
),
)
spy.reset_mock()
@ -1505,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
mocker.call(
@ -1514,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
]
@ -1539,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
@ -1549,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
@ -1617,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
mocker.call(
@ -1626,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
]
@ -4819,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
}
tracer = FakeTracer()
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert list(runnable.stream(None)) == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
run_id = uuid.uuid4()
assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
1,
2,
3,
]
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
tracer = FakeTracer()
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert runnable.batch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}
@ -4862,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
arunnable = RunnableGenerator(agen)
tracer = FakeTracer()
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
run_id = uuid.uuid4()
assert [
p
async for p in arunnable.astream(
None, {"callbacks": [tracer], "run_id": run_id}
)
] == [
1,
2,
3,
@ -4884,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer = FakeTracer()
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert await arunnable.abatch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}

@ -7,6 +7,7 @@ import pytest
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -26,8 +27,6 @@ from langchain_core.runnables import (
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
from tests.unit_tests.fake.llm import FakeStreamingListLLM
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@ -313,6 +312,68 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
]
async def test_astream_events_from_model() -> None:
"""Test the output of a model."""
infinite_cycle = cycle(
[AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
)
# When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
model = (
GenericFakeChatModel(messages=infinite_cycle)
.with_config(
{
"metadata": {"a": "b"},
"tags": ["my_model"],
"run_name": "my_model",
}
)
.bind(stop="<stop_token>")
)
events = await _collect_events(model.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
{
"data": {"output": AIMessageChunk(content="hello world!")},
"event": "on_chat_model_end",
"metadata": {"a": "b"},
"name": "my_model",
"run_id": "",
"tags": ["my_model"],
},
]
async def test_event_stream_with_simple_chain() -> None:
"""Test as event stream."""
template = ChatPromptTemplate.from_messages(

@ -2,8 +2,8 @@
import uuid
from langchain_core.language_models import FakeListLLM
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM
def test_collect_runs() -> None:

@ -0,0 +1,199 @@
import pytest
from langchain_core.utils.function_calling import _rm_titles
output1 = {
"type": "object",
"properties": {
"people": {
"description": "List of info about people",
"type": "array",
"items": {
"description": "Information about a person.",
"type": "object",
"properties": {
"name": {"type": "string"},
"title": {"description": "person's age", "type": "integer"},
},
"required": ["name"],
},
}
},
"required": ["people"],
}
schema1 = {
"type": "object",
"properties": {
"people": {
"title": "People",
"description": "List of info about people",
"type": "array",
"items": {
"title": "Person",
"description": "Information about a person.",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"title": {
"title": "Title",
"description": "person's age",
"type": "integer",
},
},
"required": ["name"],
},
}
},
"required": ["people"],
}
output2 = {
"type": "object",
"properties": {
"title": {
"description": "List of info about people",
"type": "array",
"items": {
"description": "Information about a person.",
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"description": "person's age", "type": "integer"},
},
"required": ["name"],
},
}
},
"required": ["title"],
}
schema2 = {
"type": "object",
"properties": {
"title": {
"title": "Title",
"description": "List of info about people",
"type": "array",
"items": {
"title": "Person",
"description": "Information about a person.",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"age": {
"title": "Age",
"description": "person's age",
"type": "integer",
},
},
"required": ["name"],
},
}
},
"required": ["title"],
}
output3 = {
"type": "object",
"properties": {
"title": {
"description": "List of info about people",
"type": "array",
"items": {
"description": "Information about a person.",
"type": "object",
"properties": {
"title": {"type": "string"},
"type": {"description": "person's age", "type": "integer"},
},
"required": ["title"],
},
}
},
"required": ["title"],
}
schema3 = {
"type": "object",
"properties": {
"title": {
"title": "Title",
"description": "List of info about people",
"type": "array",
"items": {
"title": "Person",
"description": "Information about a person.",
"type": "object",
"properties": {
"title": {"title": "Title", "type": "string"},
"type": {
"title": "Type",
"description": "person's age",
"type": "integer",
},
},
"required": ["title"],
},
}
},
"required": ["title"],
}
output4 = {
"type": "object",
"properties": {
"properties": {
"description": "Information to extract",
"type": "object",
"properties": {
"title": {
"description": "Information about papers mentioned.",
"type": "object",
"properties": {
"title": {"type": "string"},
"author": {"type": "string"},
},
"required": ["title"],
}
},
"required": ["title"],
}
},
"required": ["properties"],
}
schema4 = {
"type": "object",
"properties": {
"properties": {
"title": "Info",
"description": "Information to extract",
"type": "object",
"properties": {
"title": {
"title": "Paper",
"description": "Information about papers mentioned.",
"type": "object",
"properties": {
"title": {"title": "Title", "type": "string"},
"author": {"title": "Author", "type": "string"},
},
"required": ["title"],
}
},
"required": ["title"],
}
},
"required": ["properties"],
}
@pytest.mark.parametrize(
"schema, output",
[(schema1, output1), (schema2, output2), (schema3, output3), (schema4, output4)],
)
def test_rm_titles(schema: dict, output: dict) -> None:
assert _rm_titles(schema) == output
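For readers skimming the fixtures, here is a minimal sketch of the behavior these cases pin down (an approximation, not necessarily the exact `_rm_titles` implementation, and handling only dict values as in the schemas above): auto-generated `"title"` metadata is dropped, while properties literally named `"title"` are kept.

```python
def rm_titles_sketch(kv: dict, prev_key: str = "") -> dict:
    cleaned = {}
    for key, value in kv.items():
        if key == "title":
            if isinstance(value, dict) and prev_key == "properties" and "title" in value:
                # A real property named "title": keep it, but clean its sub-schema.
                cleaned[key] = rm_titles_sketch(value, key)
            # Otherwise it is generated metadata like "title": "Person"; drop it.
            continue
        cleaned[key] = rm_titles_sketch(value, key) if isinstance(value, dict) else value
    return cleaned

# Reproduces the four expected outputs above.
assert all(
    rm_titles_sketch(s) == o
    for s, o in [(schema1, output1), (schema2, output2), (schema3, output3), (schema4, output4)]
)
```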

@ -95,7 +95,7 @@ class SemanticChunker(BaseDocumentTransformer):
"""Split the text based on semantic similarity.
Taken from Greg Kamradt's wonderful notebook:
https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb
https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb
All credits to him.
@ -106,6 +106,7 @@ class SemanticChunker(BaseDocumentTransformer):
def __init__(
self,
embeddings: Embeddings,
buffer_size: int = 1,
add_start_index: bool = False,
breakpoint_threshold_type: BreakpointThresholdType = "percentile",
breakpoint_threshold_amount: Optional[float] = None,
@ -113,6 +114,7 @@ class SemanticChunker(BaseDocumentTransformer):
):
self._add_start_index = add_start_index
self.embeddings = embeddings
self.buffer_size = buffer_size
self.breakpoint_threshold_type = breakpoint_threshold_type
self.number_of_chunks = number_of_chunks
if breakpoint_threshold_amount is None:
@ -173,7 +175,7 @@ class SemanticChunker(BaseDocumentTransformer):
_sentences = [
{"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
]
sentences = combine_sentences(_sentences)
sentences = combine_sentences(_sentences, self.buffer_size)
embeddings = self.embeddings.embed_documents(
[x["combined_sentence"] for x in sentences]
)
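The new `buffer_size` controls how many neighboring sentences are merged into each `combined_sentence` before embedding. A rough sketch of that windowing (illustrative only; the library's `combine_sentences` may differ in details):

```python
from typing import Dict, List

def combine_with_window(sentences: List[Dict], buffer_size: int = 1) -> List[Dict]:
    # sentences: [{"sentence": str, "index": int}, ...] as built in the diff above.
    for i, item in enumerate(sentences):
        window = sentences[max(0, i - buffer_size) : i + buffer_size + 1]
        item["combined_sentence"] = " ".join(s["sentence"] for s in window)
    return sentences

demo = combine_with_window([{"sentence": s, "index": i} for i, s in enumerate(["A.", "B.", "C."])])
assert demo[1]["combined_sentence"] == "A. B. C."
# With buffer_size=2 each combined_sentence spans up to five sentences, which
# smooths the embedding distances used to pick breakpoints.
```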

@ -46,5 +46,8 @@ COPY libs/core ../core
# Copy the community library for installation
COPY libs/community/ ../community/
# Copy the text-splitters library for installation
COPY libs/text-splitters/ ../text-splitters/
# Install the Poetry dependencies (this layer will be cached as long as the dependencies don't change)
RUN poetry install --no-interaction --no-ansi --with dev,test,docs

@ -171,27 +171,29 @@ def __getattr__(name: str) -> Any:
elif name == "FewShotPromptTemplate":
from langchain_core.prompts import FewShotPromptTemplate
_warn_on_import(name, replacement="langchain.prompts.FewShotPromptTemplate")
_warn_on_import(
name, replacement="langchain_core.prompts.FewShotPromptTemplate"
)
return FewShotPromptTemplate
elif name == "Prompt":
from langchain.prompts import Prompt
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain.prompts.Prompt")
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
return Prompt
# it's renamed to PromptTemplate anyway;
# this is just for backwards compatibility
return PromptTemplate
elif name == "PromptTemplate":
from langchain_core.prompts import PromptTemplate
_warn_on_import(name, replacement="langchain.prompts.PromptTemplate")
_warn_on_import(name, replacement="langchain_core.prompts.PromptTemplate")
return PromptTemplate
elif name == "BasePromptTemplate":
from langchain_core.prompts import BasePromptTemplate
_warn_on_import(
name, replacement="langchain.schema.prompt_template.BasePromptTemplate"
)
_warn_on_import(name, replacement="langchain_core.prompts.BasePromptTemplate")
return BasePromptTemplate
elif name == "ArxivAPIWrapper":

@ -1,4 +1,4 @@
from typing import Sequence
from typing import List, Sequence, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
@ -15,7 +15,7 @@ def create_json_chat_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
stop_sequence: bool = True,
stop_sequence: Union[bool, List[str]] = True,
tools_renderer: ToolsRenderer = render_text_description,
) -> Runnable:
"""Create an agent that uses JSON to format its logic, build for Chat Models.
@ -24,7 +24,11 @@ def create_json_chat_agent(
llm: LLM to use as the agent.
tools: Tools this agent has access to.
prompt: The prompt to use. See Prompt section below for more.
stop_sequence: Adds a stop token of "Observation:" to avoid hallucinates.
stop_sequence: bool or list of str.
If True, adds a stop token of "Observation:" to avoid hallucinations.
If False, does not add a stop token.
If a list of str, uses the provided list as the stop tokens.
Default is True. You may want to set this to False if the LLM you are using
does not support stop sequences.
tools_renderer: This controls how the tools are converted into a string and
@ -158,7 +162,8 @@ def create_json_chat_agent(
tool_names=", ".join([t.name for t in tools]),
)
if stop_sequence:
llm_to_use = llm.bind(stop=["\nObservation"])
stop = ["\nObservation"] if stop_sequence is True else stop_sequence
llm_to_use = llm.bind(stop=stop)
else:
llm_to_use = llm
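A minimal usage sketch of the new list form (the hub prompt handle and the `chat_model`/`tools` names are placeholders and assumptions, not part of this diff):

```python
from langchain import hub
from langchain.agents import AgentExecutor, create_json_chat_agent

prompt = hub.pull("hwchase17/react-chat-json")  # assumed prompt; any suitable ChatPromptTemplate works
agent = create_json_chat_agent(
    llm=chat_model,  # placeholder: any BaseLanguageModel
    tools=tools,     # placeholder: your Sequence[BaseTool]
    prompt=prompt,
    stop_sequence=["\nObservation", "\n\tObservation"],  # explicit stop tokens instead of True/False
)
executor = AgentExecutor(agent=agent, tools=tools)
```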
