mirror of https://github.com/hwchase17/langchain
Merge bfb9f16e13
into 242eeb537f
commit
2ecd440f00
@ -0,0 +1,636 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Milvus Hybrid Search\n",
|
||||
"\n",
|
||||
"> [Milvus](https://milvus.io/docs) is an open-source vector database built to power embedding similarity search and AI applications. Milvus makes unstructured data search more accessible, and provides a consistent user experience regardless of the deployment environment.\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use the Milvus Hybrid Search retriever, which combines the strengths of both dense and sparse vector search.\n",
|
||||
"\n",
|
||||
"For more information, please refer to [Milvus Multi-Vector Search](https://milvus.io/docs/multi-vector-search.md).\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Prerequisites\n",
|
||||
"### Install dependencies\n",
|
||||
"You need to install the following dependencies:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade --quiet pymilvus[model] langchain-milvus langchain-openai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Import necessary modules and classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pymilvus import (\n",
|
||||
" Collection,\n",
|
||||
" CollectionSchema,\n",
|
||||
" DataType,\n",
|
||||
" FieldSchema,\n",
|
||||
" WeightedRanker,\n",
|
||||
" connections,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.output_parsers import StrOutputParser\n",
|
||||
"from langchain_core.prompts import PromptTemplate\n",
|
||||
"from langchain_core.runnables import RunnablePassthrough\n",
|
||||
"from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever\n",
|
||||
"from langchain_milvus.utils.sparse import BM25SparseEmbedding\n",
|
||||
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"### Start the Milvus service\n",
|
||||
"\n",
|
||||
"Please refer to the [Milvus documentation](https://milvus.io/docs/install_standalone-docker.md) to start the Milvus service.\n",
|
||||
"\n",
|
||||
"After starting Milvus, you need to specify your Milvus connection URI.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"CONNECTION_URI = \"http://localhost:19530\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"### Prepare OpenAI API Key\n",
|
||||
"\n",
|
||||
"Please refer to the [OpenAI documentation](https://platform.openai.com/account/api-keys) to obtain your OpenAI API key, and set it as an environment variable.\n",
|
||||
"\n",
|
||||
"```shell\n",
|
||||
"export OPENAI_API_KEY=<your_api_key>\n",
|
||||
"```\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"## Prepare data and Load\n",
|
||||
"### Prepare dense and sparse embedding functions\n",
|
||||
"\n",
|
||||
"Let us create 10 fake descriptions of novels. In actual production, it may be a large amount of text data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"texts = [\n",
|
||||
" \"In 'The Whispering Walls' by Ava Moreno, a young journalist named Sophia uncovers a decades-old conspiracy hidden within the crumbling walls of an ancient mansion, where the whispers of the past threaten to destroy her own sanity.\",\n",
|
||||
" \"In 'The Last Refuge' by Ethan Blackwood, a group of survivors must band together to escape a post-apocalyptic wasteland, where the last remnants of humanity cling to life in a desperate bid for survival.\",\n",
|
||||
" \"In 'The Memory Thief' by Lila Rose, a charismatic thief with the ability to steal and manipulate memories is hired by a mysterious client to pull off a daring heist, but soon finds themselves trapped in a web of deceit and betrayal.\",\n",
|
||||
" \"In 'The City of Echoes' by Julian Saint Clair, a brilliant detective must navigate a labyrinthine metropolis where time is currency, and the rich can live forever, but at a terrible cost to the poor.\",\n",
|
||||
" \"In 'The Starlight Serenade' by Ruby Flynn, a shy astronomer discovers a mysterious melody emanating from a distant star, which leads her on a journey to uncover the secrets of the universe and her own heart.\",\n",
|
||||
" \"In 'The Shadow Weaver' by Piper Redding, a young orphan discovers she has the ability to weave powerful illusions, but soon finds herself at the center of a deadly game of cat and mouse between rival factions vying for control of the mystical arts.\",\n",
|
||||
" \"In 'The Lost Expedition' by Caspian Grey, a team of explorers ventures into the heart of the Amazon rainforest in search of a lost city, but soon finds themselves hunted by a ruthless treasure hunter and the treacherous jungle itself.\",\n",
|
||||
" \"In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers a hidden world of clockwork machines and ancient magic, where a rebellion is brewing against the tyrannical ruler of the land.\",\n",
|
||||
" \"In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by a mysterious organization to transport a valuable artifact across a war-torn continent, but soon finds themselves pursued by deadly assassins and rival factions.\",\n",
|
||||
" \"In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers she has the ability to enter people's dreams, but soon finds herself trapped in a surreal world of nightmares and illusions, where the boundaries between reality and fantasy blur.\",\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We will use the [OpenAI Embedding](https://platform.openai.com/docs/guides/embeddings) to generate dense vectors, and the [BM25 algorithm](https://en.wikipedia.org/wiki/Okapi_BM25) to generate sparse vectors.\n",
|
||||
"\n",
|
||||
"Initialize dense embedding function and get dimension"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1536"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dense_embedding_func = OpenAIEmbeddings()\n",
|
||||
"dense_dim = len(dense_embedding_func.embed_query(texts[1]))\n",
|
||||
"dense_dim"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Initialize sparse embedding function.\n",
|
||||
"\n",
|
||||
"Note that the output of sparse embedding is a set of sparse vectors, which represents the index and weight of the keywords of the input text."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{0: 0.4270424944042204,\n",
|
||||
" 21: 1.845826690498331,\n",
|
||||
" 22: 1.845826690498331,\n",
|
||||
" 23: 1.845826690498331,\n",
|
||||
" 24: 1.845826690498331,\n",
|
||||
" 25: 1.845826690498331,\n",
|
||||
" 26: 1.845826690498331,\n",
|
||||
" 27: 1.2237754316221157,\n",
|
||||
" 28: 1.845826690498331,\n",
|
||||
" 29: 1.845826690498331,\n",
|
||||
" 30: 1.845826690498331,\n",
|
||||
" 31: 1.845826690498331,\n",
|
||||
" 32: 1.845826690498331,\n",
|
||||
" 33: 1.845826690498331,\n",
|
||||
" 34: 1.845826690498331,\n",
|
||||
" 35: 1.845826690498331,\n",
|
||||
" 36: 1.845826690498331,\n",
|
||||
" 37: 1.845826690498331,\n",
|
||||
" 38: 1.845826690498331,\n",
|
||||
" 39: 1.845826690498331}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sparse_embedding_func = BM25SparseEmbedding(corpus=texts)\n",
|
||||
"sparse_embedding_func.embed_query(texts[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create Milvus Collection and load data\n",
|
||||
"\n",
|
||||
"Initialize connection URI and establish connection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"connections.connect(uri=CONNECTION_URI)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define field names and their data types"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_field = \"doc_id\"\n",
|
||||
"dense_field = \"dense_vector\"\n",
|
||||
"sparse_field = \"sparse_vector\"\n",
|
||||
"text_field = \"text\"\n",
|
||||
"fields = [\n",
|
||||
" FieldSchema(\n",
|
||||
" name=pk_field,\n",
|
||||
" dtype=DataType.VARCHAR,\n",
|
||||
" is_primary=True,\n",
|
||||
" auto_id=True,\n",
|
||||
" max_length=100,\n",
|
||||
" ),\n",
|
||||
" FieldSchema(name=dense_field, dtype=DataType.FLOAT_VECTOR, dim=dense_dim),\n",
|
||||
" FieldSchema(name=sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),\n",
|
||||
" FieldSchema(name=text_field, dtype=DataType.VARCHAR, max_length=65_535),\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create a collection with the defined schema"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"schema = CollectionSchema(fields=fields, enable_dynamic_field=False)\n",
|
||||
"collection = Collection(\n",
|
||||
" name=\"IntroductionToTheNovels\", schema=schema, consistency_level=\"Strong\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define index for dense and sparse vectors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dense_index = {\"index_type\": \"FLAT\", \"metric_type\": \"IP\"}\n",
|
||||
"collection.create_index(\"dense_vector\", dense_index)\n",
|
||||
"sparse_index = {\"index_type\": \"SPARSE_INVERTED_INDEX\", \"metric_type\": \"IP\"}\n",
|
||||
"collection.create_index(\"sparse_vector\", sparse_index)\n",
|
||||
"collection.flush()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Insert entities into the collection and load the collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entities = []\n",
|
||||
"for text in texts:\n",
|
||||
" entity = {\n",
|
||||
" dense_field: dense_embedding_func.embed_documents([text])[0],\n",
|
||||
" sparse_field: sparse_embedding_func.embed_documents([text])[0],\n",
|
||||
" text_field: text,\n",
|
||||
" }\n",
|
||||
" entities.append(entity)\n",
|
||||
"collection.insert(entities)\n",
|
||||
"collection.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build RAG chain with Retriever\n",
|
||||
"### Create the Retriever\n",
|
||||
"\n",
|
||||
"Define search parameters for sparse and dense fields, and create a retriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sparse_search_params = {\"metric_type\": \"IP\"}\n",
|
||||
"dense_search_params = {\"metric_type\": \"IP\", \"params\": {}}\n",
|
||||
"retriever = MilvusCollectionHybridSearchRetriever(\n",
|
||||
" collection=collection,\n",
|
||||
" rerank=WeightedRanker(0.5, 0.5),\n",
|
||||
" anns_fields=[dense_field, sparse_field],\n",
|
||||
" field_embeddings=[dense_embedding_func, sparse_embedding_func],\n",
|
||||
" field_search_params=[dense_search_params, sparse_search_params],\n",
|
||||
" top_k=3,\n",
|
||||
" text_field=text_field,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"In the input parameters of this Retriever, we use a dense embedding and a sparse embedding to perform hybrid search on the two fields of this Collection, and use WeightedRanker for reranking. Finally, the top 3 documents will be returned."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content=\"In 'The Lost Expedition' by Caspian Grey, a team of explorers ventures into the heart of the Amazon rainforest in search of a lost city, but soon finds themselves hunted by a ruthless treasure hunter and the treacherous jungle itself.\", metadata={'doc_id': '449281835035545843'}),\n",
|
||||
" Document(page_content=\"In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by a mysterious organization to transport a valuable artifact across a war-torn continent, but soon finds themselves pursued by deadly assassins and rival factions.\", metadata={'doc_id': '449281835035545845'}),\n",
|
||||
" Document(page_content=\"In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers she has the ability to enter people's dreams, but soon finds herself trapped in a surreal world of nightmares and illusions, where the boundaries between reality and fantasy blur.\", metadata={'doc_id': '449281835035545846'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever.invoke(\"What are the story about ventures?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Build the RAG chain\n",
|
||||
"\n",
|
||||
"Initialize ChatOpenAI and define a prompt template"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatOpenAI()\n",
|
||||
"\n",
|
||||
"PROMPT_TEMPLATE = \"\"\"\n",
|
||||
"Human: You are an AI assistant, and provides answers to questions by using fact based and statistical information when possible.\n",
|
||||
"Use the following pieces of information to provide a concise answer to the question enclosed in <question> tags.\n",
|
||||
"\n",
|
||||
"<context>\n",
|
||||
"{context}\n",
|
||||
"</context>\n",
|
||||
"\n",
|
||||
"<question>\n",
|
||||
"{question}\n",
|
||||
"</question>\n",
|
||||
"\n",
|
||||
"Assistant:\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" template=PROMPT_TEMPLATE, input_variables=[\"context\", \"question\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Define a function for formatting documents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def format_docs(docs):\n",
|
||||
" return \"\\n\\n\".join(doc.page_content for doc in docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Define a chain using the retriever and other components"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rag_chain = (\n",
|
||||
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
|
||||
" | prompt\n",
|
||||
" | llm\n",
|
||||
" | StrOutputParser()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Perform a query using the defined chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Lila Rose has written 'The Memory Thief,' which follows a charismatic thief with the ability to steal and manipulate memories as they navigate a daring heist and a web of deceit and betrayal.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"rag_chain.invoke(\"What novels has Lila written and what are their contents?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Drop the collection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"collection.drop()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1 @@
|
||||
__pycache__
|
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -0,0 +1,57 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
integration_test integration_tests: TEST_FILE=tests/integration_tests/
|
||||
|
||||
test tests integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/milvus --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_milvus
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_milvus -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
@ -0,0 +1,42 @@
|
||||
# langchain-milvus
|
||||
|
||||
This is a library integration with [Milvus](https://milvus.io/) and [Zilliz Cloud](https://zilliz.com/cloud).
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -U langchain-milvus
|
||||
```
|
||||
|
||||
## Milvus vector database
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/milvus/)
|
||||
|
||||
```python
|
||||
from langchain_milvus import Milvus
|
||||
```
|
||||
|
||||
## Milvus hybrid search
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/milvus_hybrid_search/).
|
||||
|
||||
```python
|
||||
from langchain_milvus import MilvusCollectionHybridSearchRetriever
|
||||
```
|
||||
|
||||
|
||||
## Zilliz Cloud vector database
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/zilliz/).
|
||||
|
||||
```python
|
||||
from langchain_milvus import Zilliz
|
||||
```
|
||||
|
||||
## Zilliz Cloud Pipeline Retriever
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/zilliz_cloud_pipeline).
|
||||
|
||||
```python
|
||||
from langchain_milvus import ZillizCloudPipelineRetriever
|
||||
```
|
@ -0,0 +1,12 @@
|
||||
from langchain_milvus.retrievers import (
|
||||
MilvusCollectionHybridSearchRetriever,
|
||||
ZillizCloudPipelineRetriever,
|
||||
)
|
||||
from langchain_milvus.vectorstores import Milvus, Zilliz
|
||||
|
||||
__all__ = [
|
||||
"Milvus",
|
||||
"Zilliz",
|
||||
"ZillizCloudPipelineRetriever",
|
||||
"MilvusCollectionHybridSearchRetriever",
|
||||
]
|
@ -0,0 +1,8 @@
|
||||
from langchain_milvus.retrievers.milvus_hybrid_search import (
|
||||
MilvusCollectionHybridSearchRetriever,
|
||||
)
|
||||
from langchain_milvus.retrievers.zilliz_cloud_pipeline_retriever import (
|
||||
ZillizCloudPipelineRetriever,
|
||||
)
|
||||
|
||||
__all__ = ["ZillizCloudPipelineRetriever", "MilvusCollectionHybridSearchRetriever"]
|
@ -0,0 +1,160 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pymilvus import AnnSearchRequest, Collection
|
||||
from pymilvus.client.abstract import BaseRanker, SearchResult # type: ignore
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
|
||||
|
||||
class MilvusCollectionHybridSearchRetriever(BaseRetriever):
    """Hybrid search retriever that uses a Milvus Collection to retrieve
    documents based on multiple vector fields (e.g. one dense and one sparse).

    Each ANNS field is searched with its own embedding function and search
    parameters; the per-field results are then combined by the given ranker.

    For more information, please refer to:
    https://milvus.io/docs/release_notes.md#Multi-Embedding---Hybrid-Search
    """

    collection: Collection
    """Milvus Collection object."""
    rerank: BaseRanker
    """Milvus ranker object. Such as WeightedRanker or RRFRanker."""
    anns_fields: List[str]
    """The names of vector fields that are used for ANNS search."""
    field_embeddings: List[Union[Embeddings, BaseSparseEmbedding]]
    """The embedding functions of each vector fields,
    which can be either Embeddings or BaseSparseEmbedding."""
    field_search_params: Optional[List[Dict]] = None
    """The search parameters of each vector fields.
    If not specified, the default search parameters will be used."""
    field_limits: Optional[List[int]] = None
    """Limit number of results for each ANNS field.
    If not specified, the default top_k will be used."""
    field_exprs: Optional[List[Optional[str]]] = None
    """The boolean expression for filtering the search results."""
    top_k: int = 4
    """Final top-K number of documents to retrieve."""
    text_field: str = "text"
    """The text field name,
    which will be used as the `page_content` of a `Document` object."""
    output_fields: Optional[List[str]] = None
    """Final output fields of the documents.
    If not specified, all fields except the vector fields will be used as output fields,
    which will be the `metadata` of a `Document` object."""

    def __init__(self, **kwargs: Any):
        """Fill in per-field defaults, validate the configuration, and load
        the collection.

        Raises:
            ValueError: If the per-field configuration lists do not all have
                the same length as ``anns_fields``.
        """
        super().__init__(**kwargs)

        # If some parameters are not specified, set default values
        if self.field_search_params is None:
            default_search_params = {
                "metric_type": "L2",
                "params": {"nprobe": 10},
            }
            self.field_search_params = [default_search_params] * len(self.anns_fields)
        if self.field_limits is None:
            self.field_limits = [self.top_k] * len(self.anns_fields)
        if self.field_exprs is None:
            self.field_exprs = [None] * len(self.anns_fields)

        # Check the fields
        self._validate_fields_num()
        self.output_fields = self._get_output_fields()
        self._validate_fields_name()

        # Load collection
        self.collection.load()

    def _validate_fields_num(self) -> None:
        """Check that every per-field list lines up with ``anns_fields``."""
        assert (
            len(self.anns_fields) >= 2
        ), "At least two fields are required for hybrid search."
        # Include field_embeddings here: a mismatched embeddings list would
        # otherwise be silently truncated by zip() in
        # _build_ann_search_requests, dropping search fields without warning.
        lengths = [len(self.anns_fields), len(self.field_embeddings)]
        if self.field_limits is not None:
            lengths.append(len(self.field_limits))
        if self.field_exprs is not None:
            lengths.append(len(self.field_exprs))

        if not all(length == lengths[0] for length in lengths):
            raise ValueError("All field-related lists must have the same length.")

        if len(self.field_search_params) != len(self.anns_fields):  # type: ignore[arg-type]
            raise ValueError(
                "field_search_params must have the same length as anns_fields."
            )

    def _validate_fields_name(self) -> None:
        """Check that every configured field name exists in the collection schema."""
        collection_fields = [x.name for x in self.collection.schema.fields]
        for field in self.anns_fields:
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."
        assert (
            self.text_field in collection_fields
        ), f"{self.text_field} is not a valid field in the collection."
        for field in self.output_fields:  # type: ignore[union-attr]
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."

    def _get_output_fields(self) -> List[str]:
        """Return the output fields, defaulting to all non-vector schema fields.

        The text field is always included so `page_content` can be populated.
        """
        if self.output_fields:
            return self.output_fields
        output_fields = [x.name for x in self.collection.schema.fields]
        for field in self.anns_fields:
            if field in output_fields:
                output_fields.remove(field)
        if self.text_field not in output_fields:
            output_fields.append(self.text_field)
        return output_fields

    def _build_ann_search_requests(self, query: str) -> List[AnnSearchRequest]:
        """Embed the query once per ANNS field and build one search request each."""
        search_requests = []
        for ann_field, embedding, param, limit, expr in zip(
            self.anns_fields,
            self.field_embeddings,
            self.field_search_params,  # type: ignore[arg-type]
            self.field_limits,  # type: ignore[arg-type]
            self.field_exprs,  # type: ignore[arg-type]
        ):
            request = AnnSearchRequest(
                data=[embedding.embed_query(query)],
                anns_field=ann_field,
                param=param,
                limit=limit,
                expr=expr,
            )
            search_requests.append(request)
        return search_requests

    def _parse_document(self, data: dict) -> Document:
        """Convert one Milvus hit dict into a Document.

        The text field becomes `page_content`; the remaining keys become
        `metadata`. Note that `data` is mutated (the text field is popped).
        """
        return Document(
            page_content=data.pop(self.text_field),
            metadata=data,
        )

    def _process_search_result(
        self, search_results: List[SearchResult]
    ) -> List[Document]:
        """Convert the hits of the first (and only) query into Documents."""
        documents = []
        for result in search_results[0]:
            data = {x: result.entity.get(x) for x in self.output_fields}  # type: ignore[union-attr]
            doc = self._parse_document(data)
            documents.append(doc)
        return documents

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run a hybrid search over all ANNS fields and rerank the results.

        Args:
            query: The user query to embed and search for.
            run_manager: Callback manager for this retriever run.

        Returns:
            The final top-K documents after reranking.
        """
        requests = self._build_ann_search_requests(query)
        search_result = self.collection.hybrid_search(
            requests, self.rerank, limit=self.top_k, output_fields=self.output_fields
        )
        documents = self._process_search_result(search_result)
        return documents
|
@ -0,0 +1,215 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class ZillizCloudPipelineRetriever(BaseRetriever):
    """`Zilliz Cloud Pipeline` retriever

    Args:
        pipeline_ids (dict): A dictionary of pipeline ids.
            Valid keys: "ingestion", "search", "deletion".
        token (str): Zilliz Cloud's token. Defaults to "".
        cloud_region (str='gcp-us-west1'): The region of Zilliz Cloud's cluster.
            Defaults to 'gcp-us-west1'.
    """

    # Mapping from pipeline kind ("ingestion" / "search" / "deletion") to its id.
    pipeline_ids: Dict
    # Zilliz Cloud API token, sent as a Bearer credential on every call.
    token: str = ""
    # Region prefix of the Zilliz Cloud controller endpoint.
    cloud_region: str = "gcp-us-west1"

    def _run_pipeline(self, pipeline_id: str, params: Dict[str, Any]) -> Dict:
        """POST ``params`` to a pipeline's run endpoint and return its data payload.

        Shared by search/ingestion/deletion, which previously duplicated the
        URL, header, and response-checking code.

        Args:
            pipeline_id: Id of the pipeline to run.
            params: JSON body for the run request.

        Returns:
            The ``data`` object from the API response.

        Raises:
            RuntimeError: If the HTTP request fails or the API reports a
                non-200 body code.
        """
        domain = (
            f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
        )
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        url = f"{domain}/{pipeline_id}/run"

        response = requests.post(url, headers=headers, json=params)
        if response.status_code != 200:
            raise RuntimeError(response.text)
        response_dict = response.json()
        # The API may wrap errors in an HTTP-200 response with a non-200 body code.
        if response_dict["code"] != 200:
            raise RuntimeError(response_dict)
        return response_dict["data"]

    def _get_relevant_documents(
        self,
        query: str,
        top_k: int = 10,
        offset: int = 0,
        output_fields: Optional[List] = None,
        filter: str = "",
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get documents relevant to a query.

        Args:
            query (str): String to find relevant documents for
            top_k (int=10): The number of results. Defaults to 10.
            offset (int=0): The number of records to skip in the search result.
                Defaults to 0.
            output_fields (Optional[List]): The extra fields to present in output.
                Defaults to None (no extra fields). The previous default was a
                shared mutable list (``[]``), which is a Python anti-pattern.
            filter (str=""): The Milvus expression to filter search results.
                Defaults to "".
            run_manager (CallBackManagerForRetrieverRun): The callbacks handler to use.

        Returns:
            List of relevant documents
        """
        if "search" not in self.pipeline_ids:
            raise Exception(
                "A search pipeline id must be provided in pipeline_ids to "
                "get relevant documents."
            )
        params = {
            "data": {"query_text": query},
            "params": {
                "limit": top_k,
                "offset": offset,
                "outputFields": output_fields if output_fields is not None else [],
                "filter": filter,
            },
        }
        response_data = self._run_pipeline(self.pipeline_ids["search"], params)
        search_results = response_data["result"]
        return [
            Document(
                page_content=result.pop("text")
                if "text" in result
                else result.pop("chunk_text"),
                metadata=result,
            )
            for result in search_results
        ]

    def add_texts(
        self, texts: List[str], metadata: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Add documents to store.
        Only supported by a text ingestion pipeline in Zilliz Cloud.

        Args:
            texts (List[str]): A list of text strings.
            metadata (Dict[str, Any]): A key-value dictionary of metadata will
                be inserted as preserved fields required by ingestion pipeline.
                Defaults to None.
        """
        if "ingestion" not in self.pipeline_ids:
            raise Exception(
                "An ingestion pipeline id must be provided in pipeline_ids to"
                " add documents."
            )
        params: Dict[str, Any] = {"data": {"text_list": texts}}
        params["data"].update(metadata or {})
        return self._run_pipeline(self.pipeline_ids["ingestion"], params)

    def add_doc_url(
        self, doc_url: str, metadata: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Add a document from url.
        Only supported by a document ingestion pipeline in Zilliz Cloud.

        Args:
            doc_url: A document url.
            metadata (Dict[str, Any]): A key-value dictionary of metadata will
                be inserted as preserved fields required by ingestion pipeline.
                Defaults to None.
        """
        if "ingestion" not in self.pipeline_ids:
            raise Exception(
                "An ingestion pipeline id must be provided in pipeline_ids to "
                "add documents."
            )
        params: Dict[str, Any] = {"data": {"doc_url": doc_url}}
        params["data"].update(metadata or {})
        return self._run_pipeline(self.pipeline_ids["ingestion"], params)

    def delete(self, key: str, value: Any) -> Dict:
        """
        Delete documents. Only supported by a deletion pipeline in Zilliz Cloud.

        Args:
            key: input name to run the deletion pipeline
            value: input value to run deletion pipeline
        """
        if "deletion" not in self.pipeline_ids:
            # Fix: the original message said "add documents" (copy-paste error).
            raise Exception(
                "A deletion pipeline id must be provided in pipeline_ids to "
                "delete documents."
            )
        return self._run_pipeline(
            self.pipeline_ids["deletion"], {"data": {key: value}}
        )
|
@ -0,0 +1,54 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from scipy.sparse import csr_array # type: ignore
|
||||
|
||||
|
||||
class BaseSparseEmbedding(ABC):
    """Abstract interface for sparse embedding models.

    Inherit from this class and implement both methods to plug a custom
    sparse embedding model into Milvus hybrid search.
    """

    @abstractmethod
    def embed_query(self, query: str) -> Dict[int, float]:
        """Embed query text."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Embed search docs."""
|
||||
|
||||
|
||||
class BM25SparseEmbedding(BaseSparseEmbedding):
    """Sparse embedding model based on BM25, implementing BaseSparseEmbedding.

    This class uses the BM25 model in Milvus model to implement sparse vector
    embedding. It requires pymilvus[model] to be installed:
    `pip install pymilvus[model]`
    For more information please refer to:
    https://milvus.io/docs/embed-with-bm25.md
    """

    def __init__(self, corpus: List[str], language: str = "en"):
        # Imported lazily so the package works without pymilvus[model]
        # unless this class is actually used.
        from pymilvus.model.sparse import BM25EmbeddingFunction  # type: ignore
        from pymilvus.model.sparse.bm25.tokenizers import (  # type: ignore
            build_default_analyzer,
        )

        self.analyzer = build_default_analyzer(language=language)
        self.bm25_ef = BM25EmbeddingFunction(self.analyzer, num_workers=1)
        # Fit the BM25 statistics (IDF etc.) on the supplied corpus.
        self.bm25_ef.fit(corpus)

    def embed_query(self, text: str) -> Dict[int, float]:
        """Encode a single query string into a sparse {column: weight} dict."""
        query_matrix = self.bm25_ef.encode_queries([text])
        return self._sparse_to_dict(query_matrix)

    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Encode each document into a sparse {column: weight} dict."""
        return [
            self._sparse_to_dict(document_array)
            for document_array in self.bm25_ef.encode_documents(texts)
        ]

    def _sparse_to_dict(self, sparse_array: csr_array) -> Dict[int, float]:
        """Flatten a single-row sparse array into a {column: value} mapping."""
        _, col_indices = sparse_array.nonzero()
        return dict(zip(col_indices, sparse_array.data))
|
@ -0,0 +1,7 @@
|
||||
from langchain_milvus.vectorstores.milvus import Milvus
from langchain_milvus.vectorstores.zilliz import Zilliz

# Public vector store implementations exported by this subpackage.
__all__ = [
    "Milvus",
    "Zilliz",
]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_milvus.vectorstores.milvus import Milvus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Zilliz(Milvus):
    """`Zilliz` vector store.

    You need to have `pymilvus` installed and a
    running Zilliz database.

    See the following documentation for how to run a Zilliz instance:
    https://docs.zilliz.com/docs/create-cluster

    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Zilliz collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, any]]): The connection args used for
            this class comes in the form of a dict.
        consistency_level (str): The consistency level to use for a collection.
            Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. Defaults to
            HNSW/AUTOINDEX depending on service.
        search_params (Optional[dict]): Which search params to use. Defaults to
            default of index.
        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
            to False.
        auto_id (bool): Whether to enable auto id for primary key. Defaults to False.
            If False, you need to provide text ids (string less than 65535 bytes).
            If True, Milvus will generate unique integers as primary keys.

    The connection args used for this class comes in the form of a dict,
    here are a few of the options:
        address (str): The actual address of Zilliz
            instance. Example address: "localhost:19530"
        uri (str): The uri of Zilliz instance. Example uri:
            "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
        host (str): The host of Zilliz instance. Default at "localhost",
            PyMilvus will fill in the default host if only port is provided.
        port (str/int): The port of Zilliz instance. Default at 19530, PyMilvus
            will fill in the default port if only host is provided.
        user (str): Use which user to connect to Zilliz instance. If user and
            password are provided, we will add related header in every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.
        token (str): API key, for serverless clusters which can be used as
            replacements for user and password.
        secure (bool): Default is false. If set to true, tls will be enabled.
        client_key_path (str): If use tls two-way authentication, need to
            write the client.key path.
        client_pem_path (str): If use tls two-way authentication, need to
            write the client.pem path.
        ca_pem_path (str): If use tls two-way authentication, need to write
            the ca.pem path.
        server_pem_path (str): If use tls one-way authentication, need to
            write the server.pem path.
        server_name (str): If use tls, need to write the common name.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Zilliz
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a Zilliz instance
            zilliz_store = Zilliz(
                embedding_function = embedding,
                collection_name = "LangChainCollection",
                connection_args = {
                    "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
                    "user": "temp",
                    "password": "temp",
                    "token": "temp", # API key as replacements for user and password
                    "secure": True
                },
                drop_old=True,
            )

    Raises:
        ValueError: If the pymilvus python package is not installed.
    """

    def _create_index(self) -> None:
        """Create an index on the collection.

        Tries AUTOINDEX first (the Zilliz Cloud default); if the service
        rejects it, falls back to an HNSW index (self-hosted Milvus).
        """
        from pymilvus import Collection, MilvusException

        # Only act when connected and the collection has no index yet.
        if isinstance(self.col, Collection) and self._get_index() is None:
            try:
                # If no index params, use a default AutoIndex based one
                if self.index_params is None:
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }

                try:
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )

                # If default did not work, most likely Milvus self-hosted
                except MilvusException:
                    # Use HNSW based index
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                logger.debug(
                    "Successfully created an index on collection: %s",
                    self.collection_name,
                )

            except MilvusException as e:
                logger.error(
                    "Failed to create an index on collection: %s", self.collection_name
                )
                raise e

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        *,
        ids: Optional[List[str]] = None,
        auto_id: bool = False,
        **kwargs: Any,
    ) -> Zilliz:
        """Create a Zilliz collection, indexes it with HNSW/AUTOINDEX, and insert data.

        Args:
            texts (List[str]): Text data.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            collection_name (str, optional): Collection name to use. Defaults to
                "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to use. Defaults
                to DEFAULT_MILVUS_CONNECTION.
            consistency_level (str, optional): Which consistency level to use. Defaults
                to "Session".
            index_params (Optional[dict], optional): Which index_params to use.
                Defaults to None.
            search_params (Optional[dict], optional): Which search params to use.
                Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the collection with
                that name if it exists. Defaults to False.
            ids (Optional[List[str]]): List of text ids.
            auto_id (bool): Whether to enable auto id for primary key. Defaults to
                False. If False, you needs to provide text ids (string less than 65535
                bytes). If True, Milvus will generate unique integers as primary keys.

        Returns:
            Zilliz: Zilliz Vector Store
        """
        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args or {},
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            auto_id=auto_id,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return vector_db
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,101 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-milvus"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Milvus and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/milvus"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = ">=0.0.12"
|
||||
pymilvus = "^2.4.2"
|
||||
scipy = "^1.7"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
types-requests = "^2"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["pymilvus"]
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
@ -0,0 +1,17 @@
|
||||
import sys
import traceback
from importlib.machinery import SourceFileLoader


def check_files(files: list) -> int:
    """Try to import each given file; print a traceback for every failure.

    Args:
        files: Paths of Python source files to import-check.

    Returns:
        Process exit code: 1 if any file failed to import, else 0.
    """
    has_failure = False
    for file in files:
        try:
            # NOTE: load_module() is deprecated but kept for behavioral parity.
            SourceFileLoader("x", file).load_module()
        except Exception:
            # Fix: the original assigned a misspelled variable (`has_faillure`),
            # so the flag never flipped and the script always exited 0 even
            # when imports failed.
            has_failure = True
            print(file)  # noqa: T201
            traceback.print_exc()
            print()  # noqa: T201

    return 1 if has_failure else 0


if __name__ == "__main__":
    sys.exit(check_files(sys.argv[1:]))
|
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository, and fails if any are found so that
# all pydantic usage goes through langchain_core.pydantic_v1.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository.
# Note: git grep exits non-zero when nothing matches, leaving $result empty.
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
# (a match prints the offending lines and bumps the counter; grep's
# non-zero exit on "no match" is tolerated because it is the left-hand
# side of &&, which `set -e` does not treat as fatal)
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
    exit 1
else
    exit 0
fi
|
@ -0,0 +1,99 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
WeightedRanker,
|
||||
connections,
|
||||
)
|
||||
|
||||
from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever
|
||||
from tests.integration_tests.utils import FakeEmbeddings
|
||||
|
||||
#
|
||||
# To run this test properly, please start a Milvus server with the following command:
|
||||
#
|
||||
# ```shell
|
||||
# wget https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh
|
||||
# bash standalone_embed.sh start
|
||||
# ```
|
||||
#
|
||||
# Here is the reference:
|
||||
# https://milvus.io/docs/install_standalone-docker.md
|
||||
#
|
||||
|
||||
|
||||
class TestMilvusHybridSearch(unittest.TestCase):
    """Integration test for MilvusCollectionHybridSearchRetriever.

    Requires a Milvus standalone server reachable on localhost:19530.
    """

    def setUp(self) -> None:
        # Address of the Milvus server the test talks to.
        self.connection_uri = (
            "http://localhost:19530"  # Replace with your Milvus server IP
        )
        self.insert_data_using_orm()

    def tearDown(self) -> None:
        # Drop the test collection so repeated runs start from a clean state.
        self.collection.drop()

    def insert_data_using_orm(self) -> None:
        """Create the test collection and fill it with 1000 random entities."""
        connections.connect(uri=self.connection_uri)
        # Vector dimensionality is derived from the fake embedding output.
        dim = len(FakeEmbeddings().embed_query("foo"))
        fields = [
            FieldSchema(name="film_id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="filmVector", dtype=DataType.FLOAT_VECTOR, dim=dim
            ),  # Vector field for film vectors
            FieldSchema(
                name="posterVector", dtype=DataType.FLOAT_VECTOR, dim=dim
            ),  # Vector field for poster vectors
            FieldSchema(
                name="film_description", dtype=DataType.VARCHAR, max_length=65_535
            ),
        ]

        schema = CollectionSchema(fields=fields, enable_dynamic_field=False)

        self.collection = Collection(name="film_information", schema=schema)
        index_params = {
            "metric_type": "L2",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128},
        }

        # Both vector fields need an index before the collection can be loaded.
        self.collection.create_index("filmVector", index_params)
        self.collection.create_index("posterVector", index_params)

        entities = []

        for _ in range(1000):
            # generate random values for each field in the schema
            film_id = random.randint(1, 1000)
            film_vector = [random.random() for _ in range(dim)]
            poster_vector = [random.random() for _ in range(dim)]

            # create a dictionary for each entity
            entity = {
                "film_id": film_id,
                "filmVector": film_vector,
                "posterVector": poster_vector,
                "film_description": "test_description",
            }

            # add the entity to the list
            entities.append(entity)

        self.collection.insert(entities)
        self.collection.load()

    def test_retriever(self) -> None:
        """Hybrid search over both vector fields should return top_k documents."""
        retriever = MilvusCollectionHybridSearchRetriever(
            collection=self.collection,
            rerank=WeightedRanker(0.5, 0.5),
            anns_fields=["filmVector", "posterVector"],
            field_embeddings=[FakeEmbeddings(), FakeEmbeddings()],
            top_k=5,
            text_field="film_description",
        )
        res_documents = retriever.invoke("foo")
        assert len(res_documents) == 5
|
@ -0,0 +1,7 @@
|
||||
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
|
@ -0,0 +1,40 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
# Canonical toy corpus shared by the Milvus integration tests.
fake_texts = ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings.
        Embeddings encode each text as its index."""
        return [[1.0] * 9 + [float(index)] for index in range(len(texts))]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant: delegates to the synchronous implementation."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return constant query embeddings.
        Embeddings are identical to embed_documents(texts)[0].
        Distance to each text will be that text's index,
        as it was passed to embed_documents."""
        return [1.0] * 9 + [0.0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant: delegates to the synchronous implementation."""
        return self.embed_query(text)
|
||||
|
||||
|
||||
def assert_docs_equal_without_pk(
    docs1: List[Document], docs2: List[Document], pk_field: str = "pk"
) -> None:
    """Assert two lists of Documents are equal, ignoring the primary key field.

    NOTE: intentionally mutates the inputs (pops ``pk_field`` out of each
    document's metadata), matching the original behavior.
    """
    assert len(docs1) == len(docs2)
    for left, right in zip(docs1, docs2):
        assert left.page_content == right.page_content
        left.metadata.pop(pk_field, None)
        right.metadata.pop(pk_field, None)
        assert left.metadata == right.metadata
|
@ -0,0 +1,183 @@
|
||||
"""Test Milvus functionality."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_milvus.vectorstores import Milvus
|
||||
from tests.integration_tests.utils import (
|
||||
FakeEmbeddings,
|
||||
assert_docs_equal_without_pk,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
#
|
||||
# To run this test properly, please start a Milvus server with the following command:
|
||||
#
|
||||
# ```shell
|
||||
# wget https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh
|
||||
# bash standalone_embed.sh start
|
||||
# ```
|
||||
#
|
||||
# Here is the reference:
|
||||
# https://milvus.io/docs/install_standalone-docker.md
|
||||
#
|
||||
|
||||
|
||||
def _milvus_from_texts(
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    drop: bool = True,
) -> Milvus:
    """Build a Milvus store from the shared fake corpus on a local server.

    Args:
        metadatas: Optional per-text metadata dicts.
        ids: Optional explicit primary keys for the texts.
        drop: Whether to drop any pre-existing collection first.
    """
    return Milvus.from_texts(
        fake_texts,
        FakeEmbeddings(),
        metadatas=metadatas,
        ids=ids,
        connection_args={"uri": "http://127.0.0.1:19530"},
        drop_old=drop,
    )
|
||||
|
||||
|
||||
def _get_pks(expr: str, docsearch: Milvus) -> List[Any]:
    """Return primary keys of entities matching a Milvus boolean expression."""
    return docsearch.get_pks(expr)  # type: ignore[return-value]
|
||||
|
||||
|
||||
def test_milvus() -> None:
    """Test end to end construction and search."""
    store = _milvus_from_texts()
    results = store.similarity_search("foo", k=1)
    assert_docs_equal_without_pk(results, [Document(page_content="foo")])
|
||||
|
||||
|
||||
def test_milvus_with_metadata() -> None:
    """Test with metadata"""
    labels = [{"label": "test"}] * len(fake_texts)
    store = _milvus_from_texts(metadatas=labels)
    results = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", metadata={"label": "test"})]
    assert_docs_equal_without_pk(results, expected)
|
||||
|
||||
|
||||
def test_milvus_with_id() -> None:
    """Test with ids: search, delete by id, and rejection of duplicate ids."""
    ids = ["id_" + str(i) for i in range(len(fake_texts))]
    docsearch = _milvus_from_texts(ids=ids)
    output = docsearch.similarity_search("foo", k=1)
    assert_docs_equal_without_pk(output, [Document(page_content="foo")])

    output = docsearch.delete(ids=ids)
    assert output.delete_count == len(fake_texts)  # type: ignore[attr-defined]

    # Inserting duplicate ids must raise an AssertionError. The original
    # try/except passed silently when NO exception was raised at all; the
    # `else` branch closes that hole. Other exception types still propagate
    # and fail the test, as before.
    ids = ["dup_id" for _ in fake_texts]
    try:
        _milvus_from_texts(ids=ids)
    except AssertionError:
        pass
    else:
        raise AssertionError("duplicate ids should have raised an AssertionError")
|
||||
|
||||
|
||||
def test_milvus_with_score() -> None:
    """Test end to end construction and search with scores and IDs."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    store = _milvus_from_texts(metadatas=metadatas)
    scored = store.similarity_search_with_score("foo", k=3)
    docs = [doc for doc, _ in scored]
    scores = [score for _, score in scored]
    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert_docs_equal_without_pk(docs, expected)
    # Scores are distances: the closest match comes first with the smallest value.
    assert scores[0] < scores[1] < scores[2]
|
||||
|
||||
|
||||
def test_milvus_max_marginal_relevance_search() -> None:
    """Test end to end construction and MMR search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    store = _milvus_from_texts(metadatas=metadatas)
    results = store.max_marginal_relevance_search("foo", k=2, fetch_k=3)
    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert_docs_equal_without_pk(results, expected)
|
||||
|
||||
|
||||
def test_milvus_add_extra() -> None:
    """Test adding extra texts to an existing collection."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _milvus_from_texts(metadatas=metadatas)

    # Insert the same three texts again; the collection should then hold six.
    docsearch.add_texts(texts, metadatas)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6
|
||||
|
||||
|
||||
def test_milvus_no_drop() -> None:
    """Test that reconnecting with drop=False keeps previously stored entities."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _milvus_from_texts(metadatas=metadatas)
    del docsearch

    # Reconnect without dropping the old collection: the three texts inserted
    # above plus the three inserted here should both be present.
    docsearch = _milvus_from_texts(metadatas=metadatas, drop=False)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6
|
||||
|
||||
|
||||
def test_milvus_get_pks() -> None:
    """Test end to end construction and get pks with expr"""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": i} for i in range(len(texts))]
    store = _milvus_from_texts(metadatas=metadatas)
    # Two of the three stored ids (1 and 2) match this filter expression.
    matched = _get_pks("id in [1,2]", store)
    assert len(matched) == 2
|
||||
|
||||
|
||||
def test_milvus_delete_entities() -> None:
    """Test end to end construction and delete entities"""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": i} for i in range(len(texts))]
    store = _milvus_from_texts(metadatas=metadatas)
    # Resolve the primary keys for ids 1 and 2, then delete those entities.
    pks = _get_pks("id in [1,2]", store)
    result = store.delete(pks)
    assert result.delete_count == 2  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_milvus_upsert_entities() -> None:
    """Test end to end construction and upsert entities"""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": i} for i in range(len(texts))]
    store = _milvus_from_texts(metadatas=metadatas)
    pks = _get_pks("id in [1,2]", store)
    # Replace the two matched entities with fresh documents.
    replacements = [
        Document(page_content="test_1", metadata={"id": 1}),
        Document(page_content="test_2", metadata={"id": 3}),
    ]
    new_ids = store.upsert(pks, replacements)
    assert len(new_ids) == 2  # type: ignore[arg-type]
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_milvus()
|
||||
# test_milvus_with_metadata()
|
||||
# test_milvus_with_id()
|
||||
# test_milvus_with_score()
|
||||
# test_milvus_max_marginal_relevance_search()
|
||||
# test_milvus_add_extra()
|
||||
# test_milvus_no_drop()
|
||||
# test_milvus_get_pks()
|
||||
# test_milvus_delete_entities()
|
||||
# test_milvus_upsert_entities()
|
@ -0,0 +1,108 @@
|
||||
"""Test Zilliz functionality."""
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_milvus.vectorstores import Zilliz
|
||||
from tests.integration_tests.utils import (
|
||||
FakeEmbeddings,
|
||||
assert_docs_equal_without_pk,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
#
|
||||
# To run this test properly, you need to log in [Zilliz](https://zilliz.com/cloud),
|
||||
# and set `ZILLIZ_CLOUD_URI` and `ZILLIZ_CLOUD_API_KEY` environment variables.
|
||||
#
|
||||
|
||||
|
||||
def _zilliz_from_texts(
    metadatas: Optional[List[dict]] = None, drop: bool = True
) -> Zilliz:
    """Build a Zilliz vector store from the shared fake texts.

    Connection details are read from the ZILLIZ_CLOUD_URI and
    ZILLIZ_CLOUD_API_KEY environment variables.
    """
    return Zilliz.from_texts(
        fake_texts,
        FakeEmbeddings(),
        metadatas=metadatas,
        connection_args={
            "uri": os.getenv("ZILLIZ_CLOUD_URI"),
            "token": os.getenv("ZILLIZ_CLOUD_API_KEY"),
            "secure": True,
        },
        drop_old=drop,  # when True, any pre-existing collection is dropped
        auto_id=True,
    )
|
||||
|
||||
|
||||
def test_zilliz() -> None:
    """Test end to end construction and search."""
    store = _zilliz_from_texts()
    results = store.similarity_search("foo", k=1)
    assert_docs_equal_without_pk(results, [Document(page_content="foo")])
|
||||
|
||||
|
||||
def test_zilliz_with_score() -> None:
    """Test end to end construction and search with scores and IDs."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    store = _zilliz_from_texts(metadatas=metadatas)
    results = store.similarity_search_with_score("foo", k=3)
    docs = [doc for doc, _ in results]
    scores = [score for _, score in results]
    expected = [
        Document(page_content=text, metadata={"page": page})
        for page, text in enumerate(texts)
    ]
    assert_docs_equal_without_pk(docs, expected)
    # Scores are expected in strictly ascending order for the query "foo".
    assert scores[0] < scores[1] < scores[2]
|
||||
|
||||
|
||||
def test_zilliz_max_marginal_relevance_search() -> None:
    """Test end to end construction and MMR (maximal marginal relevance) search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _zilliz_from_texts(metadatas=metadatas)
    # Fetch 3 candidates, re-rank for diversity, and keep the top 2.
    output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
    assert_docs_equal_without_pk(
        output,
        [
            Document(page_content="foo", metadata={"page": 0}),
            Document(page_content="baz", metadata={"page": 2}),
        ],
    )
|
||||
|
||||
|
||||
def test_zilliz_add_extra() -> None:
    """Test adding extra texts to an existing collection."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _zilliz_from_texts(metadatas=metadatas)

    # Insert the same three texts again; the collection should then hold six.
    docsearch.add_texts(texts, metadatas)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6
|
||||
|
||||
|
||||
def test_zilliz_no_drop() -> None:
    """Test that reconnecting with drop=False keeps previously stored entities."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _zilliz_from_texts(metadatas=metadatas)
    del docsearch

    # Reconnect without dropping the old collection: the three texts inserted
    # above plus the three inserted here should both be present.
    docsearch = _zilliz_from_texts(metadatas=metadatas, drop=False)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_zilliz()
|
||||
# test_zilliz_with_score()
|
||||
# test_zilliz_max_marginal_relevance_search()
|
||||
# test_zilliz_add_extra()
|
||||
# test_zilliz_no_drop()
|
@ -0,0 +1,12 @@
|
||||
from langchain_milvus import __all__
|
||||
|
||||
# Names the langchain_milvus package is expected to export publicly.
EXPECTED_ALL = [
    "Milvus",
    "MilvusCollectionHybridSearchRetriever",
    "Zilliz",
    "ZillizCloudPipelineRetriever",
]


def test_all_imports() -> None:
    """The package's __all__ must match the expected public surface exactly."""
    assert sorted(EXPECTED_ALL) == sorted(__all__)
|
@ -0,0 +1,17 @@
|
||||
import os
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_milvus.vectorstores import Milvus
|
||||
|
||||
|
||||
def test_initialization() -> None:
    """Test integration milvus initialization."""
    fake_embedding = Mock()
    # Use a throwaway directory so the local milvus.db file is cleaned up.
    with TemporaryDirectory() as workdir:
        db_uri = os.path.join(workdir, "milvus.db")
        Milvus(
            embedding_function=fake_embedding,
            connection_args={"uri": db_uri},
        )
|
@ -0,0 +1 @@
|
||||
__pycache__
|
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -0,0 +1,68 @@
|
||||
# rag-milvus
|
||||
|
||||
This template performs RAG using Milvus and OpenAI.
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Start the Milvus server instance, and note its host IP and port.
|
||||
|
||||
Set the `OPENAI_API_KEY` environment variable to access the OpenAI models.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this package, you should first have the LangChain CLI installed:
|
||||
|
||||
```shell
|
||||
pip install -U langchain-cli
|
||||
```
|
||||
|
||||
To create a new LangChain project and install this as the only package, you can do:
|
||||
|
||||
```shell
|
||||
langchain app new my-app --package rag-milvus
|
||||
```
|
||||
|
||||
If you want to add this to an existing project, you can just run:
|
||||
|
||||
```shell
|
||||
langchain app add rag-milvus
|
||||
```
|
||||
|
||||
And add the following code to your `server.py` file:
|
||||
```python
|
||||
from rag_milvus import chain as rag_milvus_chain
|
||||
|
||||
add_routes(app, rag_milvus_chain, path="/rag-milvus")
|
||||
```
|
||||
|
||||
(Optional) Let's now configure LangSmith.
|
||||
LangSmith will help us trace, monitor and debug LangChain applications.
|
||||
You can sign up for LangSmith [here](https://smith.langchain.com/).
|
||||
If you don't have access, you can skip this section.
|
||||
|
||||
|
||||
```shell
|
||||
export LANGCHAIN_TRACING_V2=true
|
||||
export LANGCHAIN_API_KEY=<your-api-key>
|
||||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default"
|
||||
```
|
||||
|
||||
If you are inside this directory, then you can spin up a LangServe instance directly by:
|
||||
|
||||
```shell
|
||||
langchain serve
|
||||
```
|
||||
|
||||
This will start the FastAPI app with a server running locally at
[http://localhost:8000](http://localhost:8000)
|
||||
|
||||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
|
||||
We can access the playground at [http://127.0.0.1:8000/rag-milvus/playground](http://127.0.0.1:8000/rag-milvus/playground)
|
||||
|
||||
We can access the template from code with:
|
||||
|
||||
```python
|
||||
from langserve.client import RemoteRunnable
|
||||
|
||||
runnable = RemoteRunnable("http://localhost:8000/rag-milvus")
|
||||
```
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,34 @@
|
||||
[tool.poetry]
|
||||
name = "rag-milvus"
|
||||
version = "0.1.0"
|
||||
description = "RAG using Milvus"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain = "^0.1"
|
||||
langchain-core = "^0.1"
|
||||
langchain-openai = "^0.1"
|
||||
langchain-community = "^0.0.30"
|
||||
pymilvus = "^2.4"
|
||||
scipy = "^1.9"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-cli = ">=0.0.4"
|
||||
fastapi = "^0.104.0"
|
||||
sse-starlette = "^1.6.5"
|
||||
|
||||
[tool.langserve]
|
||||
export_module = "rag_milvus"
|
||||
export_attr = "chain"
|
||||
|
||||
[tool.templates-hub]
|
||||
use-case = "rag"
|
||||
author = "LangChain"
|
||||
integrations = ["OpenAI", "Milvus"]
|
||||
tags = ["vectordbs"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
@ -0,0 +1,3 @@
|
||||
from rag_milvus.chain import chain
|
||||
|
||||
__all__ = ["chain"]
|
@ -0,0 +1,69 @@
|
||||
from langchain_community.vectorstores import Milvus
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
|
||||
# Example for document loading (from url), splitting, and creating vectorstore
|
||||
|
||||
"""
|
||||
# Load
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
|
||||
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
|
||||
data = loader.load()
|
||||
|
||||
# Split
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
|
||||
all_splits = text_splitter.split_documents(data)
|
||||
|
||||
# Add to vectorDB
|
||||
vectorstore = Milvus.from_documents(documents=all_splits,
|
||||
collection_name="rag_milvus",
|
||||
embedding=OpenAIEmbeddings(),
|
||||
drop_old=True,
|
||||
)
|
||||
retriever = vectorstore.as_retriever()
|
||||
"""
|
||||
|
||||
# Embed a single document as a test.
# NOTE(review): this runs at import time and assumes a Milvus server is
# reachable at the hardcoded local uri — confirm/adjust for other deployments.
vectorstore = Milvus.from_texts(
    ["harrison worked at kensho"],
    collection_name="rag_milvus",
    embedding=OpenAIEmbeddings(),
    drop_old=True,  # recreate the collection on every startup
    connection_args={
        "uri": "http://127.0.0.1:19530",
    },
)
retriever = vectorstore.as_retriever()

# RAG prompt: restrict the model to answering from the retrieved context only.
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# LLM (reads OPENAI_API_KEY from the environment)
model = ChatOpenAI()

# RAG chain: retrieve context and pass the question through in parallel,
# fill the prompt, call the model, and return the answer as a plain string.
chain = (
    RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
    | prompt
    | model
    | StrOutputParser()
)
|
||||
|
||||
|
||||
# Add typing for input
class Question(BaseModel):
    # The chain takes a bare string question as its root input.
    __root__: str


# Attach the input schema so LangServe can generate docs and validation.
chain = chain.with_types(input_type=Question)
|
Loading…
Reference in New Issue