mirror of https://github.com/hwchase17/langchain
commit 0495ca0d10: Merge branch 'master' into bagatur/0.2
@@ -0,0 +1,103 @@
/* eslint-disable react/jsx-props-no-spreading */
import React from "react";
import Tabs from "@theme/Tabs";
import TabItem from "@theme/TabItem";
import CodeBlock from "@theme-original/CodeBlock";

function Setup({ apiKeyName, packageName }) {
  const apiKeyText = `import getpass
import os

os.environ["${apiKeyName}"] = getpass.getpass()`;
  return (
    <>
      <h5>Install dependencies</h5>
      <CodeBlock language="bash">{`pip install -qU ${packageName}`}</CodeBlock>
      <h5>Set environment variables</h5>
      <CodeBlock language="python">{apiKeyText}</CodeBlock>
    </>
  );
}

/**
 * @param {{ openaiParams?: string, anthropicParams?: string, fireworksParams?: string, mistralParams?: string, googleParams?: string, hideOpenai?: boolean, hideAnthropic?: boolean, hideFireworks?: boolean, hideMistral?: boolean, hideGoogle?: boolean }} props
 */
export default function ChatModelTabs(props) {
  const {
    openaiParams,
    anthropicParams,
    fireworksParams,
    mistralParams,
    googleParams,
    hideOpenai,
    hideAnthropic,
    hideFireworks,
    hideMistral,
    hideGoogle,
  } = props;

  const openAIParamsOrDefault = openaiParams ?? `model="gpt-3.5-turbo-0125"`;
  const anthropicParamsOrDefault = anthropicParams ?? `model="claude-3-sonnet-20240229"`;
  const fireworksParamsOrDefault = fireworksParams ?? `model="accounts/fireworks/models/mixtral-8x7b-instruct"`;
  const mistralParamsOrDefault = mistralParams ?? `model="mistral-large-latest"`;
  const googleParamsOrDefault = googleParams ?? `model="gemini-pro"`;

  const tabItems = [
    {
      value: "OpenAI",
      label: "OpenAI",
      text: `from langchain_openai import ChatOpenAI\n\nmodel = ChatOpenAI(${openAIParamsOrDefault})`,
      apiKeyName: "OPENAI_API_KEY",
      packageName: "langchain-openai",
      default: true,
      shouldHide: hideOpenai,
    },
    {
      value: "Anthropic",
      label: "Anthropic",
      text: `from langchain_anthropic import ChatAnthropic\n\nmodel = ChatAnthropic(${anthropicParamsOrDefault})`,
      apiKeyName: "ANTHROPIC_API_KEY",
      packageName: "langchain-anthropic",
      default: false,
      shouldHide: hideAnthropic,
    },
    {
      value: "FireworksAI",
      label: "FireworksAI",
      text: `from langchain_fireworks import ChatFireworks\n\nmodel = ChatFireworks(${fireworksParamsOrDefault})`,
      apiKeyName: "FIREWORKS_API_KEY",
      packageName: "langchain-fireworks",
      default: false,
      shouldHide: hideFireworks,
    },
    {
      value: "MistralAI",
      label: "MistralAI",
      text: `from langchain_mistralai import ChatMistralAI\n\nmodel = ChatMistralAI(${mistralParamsOrDefault})`,
      apiKeyName: "MISTRAL_API_KEY",
      packageName: "langchain-mistralai",
      default: false,
      shouldHide: hideMistral,
    },
    {
      value: "Google",
      label: "Google",
      text: `from langchain_google_genai import ChatGoogleGenerativeAI\n\nmodel = ChatGoogleGenerativeAI(${googleParamsOrDefault})`,
      apiKeyName: "GOOGLE_API_KEY",
      packageName: "langchain-google-genai",
      default: false,
      shouldHide: hideGoogle,
    },
  ];

  return (
    <Tabs groupId="modelTabs">
      {tabItems.filter((tabItem) => !tabItem.shouldHide).map((tabItem) => (
        <TabItem key={tabItem.value} value={tabItem.value} label={tabItem.label} default={tabItem.default}>
          <Setup apiKeyName={tabItem.apiKeyName} packageName={tabItem.packageName} />
          <CodeBlock language="python">{tabItem.text}</CodeBlock>
        </TabItem>
      ))}
    </Tabs>
  );
}
@@ -0,0 +1,617 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from couchbase.cluster import Cluster
|
||||
|
||||
|
||||
class CouchbaseVectorStore(VectorStore):
|
||||
"""`Couchbase Vector Store` vector store.
|
||||
|
||||
To use it, you need
|
||||
- a recent installation of the `couchbase` library
|
||||
- a Couchbase database with a pre-defined Search index with support for
|
||||
vector fields
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.vectorstores import CouchbaseVectorStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.auth import PasswordAuthenticator
|
||||
from couchbase.options import ClusterOptions
|
||||
from datetime import timedelta
|
||||
|
||||
auth = PasswordAuthenticator(username, password)
|
||||
options = ClusterOptions(auth)
|
||||
connect_string = "couchbases://localhost"
|
||||
cluster = Cluster(connect_string, options)
|
||||
|
||||
# Wait until the cluster is ready for use.
|
||||
cluster.wait_until_ready(timedelta(seconds=5))
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
bucket_name="",
|
||||
scope_name="",
|
||||
collection_name="",
|
||||
embedding=embeddings,
|
||||
index_name="vector-index",
|
||||
)
|
||||
|
||||
vectorstore.add_texts(["hello", "world"])
|
||||
results = vectorstore.similarity_search("ola", k=1)
|
||||
"""
|
||||
|
||||
# Default batch size
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
_metadata_key = "metadata"
|
||||
_default_text_key = "text"
|
||||
_default_embedding_key = "embedding"
|
||||
|
||||
def _check_bucket_exists(self) -> bool:
|
||||
"""Check if the bucket exists in the linked Couchbase cluster"""
|
||||
bucket_manager = self._cluster.buckets()
|
||||
try:
|
||||
bucket_manager.get_bucket(self._bucket_name)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_scope_and_collection_exists(self) -> bool:
|
||||
"""Check if the scope and collection exists in the linked Couchbase bucket
|
||||
Raises a ValueError if either is not found"""
|
||||
scope_collection_map: Dict[str, Any] = {}
|
||||
|
||||
# Get a list of all scopes in the bucket
|
||||
for scope in self._bucket.collections().get_all_scopes():
|
||||
scope_collection_map[scope.name] = []
|
||||
|
||||
# Get a list of all the collections in the scope
|
||||
for collection in scope.collections:
|
||||
scope_collection_map[scope.name].append(collection.name)
|
||||
|
||||
# Check if the scope exists
|
||||
if self._scope_name not in scope_collection_map.keys():
|
||||
raise ValueError(
|
||||
f"Scope {self._scope_name} not found in Couchbase "
|
||||
f"bucket {self._bucket_name}"
|
||||
)
|
||||
|
||||
# Check if the collection exists in the scope
|
||||
if self._collection_name not in scope_collection_map[self._scope_name]:
|
||||
raise ValueError(
|
||||
f"Collection {self._collection_name} not found in scope "
|
||||
f"{self._scope_name} in Couchbase bucket {self._bucket_name}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _check_index_exists(self) -> bool:
|
||||
"""Check if the Search index exists in the linked Couchbase cluster
|
||||
Raises a ValueError if the index does not exist"""
|
||||
if self._scoped_index:
|
||||
all_indexes = [
|
||||
index.name for index in self._scope.search_indexes().get_all_indexes()
|
||||
]
|
||||
if self._index_name not in all_indexes:
|
||||
raise ValueError(
|
||||
f"Index {self._index_name} does not exist. "
|
||||
" Please create the index before searching."
|
||||
)
|
||||
else:
|
||||
all_indexes = [
|
||||
index.name for index in self._cluster.search_indexes().get_all_indexes()
|
||||
]
|
||||
if self._index_name not in all_indexes:
|
||||
raise ValueError(
|
||||
f"Index {self._index_name} does not exist. "
|
||||
" Please create the index before searching."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cluster: Cluster,
|
||||
bucket_name: str,
|
||||
scope_name: str,
|
||||
collection_name: str,
|
||||
embedding: Embeddings,
|
||||
index_name: str,
|
||||
*,
|
||||
text_key: Optional[str] = _default_text_key,
|
||||
embedding_key: Optional[str] = _default_embedding_key,
|
||||
scoped_index: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Couchbase Vector Store.
|
||||
|
||||
Args:
|
||||
|
||||
cluster (Cluster): couchbase cluster object with active connection.
|
||||
bucket_name (str): name of bucket to store documents in.
|
||||
scope_name (str): name of scope in the bucket to store documents in.
|
||||
collection_name (str): name of collection in the scope to store documents in
|
||||
embedding (Embeddings): embedding function to use.
|
||||
index_name (str): name of the Search index to use.
|
||||
text_key (optional[str]): key in document to use as text.
|
||||
Set to text by default.
|
||||
embedding_key (optional[str]): key in document to use for the embeddings.
|
||||
Set to embedding by default.
|
||||
scoped_index (optional[bool]): specify whether the index is a scoped index.
|
||||
Set to True by default.
|
||||
"""
|
||||
try:
|
||||
from couchbase.cluster import Cluster
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import couchbase python package. "
|
||||
"Please install couchbase SDK with `pip install couchbase`."
|
||||
) from e
|
||||
|
||||
if not isinstance(cluster, Cluster):
|
||||
raise ValueError(
|
||||
f"cluster should be an instance of couchbase.Cluster, "
|
||||
f"got {type(cluster)}"
|
||||
)
|
||||
|
||||
self._cluster = cluster
|
||||
|
||||
if not embedding:
|
||||
raise ValueError("Embeddings instance must be provided.")
|
||||
|
||||
if not bucket_name:
|
||||
raise ValueError("bucket_name must be provided.")
|
||||
|
||||
if not scope_name:
|
||||
raise ValueError("scope_name must be provided.")
|
||||
|
||||
if not collection_name:
|
||||
raise ValueError("collection_name must be provided.")
|
||||
|
||||
if not index_name:
|
||||
raise ValueError("index_name must be provided.")
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
self._scope_name = scope_name
|
||||
self._collection_name = collection_name
|
||||
self._embedding_function = embedding
|
||||
self._text_key = text_key
|
||||
self._embedding_key = embedding_key
|
||||
self._index_name = index_name
|
||||
self._scoped_index = scoped_index
|
||||
|
||||
# Check if the bucket exists
|
||||
if not self._check_bucket_exists():
|
||||
raise ValueError(
|
||||
f"Bucket {self._bucket_name} does not exist. "
|
||||
" Please create the bucket before searching."
|
||||
)
|
||||
|
||||
try:
|
||||
self._bucket = self._cluster.bucket(self._bucket_name)
|
||||
self._scope = self._bucket.scope(self._scope_name)
|
||||
self._collection = self._scope.collection(self._collection_name)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Error connecting to couchbase. "
|
||||
"Please check the connection and credentials."
|
||||
) from e
|
||||
|
||||
# Check if the scope and collection exist; raises ValueError if not
self._check_scope_and_collection_exists()

# Check if the index exists; raises ValueError if not
self._check_index_exists()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run texts through the embeddings and persist in vectorstore.
|
||||
|
||||
If the document IDs are passed, the existing documents (if any) will be
|
||||
overwritten with the new ones.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): Iterable of strings to add to the vectorstore.
|
||||
metadatas (Optional[List[Dict]]): Optional list of metadatas associated
|
||||
with the texts.
|
||||
ids (Optional[List[str]]): Optional list of ids associated with the texts.
|
||||
IDs have to be unique strings across the collection.
|
||||
If it is not specified uuids are generated and used as ids.
|
||||
batch_size (Optional[int]): Optional batch size for bulk insertions.
|
||||
Default is 100.
|
||||
|
||||
Returns:
|
||||
List[str]: List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
from couchbase.exceptions import DocumentExistsException
|
||||
|
||||
if not batch_size:
|
||||
batch_size = self.DEFAULT_BATCH_SIZE
|
||||
doc_ids: List[str] = []
|
||||
|
||||
if ids is None:
|
||||
ids = [uuid.uuid4().hex for _ in texts]
|
||||
|
||||
if metadatas is None:
|
||||
metadatas = [{} for _ in texts]
|
||||
|
||||
embedded_texts = self._embedding_function.embed_documents(list(texts))
|
||||
|
||||
documents_to_insert = {
    doc_id: {
        self._text_key: text,
        self._embedding_key: vector,
        self._metadata_key: metadata,
    }
    for doc_id, text, vector, metadata in zip(
        ids, texts, embedded_texts, metadatas
    )
}

# Insert in batches of batch_size documents
doc_items = list(documents_to_insert.items())
for i in range(0, len(doc_items), batch_size):
    batch = dict(doc_items[i : i + batch_size])
    try:
        result = self._collection.upsert_multi(batch)
        if result.all_ok:
            doc_ids.extend(batch.keys())
    except DocumentExistsException as e:
        raise ValueError(f"Document already exists: {e}") from e

return doc_ids
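# Illustrative sketch (not part of the module): because documents are upserted,
# passing the same ids again overwrites the stored documents rather than
# duplicating them, and batch_size controls how many documents go into each
# KV multi-operation. Names are placeholders; vectorstore is a
# CouchbaseVectorStore constructed as in the class docstring.
ids = vectorstore.add_texts(
    ["foo", "bar", "baz"],
    metadatas=[{"page": 1}, {"page": 2}, {"page": 3}],
    ids=["doc-1", "doc-2", "doc-3"],
    batch_size=2,  # two upsert_multi calls: 2 docs, then 1
)
assert ids == ["doc-1", "doc-2", "doc-3"]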
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete documents from the vector store by ids.
|
||||
|
||||
Args:
|
||||
ids (List[str]): List of IDs of the documents to delete.
|
||||
batch_size (Optional[int]): Optional batch size for bulk deletions.
|
||||
|
||||
Returns:
|
||||
bool: True if all the documents were deleted successfully, False otherwise.
|
||||
|
||||
"""
|
||||
from couchbase.exceptions import DocumentNotFoundException
|
||||
|
||||
if ids is None:
|
||||
raise ValueError("No document ids provided to delete.")
|
||||
|
||||
batch_size = kwargs.get("batch_size", self.DEFAULT_BATCH_SIZE)
|
||||
deletion_status = True
|
||||
|
||||
# Delete in batches
|
||||
for i in range(0, len(ids), batch_size):
|
||||
batch = ids[i : i + batch_size]
|
||||
try:
    result = self._collection.remove_multi(batch)
except DocumentNotFoundException as e:
    raise ValueError(f"Document not found: {e}") from e

deletion_status &= result.all_ok
|
||||
|
||||
return deletion_status
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
"""Return the query embedding object."""
|
||||
return self._embedding_function
|
||||
|
||||
def _format_metadata(self, row_fields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Helper method to format the metadata from the Couchbase Search API.
|
||||
Args:
|
||||
row_fields (Dict[str, Any]): The fields to format.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The formatted metadata.
|
||||
"""
|
||||
metadata = {}
|
||||
for key, value in row_fields.items():
|
||||
# Couchbase Search returns the metadata key with a prefix
|
||||
# `metadata.` We remove it to get the original metadata key
|
||||
if key.startswith(self._metadata_key):
|
||||
new_key = key.split(self._metadata_key + ".")[-1]
|
||||
metadata[new_key] = value
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
return metadata
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
search_options: Optional[Dict[str, Any]] = {},
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to embedding vector with their scores.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): Embedding vector to look up documents similar to.
|
||||
k (int): Number of Documents to return.
|
||||
Defaults to 4.
|
||||
search_options (Optional[Dict[str, Any]]): Optional search options that are
|
||||
passed to Couchbase search.
|
||||
Defaults to empty dictionary.
|
||||
fields (Optional[List[str]]): Optional list of fields to include in the
|
||||
metadata of results. Note that these need to be stored in the index.
|
||||
If nothing is specified, defaults to all the fields stored in the index.
|
||||
|
||||
Returns:
|
||||
List of (Document, score) that are the most similar to the query vector.
|
||||
"""
|
||||
import couchbase.search as search
|
||||
from couchbase.options import SearchOptions
|
||||
from couchbase.vector_search import VectorQuery, VectorSearch
|
||||
|
||||
fields = kwargs.get("fields", ["*"])
|
||||
|
||||
# Document text field needs to be returned from the search
|
||||
if fields != ["*"] and self._text_key not in fields:
|
||||
fields.append(self._text_key)
|
||||
|
||||
search_req = search.SearchRequest.create(
|
||||
VectorSearch.from_vector_query(
|
||||
VectorQuery(
|
||||
self._embedding_key,
|
||||
embedding,
|
||||
k,
|
||||
)
|
||||
)
|
||||
)
|
||||
try:
|
||||
if self._scoped_index:
|
||||
search_iter = self._scope.search(
|
||||
self._index_name,
|
||||
search_req,
|
||||
SearchOptions(
|
||||
limit=k,
|
||||
fields=fields,
|
||||
raw=search_options,
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
search_iter = self._cluster.search(
|
||||
index=self._index_name,
|
||||
request=search_req,
|
||||
options=SearchOptions(limit=k, fields=fields, raw=search_options),
|
||||
)
|
||||
|
||||
docs_with_score = []
|
||||
|
||||
# Parse the results
|
||||
for row in search_iter.rows():
|
||||
text = row.fields.pop(self._text_key, "")
|
||||
|
||||
# Format the metadata from Couchbase
|
||||
metadata = self._format_metadata(row.fields)
|
||||
|
||||
score = row.score
|
||||
doc = Document(page_content=text, metadata=metadata)
|
||||
docs_with_score.append((doc, score))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Search failed with error: {e}")
|
||||
|
||||
return docs_with_score
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
search_options: Optional[Dict[str, Any]] = {},
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return documents most similar to embedding vector with their scores.
|
||||
|
||||
Args:
|
||||
query (str): Query to look up for similar documents
|
||||
k (int): Number of Documents to return.
|
||||
Defaults to 4.
|
||||
search_options (Optional[Dict[str, Any]]): Optional search options that are
|
||||
passed to Couchbase search.
|
||||
Defaults to empty dictionary
|
||||
fields (Optional[List[str]]): Optional list of fields to include in the
|
||||
metadata of results. Note that these need to be stored in the index.
|
||||
If nothing is specified, defaults to all the fields stored in the index.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
query_embedding = self.embeddings.embed_query(query)
|
||||
docs_with_scores = self.similarity_search_with_score_by_vector(
|
||||
query_embedding, k, search_options, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_with_scores]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
search_options: Optional[Dict[str, Any]] = {},
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return documents that are most similar to the query with their scores.
|
||||
|
||||
Args:
|
||||
query (str): Query to look up for similar documents
|
||||
k (int): Number of Documents to return.
|
||||
Defaults to 4.
|
||||
search_options (Optional[Dict[str, Any]]): Optional search options that are
|
||||
passed to Couchbase search.
|
||||
Defaults to empty dictionary.
|
||||
fields (Optional[List[str]]): Optional list of fields to include in the
|
||||
metadata of results. Note that these need to be stored in the index.
|
||||
If nothing is specified, defaults to text and metadata fields.
|
||||
|
||||
Returns:
|
||||
List of (Document, score) that are most similar to the query.
|
||||
"""
|
||||
query_embedding = self.embeddings.embed_query(query)
|
||||
docs_with_score = self.similarity_search_with_score_by_vector(
|
||||
query_embedding, k, search_options, **kwargs
|
||||
)
|
||||
return docs_with_score
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
search_options: Optional[Dict[str, Any]] = {},
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return documents that are most similar to the vector embedding.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return.
|
||||
Defaults to 4.
|
||||
search_options (Optional[Dict[str, Any]]): Optional search options that are
|
||||
passed to Couchbase search.
|
||||
Defaults to empty dictionary.
|
||||
fields (Optional[List[str]]): Optional list of fields to include in the
|
||||
metadata of results. Note that these need to be stored in the index.
|
||||
If nothing is specified, defaults to document text and metadata fields.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
docs_with_score = self.similarity_search_with_score_by_vector(
|
||||
embedding, k, search_options, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_with_score]
|
||||
|
||||
@classmethod
|
||||
def _from_kwargs(
|
||||
cls: Type[CouchbaseVectorStore],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> CouchbaseVectorStore:
|
||||
"""Initialize the Couchbase vector store from keyword arguments for the
|
||||
vector store.
|
||||
|
||||
Args:
|
||||
embedding: Embedding object to use to embed text.
|
||||
**kwargs: Keyword arguments to initialize the vector store with.
|
||||
Accepted arguments are:
|
||||
- cluster
|
||||
- bucket_name
|
||||
- scope_name
|
||||
- collection_name
|
||||
- index_name
|
||||
- text_key
|
||||
- embedding_key
|
||||
- scoped_index
|
||||
|
||||
"""
|
||||
cluster = kwargs.get("cluster", None)
|
||||
bucket_name = kwargs.get("bucket_name", None)
|
||||
scope_name = kwargs.get("scope_name", None)
|
||||
collection_name = kwargs.get("collection_name", None)
|
||||
index_name = kwargs.get("index_name", None)
|
||||
text_key = kwargs.get("text_key", cls._default_text_key)
|
||||
embedding_key = kwargs.get("embedding_key", cls._default_embedding_key)
|
||||
scoped_index = kwargs.get("scoped_index", True)
|
||||
|
||||
return cls(
|
||||
embedding=embedding,
|
||||
cluster=cluster,
|
||||
bucket_name=bucket_name,
|
||||
scope_name=scope_name,
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
text_key=text_key,
|
||||
embedding_key=embedding_key,
|
||||
scoped_index=scoped_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[CouchbaseVectorStore],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> CouchbaseVectorStore:
|
||||
"""Construct a Couchbase vector store from a list of texts.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.vectorstores import CouchbaseVectorStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.auth import PasswordAuthenticator
|
||||
from couchbase.options import ClusterOptions
|
||||
from datetime import timedelta
|
||||
|
||||
auth = PasswordAuthenticator(username, password)
|
||||
options = ClusterOptions(auth)
|
||||
connect_string = "couchbases://localhost"
|
||||
cluster = Cluster(connect_string, options)
|
||||
|
||||
# Wait until the cluster is ready for use.
|
||||
cluster.wait_until_ready(timedelta(seconds=5))
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
|
||||
texts = ["hello", "world"]
|
||||
|
||||
vectorstore = CouchbaseVectorStore.from_texts(
|
||||
texts,
|
||||
embedding=embeddings,
|
||||
cluster=cluster,
|
||||
bucket_name="",
|
||||
scope_name="",
|
||||
collection_name="",
|
||||
index_name="vector-index",
|
||||
)
|
||||
|
||||
Args:
|
||||
texts (List[str]): list of texts to add to the vector store.
|
||||
embedding (Embeddings): embedding function to use.
|
||||
metadatas (optional[List[Dict]): list of metadatas to add to documents.
|
||||
**kwargs: Keyword arguments used to initialize the vector store with and/or
|
||||
passed to `add_texts` method. Check the constructor and/or `add_texts`
|
||||
for the list of accepted arguments.
|
||||
|
||||
Returns:
|
||||
A Couchbase vector store.
|
||||
|
||||
"""
|
||||
vector_store = cls._from_kwargs(embedding, **kwargs)
|
||||
batch_size = kwargs.get("batch_size", vector_store.DEFAULT_BATCH_SIZE)
|
||||
ids = kwargs.get("ids", None)
|
||||
vector_store.add_texts(
|
||||
texts, metadatas=metadatas, ids=ids, batch_size=batch_size
|
||||
)
|
||||
|
||||
return vector_store
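Taken together, the constructor, `add_texts`, and the search methods above support both plain vector search and hybrid search: extra Couchbase search options are passed through `search_options`, and the metadata returned with each hit can be narrowed with the `fields` keyword. A minimal usage sketch, assuming a reachable cluster and an existing Search index named "vector-index"; all names below are illustrative:

from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions
from langchain_community.vectorstores import CouchbaseVectorStore
from langchain_openai import OpenAIEmbeddings

# Connect to the cluster (connection string and credentials are placeholders)
auth = PasswordAuthenticator("Administrator", "password")
cluster = Cluster("couchbases://localhost", ClusterOptions(auth))
cluster.wait_until_ready(timedelta(seconds=5))

vectorstore = CouchbaseVectorStore(
    cluster=cluster,
    bucket_name="docs-bucket",
    scope_name="docs-scope",
    collection_name="docs",
    embedding=OpenAIEmbeddings(),
    index_name="vector-index",
)
vectorstore.add_texts(
    ["hello", "world"],
    metadatas=[{"section": "greeting"}, {"section": "noun"}],
)

# Plain vector search, returning only the metadata.section field of each hit
docs = vectorstore.similarity_search("hi", k=2, fields=["metadata.section"])

# Hybrid search: combine the vector query with a full-text match on metadata
docs_with_scores = vectorstore.similarity_search_with_score(
    "hi",
    k=2,
    search_options={"query": {"match": "greeting", "field": "metadata.section"}},
)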
|
@@ -0,0 +1,199 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.utils.math import cosine_similarity
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
class InMemoryVectorStore(VectorStore):
|
||||
"""In-memory implementation of VectorStore using a dictionary.
|
||||
Uses numpy to compute cosine similarity for search.
|
||||
|
||||
Args:
|
||||
embedding: embedding function to use.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding: Embeddings) -> None:
|
||||
self.store: Dict[str, Dict[str, Any]] = {}
|
||||
self.embedding = embedding
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
if ids:
|
||||
for _id in ids:
|
||||
self.store.pop(_id, None)
|
||||
|
||||
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
self.delete(ids)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
ids = []
|
||||
vectors = self.embedding.embed_documents(list(texts))
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
doc_id = str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"vector": vectors[i],
|
||||
"text": text,
|
||||
"metadata": metadatas[i] if metadatas else {},
|
||||
}
|
||||
return ids
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
return self.add_texts(texts, metadatas, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
docs_with_similarity = []
|
||||
for doc in self.store.values():
|
||||
similarity = float(cosine_similarity([embedding], [doc["vector"]]).item(0))
|
||||
docs_with_similarity.append(
|
||||
(
|
||||
Document(page_content=doc["text"], metadata=doc["metadata"]),
|
||||
similarity,
|
||||
)
|
||||
)
|
||||
docs_with_similarity.sort(key=lambda x: x[1], reverse=True)
|
||||
return docs_with_similarity[:k]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embedding = self.embedding.embed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
)
|
||||
return docs
|
||||
|
||||
async def asimilarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
async def asimilarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return self.similarity_search_by_vector(embedding, k, **kwargs)
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return self.similarity_search(query, k, **kwargs)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
docs_with_similarity = []
|
||||
for doc in self.store.values():
|
||||
similarity = float(cosine_similarity([embedding], [doc["vector"]]).item(0))
|
||||
docs_with_similarity.append(
|
||||
(
|
||||
doc,
|
||||
similarity,
|
||||
)
|
||||
)
|
||||
docs_with_similarity.sort(key=lambda x: x[1], reverse=True)
|
||||
prefetch_hits = docs_with_similarity[:fetch_k]
|
||||
|
||||
mmr_chosen_indices = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
[doc["vector"] for doc, _ in prefetch_hits],
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
page_content=prefetch_hits[idx][0]["text"],
|
||||
metadata=prefetch_hits[idx][0]["metadata"],
|
||||
)
|
||||
for idx in mmr_chosen_indices
|
||||
]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding_vector,
|
||||
k,
|
||||
fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "InMemoryVectorStore":
|
||||
store = cls(
|
||||
embedding=embedding,
|
||||
)
|
||||
store.add_texts(texts=texts, metadatas=metadatas)
|
||||
return store
|
||||
|
||||
@classmethod
|
||||
async def afrom_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "InMemoryVectorStore":
|
||||
return cls.from_texts(texts, embedding, metadatas, **kwargs)
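Since the store keeps everything in a plain dictionary and computes cosine similarity with numpy, it is most useful for tests and small examples. A short sketch of the synchronous API, using the deterministic fake embeddings from `langchain_core` purely for illustration:

from langchain_core.embeddings import DeterministicFakeEmbedding

from langchain_community.vectorstores.inmemory import InMemoryVectorStore

store = InMemoryVectorStore.from_texts(
    ["foo", "bar", "baz"],
    DeterministicFakeEmbedding(size=8),
    metadatas=[{"i": 0}, {"i": 1}, {"i": 2}],
)

# Plain similarity search and the scored variant
docs = store.similarity_search("foo", k=2)
docs_and_scores = store.similarity_search_with_score("foo", k=2)

# Maximal marginal relevance trades similarity off against diversity
diverse_docs = store.max_marginal_relevance_search("foo", k=2, fetch_k=3)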
|
@@ -0,0 +1,367 @@
|
||||
"""Test Couchbase Vector Store functionality"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.vectorstores.couchbase import CouchbaseVectorStore
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
)
|
||||
|
||||
CONNECTION_STRING = os.getenv("COUCHBASE_CONNECTION_STRING", "")
|
||||
BUCKET_NAME = os.getenv("COUCHBASE_BUCKET_NAME", "")
|
||||
SCOPE_NAME = os.getenv("COUCHBASE_SCOPE_NAME", "")
|
||||
COLLECTION_NAME = os.getenv("COUCHBASE_COLLECTION_NAME", "")
|
||||
USERNAME = os.getenv("COUCHBASE_USERNAME", "")
|
||||
PASSWORD = os.getenv("COUCHBASE_PASSWORD", "")
|
||||
INDEX_NAME = os.getenv("COUCHBASE_INDEX_NAME", "")
|
||||
SLEEP_DURATION = 1
|
||||
|
||||
|
||||
def set_all_env_vars() -> bool:
|
||||
return all(
|
||||
[
|
||||
CONNECTION_STRING,
|
||||
BUCKET_NAME,
|
||||
SCOPE_NAME,
|
||||
COLLECTION_NAME,
|
||||
USERNAME,
|
||||
PASSWORD,
|
||||
INDEX_NAME,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_cluster() -> Any:
|
||||
"""Get a couchbase cluster object"""
|
||||
from datetime import timedelta
|
||||
|
||||
from couchbase.auth import PasswordAuthenticator
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.options import ClusterOptions
|
||||
|
||||
auth = PasswordAuthenticator(USERNAME, PASSWORD)
|
||||
options = ClusterOptions(auth)
|
||||
connect_string = CONNECTION_STRING
|
||||
cluster = Cluster(connect_string, options)
|
||||
|
||||
# Wait until the cluster is ready for use.
|
||||
cluster.wait_until_ready(timedelta(seconds=5))
|
||||
|
||||
return cluster
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cluster() -> Any:
|
||||
"""Get a couchbase cluster object"""
|
||||
return get_cluster()
|
||||
|
||||
|
||||
def delete_documents(
|
||||
cluster: Any, bucket_name: str, scope_name: str, collection_name: str
|
||||
) -> None:
|
||||
"""Delete all the documents in the collection"""
|
||||
query = f"DELETE FROM `{bucket_name}`.`{scope_name}`.`{collection_name}`"
|
||||
cluster.query(query).execute()
|
||||
|
||||
|
||||
@pytest.mark.requires("couchbase")
|
||||
@pytest.mark.skipif(
|
||||
not set_all_env_vars(), reason="Missing Couchbase environment variables"
|
||||
)
|
||||
class TestCouchbaseVectorStore:
|
||||
def setup_method(self) -> None:
|
||||
cluster = get_cluster()
|
||||
# Delete all the documents in the collection
|
||||
delete_documents(cluster, BUCKET_NAME, SCOPE_NAME, COLLECTION_NAME)
|
||||
|
||||
def test_from_documents(self, cluster: Any) -> None:
|
||||
"""Test end to end search using a list of documents."""
|
||||
|
||||
documents = [
|
||||
Document(page_content="foo", metadata={"page": 1}),
|
||||
Document(page_content="bar", metadata={"page": 2}),
|
||||
Document(page_content="baz", metadata={"page": 3}),
|
||||
]
|
||||
|
||||
vectorstore = CouchbaseVectorStore.from_documents(
|
||||
documents,
|
||||
ConsistentFakeEmbeddings(),
|
||||
cluster=cluster,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
index_name=INDEX_NAME,
|
||||
)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("baz", k=1)
|
||||
assert output[0].page_content == "baz"
|
||||
assert output[0].metadata["page"] == 3
|
||||
|
||||
def test_from_texts(self, cluster: Any) -> None:
|
||||
"""Test end to end search using a list of texts."""
|
||||
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
vectorstore = CouchbaseVectorStore.from_texts(
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
cluster=cluster,
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("foo", k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0].page_content == "foo"
|
||||
|
||||
def test_from_texts_with_metadatas(self, cluster: Any) -> None:
|
||||
"""Test end to end search using a list of texts and metadatas."""
|
||||
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore.from_texts(
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
cluster=cluster,
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("baz", k=1)
|
||||
assert output[0].page_content == "baz"
|
||||
assert output[0].metadata["c"] == 3
|
||||
|
||||
def test_add_texts_with_ids_and_metadatas(self, cluster: Any) -> None:
|
||||
"""Test end to end search by adding a list of texts, ids and metadatas."""
|
||||
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
ids = ["a", "b", "c"]
|
||||
|
||||
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
results = vectorstore.add_texts(
|
||||
texts,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
assert results == ids
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("foo", k=1)
|
||||
assert output[0].page_content == "foo"
|
||||
assert output[0].metadata["a"] == 1
|
||||
|
||||
def test_delete_texts_with_ids(self, cluster: Any) -> None:
|
||||
"""Test deletion of documents by ids."""
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
ids = ["a", "b", "c"]
|
||||
|
||||
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
results = vectorstore.add_texts(
|
||||
texts,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
assert results == ids
|
||||
assert vectorstore.delete(ids)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("foo", k=1)
|
||||
assert len(output) == 0
|
||||
|
||||
def test_similarity_search_with_scores(self, cluster: Any) -> None:
|
||||
"""Test similarity search with scores."""
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
|
||||
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
vectorstore.add_texts(texts, metadatas=metadatas)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search_with_score("foo", k=2)
|
||||
|
||||
assert len(output) == 2
|
||||
assert output[0][0].page_content == "foo"
|
||||
|
||||
# check if the scores are sorted
|
||||
assert output[0][0].metadata["a"] == 1
|
||||
assert output[0][1] > output[1][1]
|
||||
|
||||
def test_similarity_search_by_vector(self, cluster: Any) -> None:
|
||||
"""Test similarity search by vector."""
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
|
||||
metadatas = [{"a": 1}, {"b": 2}, {"c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
vectorstore.add_texts(texts, metadatas=metadatas)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
vector = ConsistentFakeEmbeddings().embed_query("foo")
|
||||
vector_output = vectorstore.similarity_search_by_vector(vector, k=1)
|
||||
|
||||
assert vector_output[0].page_content == "foo"
|
||||
|
||||
similarity_output = vectorstore.similarity_search("foo", k=1)
|
||||
|
||||
assert similarity_output == vector_output
|
||||
|
||||
def test_output_fields(self, cluster: Any) -> None:
|
||||
"""Test that output fields are set correctly."""
|
||||
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
metadatas = [{"page": 1, "a": 1}, {"page": 2, "b": 2}, {"page": 3, "c": 3}]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
ids = vectorstore.add_texts(texts, metadatas)
|
||||
assert len(ids) == len(texts)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
output = vectorstore.similarity_search("foo", k=1, fields=["metadata.page"])
|
||||
assert output[0].page_content == "foo"
|
||||
assert output[0].metadata["page"] == 1
|
||||
assert "a" not in output[0].metadata
|
||||
|
||||
def test_hybrid_search(self, cluster: Any) -> None:
|
||||
"""Test hybrid search."""
|
||||
|
||||
texts = [
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
|
||||
metadatas = [
|
||||
{"section": "index"},
|
||||
{"section": "glossary"},
|
||||
{"section": "appendix"},
|
||||
]
|
||||
|
||||
vectorstore = CouchbaseVectorStore(
|
||||
cluster=cluster,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_name=INDEX_NAME,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
|
||||
vectorstore.add_texts(texts, metadatas=metadatas)
|
||||
|
||||
# Wait for the documents to be indexed
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
result, score = vectorstore.similarity_search_with_score("foo", k=1)[0]
|
||||
|
||||
# Wait for the documents to be indexed for hybrid search
|
||||
time.sleep(SLEEP_DURATION)
|
||||
|
||||
hybrid_result, hybrid_score = vectorstore.similarity_search_with_score(
|
||||
"foo",
|
||||
k=1,
|
||||
search_options={"query": {"match": "index", "field": "metadata.section"}},
|
||||
)[0]
|
||||
|
||||
assert result == hybrid_result
|
||||
assert score <= hybrid_score
|
@@ -0,0 +1,33 @@
from langchain_core.documents import Document

from langchain_community.vectorstores.inmemory import InMemoryVectorStore
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
)


async def test_inmemory() -> None:
    """Test end to end construction and search."""
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "bar", "baz"], ConsistentFakeEmbeddings()
    )
    output = await store.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]

    output = await store.asimilarity_search("bar", k=2)
    assert output == [Document(page_content="bar"), Document(page_content="baz")]

    output2 = await store.asimilarity_search_with_score("bar", k=2)
    assert output2[0][1] > output2[1][1]


async def test_inmemory_mmr() -> None:
    texts = ["foo", "foo", "fou", "foy"]
    docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())
    # make sure we can request more results (k) than there are docs in the store
    output = await docsearch.amax_marginal_relevance_search(
        "foo", k=10, lambda_mult=0.1
    )
    assert len(output) == len(texts)
    assert output[0] == Document(page_content="foo")
    assert output[1] == Document(page_content="foy")
@@ -0,0 +1,4 @@
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings

__all__ = ["DeterministicFakeEmbedding", "Embeddings", "FakeEmbeddings"]
@@ -0,0 +1,52 @@
import hashlib
from typing import List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel


class FakeEmbeddings(Embeddings, BaseModel):
    """Fake embedding model."""

    size: int
    """The size of the embedding vector."""

    def _get_embedding(self) -> List[float]:
        import numpy as np  # type: ignore[import-not-found, import-untyped]

        return list(np.random.normal(size=self.size))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._get_embedding() for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._get_embedding()


class DeterministicFakeEmbedding(Embeddings, BaseModel):
    """
    Fake embedding model that always returns
    the same embedding vector for the same text.
    """

    size: int
    """The size of the embedding vector."""

    def _get_embedding(self, seed: int) -> List[float]:
        import numpy as np  # type: ignore[import-not-found, import-untyped]

        # set the seed for the random generator
        np.random.seed(seed)
        return list(np.random.normal(size=self.size))

    def _get_seed(self, text: str) -> int:
        """
        Get a seed for the random generator, using the hash of the text.
        """
        return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._get_embedding(seed=self._get_seed(text))
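A quick way to see the difference between the two classes (sizes and texts below are arbitrary):

from langchain_core.embeddings import DeterministicFakeEmbedding, FakeEmbeddings

random_fake = FakeEmbeddings(size=4)
deterministic_fake = DeterministicFakeEmbedding(size=4)

# FakeEmbeddings draws a fresh random vector on every call, so repeated
# queries for the same text will (almost surely) differ
v1 = random_fake.embed_query("hello")
v2 = random_fake.embed_query("hello")

# DeterministicFakeEmbedding seeds numpy from a hash of the text, so
# identical texts always map to identical vectors
assert deterministic_fake.embed_query("hello") == deterministic_fake.embed_query("hello")
assert deterministic_fake.embed_query("hello") != deterministic_fake.embed_query("bye")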
@@ -0,0 +1,228 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
|
||||
|
||||
AnyMessage = Union[
|
||||
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
|
||||
]
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
) -> str:
|
||||
"""Convert a sequence of Messages to strings and concatenate them into one string.
|
||||
|
||||
Args:
|
||||
messages: Messages to be converted to strings.
|
||||
human_prefix: The prefix to prepend to contents of HumanMessages.
|
||||
ai_prefix: The prefix to prepend to contents of AIMessages.
|
||||
|
||||
Returns:
|
||||
A single string concatenation of all input messages.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hi, how are you?"),
|
||||
AIMessage(content="Good, how are you?"),
|
||||
]
|
||||
get_buffer_string(messages)
|
||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||
"""
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
role = human_prefix
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
elif isinstance(m, FunctionMessage):
|
||||
role = "Function"
|
||||
elif isinstance(m, ToolMessage):
|
||||
role = "Tool"
|
||||
elif isinstance(m, ChatMessage):
|
||||
role = m.role
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {m}")
|
||||
message = f"{role}: {m.content}"
|
||||
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
||||
message += f"{m.additional_kwargs['function_call']}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
elif _type == "ai":
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "chat":
|
||||
return ChatMessage(**message["data"])
|
||||
elif _type == "function":
|
||||
return FunctionMessage(**message["data"])
|
||||
elif _type == "tool":
|
||||
return ToolMessage(**message["data"])
|
||||
elif _type == "AIMessageChunk":
|
||||
return AIMessageChunk(**message["data"])
|
||||
elif _type == "HumanMessageChunk":
|
||||
return HumanMessageChunk(**message["data"])
|
||||
elif _type == "FunctionMessageChunk":
|
||||
return FunctionMessageChunk(**message["data"])
|
||||
elif _type == "ToolMessageChunk":
|
||||
return ToolMessageChunk(**message["data"])
|
||||
elif _type == "SystemMessageChunk":
|
||||
return SystemMessageChunk(**message["data"])
|
||||
elif _type == "ChatMessageChunk":
|
||||
return ChatMessageChunk(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected message type: {_type}")
|
||||
|
||||
|
||||
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages from dicts to Message objects.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as dicts) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
"""Convert a message chunk to a message.
|
||||
|
||||
Args:
|
||||
chunk: Message chunk to convert.
|
||||
|
||||
Returns:
|
||||
Message.
|
||||
"""
|
||||
if not isinstance(chunk, BaseMessageChunk):
|
||||
return chunk
|
||||
# chunk classes always have the equivalent non-chunk class as their first parent
|
||||
return chunk.__class__.__mro__[1](
|
||||
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
|
||||
)
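# Illustrative sketch (not part of the original module): chunks accumulated
# while streaming can be collapsed back into a plain message once the stream
# is finished.
_streamed_chunk = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")
_final_message = message_chunk_to_message(_streamed_chunk)
# _final_message is AIMessage(content="Hello, world")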
|
||||
|
||||
|
||||
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
|
||||
|
||||
|
||||
def _create_message_from_message_type(
|
||||
message_type: str,
|
||||
content: str,
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**additional_kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Create a message from a message type and content string.
|
||||
|
||||
Args:
|
||||
message_type: str the type of the message (e.g., "human", "ai", etc.)
|
||||
content: str the content string.
|
||||
|
||||
Returns:
|
||||
a message of the appropriate type.
|
||||
"""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if tool_call_id is not None:
|
||||
kwargs["tool_call_id"] = tool_call_id
|
||||
if additional_kwargs:
|
||||
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
|
||||
if message_type in ("human", "user"):
|
||||
message: BaseMessage = HumanMessage(content=content, **kwargs)
|
||||
elif message_type in ("ai", "assistant"):
|
||||
message = AIMessage(content=content, **kwargs)
|
||||
elif message_type == "system":
|
||||
message = SystemMessage(content=content, **kwargs)
|
||||
elif message_type == "function":
|
||||
message = FunctionMessage(content=content, **kwargs)
|
||||
elif message_type == "tool":
|
||||
message = ToolMessage(content=content, **kwargs)
|
||||
else:
|
||||
raise ValueError(
    f"Unexpected message type: {message_type}. Use one of 'human',"
    " 'user', 'ai', 'assistant', 'system', 'function', or 'tool'."
)
|
||||
return message
|
||||
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
) -> BaseMessage:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
The message format can be one of the following:
|
||||
|
||||
- BaseMessage
|
||||
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
|
||||
- dict: a message dict with role and content keys
|
||||
- string: shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Args:
|
||||
message: a representation of a message in one of the supported formats
|
||||
|
||||
Returns:
|
||||
an instance of a message
|
||||
"""
|
||||
if isinstance(message, BaseMessage):
|
||||
_message = message
|
||||
elif isinstance(message, str):
|
||||
_message = _create_message_from_message_type("human", message)
|
||||
elif isinstance(message, tuple):
|
||||
if len(message) != 2:
|
||||
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
|
||||
message_type_str, template = message
|
||||
_message = _create_message_from_message_type(message_type_str, template)
|
||||
elif isinstance(message, dict):
|
||||
msg_kwargs = message.copy()
|
||||
try:
|
||||
msg_type = msg_kwargs.pop("role")
|
||||
msg_content = msg_kwargs.pop("content")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
)
|
||||
_message = _create_message_from_message_type(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
return _message
|
||||
|
||||
|
||||
def convert_to_messages(
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages to a list of messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_convert_to_message(m) for m in messages]
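The accepted shorthands make `convert_to_messages` convenient for normalising user-supplied chat histories. A small sketch, assuming the function is re-exported from `langchain_core.messages` (otherwise import it from this module directly):

from langchain_core.messages import convert_to_messages

messages = convert_to_messages(
    [
        ("system", "You are a helpful assistant."),        # 2-tuple of (role, content)
        "Hello there!",                                     # bare string -> HumanMessage
        {"role": "ai", "content": "Hi! How can I help?"},   # dict with role and content keys
    ]
)
# -> [SystemMessage(...), HumanMessage(...), AIMessage(...)]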
|
@@ -0,0 +1,15 @@
# from langchain_core.runnables.base import RunnableBinding


# class RunnableLearnable(RunnableBinding):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.parameters = []

#     def backward(self):
#         for param in self.parameters:
#             param.backward()

#     def update(self, optimizer):
#         for param in self.parameters:
#             optimizer.update(param)
@@ -0,0 +1,16 @@
from langchain_core.embeddings import DeterministicFakeEmbedding


def test_deterministic_fake_embeddings() -> None:
    """
    Test that the deterministic fake embeddings return the same
    embedding vector for the same text.
    """
    fake = DeterministicFakeEmbedding(size=10)
    text = "Hello world!"
    assert fake.embed_query(text) == fake.embed_query(text)
    assert fake.embed_query(text) != fake.embed_query("Goodbye world!")
    assert fake.embed_documents([text, text]) == fake.embed_documents([text, text])
    assert fake.embed_documents([text, text]) != fake.embed_documents(
        [text, "Goodbye world!"]
    )
@@ -0,0 +1,268 @@
|
||||
"""Module tests interaction of chat model with caching abstraction.."""
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain_core.globals import set_llm_cache
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModel,
|
||||
GenericFakeChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""In-memory cache used for testing purposes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize with empty cache."""
|
||||
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
return self._cache.get((prompt, llm_string), None)
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self._cache[(prompt, llm_string)] = return_val
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache."""
|
||||
self._cache = {}
|
||||
|
||||
|
def test_local_cache_sync() -> None:
    """Test that the local cache is being populated but not the global one."""
    global_cache = InMemoryCache()
    local_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=local_cache, responses=["hello", "goodbye"]
        )
        assert chat_model.invoke("How are you?").content == "hello"
        # If the cache works we should get the same response since
        # the prompt is the same
        assert chat_model.invoke("How are you?").content == "hello"
        # The global cache should be empty
        assert global_cache._cache == {}
        # The local cache should be populated
        assert len(local_cache._cache) == 1
        llm_result = list(local_cache._cache.values())
        chat_generation = llm_result[0][0]
        assert isinstance(chat_generation, ChatGeneration)
        assert chat_generation.message.content == "hello"
        # Verify that another prompt will trigger the call to the model
        assert chat_model.invoke("meow?").content == "goodbye"
        # The global cache should be empty
        assert global_cache._cache == {}
        # The local cache should be populated
        assert len(local_cache._cache) == 2
    finally:
        set_llm_cache(None)


async def test_local_cache_async() -> None:
    # Use InMemoryCache as the cache
    global_cache = InMemoryCache()
    local_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=local_cache, responses=["hello", "goodbye"]
        )
        assert (await chat_model.ainvoke("How are you?")).content == "hello"
        # If the cache works we should get the same response since
        # the prompt is the same
        assert (await chat_model.ainvoke("How are you?")).content == "hello"
        # The global cache should be empty
        assert global_cache._cache == {}
        # The local cache should be populated
        assert len(local_cache._cache) == 1
        llm_result = list(local_cache._cache.values())
        chat_generation = llm_result[0][0]
        assert isinstance(chat_generation, ChatGeneration)
        assert chat_generation.message.content == "hello"
        # Verify that another prompt will trigger the call to the model
        assert chat_model.invoke("meow?").content == "goodbye"
        # The global cache should be empty
        assert global_cache._cache == {}
        # The local cache should be populated
        assert len(local_cache._cache) == 2
    finally:
        set_llm_cache(None)
def test_global_cache_sync() -> None:
    """Test that the global cache gets populated when cache = True."""
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=True, responses=["hello", "goodbye", "meow", "woof"]
        )
        assert (chat_model.invoke("How are you?")).content == "hello"
        # If the cache works we should get the same response since
        # the prompt is the same
        assert (chat_model.invoke("How are you?")).content == "hello"
        # The global cache should be populated
        assert len(global_cache._cache) == 1
        llm_result = list(global_cache._cache.values())
        chat_generation = llm_result[0][0]
        assert isinstance(chat_generation, ChatGeneration)
        assert chat_generation.message.content == "hello"
        # Verify that another prompt will trigger the call to the model
        assert chat_model.invoke("nice").content == "goodbye"
        # The global cache should now hold both prompts
        assert len(global_cache._cache) == 2
    finally:
        set_llm_cache(None)


async def test_global_cache_async() -> None:
    """Test that the global cache gets populated when cache = True."""
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=True, responses=["hello", "goodbye", "meow", "woof"]
        )
        assert (await chat_model.ainvoke("How are you?")).content == "hello"
        # If the cache works we should get the same response since
        # the prompt is the same
        assert (await chat_model.ainvoke("How are you?")).content == "hello"
        # The global cache should be populated
        assert len(global_cache._cache) == 1
        llm_result = list(global_cache._cache.values())
        chat_generation = llm_result[0][0]
        assert isinstance(chat_generation, ChatGeneration)
        assert chat_generation.message.content == "hello"
        # Verify that another prompt will trigger the call to the model
        assert chat_model.invoke("nice").content == "goodbye"
        # The global cache should now hold both prompts
        assert len(global_cache._cache) == 2
    finally:
        set_llm_cache(None)
def test_no_cache_sync() -> None:
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=False, responses=["hello", "goodbye"]
        )  # Set cache=False
        assert (chat_model.invoke("How are you?")).content == "hello"
        # The global cache should not be populated since cache=False
        # so we should get the second response
        assert (chat_model.invoke("How are you?")).content == "goodbye"
        # The global cache should not be populated since cache=False
        assert len(global_cache._cache) == 0
    finally:
        set_llm_cache(None)


async def test_no_cache_async() -> None:
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=False, responses=["hello", "goodbye"]
        )  # Set cache=False
        assert (await chat_model.ainvoke("How are you?")).content == "hello"
        # The global cache should not be populated since cache=False
        # so we should get the second response
        assert (await chat_model.ainvoke("How are you?")).content == "goodbye"
        # The global cache should not be populated since cache=False
        assert len(global_cache._cache) == 0
    finally:
        set_llm_cache(None)
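
Taken together with the global-cache tests above, these two tests pin down a precedence rule: a per-model cache=False opts out even when a global cache is registered, a per-model cache instance is used instead of the global one, and cache=True uses the globally registered cache. A minimal sketch of that decision, assuming the get_llm_cache accessor that accompanies set_llm_cache (the resolve_cache helper name is hypothetical):

# Hypothetical helper illustrating the precedence the tests assert; not the
# actual langchain_core implementation.
from typing import Optional, Union

from langchain_core.caches import BaseCache
from langchain_core.globals import get_llm_cache


def resolve_cache(model_cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
    if model_cache is False:
        return None                  # explicit opt-out beats any global cache
    if isinstance(model_cache, BaseCache):
        return model_cache           # per-model cache instance takes priority
    # cache=True defers to the global cache; the cache=None case is not
    # exercised by the tests above, so it is not modeled here.
    return get_llm_cache()
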
async def test_global_cache_abatch() -> None:
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=True, responses=["hello", "goodbye", "meow", "woof"]
        )
        results = await chat_model.abatch(["first prompt", "second prompt"])
        assert results[0].content == "hello"
        assert results[1].content == "goodbye"

        # Now try with the same prompt
        results = await chat_model.abatch(["first prompt", "first prompt"])
        assert results[0].content == "hello"
        assert results[1].content == "hello"

        ## RACE CONDITION -- note behavior is different from sync
        # Now reset the cache and test the race condition.
        # For now we just hard-code the expected result; if this changes
        # we can investigate further.
        global_cache = InMemoryCache()
        set_llm_cache(global_cache)
        assert global_cache._cache == {}
        results = await chat_model.abatch(["prompt", "prompt"])
        # We suspect tasks are scheduled and executed in order;
        # if this ever fails, we can relax this to a set comparison.
        # With a fresh cache, both concurrent calls miss and hit the model.
        assert results[0].content == "meow"
        assert results[1].content == "woof"
    finally:
        set_llm_cache(None)


def test_global_cache_batch() -> None:
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        chat_model = FakeListChatModel(
            cache=True, responses=["hello", "goodbye", "meow", "woof"]
        )
        results = chat_model.batch(["first prompt", "second prompt"])
        # These may be in any order
        assert {results[0].content, results[1].content} == {"hello", "goodbye"}

        # Now try with the same prompt
        results = chat_model.batch(["first prompt", "first prompt"])
        # These could be either "hello" or "goodbye" and should be identical
        assert results[0].content == results[1].content
        assert {results[0].content, results[1].content}.issubset({"hello", "goodbye"})

        ## RACE CONDITION -- note behavior is different from async
        # Now reset the cache and test the race condition.
        # For now we just hard-code the expected result; if this changes
        # we can investigate further.
        global_cache = InMemoryCache()
        set_llm_cache(global_cache)
        assert global_cache._cache == {}
        results = chat_model.batch(
            [
                "prompt",
                "prompt",
            ]
        )
        # Both results are "meow": the first call's result is cached before
        # the second identical prompt is served.
        assert {results[0].content, results[1].content} == {"meow"}
    finally:
        set_llm_cache(None)
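
One subtlety behind the set-based assertions in test_global_cache_batch: the synchronous batch path fans calls out to worker threads, so the canned responses of the fake model can be consumed in either order. A minimal illustration of that effect (plain ThreadPoolExecutor, not langchain internals):

# Minimal illustration, not langchain internals: when calls fan out to a
# thread pool, a shared list of canned responses is consumed in whatever
# order the worker threads happen to run, so prompt->response pairing varies.
from concurrent.futures import ThreadPoolExecutor

responses = ["hello", "goodbye"]


def fake_call(prompt: str) -> str:
    # Each call pops the next canned response, regardless of which prompt it is.
    return responses.pop(0)


with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(fake_call, ["first prompt", "second prompt"]))

# Both canned strings always come back, but which prompt got which response
# depends on thread scheduling -- hence the set comparison in the test above.
assert set(results) == {"hello", "goodbye"}
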
@pytest.mark.xfail(reason="Abstraction does not support caching for streaming yet.")
def test_global_cache_stream() -> None:
    """Test streaming."""
    global_cache = InMemoryCache()
    try:
        set_llm_cache(global_cache)
        messages = [
            AIMessage(content="hello world"),
            AIMessage(content="goodbye world"),
        ]
        model = GenericFakeChatModel(messages=iter(messages), cache=True)
        chunks = [chunk for chunk in model.stream("some input")]
        assert len(chunks) == 3
        # Assert that streaming information gets cached
        assert global_cache._cache != {}
    finally:
        set_llm_cache(None)
@ -0,0 +1,199 @@
import pytest

from langchain_core.utils.function_calling import _rm_titles

output1 = {
    "type": "object",
    "properties": {
        "people": {
            "description": "List of info about people",
            "type": "array",
            "items": {
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "title": {"description": "person's age", "type": "integer"},
                },
                "required": ["name"],
            },
        }
    },
    "required": ["people"],
}

schema1 = {
    "type": "object",
    "properties": {
        "people": {
            "title": "People",
            "description": "List of info about people",
            "type": "array",
            "items": {
                "title": "Person",
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "name": {"title": "Name", "type": "string"},
                    "title": {
                        "title": "Title",
                        "description": "person's age",
                        "type": "integer",
                    },
                },
                "required": ["name"],
            },
        }
    },
    "required": ["people"],
}

output2 = {
    "type": "object",
    "properties": {
        "title": {
            "description": "List of info about people",
            "type": "array",
            "items": {
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"description": "person's age", "type": "integer"},
                },
                "required": ["name"],
            },
        }
    },
    "required": ["title"],
}

schema2 = {
    "type": "object",
    "properties": {
        "title": {
            "title": "Title",
            "description": "List of info about people",
            "type": "array",
            "items": {
                "title": "Person",
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "name": {"title": "Name", "type": "string"},
                    "age": {
                        "title": "Age",
                        "description": "person's age",
                        "type": "integer",
                    },
                },
                "required": ["name"],
            },
        }
    },
    "required": ["title"],
}


output3 = {
    "type": "object",
    "properties": {
        "title": {
            "description": "List of info about people",
            "type": "array",
            "items": {
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "type": {"description": "person's age", "type": "integer"},
                },
                "required": ["title"],
            },
        }
    },
    "required": ["title"],
}

schema3 = {
    "type": "object",
    "properties": {
        "title": {
            "title": "Title",
            "description": "List of info about people",
            "type": "array",
            "items": {
                "title": "Person",
                "description": "Information about a person.",
                "type": "object",
                "properties": {
                    "title": {"title": "Title", "type": "string"},
                    "type": {
                        "title": "Type",
                        "description": "person's age",
                        "type": "integer",
                    },
                },
                "required": ["title"],
            },
        }
    },
    "required": ["title"],
}


output4 = {
    "type": "object",
    "properties": {
        "properties": {
            "description": "Information to extract",
            "type": "object",
            "properties": {
                "title": {
                    "description": "Information about papers mentioned.",
                    "type": "object",
                    "properties": {
                        "title": {"type": "string"},
                        "author": {"type": "string"},
                    },
                    "required": ["title"],
                }
            },
            "required": ["title"],
        }
    },
    "required": ["properties"],
}

schema4 = {
    "type": "object",
    "properties": {
        "properties": {
            "title": "Info",
            "description": "Information to extract",
            "type": "object",
            "properties": {
                "title": {
                    "title": "Paper",
                    "description": "Information about papers mentioned.",
                    "type": "object",
                    "properties": {
                        "title": {"title": "Title", "type": "string"},
                        "author": {"title": "Author", "type": "string"},
                    },
                    "required": ["title"],
                }
            },
            "required": ["title"],
        }
    },
    "required": ["properties"],
}


@pytest.mark.parametrize(
    "schema, output",
    [(schema1, output1), (schema2, output2), (schema3, output3), (schema4, output4)],
)
def test_rm_titles(schema: dict, output: dict) -> None:
    assert _rm_titles(schema) == output
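
For orientation, _rm_titles strips the "title" annotations that Pydantic inserts at every level of a generated JSON schema, while keeping properties that are genuinely named "title" (which is exactly what the schema2-schema4 fixtures exercise). A short usage check that reuses the schema1/output1 fixtures defined above:

# Reuses the fixtures above; this is the same equality the parametrized test
# asserts, plus a spot-check that the real "title" property survives.
cleaned = _rm_titles(schema1)
assert cleaned == output1
assert "title" in cleaned["properties"]["people"]["items"]["properties"]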