multiple: langchain 0.2 in master (#21191)

0.2rc 

migrations

- [x] Move memory
- [x] Move remaining retrievers
- [x] graph_qa chains
- [x] Resolve a potential dependency of the evaluation code on math utils
- [x] Move openapi chain from `langchain.chains.api.openapi` to
`langchain_community.chains.openapi`
- [x] Migrate `langchain.chains.ernie_functions` to
`langchain_community.chains.ernie_functions`
- [x] migrate `langchain/chains/llm_requests.py` to
`langchain_community.chains.llm_requests`
- [x] Moving `langchain_community.cross_encoders.base:BaseCrossEncoder`
->
`langchain.retrievers.document_compressors.cross_encoder:BaseCrossEncoder`
(namespace not ideal, but it needs to be moved to `langchain` to avoid
circular deps)
- [x] unit tests langchain -- add pytest.mark.community to some unit
tests that will stay in langchain
- [x] unit tests community -- move unit tests that depend on community
to community
- [x] mv integration tests that depend on community to community
- [x] mypy checks

Other todo

- [x] Make deprecation warnings less noisy (use `warn_deprecated` and
verify the warnings are implemented properly)
- [x] Update deprecation messages with a timeline for code removal (we
likely won't actually remove anything until the 0.4 release), which gives
people more time to transition their code.
- [ ] Add information to the deprecation warning showing users how to
migrate their code base using langchain-cli (see the sketch after this list)
- [ ] Remove any unnecessary requirements in langchain (e.g., is
SQLAlchemy required?)
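
A minimal sketch of the warning setup referenced above, assuming
`langchain_core`'s `@deprecated` decorator keeps its `since` / `removal` /
`alternative_import` parameters (the function and module names here are
illustrative placeholders, not the actual shims in this PR):

    from langchain_core._api import deprecated

    @deprecated(
        since="0.2.0",
        removal="0.4.0",
        alternative_import="langchain_community.chains.example.example_chain",
    )
    def example_chain(*args, **kwargs):
        """Old import location; calling this emits one deprecation warning."""
        ...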

---------

Co-authored-by: Erick Friis <erick@langchain.dev>

@@ -193,7 +193,7 @@ def create_sql_agent(
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
-runnable=create_openai_functions_agent(llm, tools, prompt),
+runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
@@ -208,10 +208,10 @@ def create_sql_agent(
]
prompt = ChatPromptTemplate.from_messages(messages)
if agent_type == "openai-tools":
-runnable = create_openai_tools_agent(llm, tools, prompt)
+runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore
else:
-runnable = create_tool_calling_agent(llm, tools, prompt)
-agent = RunnableMultiActionAgent(
+runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore
+agent = RunnableMultiActionAgent( # type: ignore[assignment]
runnable=runnable,
input_keys_arg=["input"],
return_keys_arg=["output"],
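
For orientation, a minimal sketch of invoking `create_sql_agent` with the
agent types touched by this hunk, assuming `langchain_community`'s
`SQLDatabase` helper and the `langchain_openai` integration (connection
string and model name are placeholders):

    from langchain_community.agent_toolkits import create_sql_agent
    from langchain_community.utilities import SQLDatabase
    from langchain_openai import ChatOpenAI

    db = SQLDatabase.from_uri("sqlite:///Chinook.db")  # placeholder database
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    # "openai-tools" routes through create_openai_tools_agent above
    agent = create_sql_agent(llm=llm, db=db, agent_type="openai-tools", verbose=True)
    agent.invoke({"input": "How many employees are there?"})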

@@ -0,0 +1,17 @@
from langchain.chains.ernie_functions.base import (
convert_to_ernie_function,
create_ernie_fn_chain,
create_ernie_fn_runnable,
create_structured_output_chain,
create_structured_output_runnable,
get_ernie_output_parser,
)
__all__ = [
"convert_to_ernie_function",
"create_structured_output_chain",
"create_ernie_fn_chain",
"create_structured_output_runnable",
"create_ernie_fn_runnable",
"get_ernie_output_parser",
]

@@ -0,0 +1,551 @@
"""Methods for creating chains that use Ernie function-calling APIs."""
import inspect
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain.chains import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_community.output_parsers.ernie_functions import (
JsonOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_community.utils.ernie_functions import convert_pydantic_to_ernie_function
PYTHON_TO_JSON_TYPES = {
"str": "string",
"int": "number",
"float": "number",
"bool": "boolean",
}
def _get_python_function_name(function: Callable) -> str:
"""Get the name of a Python function."""
return function.__name__
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
docstring = inspect.getdoc(function)
if docstring:
docstring_blocks = docstring.split("\n\n")
descriptors = []
args_block = None
past_descriptors = False
for block in docstring_blocks:
if block.startswith("Args:"):
args_block = block
break
elif block.startswith("Returns:") or block.startswith("Example:"):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
descriptors.append(block)
else:
continue
description = " ".join(descriptors)
else:
description = ""
args_block = None
arg_descriptions = {}
if args_block:
arg = None
for line in args_block.split("\n")[1:]:
if ":" in line:
arg, desc = line.split(":", maxsplit=1)
arg_descriptions[arg.strip()] = desc.strip()
elif arg:
arg_descriptions[arg.strip()] += " " + line.strip()
return description, arg_descriptions
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
"""Get JsonSchema describing a Python functions arguments.
Assumes all function arguments are of primitive types (int, float, str, bool) or
are subclasses of pydantic.BaseModel.
"""
properties = {}
annotations = inspect.getfullargspec(function).annotations
for arg, arg_type in annotations.items():
if arg == "return":
continue
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
# Mypy error:
# "type" has no attribute "schema"
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
elif arg_type.__name__ in PYTHON_TO_JSON_TYPES:
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
if arg in arg_descriptions:
if arg not in properties:
properties[arg] = {}
properties[arg]["description"] = arg_descriptions[arg]
return properties
def _get_python_function_required_args(function: Callable) -> List[str]:
"""Get the required arguments for a Python function."""
spec = inspect.getfullargspec(function)
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
is_class = type(function) is type
if is_class and required[0] == "self":
required = required[1:]
return required
def convert_python_function_to_ernie_function(
function: Callable,
) -> Dict[str, Any]:
"""Convert a Python function to an Ernie function-calling API compatible dict.
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
"""
description, arg_descriptions = _parse_python_function_docstring(function)
return {
"name": _get_python_function_name(function),
"description": description,
"parameters": {
"type": "object",
"properties": _get_python_function_arguments(function, arg_descriptions),
"required": _get_python_function_required_args(function),
},
}
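# Illustrative example (not part of the original module). Given:
#
#     def multiply(a: int, b: int) -> int:
#         """Multiply two integers.
#
#         Args:
#             a: First factor.
#             b: Second factor.
#         """
#         return a * b
#
# convert_python_function_to_ernie_function(multiply) produces:
#
#     {
#         "name": "multiply",
#         "description": "Multiply two integers.",
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "a": {"type": "number", "description": "First factor."},
#                 "b": {"type": "number", "description": "Second factor."},
#             },
#             "required": ["a", "b"],
#         },
#     }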
def convert_to_ernie_function(
function: Union[Dict[str, Any], Type[BaseModel], Callable],
) -> Dict[str, Any]:
"""Convert a raw function/class to an Ernie function.
Args:
function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
If a dictionary is passed in, it is assumed to already be a valid Ernie
function.
Returns:
A dict version of the passed in function which is compatible with the
Ernie function-calling API.
"""
if isinstance(function, dict):
return function
elif isinstance(function, type) and issubclass(function, BaseModel):
return cast(Dict, convert_pydantic_to_ernie_function(function))
elif callable(function):
return convert_python_function_to_ernie_function(function)
else:
raise ValueError(
f"Unsupported function type {type(function)}. Functions must be passed in"
f" as Dict, pydantic.BaseModel, or Callable."
)
def get_ernie_output_parser(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions.
Args:
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
or a Python function. If a dictionary is passed in, it is assumed to
already be a valid Ernie function.
Returns:
A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
a JsonOutputFunctionsParser. If there's only one function and it is
not a Pydantic class, then the output parser will automatically extract
only the function arguments and not the function name.
"""
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
name: fn for name, fn in zip(function_names, functions)
}
else:
pydantic_schema = functions[0]
output_parser: Union[
BaseOutputParser, BaseGenerationOutputParser
] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
else:
output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
return output_parser
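# Illustrative behavior (not part of the original module): for pydantic inputs
# such as [RecordPerson, RecordDog], a PydanticOutputFunctionsParser keyed by
# function name is returned; for a single dict or plain-function input,
# JsonOutputFunctionsParser(args_only=True) extracts just the argument mapping.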
def create_ernie_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: Runnable,
prompt: BasePromptTemplate,
*,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
"""Create a runnable sequence that uses Ernie functions.
Args:
functions: A sequence of either dictionaries, pydantic.BaseModel classes, or
Python functions. If dictionaries are passed in, they are assumed to
already be valid Ernie functions. If only a single
function is passed in, then it will be enforced that the model use that
function. pydantic.BaseModels and Python functions should have docstrings
describing what the function does. For best results, pydantic.BaseModels
should have descriptions of the parameters and Python functions should have
Google Python style args descriptions in the docstring. Additionally,
Python functions should only use primitive types (str, int, float, bool) or
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON. If multiple functions are
passed in and they are not pydantic.BaseModels, the chain output will
include both the name of the function that was returned and the arguments
to pass to the function.
Returns:
A runnable sequence that will pass in the given functions to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_ernie_fn_runnable
from langchain_community.chat_models import ErnieBotChat
from langchain_core.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class RecordPerson(BaseModel):
\"\"\"Record some identifying information about a person.\"\"\"
name: str = Field(..., description="The person's name")
age: int = Field(..., description="The person's age")
fav_food: Optional[str] = Field(None, description="The person's favorite food")
class RecordDog(BaseModel):
\"\"\"Record some identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
ernie_functions = [convert_to_ernie_function(f) for f in functions]
llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs}
if len(ernie_functions) == 1:
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
output_parser = output_parser or get_ernie_output_parser(functions)
return prompt | llm.bind(**llm_kwargs) | output_parser
def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: Runnable,
prompt: BasePromptTemplate,
*,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
"""Create a runnable that uses an Ernie function to get a structured output.
Args:
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON.
Returns:
A runnable sequence that will pass the given function to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_structured_output_runnable
from langchain_community.chat_models import ErnieBotChat
from langchain_core.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class Dog(BaseModel):
\"\"\"Identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Use the given format to extract information from the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_structured_output_runnable(Dog, llm, prompt)
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if isinstance(output_schema, dict):
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": output_schema,
}
else:
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
pydantic_schema=_OutputFormatter, attr_name="output"
)
return create_ernie_fn_runnable(
[function],
llm,
prompt,
output_parser=output_parser,
**kwargs,
)
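# Illustrative example (not part of the original module): passing a raw
# JsonSchema dict wraps it in a single "output_formatter" Ernie function, so
#     create_structured_output_runnable(
#         {"type": "object", "properties": {"name": {"type": "string"}}},
#         llm,
#         prompt,
#     )
# binds that one function and, having no pydantic schema, parses the model
# output as plain JSON.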
""" --- Legacy --- """
def create_ernie_fn_chain(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""[Legacy] Create an LLM chain that uses Ernie functions.
Args:
functions: A sequence of either dictionaries, pydantic.BaseModel classes, or
Python functions. If dictionaries are passed in, they are assumed to
already be valid Ernie functions. If only a single
function is passed in, then it will be enforced that the model use that
function. pydantic.BaseModels and Python functions should have docstrings
describing what the function does. For best results, pydantic.BaseModels
should have descriptions of the parameters and Python functions should have
Google Python style args descriptions in the docstring. Additionally,
Python functions should only use primitive types (str, int, float, bool) or
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON. If multiple functions are
passed in and they are not pydantic.BaseModels, the chain output will
include both the name of the function that was returned and the arguments
to pass to the function.
Returns:
An LLMChain that will pass in the given functions to the model when run.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_ernie_fn_chain
from langchain_community.chat_models import ErnieBotChat
from langchain_core.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class RecordPerson(BaseModel):
\"\"\"Record some identifying information about a person.\"\"\"
name: str = Field(..., description="The person's name")
age: int = Field(..., description="The person's age")
fav_food: Optional[str] = Field(None, description="The person's favorite food")
class RecordDog(BaseModel):
\"\"\"Record some identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
ernie_functions = [convert_to_ernie_function(f) for f in functions]
output_parser = output_parser or get_ernie_output_parser(functions)
llm_kwargs: Dict[str, Any] = {
"functions": ernie_functions,
}
if len(ernie_functions) == 1:
llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
output_parser=output_parser,
llm_kwargs=llm_kwargs,
output_key=output_key,
**kwargs,
)
return llm_chain
def create_structured_output_chain(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.
Args:
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the Ernie function-calling API.
prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
model outputs will simply be parsed as JSON.
Returns:
An LLMChain that will pass the given function to the model.
Example:
.. code-block:: python
from typing import Optional
from langchain.chains.ernie_functions import create_structured_output_chain
from langchain_community.chat_models import ErnieBotChat
from langchain_core.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
class Dog(BaseModel):
\"\"\"Identifying information about a dog.\"\"\"
name: str = Field(..., description="The dog's name")
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ErnieBotChat(model_name="ERNIE-Bot-4")
prompt = ChatPromptTemplate.from_messages(
[
("user", "Use the given format to extract information from the following input: {input}"),
("assistant", "OK!"),
("user", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_structured_output_chain(Dog, llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
if isinstance(output_schema, dict):
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": output_schema,
}
else:
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
pydantic_schema=_OutputFormatter, attr_name="output"
)
return create_ernie_fn_chain(
[function],
llm,
prompt,
output_key=output_key,
output_parser=output_parser,
**kwargs,
)

@@ -0,0 +1 @@
"""Question answering over a knowledge graph."""

@@ -0,0 +1,241 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
AQL_FIX_PROMPT,
AQL_GENERATION_PROMPT,
AQL_QA_PROMPT,
)
from langchain_community.graphs.arangodb_graph import ArangoGraph
class ArangoGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating AQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: ArangoGraph = Field(exclude=True)
aql_generation_chain: LLMChain
aql_fix_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
# Specifies the maximum number of AQL Query Results to return
top_k: int = 10
# Specifies the set of AQL Query Examples that promote few-shot-learning
aql_examples: str = ""
# Specify whether to return the AQL Query in the output dictionary
return_aql_query: bool = False
# Specify whether to return the AQL JSON Result in the output dictionary
return_aql_result: bool = False
# Specify the maximum amount of AQL Generation attempts that should be made
max_aql_generation_attempts: int = 3
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
@property
def _chain_type(self) -> str:
return "graph_aql_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
**kwargs: Any,
) -> ArangoGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)
return cls(
qa_chain=qa_chain,
aql_generation_chain=aql_generation_chain,
aql_fix_chain=aql_fix_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""
Generate an AQL statement from user input, use it to retrieve a response
from an ArangoDB Database instance, and respond to the user input
in natural language.
Users can modify the following ArangoGraphQAChain Class Variables:
:var top_k: The maximum number of AQL Query Results to return
:type top_k: int
:var aql_examples: A set of AQL Query Examples that are passed to
the AQL Generation Prompt Template to promote few-shot-learning.
Defaults to an empty string.
:type aql_examples: str
:var return_aql_query: Whether to return the AQL Query in the
output dictionary. Defaults to False.
:type return_aql_query: bool
:var return_aql_result: Whether to return the AQL JSON Result in the
output dictionary. Defaults to False.
:type return_aql_result: bool
:var max_aql_generation_attempts: The maximum amount of AQL
Generation attempts to be made prior to raising the last
AQL Query Execution Error. Defaults to 3.
:type max_aql_generation_attempts: int
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
user_input = inputs[self.input_key]
#########################
# Generate AQL Query #
aql_generation_output = self.aql_generation_chain.run(
{
"adb_schema": self.graph.schema,
"aql_examples": self.aql_examples,
"user_input": user_input,
},
callbacks=callbacks,
)
#########################
aql_query = ""
aql_error = ""
aql_result = None
aql_generation_attempt = 1
while (
aql_result is None
and aql_generation_attempt < self.max_aql_generation_attempts + 1
):
#####################
# Extract AQL Query #
pattern = r"```(?i:aql)?(.*?)```"
matches = re.findall(pattern, aql_generation_output, re.DOTALL)
if not matches:
_run_manager.on_text(
"Invalid Response: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
aql_generation_output, color="red", end="\n", verbose=self.verbose
)
raise ValueError(f"Response is Invalid: {aql_generation_output}")
aql_query = matches[0]
#####################
_run_manager.on_text(
f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
)
_run_manager.on_text(
aql_query, color="green", end="\n", verbose=self.verbose
)
#####################
# Execute AQL Query #
from arango import AQLQueryExecuteError
try:
aql_result = self.graph.query(aql_query, self.top_k)
except AQLQueryExecuteError as e:
aql_error = e.error_message
_run_manager.on_text(
"AQL Query Execution Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
aql_error, color="yellow", end="\n\n", verbose=self.verbose
)
########################
# Retry AQL Generation #
aql_generation_output = self.aql_fix_chain.run(
{
"adb_schema": self.graph.schema,
"aql_query": aql_query,
"aql_error": aql_error,
},
callbacks=callbacks,
)
########################
#####################
aql_generation_attempt += 1
if aql_result is None:
m = f"""
Maximum amount of AQL Query Generation attempts reached.
Unable to execute the AQL Query due to the following error:
{aql_error}
"""
raise ValueError(m)
_run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(aql_result), color="green", end="\n", verbose=self.verbose
)
########################
# Interpret AQL Result #
result = self.qa_chain(
{
"adb_schema": self.graph.schema,
"user_input": user_input,
"aql_query": aql_query,
"aql_result": aql_result,
},
callbacks=callbacks,
)
########################
# Return results #
result = {self.output_key: result[self.qa_chain.output_key]}
if self.return_aql_query:
result["aql_query"] = aql_query
if self.return_aql_result:
result["aql_result"] = aql_result
return result
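
A minimal usage sketch, assuming a reachable ArangoDB instance via
python-arango and the `langchain_openai` integration (host, credentials,
and the module path, inferred from the imports above, are placeholders):

    from arango import ArangoClient
    from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
    from langchain_community.graphs import ArangoGraph
    from langchain_openai import ChatOpenAI

    db = ArangoClient(hosts="http://localhost:8529").db(
        "_system", username="root", password=""
    )
    chain = ArangoGraphQAChain.from_llm(
        ChatOpenAI(temperature=0), graph=ArangoGraph(db), verbose=True
    )
    chain.invoke({"query": "Who starred in Pulp Fiction?"})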

@@ -0,0 +1,103 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
ENTITY_EXTRACTION_PROMPT,
GRAPH_QA_PROMPT,
)
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
class GraphQAChain(Chain):
"""Chain for question-answering against a graph.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: NetworkxEntityGraph = Field(exclude=True)
entity_extraction_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
**kwargs: Any,
) -> GraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
entity_chain = LLMChain(llm=llm, prompt=entity_prompt)
return cls(
qa_chain=qa_chain,
entity_extraction_chain=entity_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Extract entities, look up info and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
entity_string = self.entity_extraction_chain.run(question)
_run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
_run_manager.on_text(
entity_string, color="green", end="\n", verbose=self.verbose
)
entities = get_entities(entity_string)
context = ""
all_triplets = []
for entity in entities:
all_triplets.extend(self.graph.get_entity_knowledge(entity))
context = "\n".join(all_triplets)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=_run_manager.get_child(),
)
return {self.output_key: result[self.qa_chain.output_key]}
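
A corresponding sketch for this networkx-backed chain, assuming
`KnowledgeTriple` and `add_triple` from
`langchain_community.graphs.networkx_graph` (the model choice is a placeholder):

    from langchain_community.graphs.networkx_graph import (
        KnowledgeTriple,
        NetworkxEntityGraph,
    )
    from langchain_openai import ChatOpenAI

    graph = NetworkxEntityGraph()
    graph.add_triple(KnowledgeTriple("Neo4j", "is written in", "Java"))
    chain = GraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
    chain.invoke({"query": "What is Neo4j written in?"})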

@@ -0,0 +1,298 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.cypher_utils import (
CypherQueryCorrector,
Schema,
)
from langchain_community.chains.graph_qa.prompts import (
CYPHER_GENERATION_PROMPT,
CYPHER_QA_PROMPT,
)
from langchain_community.graphs.graph_store import GraphStore
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
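# Example (illustrative): extract_cypher("```MATCH (n) RETURN n```") returns
# "MATCH (n) RETURN n"; text without triple backticks is returned unchanged.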
def construct_schema(
structured_schema: Dict[str, Any],
include_types: List[str],
exclude_types: List[str],
) -> str:
"""Filter the schema based on included or excluded types"""
def filter_func(x: str) -> bool:
return x in include_types if include_types else x not in exclude_types
filtered_schema: Dict[str, Any] = {
"node_props": {
k: v
for k, v in structured_schema.get("node_props", {}).items()
if filter_func(k)
},
"rel_props": {
k: v
for k, v in structured_schema.get("rel_props", {}).items()
if filter_func(k)
},
"relationships": [
r
for r in structured_schema.get("relationships", [])
if all(filter_func(r[t]) for t in ["start", "end", "type"])
],
}
# Format node properties
formatted_node_props = []
for label, properties in filtered_schema["node_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in properties]
)
formatted_node_props.append(f"{label} {{{props_str}}}")
# Format relationship properties
formatted_rel_props = []
for rel_type, properties in filtered_schema["rel_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in properties]
)
formatted_rel_props.append(f"{rel_type} {{{props_str}}}")
# Format relationships
formatted_rels = [
f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
for el in filtered_schema["relationships"]
]
return "\n".join(
[
"Node properties are the following:",
",".join(formatted_node_props),
"Relationship properties are the following:",
",".join(formatted_rel_props),
"The relationships are the following:",
",".join(formatted_rels),
]
)
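# Illustrative output (not part of the original module): for a structured
# schema with one node type
# {"node_props": {"Movie": [{"property": "title", "type": "STRING"}]}} and an
# ACTED_IN relationship from Person to Movie, construct_schema(schema, [], [])
# returns:
#     Node properties are the following:
#     Movie {title: STRING}
#     Relationship properties are the following:
#
#     The relationships are the following:
#     (:Person)-[:ACTED_IN]->(:Movie)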
class GraphCypherQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: GraphStore = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
cypher_query_corrector: Optional[CypherQueryCorrector] = None
"""Optional cypher validation tool"""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,
llm: Optional[BaseLanguageModel] = None,
*,
qa_prompt: Optional[BasePromptTemplate] = None,
cypher_prompt: Optional[BasePromptTemplate] = None,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
validate_cypher: bool = False,
qa_llm_kwargs: Optional[Dict[str, Any]] = None,
cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
if not cypher_llm and not llm:
raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
if not qa_llm and not llm:
raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
if cypher_llm and qa_llm and llm:
raise ValueError(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
)
if cypher_prompt and cypher_llm_kwargs:
raise ValueError(
"Specifying cypher_prompt and cypher_llm_kwargs together is"
" not allowed. Please pass prompt via cypher_llm_kwargs."
)
if qa_prompt and qa_llm_kwargs:
raise ValueError(
"Specifying qa_prompt and qa_llm_kwargs together is"
" not allowed. Please pass prompt via qa_llm_kwargs."
)
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)
if "prompt" not in use_qa_llm_kwargs:
use_qa_llm_kwargs["prompt"] = (
qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
)
if "prompt" not in use_cypher_llm_kwargs:
use_cypher_llm_kwargs["prompt"] = (
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
)
qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type]
cypher_generation_chain = LLMChain(
llm=cypher_llm or llm, # type: ignore[arg-type]
**use_cypher_llm_kwargs, # type: ignore[arg-type]
)
if exclude_types and include_types:
raise ValueError(
"Either `exclude_types` or `include_types` "
"can be provided, but not both"
)
graph_schema = construct_schema(
kwargs["graph"].get_structured_schema, include_types, exclude_types
)
cypher_query_corrector = None
if validate_cypher:
corrector_schema = [
Schema(el["start"], el["type"], el["end"])
for el in kwargs["graph"].structured_schema.get("relationships")
]
cypher_query_corrector = CypherQueryCorrector(corrector_schema)
return cls(
graph_schema=graph_schema,
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
cypher_query_corrector=cypher_query_corrector,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
# Correct Cypher query if enabled
if self.cypher_query_corrector:
generated_cypher = self.cypher_query_corrector(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
# Retrieve and limit the number of results
# Generated Cypher can be null if the query corrector identifies an invalid schema
if generated_cypher:
context = self.graph.query(generated_cypher)[: self.top_k]
else:
context = []
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
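
A minimal usage sketch, assuming `langchain_community`'s `Neo4jGraph` and the
`langchain_openai` integration (connection details are placeholders):

    from langchain_community.graphs import Neo4jGraph
    from langchain_openai import ChatOpenAI

    graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="pass")
    chain = GraphCypherQAChain.from_llm(
        ChatOpenAI(temperature=0),
        graph=graph,
        validate_cypher=True,  # enables the CypherQueryCorrector below
        return_intermediate_steps=True,
    )
    chain.invoke({"query": "Who acted in Top Gun?"})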

@@ -0,0 +1,260 @@
import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple
Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])
class CypherQueryCorrector:
"""
Used to correct relationship direction in generated Cypher statements.
This code is copied from the winner's submission to the Cypher competition:
https://github.com/sakusaku-rich/cypher-direction-competition
"""
property_pattern = re.compile(r"\{.+?\}")
node_pattern = re.compile(r"\(.+?\)")
path_pattern = re.compile(
r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
)
node_relation_node_pattern = re.compile(
r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
)
relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")
def __init__(self, schemas: List[Schema]):
"""
Args:
schemas: list of schemas
"""
self.schemas = schemas
def clean_node(self, node: str) -> str:
"""
Args:
node: node in string format
"""
node = re.sub(self.property_pattern, "", node)
node = node.replace("(", "")
node = node.replace(")", "")
node = node.strip()
return node
def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
"""
Args:
query: cypher query
"""
nodes = re.findall(self.node_pattern, query)
nodes = [self.clean_node(node) for node in nodes]
res: Dict[str, Any] = {}
for node in nodes:
parts = node.split(":")
if parts == "":
continue
variable = parts[0]
if variable not in res:
res[variable] = []
res[variable] += parts[1:]
return res
def extract_paths(self, query: str) -> "List[str]":
"""
Args:
query: cypher query
"""
paths = []
idx = 0
while matched := self.path_pattern.findall(query[idx:]):
matched = matched[0]
matched = [
m for i, m in enumerate(matched) if i not in [1, len(matched) - 1]
]
path = "".join(matched)
idx = query.find(path) + len(path) - len(matched[-1])
paths.append(path)
return paths
def judge_direction(self, relation: str) -> str:
"""
Args:
relation: relation in string format
"""
direction = "BIDIRECTIONAL"
if relation[0] == "<":
direction = "INCOMING"
if relation[-1] == ">":
direction = "OUTGOING"
return direction
def extract_node_variable(self, part: str) -> Optional[str]:
"""
Args:
part: node in string format
"""
part = part.lstrip("(").rstrip(")")
idx = part.find(":")
if idx != -1:
part = part[:idx]
return None if part == "" else part
def detect_labels(
self, str_node: str, node_variable_dict: Dict[str, Any]
) -> List[str]:
"""
Args:
str_node: node in string format
node_variable_dict: dictionary of node variables
"""
splitted_node = str_node.split(":")
variable = splitted_node[0]
labels = []
if variable in node_variable_dict:
labels = node_variable_dict[variable]
elif variable == "" and len(splitted_node) > 1:
labels = splitted_node[1:]
return labels
def verify_schema(
self,
from_node_labels: List[str],
relation_types: List[str],
to_node_labels: List[str],
) -> bool:
"""
Args:
from_node_labels: labels of the from node
relation_types: types of the relation
to_node_labels: labels of the to node
"""
valid_schemas = self.schemas
if from_node_labels != []:
from_node_labels = [label.strip("`") for label in from_node_labels]
valid_schemas = [
schema for schema in valid_schemas if schema[0] in from_node_labels
]
if to_node_labels != []:
to_node_labels = [label.strip("`") for label in to_node_labels]
valid_schemas = [
schema for schema in valid_schemas if schema[2] in to_node_labels
]
if relation_types != []:
relation_types = [type.strip("`") for type in relation_types]
valid_schemas = [
schema for schema in valid_schemas if schema[1] in relation_types
]
return valid_schemas != []
def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
"""
Args:
str_relation: relation in string format
"""
relation_direction = self.judge_direction(str_relation)
relation_type = self.relation_type_pattern.search(str_relation)
if relation_type is None or relation_type.group("relation_type") is None:
return relation_direction, []
relation_types = [
t.strip().strip("!")
for t in relation_type.group("relation_type").split("|")
]
return relation_direction, relation_types
def correct_query(self, query: str) -> str:
"""
Args:
query: cypher query
"""
node_variable_dict = self.detect_node_variables(query)
paths = self.extract_paths(query)
for path in paths:
original_path = path
start_idx = 0
while start_idx < len(path):
match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
if match_res is None:
break
start_idx += match_res.start()
match_dict = match_res.groupdict()
left_node_labels = self.detect_labels(
match_dict["left_node"], node_variable_dict
)
right_node_labels = self.detect_labels(
match_dict["right_node"], node_variable_dict
)
end_idx = (
start_idx
+ 4
+ len(match_dict["left_node"])
+ len(match_dict["relation"])
+ len(match_dict["right_node"])
)
original_partial_path = original_path[start_idx : end_idx + 1]
relation_direction, relation_types = self.detect_relation_types(
match_dict["relation"]
)
if relation_types != [] and "".join(relation_types).find("*") != -1:
start_idx += (
len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
)
continue
if relation_direction == "OUTGOING":
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
if not is_legal:
is_legal = self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if is_legal:
corrected_relation = "<" + match_dict["relation"][:-1]
corrected_partial_path = original_partial_path.replace(
match_dict["relation"], corrected_relation
)
query = query.replace(
original_partial_path, corrected_partial_path
)
else:
return ""
elif relation_direction == "INCOMING":
is_legal = self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if not is_legal:
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
if is_legal:
corrected_relation = match_dict["relation"][1:] + ">"
corrected_partial_path = original_partial_path.replace(
match_dict["relation"], corrected_relation
)
query = query.replace(
original_partial_path, corrected_partial_path
)
else:
return ""
else:
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
is_legal |= self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if not is_legal:
return ""
start_idx += (
len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
)
return query
def __call__(self, query: str) -> str:
"""Correct the query to make it valid. If
Args:
query: cypher query
"""
return self.correct_query(query)
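# Example (illustrative): with the single schema Person-ACTED_IN->Movie,
#     corrector = CypherQueryCorrector([Schema("Person", "ACTED_IN", "Movie")])
#     corrector("MATCH (m:Movie)-[:ACTED_IN]->(p:Person) RETURN p")
# flips the relationship to "(m:Movie)<-[:ACTED_IN]-(p:Person)"; a query that
# cannot be reconciled with any schema comes back as "".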

@@ -0,0 +1,157 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_GENERATION_PROMPT,
CYPHER_QA_PROMPT,
)
from langchain_community.graphs import FalkorDBGraph
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""
Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class FalkorDBQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: FalkorDBGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
**kwargs: Any,
) -> FalkorDBQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
# Retrieve and limit the number of results
context = self.graph.query(generated_cypher)[: self.top_k]
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@@ -0,0 +1,221 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GRAPHDB_SPARQL_FIX_TEMPLATE,
GREMLIN_GENERATION_PROMPT,
)
from langchain_community.graphs import GremlinGraph
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_gremlin(text: str) -> str:
"""Extract Gremlin code from a text.
Args:
text: Text to extract Gremlin code from.
Returns:
Gremlin code extracted from the text.
"""
text = text.replace("`", "")
if text.startswith("gremlin"):
text = text[len("gremlin") :]
return text.replace("\n", "")
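# Example (illustrative): extract_gremlin("```gremlin\ng.V().count()\n```")
# returns "g.V().count()".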
class GremlinQAChain(Chain):
"""Chain for question-answering against a graph by generating gremlin statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: GremlinGraph = Field(exclude=True)
gremlin_generation_chain: LLMChain
qa_chain: LLMChain
gremlin_fix_chain: LLMChain
max_fix_retries: int = 3
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 100
return_direct: bool = False
return_intermediate_steps: bool = False
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
input_variables=["error_message", "generated_sparql", "schema"],
template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
"in Turtle format", ""
),
),
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
**kwargs: Any,
) -> GremlinQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
return cls(
qa_chain=qa_chain,
gremlin_generation_chain=gremlin_generation_chain,
gremlin_fix_chain=gremlin_fix_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate gremlin statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
chain_response = self.gremlin_generation_chain.invoke(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
generated_gremlin = extract_gremlin(
chain_response[self.gremlin_generation_chain.output_key]
)
_run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_gremlin, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_gremlin})
if generated_gremlin:
context = self.execute_with_retry(
_run_manager, callbacks, generated_gremlin
)[: self.top_k]
else:
context = []
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
def execute_query(self, query: str) -> List[Any]:
try:
return self.graph.query(query)
except Exception as e:
if hasattr(e, "status_message"):
raise ValueError(e.status_message)
else:
raise ValueError(str(e))
def execute_with_retry(
self,
_run_manager: CallbackManagerForChainRun,
callbacks: CallbackManager,
generated_gremlin: str,
) -> List[Any]:
try:
return self.execute_query(generated_gremlin)
except Exception as e:
retries = 0
error_message = str(e)
self.log_invalid_query(_run_manager, generated_gremlin, error_message)
fixed_gremlin = ""
while retries < self.max_fix_retries:
try:
fix_chain_result = self.gremlin_fix_chain.invoke(
{
"error_message": error_message,
# we are borrowing template from sparql
"generated_sparql": generated_gremlin,
"schema": self.schema,
},
callbacks=callbacks,
)
fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
return self.execute_query(fixed_gremlin)
except Exception as e:
retries += 1
parse_exception = str(e)
self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)
raise ValueError("The generated Gremlin query is invalid.")
def log_invalid_query(
self,
_run_manager: CallbackManagerForChainRun,
generated_query: str,
error_message: str,
) -> None:
_run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="red", end="\n", verbose=self.verbose
)
_run_manager.on_text(
"Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
error_message, color="red", end="\n\n", verbose=self.verbose
)

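A minimal usage sketch for the chain above. The `GremlinGraph` connection details and the `langchain_openai` model are placeholders, not part of this change:

```python
# Sketch only: the endpoint and model below are illustrative placeholders.
from langchain_community.graphs import GremlinGraph
from langchain_openai import ChatOpenAI

graph = GremlinGraph(url="wss://localhost:8182/gremlin")  # hypothetical endpoint
chain = GremlinQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=graph,
    max_fix_retries=3,               # LLM-assisted repairs before giving up
    return_intermediate_steps=True,  # surface generated query and context
)
result = chain.invoke({"query": "Which movies did Tom Hanks act in?"})
print(result["result"])
print(result["intermediate_steps"])  # [{"query": ...}, {"context": ...}]
```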
@ -0,0 +1,106 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GREMLIN_GENERATION_PROMPT,
)
from langchain_community.graphs.hugegraph import HugeGraph
class HugeGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating gremlin statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: HugeGraph = Field(exclude=True)
gremlin_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
**kwargs: Any,
) -> HugeGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
return cls(
qa_chain=qa_chain,
gremlin_generation_chain=gremlin_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate gremlin statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_gremlin = self.gremlin_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_gremlin, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_gremlin)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

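As with the Gremlin chain, a hedged usage sketch; the HugeGraph connection parameters are placeholders:

```python
# Sketch only: connection parameters are placeholders.
from langchain_community.graphs import HugeGraph
from langchain_openai import ChatOpenAI

graph = HugeGraph(
    username="admin", password="xxx", address="localhost", port=8080, graph="hugegraph"
)
chain = HugeGraphQAChain.from_llm(llm=ChatOpenAI(temperature=0), graph=graph)
print(chain.invoke({"query": "Who is Al Pacino?"})["result"])
```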
@ -0,0 +1,143 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
KUZU_GENERATION_PROMPT,
)
from langchain_community.graphs.kuzu_graph import KuzuGraph
def remove_prefix(text: str, prefix: str) -> str:
"""Remove a prefix from a text.
Args:
text: Text to remove the prefix from.
prefix: Prefix to remove from the text.
Returns:
Text with the prefix removed.
"""
if text.startswith(prefix):
return text[len(prefix) :]
return text
def extract_cypher(text: str) -> str:
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class KuzuQAChain(Chain):
"""Question-answering against a graph by generating Cypher statements for Kùzu.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: KuzuGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
**kwargs: Any,
) -> KuzuQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in triple backticks
# with the language marker "cypher"
generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher")
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_cypher)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

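The two helpers above compose to strip a fenced LLM answer down to bare Cypher; a small illustrative check (not part of the module):

```python
# Illustrative: extract_cypher pulls out the fenced body, remove_prefix drops
# the "cypher" language marker that often follows the opening backticks.
fence = "`" * 3  # avoid a literal backtick run inside this fenced example
raw = f"{fence}cypher\nMATCH (n:Person) RETURN n.name\n{fence}"
stripped = remove_prefix(extract_cypher(raw).strip(), "cypher").strip()
assert stripped == "MATCH (n:Person) RETURN n.name"
```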
@ -0,0 +1,106 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
NGQL_GENERATION_PROMPT,
)
from langchain_community.graphs.nebula_graph import NebulaGraph
class NebulaGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating nGQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: NebulaGraph = Field(exclude=True)
ngql_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
**kwargs: Any,
) -> NebulaGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)
return cls(
qa_chain=qa_chain,
ngql_generation_chain=ngql_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate nGQL statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_ngql = self.ngql_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_ngql, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_ngql)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

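A hedged usage sketch; the NebulaGraph connection parameters are placeholders:

```python
# Sketch only: connection parameters are placeholders.
from langchain_community.graphs import NebulaGraph
from langchain_openai import ChatOpenAI

graph = NebulaGraph(
    space="basketballplayer",
    username="root",
    password="nebula",
    address="127.0.0.1",
    port=9669,
)
chain = NebulaGraphQAChain.from_llm(llm=ChatOpenAI(temperature=0), graph=graph)
print(chain.invoke({"query": "Which players did Tim Duncan follow?"})["result"])
```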
@ -0,0 +1,217 @@
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)
from langchain_community.graphs import BaseNeptuneGraph
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def trim_query(query: str) -> str:
"""Trim the query to only include Cypher keywords."""
keywords = (
"CALL",
"CREATE",
"DELETE",
"DETACH",
"LIMIT",
"MATCH",
"MERGE",
"OPTIONAL",
"ORDER",
"REMOVE",
"RETURN",
"SET",
"SKIP",
"UNWIND",
"WITH",
"WHERE",
"//",
)
lines = query.split("\n")
new_query = ""
for line in lines:
if line.strip().upper().startswith(keywords):
new_query += line + "\n"
return new_query
def extract_cypher(text: str) -> str:
"""Extract Cypher code from text using Regex."""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
return True
# Bedrock anthropic
if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore
return True
return False
PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
class NeptuneOpenCypherQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating openCypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
Example:
.. code-block:: python
chain = NeptuneOpenCypherQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.run(query)
"""
graph: BaseNeptuneGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: Optional[BasePromptTemplate] = None,
extra_instructions: Optional[str] = None,
**kwargs: Any,
) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
_cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
extra_instructions=extra_instructions,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{
"question": question,
"schema": self.graph.get_schema,
"extra_instructions": self.extra_instructions or "",
},
callbacks=callbacks,
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
generated_cypher = trim_query(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
context = self.graph.query(generated_cypher)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

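For clarity, what `trim_query` keeps and drops (illustrative, not part of the module):

```python
# trim_query keeps only lines that begin with a recognized openCypher keyword
# (or a // comment); chatty LLM preamble and postscript lines are dropped.
noisy = "Here is your query:\nMATCH (a:airport)\nRETURN a.code\nHope this helps!"
print(trim_query(noisy))
# MATCH (a:airport)
# RETURN a.code
```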
@ -0,0 +1,204 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT
from langchain_community.graphs import NeptuneRdfGraph
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Examples:
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.
The question is:
{prompt}"""
SPARQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)
def extract_sparql(query: str) -> str:
"""Extract SPARQL code from a text.
Args:
query: Text to extract SPARQL code from.
Returns:
SPARQL code extracted from the text.
"""
query = query.strip()
querytoks = query.split("```")
if len(querytoks) == 3:
query = querytoks[1]
if query.startswith("sparql"):
query = query[6:]
elif query.startswith("<sparql>") and query.endswith("</sparql>"):
query = query[8:-9]
return query
class NeptuneSparqlQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating SPARQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
Example:
.. code-block:: python
chain = NeptuneSparqlQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.invoke(query)
"""
graph: NeptuneRdfGraph = Field(exclude=True)
sparql_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
examples: Optional[str] = None,
**kwargs: Any,
) -> NeptuneSparqlQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
template_to_use = SPARQL_GENERATION_TEMPLATE
if examples:
template_to_use = template_to_use.replace(
"Examples:", "Examples: " + examples
)
sparql_prompt = PromptTemplate(
input_variables=["schema", "prompt"], template=template_to_use
)
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
return cls( # type: ignore[call-arg]
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
examples=examples,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate SPARQL query, use it to retrieve a response from the gdb and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
intermediate_steps: List = []
generated_sparql = self.sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract SPARQL
generated_sparql = extract_sparql(generated_sparql)
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_sparql, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_sparql})
context = self.graph.query(generated_sparql)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"prompt": prompt, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

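`extract_sparql` accepts three answer shapes; an illustrative check (not part of the module):

```python
# Illustrative: bare queries, fenced ...sparql blocks, and <sparql> tags all
# reduce to the bare query text.
fence = "`" * 3  # avoid a literal backtick run inside this fenced example
q = "SELECT ?s WHERE { ?s ?p ?o }"
assert extract_sparql(q) == q
assert extract_sparql(f"{fence}sparql\n{q}\n{fence}").strip() == q
assert extract_sparql(f"<sparql>{q}</sparql>") == q
```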
@ -0,0 +1,190 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
import rdflib
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
GRAPHDB_QA_PROMPT,
GRAPHDB_SPARQL_FIX_PROMPT,
GRAPHDB_SPARQL_GENERATION_PROMPT,
)
from langchain_community.graphs import OntotextGraphDBGraph
class OntotextGraphDBQAChain(Chain):
"""Question-answering against Ontotext GraphDB
https://graphdb.ontotext.com/ by generating SPARQL queries.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: OntotextGraphDBGraph = Field(exclude=True)
sparql_generation_chain: LLMChain
sparql_fix_chain: LLMChain
max_fix_retries: int
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT,
sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT,
max_fix_retries: int = 5,
qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT,
**kwargs: Any,
) -> OntotextGraphDBQAChain:
"""Initialize from LLM."""
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt)
sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt)
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
return cls(
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
sparql_fix_chain=sparql_fix_chain,
max_fix_retries=max_fix_retries,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate a SPARQL query, use it to retrieve a response from GraphDB and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
ontology_schema = self.graph.get_schema
sparql_generation_chain_result = self.sparql_generation_chain.invoke(
{"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks
)
generated_sparql = sparql_generation_chain_result[
self.sparql_generation_chain.output_key
]
generated_sparql = self._get_prepared_sparql_query(
_run_manager, callbacks, generated_sparql, ontology_schema
)
query_results = self._execute_query(generated_sparql)
qa_chain_result = self.qa_chain.invoke(
{"prompt": prompt, "context": query_results}, callbacks=callbacks
)
result = qa_chain_result[self.qa_chain.output_key]
return {self.output_key: result}
def _get_prepared_sparql_query(
self,
_run_manager: CallbackManagerForChainRun,
callbacks: CallbackManager,
generated_sparql: str,
ontology_schema: str,
) -> str:
try:
return self._prepare_sparql_query(_run_manager, generated_sparql)
except Exception as e:
retries = 0
error_message = str(e)
self._log_invalid_sparql_query(
_run_manager, generated_sparql, error_message
)
while retries < self.max_fix_retries:
try:
sparql_fix_chain_result = self.sparql_fix_chain.invoke(
{
"error_message": error_message,
"generated_sparql": generated_sparql,
"schema": ontology_schema,
},
callbacks=callbacks,
)
generated_sparql = sparql_fix_chain_result[
self.sparql_fix_chain.output_key
]
return self._prepare_sparql_query(_run_manager, generated_sparql)
except Exception as e:
retries += 1
parse_exception = str(e)
self._log_invalid_sparql_query(
_run_manager, generated_sparql, parse_exception
)
raise ValueError("The generated SPARQL query is invalid.")
def _prepare_sparql_query(
self, _run_manager: CallbackManagerForChainRun, generated_sparql: str
) -> str:
from rdflib.plugins.sparql import prepareQuery
prepareQuery(generated_sparql)
self._log_prepared_sparql_query(_run_manager, generated_sparql)
return generated_sparql
def _log_prepared_sparql_query(
self, _run_manager: CallbackManagerForChainRun, generated_query: str
) -> None:
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="green", end="\n", verbose=self.verbose
)
def _log_invalid_sparql_query(
self,
_run_manager: CallbackManagerForChainRun,
generated_query: str,
error_message: str,
) -> None:
_run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="red", end="\n", verbose=self.verbose
)
_run_manager.on_text(
"SPARQL Query Parse Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
error_message, color="red", end="\n\n", verbose=self.verbose
)
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
try:
return self.graph.query(query)
except Exception:
raise ValueError("Failed to execute the generated SPARQL query.")

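A hedged usage sketch; the GraphDB repository URL and ontology query are placeholders:

```python
# Sketch only: repository URL and ontology query are placeholders.
from langchain_community.graphs import OntotextGraphDBGraph
from langchain_openai import ChatOpenAI

graph = OntotextGraphDBGraph(
    query_endpoint="http://localhost:7200/repositories/starwars",
    query_ontology="CONSTRUCT {?s ?p ?o} FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=graph,
    max_fix_retries=3,  # LLM-assisted repairs of an unparsable SPARQL query
)
print(chain.invoke({"query": "What is the climate on Tatooine?"})["result"])
```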
@ -0,0 +1,415 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
Return the output as a single comma-separated list, or NONE if there is nothing of note to return.
EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE
EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam.
Output: Langchain, Sam
END OF EXAMPLE
Begin!
{input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)
_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
GRAPH_QA_PROMPT = PromptTemplate(
template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"]
)
CYPHER_GENERATION_TEMPLATE = """Task: Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:
First, generate Cypher, then convert it to the NebulaGraph Cypher dialect (rather than standard):
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
3. it uses double equals sign for comparison: `==` rather than `=`
For instance:
```diff
< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II'
< RETURN p.name, e.year, m.name;
---
> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'
> RETURN p.`person`.`name`, e.year, m.`movie`.`name`;
```\n"""
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)
NGQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)
KUZU_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not use a `WHERE EXISTS` clause to check the existence of a property.
2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
3. Do not include any notes or comments even if the statement does not produce the expected result.
```\n"""
KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Generate Cypher", "Generate Kùzu Cypher"
).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS)
KUZU_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE
)
GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin")
GREMLIN_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE
)
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound like a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Information:
{context}
Question: {question}
Helpful Answer:"""
CYPHER_QA_PROMPT = PromptTemplate(
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
)
SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type.
You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types.
Consider only the following query types:
* SELECT: this query type corresponds to questions
* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type.
Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'.
The prompt is:
{prompt}
Helpful Answer:"""
SPARQL_INTENT_PROMPT = PromptTemplate(
input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
)
SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.
The question is:
{prompt}"""
SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
)
SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
INSERT {{
?person foaf:mbox <mailto:jane.doe@foo.bar> .
}}
WHERE {{
?person foaf:name "Jane Doe" .
}}
```
Instructions:
Make the query as short as possible and avoid adding unnecessary triples.
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Return only the generated SPARQL query, nothing else.
The information to be inserted is:
{prompt}"""
SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
)
SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Information:
{context}
Question: {prompt}
Helpful Answer:"""
SPARQL_QA_PROMPT = PromptTemplate(
input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
)
GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
Write a SPARQL SELECT query for querying a graph database.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
Use only the classes and properties provided in the schema to construct the SPARQL query.
Do not use any classes or properties that are not explicitly provided in the SPARQL query.
Include all necessary prefixes.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The question delimited by triple backticks is:
```
{prompt}
```
"""
GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"],
template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
)
GRAPHDB_SPARQL_FIX_TEMPLATE = """
The following SPARQL query, delimited by triple backticks,
```
{generated_sparql}
```
is not valid.
The error delimited by triple backticks is
```
{error_message}
```
Give me a correct version of the SPARQL query.
Do not change the logic of the query.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
"""
GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
input_variables=["error_message", "generated_sparql", "schema"],
template=GRAPHDB_SPARQL_FIX_TEMPLATE,
)
GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Don't use internal knowledge to answer the question, just say you don't know if no information is available.
Information:
{context}
Question: {prompt}
Helpful Answer:"""
GRAPHDB_QA_PROMPT = PromptTemplate(
input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE
)
AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.
You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.
You are given an `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.
Things you should do:
- Think step by step.
- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
- Begin the `AQL Query` with the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
- Return the `AQL Query` wrapped in 3 backticks (```).
- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries.
- Only answer requests related to generating an AQL Query.
- If a request is unrelated to generating an AQL Query, say that you cannot help the user.
Things you should not do:
- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
- Do not include any text except the generated AQL Query.
- Do not provide explanations or apologies in your responses.
- Do not generate an AQL Query that removes or deletes any data.
Under no circumstance should you generate an AQL Query that deletes any data whatsoever.
ArangoDB Schema:
{adb_schema}
AQL Query Examples (Optional):
{aql_examples}
User Input:
{user_input}
AQL Query:
"""
AQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["adb_schema", "aql_examples", "user_input"],
template=AQL_GENERATION_TEMPLATE,
)
AQL_FIX_TEMPLATE = """Task: Address the error message produced by an ArangoDB Query Language (AQL) query.
You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`.
The `AQL Error` explains why the `AQL Query` could not be executed in the database.
The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`.
For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`.
You are also given the `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query.
Remember to think step by step.
ArangoDB Schema:
{adb_schema}
AQL Query:
{aql_query}
AQL Error:
{aql_error}
Corrected AQL Query:
"""
AQL_FIX_PROMPT = PromptTemplate(
input_variables=[
"adb_schema",
"aql_query",
"aql_error",
],
template=AQL_FIX_TEMPLATE,
)
AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query.
You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`.
A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format.
You are responsible for creating a `Summary` based on the AQL Result.
You are given the following information:
- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database.
- `User Input`: the original question/request of the user, which has been translated into an AQL Query.
- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query.
- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database.
Remember to think step by step.
Your `Summary` should sound like it is a response to the `User Input`.
Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`.
ArangoDB Schema:
{adb_schema}
User Input:
{user_input}
AQL Query:
{aql_query}
AQL Result:
{aql_result}
"""
AQL_QA_PROMPT = PromptTemplate(
input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
template=AQL_QA_TEMPLATE,
)
NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)

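Because several generation prompts above are derived by string replacement on `CYPHER_GENERATION_TEMPLATE`, rendering one of them is a quick way to confirm the splices landed; an illustrative check:

```python
# Illustrative: render the NebulaGraph prompt with toy inputs and confirm the
# dialect instructions were spliced into the base Cypher template.
rendered = NGQL_GENERATION_PROMPT.format(
    schema="(:person)-[:directed]->(:movie)",
    question="Who directed The Godfather II?",
)
assert "NebulaGraph Cypher" in rendered
assert "==" in rendered  # dialect rule: double equals for comparison
```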
@ -0,0 +1,152 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
SPARQL_GENERATION_SELECT_PROMPT,
SPARQL_GENERATION_UPDATE_PROMPT,
SPARQL_INTENT_PROMPT,
SPARQL_QA_PROMPT,
)
from langchain_community.graphs.rdf_graph import RdfGraph
class GraphSparqlQAChain(Chain):
"""Question-answering against an RDF or OWL graph by generating SPARQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: RdfGraph = Field(exclude=True)
sparql_generation_select_chain: LLMChain
sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain
qa_chain: LLMChain
return_sparql_query: bool = False
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
sparql_query_key: str = "sparql_query" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT,
sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT,
sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT,
**kwargs: Any,
) -> GraphSparqlQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt)
sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt)
sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt)
return cls(
qa_chain=qa_chain,
sparql_generation_select_chain=sparql_generation_select_chain,
sparql_generation_update_chain=sparql_generation_update_chain,
sparql_intent_chain=sparql_intent_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate SPARQL query, use it to retrieve a response from the gdb and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
_intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks)
intent = _intent.strip()
if "SELECT" in intent and "UPDATE" not in intent:
sparql_generation_chain = self.sparql_generation_select_chain
intent = "SELECT"
elif "UPDATE" in intent and "SELECT" not in intent:
sparql_generation_chain = self.sparql_generation_update_chain
intent = "UPDATE"
else:
raise ValueError(
"I am sorry, but this prompt seems to fit none of the currently "
"supported SPARQL query types, i.e., SELECT and UPDATE."
)
_run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose)
_run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose)
generated_sparql = sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_sparql, color="green", end="\n", verbose=self.verbose
)
if intent == "SELECT":
context = self.graph.query(generated_sparql)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"prompt": prompt, "context": context},
callbacks=callbacks,
)
res = result[self.qa_chain.output_key]
elif intent == "UPDATE":
self.graph.update(generated_sparql)
res = "Successfully inserted triples into the graph."
else:
raise ValueError("Unsupported SPARQL query type.")
chain_result: Dict[str, Any] = {self.output_key: res}
if self.return_sparql_query:
chain_result[self.sparql_query_key] = generated_sparql
return chain_result

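A hedged usage sketch; the RDF source is the public FOAF card used elsewhere in the LangChain docs:

```python
# Sketch only: source file and question are illustrative.
from langchain_community.graphs import RdfGraph
from langchain_openai import ChatOpenAI

graph = RdfGraph(
    source_file="http://www.w3.org/People/Berners-Lee/card",
    standard="rdf",
)
chain = GraphSparqlQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=graph,
    return_sparql_query=True,  # also surface the generated query
)
out = chain.invoke({"query": "What is Tim Berners-Lee's work homepage?"})
print(out["result"])
print(out["sparql_query"])
```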
@ -0,0 +1,97 @@
"""Chain that hits a URL and then uses an LLM to parse results."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_community.utilities.requests import TextRequestsWrapper
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501
}
class LLMRequestsChain(Chain):
"""Chain that requests a URL and then uses an LLM to parse results.
**Security Note**: This chain can make GET requests to arbitrary URLs,
including internal URLs.
Control access to who can run this chain and what network access
this chain has.
See https://python.langchain.com/docs/security for more information.
"""
llm_chain: LLMChain # type: ignore[valid-type]
requests_wrapper: TextRequestsWrapper = Field(
default_factory=lambda: TextRequestsWrapper(headers=DEFAULT_HEADERS),
exclude=True,
)
text_length: int = 8000
requests_key: str = "requests_result" #: :meta private:
input_key: str = "url" #: :meta private:
output_key: str = "output" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
from bs4 import BeautifulSoup # noqa: F401
except ImportError:
raise ImportError(
"Could not import bs4 python package. "
"Please install it with `pip install bs4`."
)
return values
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
from bs4 import BeautifulSoup
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
url = inputs[self.input_key]
res = self.requests_wrapper.get(url)
# extract the text from the html
soup = BeautifulSoup(res, "html.parser")
other_keys[self.requests_key] = soup.get_text()[: self.text_length]
result = self.llm_chain.predict( # type: ignore[attr-defined]
callbacks=_run_manager.get_child(), **other_keys
)
return {self.output_key: result}
@property
def _chain_type(self) -> str:
return "llm_requests_chain"

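A hedged usage sketch. The prompt must expose a `requests_result` variable, since `_call` injects the fetched page text under `self.requests_key`:

```python
# Sketch only: the URL and question are illustrative.
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

template = """Between >>> and <<< is the raw text of a web page.
Extract the answer to the question '{query}', or say "not found" if absent.
>>> {requests_result} <<<
Answer:"""
prompt = PromptTemplate(input_variables=["query", "requests_result"], template=template)
chain = LLMRequestsChain(llm_chain=LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt))
out = chain.invoke({"url": "https://example.com", "query": "What is this page about?"})
print(out["output"])
```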
@ -0,0 +1,229 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
import json
from typing import Any, Dict, List, NamedTuple, Optional, cast
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from requests import Response
from langchain_community.tools.openapi.utils.api_models import APIOperation
from langchain_community.utilities.requests import Requests
class _ParamMapping(NamedTuple):
"""Mapping from parameter name to parameter value."""
query_params: List[str]
body_params: List[str]
path_params: List[str]
class OpenAPIEndpointChain(Chain, BaseModel):
"""Chain interacts with an OpenAPI endpoint using natural language."""
api_request_chain: LLMChain
api_response_chain: Optional[LLMChain]
api_operation: APIOperation
requests: Requests = Field(exclude=True, default_factory=Requests)
param_mapping: _ParamMapping = Field(alias="param_mapping")
return_intermediate_steps: bool = False
instructions_key: str = "instructions" #: :meta private:
output_key: str = "output" #: :meta private:
max_text_length: Optional[int] = Field(ge=0) #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.instructions_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, "intermediate_steps"]
def _construct_path(self, args: Dict[str, str]) -> str:
"""Construct the path from the deserialized input."""
path = self.api_operation.base_url + self.api_operation.path
for param in self.param_mapping.path_params:
path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
return path
def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
"""Extract the query params from the deserialized input."""
query_params = {}
for param in self.param_mapping.query_params:
if param in args:
query_params[param] = args.pop(param)
return query_params
def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]:
"""Extract the request body params from the deserialized input."""
body_params = None
if self.param_mapping.body_params:
body_params = {}
for param in self.param_mapping.body_params:
if param in args:
body_params[param] = args.pop(param)
return body_params
def deserialize_json_input(self, serialized_args: str) -> dict:
"""Use the serialized typescript dictionary.
Resolve the path, query params dict, and optional requestBody dict.
"""
args: dict = json.loads(serialized_args)
path = self._construct_path(args)
body_params = self._extract_body_params(args)
query_params = self._extract_query_params(args)
return {
"url": path,
"data": body_params,
"params": query_params,
}
def _get_output(self, output: str, intermediate_steps: dict) -> dict:
"""Return the output from the API call."""
if self.return_intermediate_steps:
return {
self.output_key: output,
"intermediate_steps": intermediate_steps,
}
else:
return {self.output_key: output}
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
intermediate_steps = {}
instructions = inputs[self.instructions_key]
instructions = instructions[: self.max_text_length]
_api_arguments = self.api_request_chain.predict_and_parse(
instructions=instructions, callbacks=_run_manager.get_child()
)
api_arguments = cast(str, _api_arguments)
intermediate_steps["request_args"] = api_arguments
_run_manager.on_text(
api_arguments, color="green", end="\n", verbose=self.verbose
)
if api_arguments.startswith("ERROR"):
return self._get_output(api_arguments, intermediate_steps)
elif api_arguments.startswith("MESSAGE:"):
return self._get_output(
api_arguments[len("MESSAGE:") :], intermediate_steps
)
try:
request_args = self.deserialize_json_input(api_arguments)
method = getattr(self.requests, self.api_operation.method.value)
api_response: Response = method(**request_args)
if api_response.status_code != 200:
method_str = str(self.api_operation.method.value)
response_text = (
f"{api_response.status_code}: {api_response.reason}"
+ f"\nFor {method_str.upper()} {request_args['url']}\n"
+ f"Called with args: {request_args['params']}"
)
else:
response_text = api_response.text
except Exception as e:
response_text = f"Error with message {str(e)}"
response_text = response_text[: self.max_text_length]
intermediate_steps["response_text"] = response_text
_run_manager.on_text(
response_text, color="blue", end="\n", verbose=self.verbose
)
if self.api_response_chain is not None:
_answer = self.api_response_chain.predict_and_parse(
response=response_text,
instructions=instructions,
callbacks=_run_manager.get_child(),
)
answer = cast(str, _answer)
_run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
return self._get_output(answer, intermediate_steps)
else:
return self._get_output(response_text, intermediate_steps)
@classmethod
def from_url_and_method(
cls,
spec_url: str,
path: str,
method: str,
llm: BaseLanguageModel,
requests: Optional[Requests] = None,
return_intermediate_steps: bool = False,
**kwargs: Any,
# TODO: Handle async
) -> "OpenAPIEndpointChain":
"""Create an OpenAPIEndpoint from a spec at the specified url."""
operation = APIOperation.from_openapi_url(spec_url, path, method)
return cls.from_api_operation(
operation,
requests=requests,
llm=llm,
return_intermediate_steps=return_intermediate_steps,
**kwargs,
)
@classmethod
def from_api_operation(
cls,
operation: APIOperation,
llm: BaseLanguageModel,
requests: Optional[Requests] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
raw_response: bool = False,
callbacks: Callbacks = None,
**kwargs: Any,
# TODO: Handle async
) -> "OpenAPIEndpointChain":
"""Create an OpenAPIEndpointChain from an operation and a spec."""
param_mapping = _ParamMapping(
query_params=operation.query_params,
body_params=operation.body_params,
path_params=operation.path_params,
)
requests_chain = APIRequesterChain.from_llm_and_typescript(
llm,
typescript_definition=operation.to_typescript(),
verbose=verbose,
callbacks=callbacks,
)
if raw_response:
response_chain = None
else:
response_chain = APIResponderChain.from_llm(
llm, verbose=verbose, callbacks=callbacks
)
_requests = requests or Requests()
return cls(
api_request_chain=requests_chain,
api_response_chain=response_chain,
api_operation=operation,
requests=_requests,
param_mapping=param_mapping,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
callbacks=callbacks,
**kwargs,
)
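For reviewers, a minimal usage sketch of the chain above. The spec URL, the path, and the post-move import path are assumptions, not part of this diff:

```python
# Sketch only: spec URL/path are placeholders, and the import path assumes
# the new langchain_community.chains.openapi layout introduced by this PR.
from langchain_community.chains.openapi.chain import OpenAPIEndpointChain
from langchain_openai import OpenAI

chain = OpenAPIEndpointChain.from_url_and_method(
    spec_url="https://example.com/openapi.json",  # placeholder spec
    path="/pets/{petId}",
    method="get",
    llm=OpenAI(temperature=0),
    return_intermediate_steps=True,
)
result = chain.invoke({"instructions": "Fetch the pet with id 42"})
print(result["output"], result["intermediate_steps"])
```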

@ -0,0 +1,57 @@
# flake8: noqa
REQUEST_TEMPLATE = """You are a helpful AI Assistant. Please provide JSON arguments to agentFunc() based on the user's instructions.
API_SCHEMA: ```typescript
{schema}
```
USER_INSTRUCTIONS: "{instructions}"
Your arguments must be plain json provided in a markdown block:
ARGS: ```json
{{valid json conforming to API_SCHEMA}}
```
Example
-----
ARGS: ```json
{{"foo": "bar", "baz": {{"qux": "quux"}}}}
```
The block must be no more than 1 line long, and all arguments must be valid JSON. All string arguments must be wrapped in double quotes.
You MUST strictly comply with the types indicated by the provided schema, including all required args.
If you don't have sufficient information to call the function due to things like requiring specific UUIDs, you can reply with the following message:
Message: ```text
Concise response requesting the additional information that would make calling the function successful.
```
Begin
-----
ARGS:
"""
RESPONSE_TEMPLATE = """You are a helpful AI assistant trained to answer user queries from API responses.
You attempted to call an API, which resulted in:
API_RESPONSE: {response}
USER_COMMENT: "{instructions}"
If the API_RESPONSE can answer the USER_COMMENT, respond with the following markdown json block:
Response: ```json
{{"response": "Human-understandable synthesis of the API_RESPONSE"}}
```
Otherwise respond with the following markdown json block:
Response Error: ```json
{{"response": "What you did and a concise statement of the resulting error. If it can be easily fixed, provide a suggestion."}}
```
You MUST respond as a markdown json code block. The person you are responding to CANNOT see the API_RESPONSE, so if there is any relevant information there you must include it in your response.
Begin:
---
"""

@ -0,0 +1,62 @@
"""request parser."""
import json
import re
from typing import Any
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
class APIRequesterOutputParser(BaseOutputParser):
"""Parse the request and error tags."""
def _load_json_block(self, serialized_block: str) -> str:
try:
return json.dumps(json.loads(serialized_block, strict=False))
except json.JSONDecodeError:
return "ERROR serializing request."
def parse(self, llm_output: str) -> str:
"""Parse the request and error tags."""
json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
if json_match:
return self._load_json_block(json_match.group(1).strip())
message_match = re.search(r"```text(.*?)```", llm_output, re.DOTALL)
if message_match:
return f"MESSAGE: {message_match.group(1).strip()}"
return "ERROR making request"
@property
def _type(self) -> str:
return "api_requester"
class APIRequesterChain(LLMChain):
"""Get the request parser."""
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@classmethod
def from_llm_and_typescript(
cls,
llm: BaseLanguageModel,
typescript_definition: str,
verbose: bool = True,
**kwargs: Any,
) -> LLMChain:
"""Get the request parser."""
output_parser = APIRequesterOutputParser()
prompt = PromptTemplate(
template=REQUEST_TEMPLATE,
output_parser=output_parser,
partial_variables={"schema": typescript_definition},
input_variables=["instructions"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
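A quick check of the parser's three output shapes (the module path in the import is my assumption from the surrounding files):

```python
from langchain.chains.api.openapi.requests_chain import (  # assumed module path
    APIRequesterOutputParser,
)

parser = APIRequesterOutputParser()
# A ```json block is re-serialized to compact JSON for the HTTP call.
print(parser.parse('ARGS: ```json\n{"petId": 42}\n```'))     # {"petId": 42}
# A ```text block becomes a MESSAGE: reply asking for more input.
print(parser.parse('Message: ```text\nWhich pet id?\n```'))  # MESSAGE: Which pet id?
# Anything else is flagged as an error string.
print(parser.parse("no code block here"))                    # ERROR making request
```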

@ -0,0 +1,57 @@
"""Response parser."""
import json
import re
from typing import Any
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
class APIResponderOutputParser(BaseOutputParser):
"""Parse the response and error tags."""
def _load_json_block(self, serialized_block: str) -> str:
try:
response_content = json.loads(serialized_block, strict=False)
return response_content.get("response", "ERROR parsing response.")
except json.JSONDecodeError:
return "ERROR parsing response."
def parse(self, llm_output: str) -> str:
"""Parse the response and error tags."""
json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
if json_match:
return self._load_json_block(json_match.group(1).strip())
else:
raise ValueError(f"No response found in output: {llm_output}.")
@property
def _type(self) -> str:
return "api_responder"
class APIResponderChain(LLMChain):
"""Get the response parser."""
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any
) -> LLMChain:
"""Get the response parser."""
output_parser = APIResponderOutputParser()
prompt = PromptTemplate(
template=RESPONSE_TEMPLATE,
output_parser=output_parser,
input_variables=["response", "instructions"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
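And the responder counterpart, which unwraps the `response` key or raises if no block is found (module path again assumed):

```python
from langchain.chains.api.openapi.response_chain import (  # assumed module path
    APIResponderOutputParser,
)

parser = APIResponderOutputParser()
out = parser.parse('Response: ```json\n{"response": "The pet is a corgi."}\n```')
print(out)  # The pet is a corgi.
```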

@ -1,17 +1,3 @@
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
__all__ = ["BaseCrossEncoder"]

@ -0,0 +1,70 @@
"""Logic for converting internal query language to a valid AstraDB query."""
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN]
class AstraDBTranslator(Visitor):
"""Translate AstraDB internal query language elements to valid filters."""
"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.NIN,
]
"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.OR]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
map_dict = {
Operator.AND: "$and",
Operator.OR: "$or",
Comparator.EQ: "$eq",
Comparator.NE: "$ne",
Comparator.GTE: "$gte",
Comparator.LTE: "$lte",
Comparator.LT: "$lt",
Comparator.GT: "$gt",
Comparator.IN: "$in",
Comparator.NIN: "$nin",
}
return map_dict[func]
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance(
comparison.value, list
):
comparison.value = [comparison.value]
comparator = self._format_func(comparison.comparator)
return {comparison.attribute: {comparator: comparison.value}}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
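To make the visitor behavior concrete, a small sketch of what this translator emits (the `query_constructors` import path is an assumption about the new community layout). Note how a scalar `$in` value gets wrapped in a list:

```python
from langchain_community.query_constructors.astradb import AstraDBTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

query = StructuredQuery(
    query="dogs",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="breed", value="corgi"),
            # scalar value for $in is wrapped into a list by visit_comparison
            Comparison(comparator=Comparator.IN, attribute="color", value="tan"),
        ],
    ),
)
print(AstraDBTranslator().visit_structured_query(query))
# ('dogs', {'filter': {'$and': [{'breed': {'$eq': 'corgi'}},
#                               {'color': {'$in': ['tan']}}]}})
```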

@ -0,0 +1,50 @@
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class ChromaTranslator(Visitor):
"""Translate `Chroma` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
"""Subset of allowed logical comparators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,64 @@
"""Logic for converting internal query language to a valid DashVector query."""
from typing import Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class DashvectorTranslator(Visitor):
"""Logic for converting internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.LIKE,
]
map_dict = {
Operator.AND: " AND ",
Operator.OR: " OR ",
Comparator.EQ: " = ",
Comparator.GT: " > ",
Comparator.GTE: " >= ",
Comparator.LT: " < ",
Comparator.LTE: " <= ",
Comparator.LIKE: " LIKE ",
}
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return self.map_dict[func]
def visit_operation(self, operation: Operation) -> str:
args = [arg.accept(self) for arg in operation.arguments]
return self._format_func(operation.operator).join(args)
def visit_comparison(self, comparison: Comparison) -> str:
value = comparison.value
if isinstance(value, str):
if comparison.comparator == Comparator.LIKE:
value = f"'%{value}%'"
else:
value = f"'{value}'"
return (
f"{comparison.attribute}{self._format_func(comparison.comparator)}{value}"
)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
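Unlike the dict-based translators, this one renders a SQL-ish filter string; LIKE values get `%` wildcards and quotes (import path assumed):

```python
from langchain_community.query_constructors.dashvector import DashvectorTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="jackets",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.LIKE, attribute="name", value="rain"),
            Comparison(comparator=Comparator.LT, attribute="price", value=100),
        ],
    ),
)
print(DashvectorTranslator().visit_structured_query(q))
# ('jackets', {'filter': "name LIKE '%rain%' AND price < 100"})
```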

@ -0,0 +1,94 @@
from collections import ChainMap
from itertools import chain
from typing import Dict, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
_COMPARATOR_TO_SYMBOL = {
Comparator.EQ: "",
Comparator.GT: " >",
Comparator.GTE: " >=",
Comparator.LT: " <",
Comparator.LTE: " <=",
Comparator.IN: "",
Comparator.LIKE: " LIKE",
}
class DatabricksVectorSearchTranslator(Visitor):
"""Translate `Databricks vector search` internal query language elements to
valid filters."""
"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.NOT, Operator.OR]
"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.LIKE,
]
def _visit_and_operation(self, operation: Operation) -> Dict:
return dict(ChainMap(*[arg.accept(self) for arg in operation.arguments]))
def _visit_or_operation(self, operation: Operation) -> Dict:
filter_args = [arg.accept(self) for arg in operation.arguments]
flattened_args = list(
chain.from_iterable(filter_arg.items() for filter_arg in filter_args)
)
return {
" OR ".join(key for key, _ in flattened_args): [
value for _, value in flattened_args
]
}
def _visit_not_operation(self, operation: Operation) -> Dict:
if len(operation.arguments) > 1:
raise ValueError(
f'"{operation.operator.value}" can have only one argument '
f"in Databricks vector search"
)
filter_arg = operation.arguments[0].accept(self)
return {
f"{colum_with_bool_expression} NOT": value
for colum_with_bool_expression, value in filter_arg.items()
}
def visit_operation(self, operation: Operation) -> Dict:
self._validate_func(operation.operator)
if operation.operator == Operator.AND:
return self._visit_and_operation(operation)
elif operation.operator == Operator.OR:
return self._visit_or_operation(operation)
elif operation.operator == Operator.NOT:
return self._visit_not_operation(operation)
else:
raise NotImplementedError(
f'Operator "{operation.operator}" is not supported'
)
def visit_comparison(self, comparison: Comparison) -> Dict:
self._validate_func(comparison.comparator)
comparator_symbol = _COMPARATOR_TO_SYMBOL[comparison.comparator]
return {f"{comparison.attribute}{comparator_symbol}": comparison.value}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filters": structured_query.filter.accept(self)}
return structured_query.query, kwargs
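The OR handling above is easy to misread, so a sketch of the flattened output shape (import path assumed):

```python
from langchain_community.query_constructors.databricks_vector_search import (  # assumed path
    DatabricksVectorSearchTranslator,
)
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="laptops",
    filter=Operation(
        operator=Operator.OR,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="brand", value="apple"),
            Comparison(comparator=Comparator.GT, attribute="ram", value=16),
        ],
    ),
)
print(DatabricksVectorSearchTranslator().visit_structured_query(q))
# OR folds its operands into a single joined key with the values listed:
# ('laptops', {'filters': {'brand OR ram >': ['apple', 16]}})
```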

@ -0,0 +1,88 @@
"""Logic for converting internal query language to a valid Chroma query."""
from typing import Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
COMPARATOR_TO_TQL = {
Comparator.EQ: "==",
Comparator.GT: ">",
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
}
OPERATOR_TO_TQL = {
Operator.AND: "and",
Operator.OR: "or",
Operator.NOT: "NOT",
}
def can_cast_to_float(string: str) -> bool:
"""Check if a string can be cast to a float."""
try:
float(string)
return True
except ValueError:
return False
class DeepLakeTranslator(Visitor):
"""Translate `DeepLake` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
"""Subset of allowed logical comparators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
if isinstance(func, Operator):
value = OPERATOR_TO_TQL[func.value] # type: ignore
elif isinstance(func, Comparator):
value = COMPARATOR_TO_TQL[func.value] # type: ignore
return f"{value}"
def visit_operation(self, operation: Operation) -> str:
args = [arg.accept(self) for arg in operation.arguments]
operator = self._format_func(operation.operator)
return "(" + (" " + operator + " ").join(args) + ")"
def visit_comparison(self, comparison: Comparison) -> str:
comparator = self._format_func(comparison.comparator)
values = comparison.value
if isinstance(values, list):
tql = []
for value in values:
comparison.value = value
tql.append(self.visit_comparison(comparison))
return "(" + (" or ").join(tql) + ")"
if not can_cast_to_float(comparison.value):
values = f"'{values}'"
return f"metadata['{comparison.attribute}'] {comparator} {values}"
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
tqL = f"SELECT * WHERE {structured_query.filter.accept(self)}"
kwargs = {"tql": tqL}
return structured_query.query, kwargs
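For reference, the TQL string this produces for a simple conjunction (import path assumed):

```python
from langchain_community.query_constructors.deeplake import DeepLakeTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="sci-fi",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="genre", value="sci-fi"),
            Comparison(comparator=Comparator.GT, attribute="year", value=1990),
        ],
    ),
)
print(DeepLakeTranslator().visit_structured_query(q))
# ('sci-fi', {'tql': "SELECT * WHERE (metadata['genre'] == 'sci-fi' and metadata['year'] > 1990)"})
```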

@ -0,0 +1,49 @@
from typing import Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class DingoDBTranslator(Visitor):
"""Translate `DingoDB` internal query language elements to valid filters."""
allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
)
"""Subset of allowed logical comparators."""
allowed_operators = (Operator.AND, Operator.OR)
"""Subset of allowed logical operators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"
def visit_operation(self, operation: Operation) -> Operation:
return operation
def visit_comparison(self, comparison: Comparison) -> Comparison:
return comparison
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {
"search_params": {
"langchain_expr": structured_query.filter.accept(self)
}
}
return structured_query.query, kwargs

@ -0,0 +1,100 @@
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class ElasticsearchTranslator(Visitor):
"""Translate `Elasticsearch` internal query language elements to valid filters."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.CONTAIN,
Comparator.LIKE,
]
"""Subset of allowed logical comparators."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
map_dict = {
Operator.OR: "should",
Operator.NOT: "must_not",
Operator.AND: "must",
Comparator.EQ: "term",
Comparator.GT: "gt",
Comparator.GTE: "gte",
Comparator.LT: "lt",
Comparator.LTE: "lte",
Comparator.CONTAIN: "match",
Comparator.LIKE: "match",
}
return map_dict[func]
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {"bool": {self._format_func(operation.operator): args}}
def visit_comparison(self, comparison: Comparison) -> Dict:
# ElasticsearchStore filters need to target
# the metadata object field
field = f"metadata.{comparison.attribute}"
is_range_comparator = comparison.comparator in [
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
if is_range_comparator:
value = comparison.value
if isinstance(comparison.value, dict) and "date" in comparison.value:
value = comparison.value["date"]
return {"range": {field: {self._format_func(comparison.comparator): value}}}
if comparison.comparator == Comparator.CONTAIN:
return {
self._format_func(comparison.comparator): {
field: {"query": comparison.value}
}
}
if comparison.comparator == Comparator.LIKE:
return {
self._format_func(comparison.comparator): {
field: {"query": comparison.value, "fuzziness": "AUTO"}
}
}
# we assume that if the value is a string,
# we want to use the keyword field
field = f"{field}.keyword" if isinstance(comparison.value, str) else field
if isinstance(comparison.value, dict):
if "date" in comparison.value:
comparison.value = comparison.value["date"]
return {self._format_func(comparison.comparator): {field: comparison.value}}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": [structured_query.filter.accept(self)]}
return structured_query.query, kwargs
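Worth noting: string equality targets the `.keyword` subfield while range comparators emit a `range` clause (import path assumed):

```python
from langchain_community.query_constructors.elasticsearch import ElasticsearchTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="winter boots",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="brand", value="acme"),
            Comparison(comparator=Comparator.GTE, attribute="size", value=40),
        ],
    ),
)
print(ElasticsearchTranslator().visit_structured_query(q))
# ('winter boots', {'filter': [{'bool': {'must': [
#     {'term': {'metadata.brand.keyword': 'acme'}},
#     {'range': {'metadata.size': {'gte': 40}}}]}}]})
```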

@ -0,0 +1,103 @@
"""Logic for converting internal query language to a valid Milvus query."""
from typing import Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
COMPARATOR_TO_BER = {
Comparator.EQ: "==",
Comparator.GT: ">",
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
Comparator.IN: "in",
Comparator.LIKE: "like",
}
UNARY_OPERATORS = [Operator.NOT]
def process_value(value: Union[int, float, str], comparator: Comparator) -> str:
"""Convert a value to a string and add double quotes if it is a string.
This is required for comparators involving strings.
Args:
value: The value to convert.
comparator: The comparator.
Returns:
The converted value as a string.
"""
if isinstance(value, str):
if comparator is Comparator.LIKE:
# If the comparator is LIKE, add a percent sign after it for prefix matching
# and add double quotes
return f'"{value}%"'
else:
# If the value is already a string, add double quotes
return f'"{value}"'
else:
# If the value is not a string, convert it to a string without double quotes
return str(value)
class MilvusTranslator(Visitor):
"""Translate Milvus internal query language elements to valid filters."""
"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.NOT, Operator.OR]
"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.LIKE,
]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
value = func.value
if isinstance(func, Comparator):
value = COMPARATOR_TO_BER[func]
return f"{value}"
def visit_operation(self, operation: Operation) -> str:
if operation.operator in UNARY_OPERATORS and len(operation.arguments) == 1:
operator = self._format_func(operation.operator)
return operator + "(" + operation.arguments[0].accept(self) + ")"
elif operation.operator in UNARY_OPERATORS:
raise ValueError(
f'"{operation.operator.value}" can have only one argument in Milvus'
)
else:
args = [arg.accept(self) for arg in operation.arguments]
operator = self._format_func(operation.operator)
return "(" + (" " + operator + " ").join(args) + ")"
def visit_comparison(self, comparison: Comparison) -> str:
comparator = self._format_func(comparison.comparator)
processed_value = process_value(comparison.value, comparison.comparator)
attribute = comparison.attribute
return "( " + attribute + " " + comparator + " " + processed_value + " )"
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"expr": structured_query.filter.accept(self)}
return structured_query.query, kwargs
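A sketch of the resulting `expr` string; note the trailing `%` that `process_value` appends for LIKE prefix matching (import path assumed):

```python
from langchain_community.query_constructors.milvus import MilvusTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="robots",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.LIKE, attribute="title", value="Rob"),
            Comparison(comparator=Comparator.GTE, attribute="year", value=2000),
        ],
    ),
)
print(MilvusTranslator().visit_structured_query(q))
# ('robots', {'expr': '(( title like "Rob%" ) and ( year >= 2000 ))'})
```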

@ -0,0 +1,74 @@
"""Logic for converting internal query language to a valid MongoDB Atlas query."""
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN]
class MongoDBAtlasTranslator(Visitor):
"""Translate Mongo internal query language elements to valid filters."""
"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.NIN,
]
"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.OR]
# Convert an operator or a comparator to Mongo query format
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
map_dict = {
Operator.AND: "$and",
Operator.OR: "$or",
Comparator.EQ: "$eq",
Comparator.NE: "$ne",
Comparator.GTE: "$gte",
Comparator.LTE: "$lte",
Comparator.LT: "$lt",
Comparator.GT: "$gt",
Comparator.IN: "$in",
Comparator.NIN: "$nin",
}
return map_dict[func]
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance(
comparison.value, list
):
comparison.value = [comparison.value]
comparator = self._format_func(comparison.comparator)
attribute = comparison.attribute
return {attribute: {comparator: comparison.value}}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"pre_filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,125 @@
import re
from typing import Any, Callable, Dict, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
def _DEFAULT_COMPOSER(op_name: str) -> Callable:
"""
Default composer for logical operators.
Args:
op_name: Name of the operator.
Returns:
Callable that takes a list of arguments and returns a string.
"""
def f(*args: Any) -> str:
args_: map[str] = map(str, args)
return f" {op_name} ".join(args_)
return f
def _FUNCTION_COMPOSER(op_name: str) -> Callable:
"""
Composer for functions.
Args:
op_name: Name of the function.
Returns:
Callable that takes a list of arguments and returns a string.
"""
def f(*args: Any) -> str:
args_: map[str] = map(str, args)
return f"{op_name}({','.join(args_)})"
return f
class MyScaleTranslator(Visitor):
"""Translate `MyScale` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.CONTAIN,
Comparator.LIKE,
]
map_dict = {
Operator.AND: _DEFAULT_COMPOSER("AND"),
Operator.OR: _DEFAULT_COMPOSER("OR"),
Operator.NOT: _DEFAULT_COMPOSER("NOT"),
Comparator.EQ: _DEFAULT_COMPOSER("="),
Comparator.GT: _DEFAULT_COMPOSER(">"),
Comparator.GTE: _DEFAULT_COMPOSER(">="),
Comparator.LT: _DEFAULT_COMPOSER("<"),
Comparator.LTE: _DEFAULT_COMPOSER("<="),
Comparator.CONTAIN: _FUNCTION_COMPOSER("has"),
Comparator.LIKE: _DEFAULT_COMPOSER("ILIKE"),
}
def __init__(self, metadata_key: str = "metadata") -> None:
super().__init__()
self.metadata_key = metadata_key
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
func = operation.operator
self._validate_func(func)
return self.map_dict[func](*args)
def visit_comparison(self, comparison: Comparison) -> Dict:
regex = r"\((.*?)\)"
matched = re.search(r"\(\w+\)", comparison.attribute)
# If an arbitrary function is applied to the attribute
if matched:
attr = re.sub(
regex,
f"({self.metadata_key}.{matched.group(0)[1:-1]})",
comparison.attribute,
)
else:
attr = f"{self.metadata_key}.{comparison.attribute}"
value = comparison.value
comp = comparison.comparator
value = f"'{value}'" if isinstance(value, str) else value
# convert timestamp for datetime objects
if isinstance(value, dict) and value.get("type") == "date":
attr = f"parseDateTime32BestEffort({attr})"
value = f"parseDateTime32BestEffort('{value['date']}')"
# string pattern match
if comp is Comparator.LIKE:
value = f"'%{value[1:-1]}%'"
return self.map_dict[comp](attr, value)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
print(structured_query) # noqa: T201
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"where_str": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,104 @@
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class OpenSearchTranslator(Visitor):
"""Translate `OpenSearch` internal query domain-specific
language elements to valid filters."""
allowed_comparators = [
Comparator.EQ,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.CONTAIN,
Comparator.LIKE,
]
"""Subset of allowed logical comparators."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
comp_operator_map = {
Comparator.EQ: "term",
Comparator.LT: "lt",
Comparator.LTE: "lte",
Comparator.GT: "gt",
Comparator.GTE: "gte",
Comparator.CONTAIN: "match",
Comparator.LIKE: "fuzzy",
Operator.AND: "must",
Operator.OR: "should",
Operator.NOT: "must_not",
}
return comp_operator_map[func]
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {"bool": {self._format_func(operation.operator): args}}
def visit_comparison(self, comparison: Comparison) -> Dict:
field = f"metadata.{comparison.attribute}"
if comparison.comparator in [
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
]:
if isinstance(comparison.value, dict):
if "date" in comparison.value:
return {
"range": {
field: {
self._format_func(
comparison.comparator
): comparison.value["date"]
}
}
}
else:
return {
"range": {
field: {
self._format_func(comparison.comparator): comparison.value
}
}
}
if comparison.comparator == Comparator.LIKE:
return {
self._format_func(comparison.comparator): {
field: {"value": comparison.value}
}
}
field = f"{field}.keyword" if isinstance(comparison.value, str) else field
if isinstance(comparison.value, dict):
if "date" in comparison.value:
comparison.value = comparison.value["date"]
return {self._format_func(comparison.comparator): {field: comparison.value}}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,52 @@
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class PGVectorTranslator(Visitor):
"""Translate `PGVector` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.LT,
Comparator.IN,
Comparator.NIN,
Comparator.CONTAIN,
Comparator.LIKE,
]
"""Subset of allowed logical comparators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"{func.value}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,57 @@
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class PineconeTranslator(Visitor):
"""Translate `Pinecone` internal query language elements to valid filters."""
allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.IN,
Comparator.NIN,
)
"""Subset of allowed logical comparators."""
allowed_operators = (Operator.AND, Operator.OR)
"""Subset of allowed logical operators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
if comparison.comparator in (Comparator.IN, Comparator.NIN) and not isinstance(
comparison.value, list
):
comparison.value = [comparison.value]
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
if TYPE_CHECKING:
from qdrant_client.http import models as rest
class QdrantTranslator(Visitor):
"""Translate `Qdrant` internal query language elements to valid filters."""
allowed_operators = (
Operator.AND,
Operator.OR,
Operator.NOT,
)
"""Subset of allowed logical operators."""
allowed_comparators = (
Comparator.EQ,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.LIKE,
)
"""Subset of allowed logical comparators."""
def __init__(self, metadata_key: str):
self.metadata_key = metadata_key
def visit_operation(self, operation: Operation) -> rest.Filter:
try:
from qdrant_client.http import models as rest
except ImportError as e:
raise ImportError(
"Cannot import qdrant_client. Please install with `pip install "
"qdrant-client`."
) from e
args = [arg.accept(self) for arg in operation.arguments]
operator = {
Operator.AND: "must",
Operator.OR: "should",
Operator.NOT: "must_not",
}[operation.operator]
return rest.Filter(**{operator: args})
def visit_comparison(self, comparison: Comparison) -> rest.FieldCondition:
try:
from qdrant_client.http import models as rest
except ImportError as e:
raise ImportError(
"Cannot import qdrant_client. Please install with `pip install "
"qdrant-client`."
) from e
self._validate_func(comparison.comparator)
attribute = self.metadata_key + "." + comparison.attribute
if comparison.comparator == Comparator.EQ:
return rest.FieldCondition(
key=attribute, match=rest.MatchValue(value=comparison.value)
)
if comparison.comparator == Comparator.LIKE:
return rest.FieldCondition(
key=attribute, match=rest.MatchText(text=comparison.value)
)
kwargs = {comparison.comparator.value: comparison.value}
return rest.FieldCondition(key=attribute, range=rest.Range(**kwargs))
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
try:
from qdrant_client.http import models as rest
except ImportError as e:
raise ImportError(
"Cannot import qdrant_client. Please install with `pip install "
"qdrant-client`."
) from e
if structured_query.filter is None:
kwargs = {}
else:
filter = structured_query.filter.accept(self)
if isinstance(filter, rest.FieldCondition):
filter = rest.Filter(must=[filter])
kwargs = {"filter": filter}
return structured_query.query, kwargs

@ -0,0 +1,103 @@
from __future__ import annotations
from typing import Any, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
from langchain_community.vectorstores.redis import Redis
from langchain_community.vectorstores.redis.filters import (
RedisFilterExpression,
RedisFilterField,
RedisFilterOperator,
RedisNum,
RedisTag,
RedisText,
)
from langchain_community.vectorstores.redis.schema import RedisModel
_COMPARATOR_TO_BUILTIN_METHOD = {
Comparator.EQ: "__eq__",
Comparator.NE: "__ne__",
Comparator.LT: "__lt__",
Comparator.GT: "__gt__",
Comparator.LTE: "__le__",
Comparator.GTE: "__ge__",
Comparator.CONTAIN: "__eq__",
Comparator.LIKE: "__mod__",
}
class RedisTranslator(Visitor):
"""Visitor for translating structured queries to Redis filter expressions."""
allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.CONTAIN,
Comparator.LIKE,
)
"""Subset of allowed logical comparators."""
allowed_operators = (Operator.AND, Operator.OR)
"""Subset of allowed logical operators."""
def __init__(self, schema: RedisModel) -> None:
self._schema = schema
def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField:
if attribute in [tf.name for tf in self._schema.text]:
return RedisText(attribute)
elif attribute in [tf.name for tf in self._schema.tag or []]:
return RedisTag(attribute)
elif attribute in [tf.name for tf in self._schema.numeric or []]:
return RedisNum(attribute)
else:
raise ValueError(
f"Invalid attribute {attribute} not in vector store schema. Schema is:"
f"\n{self._schema.as_dict()}"
)
def visit_comparison(self, comparison: Comparison) -> RedisFilterExpression:
filter_field = self._attribute_to_filter_field(comparison.attribute)
comparison_method = _COMPARATOR_TO_BUILTIN_METHOD[comparison.comparator]
return getattr(filter_field, comparison_method)(comparison.value)
def visit_operation(self, operation: Operation) -> Any:
left = operation.arguments[0].accept(self)
if len(operation.arguments) > 2:
right = self.visit_operation(
Operation(
operator=operation.operator, arguments=operation.arguments[1:]
)
)
else:
right = operation.arguments[1].accept(self)
redis_operator = (
RedisFilterOperator.OR
if operation.operator == Operator.OR
else RedisFilterOperator.AND
)
return RedisFilterExpression(operator=redis_operator, left=left, right=right)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
@classmethod
def from_vectorstore(cls, vectorstore: Redis) -> RedisTranslator:
return cls(vectorstore._schema)

@ -0,0 +1,97 @@
from typing import Any, Dict, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class SupabaseVectorTranslator(Visitor):
"""Translate Langchain filters to Supabase PostgREST filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.LIKE,
]
"""Subset of allowed logical comparators."""
metadata_column = "metadata"
def _map_comparator(self, comparator: Comparator) -> str:
"""
Maps Langchain comparator to PostgREST comparator:
https://postgrest.org/en/stable/references/api/tables_views.html#operators
"""
postgrest_comparator = {
Comparator.EQ: "eq",
Comparator.NE: "neq",
Comparator.GT: "gt",
Comparator.GTE: "gte",
Comparator.LT: "lt",
Comparator.LTE: "lte",
Comparator.LIKE: "like",
}.get(comparator)
if postgrest_comparator is None:
raise Exception(
f"Comparator '{comparator}' is not currently "
"supported in Supabase Vector"
)
return postgrest_comparator
def _get_json_operator(self, value: Any) -> str:
if isinstance(value, str):
return "->>"
else:
return "->"
def visit_operation(self, operation: Operation) -> str:
args = [arg.accept(self) for arg in operation.arguments]
return f"{operation.operator.value}({','.join(args)})"
def visit_comparison(self, comparison: Comparison) -> str:
if isinstance(comparison.value, list):
return self.visit_operation(
Operation(
operator=Operator.AND,
arguments=[
Comparison(
comparator=comparison.comparator,
attribute=comparison.attribute,
value=value,
)
for value in comparison.value
],
)
)
return ".".join(
[
f"{self.metadata_column}{self._get_json_operator(comparison.value)}{comparison.attribute}",
f"{self._map_comparator(comparison.comparator)}",
f"{comparison.value}",
]
)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, Dict[str, str]]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"postgrest_filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
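The PostgREST filter string this builds, with `->>` for string values and `->` otherwise (import path assumed):

```python
from langchain_community.query_constructors.supabase import SupabaseVectorTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="budget phones",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="brand", value="nokia"),
            Comparison(comparator=Comparator.LTE, attribute="price", value=150),
        ],
    ),
)
print(SupabaseVectorTranslator().visit_structured_query(q))
# ('budget phones', {'postgrest_filter': 'and(metadata->>brand.eq.nokia,metadata->price.lte.150)'})
```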

@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Optional, Sequence, Tuple
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class TencentVectorDBTranslator(Visitor):
"""Translate StructuredQuery to Tencent VectorDB query."""
COMPARATOR_MAP = {
Comparator.EQ: "=",
Comparator.NE: "!=",
Comparator.GT: ">",
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
Comparator.IN: "in",
Comparator.NIN: "not in",
}
allowed_comparators: Optional[Sequence[Comparator]] = list(COMPARATOR_MAP.keys())
allowed_operators: Optional[Sequence[Operator]] = [
Operator.AND,
Operator.OR,
Operator.NOT,
]
def __init__(self, meta_keys: Optional[Sequence[str]] = None):
"""Initialize the translator.
Args:
meta_keys: List of meta keys to be used in the query. Default: [].
"""
self.meta_keys = meta_keys or []
def visit_operation(self, operation: Operation) -> str:
"""Visit an operation node and return the translated query.
Args:
operation: Operation node to be visited.
Returns:
Translated query.
"""
if operation.operator in (Operator.AND, Operator.OR):
ret = f" {operation.operator.value} ".join(
[arg.accept(self) for arg in operation.arguments]
)
if operation.operator == Operator.OR:
ret = f"({ret})"
return ret
else:
return f"not ({operation.arguments[0].accept(self)})"
def visit_comparison(self, comparison: Comparison) -> str:
"""Visit a comparison node and return the translated query.
Args:
comparison: Comparison node to be visited.
Returns:
Translated query.
"""
if self.meta_keys and comparison.attribute not in self.meta_keys:
raise ValueError(
f"Expr Filtering found Unsupported attribute: {comparison.attribute}"
)
if comparison.comparator in self.COMPARATOR_MAP:
if comparison.comparator in [Comparator.IN, Comparator.NIN]:
value = map(
lambda x: f'"{x}"' if isinstance(x, str) else x, comparison.value
)
return (
f"{comparison.attribute}"
f" {self.COMPARATOR_MAP[comparison.comparator]} "
f"({', '.join(value)})"
)
if isinstance(comparison.value, str):
return (
f"{comparison.attribute} "
f"{self.COMPARATOR_MAP[comparison.comparator]}"
f' "{comparison.value}"'
)
return (
f"{comparison.attribute}"
f" {self.COMPARATOR_MAP[comparison.comparator]} "
f"{comparison.value}"
)
else:
raise ValueError(f"Unsupported comparator {comparison.comparator}")
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
"""Visit a structured query node and return the translated query.
Args:
structured_query: StructuredQuery node to be visited.
Returns:
Translated query and query kwargs.
"""
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"expr": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,84 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
if TYPE_CHECKING:
from timescale_vector import client
class TimescaleVectorTranslator(Visitor):
"""Translate the internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
COMPARATOR_MAP = {
Comparator.EQ: "==",
Comparator.GT: ">",
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
}
OPERATOR_MAP = {Operator.AND: "AND", Operator.OR: "OR", Operator.NOT: "NOT"}
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
if isinstance(func, Operator):
value = self.OPERATOR_MAP[func.value] # type: ignore
elif isinstance(func, Comparator):
value = self.COMPARATOR_MAP[func.value] # type: ignore
return f"{value}"
def visit_operation(self, operation: Operation) -> client.Predicates:
try:
from timescale_vector import client
except ImportError as e:
raise ImportError(
"Cannot import timescale-vector. Please install with `pip install "
"timescale-vector`."
) from e
args = [arg.accept(self) for arg in operation.arguments]
return client.Predicates(*args, operator=self._format_func(operation.operator))
def visit_comparison(self, comparison: Comparison) -> client.Predicates:
try:
from timescale_vector import client
except ImportError as e:
raise ImportError(
"Cannot import timescale-vector. Please install with `pip install "
"timescale-vector`."
) from e
return client.Predicates(
(
comparison.attribute,
self._format_func(comparison.comparator),
comparison.value,
)
)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"predicates": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,70 @@
from typing import Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
def process_value(value: Union[int, float, str]) -> str:
"""Convert a value to a string and add single quotes if it is a string."""
if isinstance(value, str):
return f"'{value}'"
else:
return str(value)
class VectaraTranslator(Visitor):
"""Translate `Vectara` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
]
"""Subset of allowed logical comparators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
map_dict = {
Operator.AND: " and ",
Operator.OR: " or ",
Comparator.EQ: "=",
Comparator.NE: "!=",
Comparator.GT: ">",
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
}
self._validate_func(func)
return map_dict[func]
def visit_operation(self, operation: Operation) -> str:
args = [arg.accept(self) for arg in operation.arguments]
operator = self._format_func(operation.operator)
return "( " + operator.join(args) + " )"
def visit_comparison(self, comparison: Comparison) -> str:
comparator = self._format_func(comparison.comparator)
processed_value = process_value(comparison.value)
attribute = comparison.attribute
return (
"( " + "doc." + attribute + " " + comparator + " " + processed_value + " )"
)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -0,0 +1,79 @@
from datetime import datetime
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class WeaviateTranslator(Visitor):
"""Translate `Weaviate` internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GTE,
Comparator.LTE,
Comparator.LT,
Comparator.GT,
]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
# https://weaviate.io/developers/weaviate/api/graphql/filters
map_dict = {
Operator.AND: "And",
Operator.OR: "Or",
Comparator.EQ: "Equal",
Comparator.NE: "NotEqual",
Comparator.GTE: "GreaterThanEqual",
Comparator.LTE: "LessThanEqual",
Comparator.LT: "LessThan",
Comparator.GT: "GreaterThan",
}
return map_dict[func]
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {"operator": self._format_func(operation.operator), "operands": args}
def visit_comparison(self, comparison: Comparison) -> Dict:
value_type = "valueText"
value = comparison.value
if isinstance(comparison.value, bool):
value_type = "valueBoolean"
elif isinstance(comparison.value, float):
value_type = "valueNumber"
elif isinstance(comparison.value, int):
value_type = "valueInt"
elif (
isinstance(comparison.value, dict)
and comparison.value.get("type") == "date"
):
value_type = "valueDate"
# ISO 8601 timestamp, formatted as RFC3339
date = datetime.strptime(comparison.value["date"], "%Y-%m-%d")
value = date.strftime("%Y-%m-%dT%H:%M:%SZ")
filter = {
"path": [comparison.attribute],
"operator": self._format_func(comparison.comparator),
value_type: value,
}
return filter
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"where_filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
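And the GraphQL-style `where_filter` this yields, with the value key picked by Python type (import path assumed):

```python
from langchain_community.query_constructors.weaviate import WeaviateTranslator  # assumed path
from langchain_core.structured_query import (
    Comparator, Comparison, Operation, Operator, StructuredQuery,
)

q = StructuredQuery(
    query="talks",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(comparator=Comparator.EQ, attribute="speaker", value="alice"),
            Comparison(comparator=Comparator.GT, attribute="views", value=1000),
        ],
    ),
)
print(WeaviateTranslator().visit_structured_query(q))
# ('talks', {'where_filter': {'operator': 'And', 'operands': [
#     {'path': ['speaker'], 'operator': 'Equal', 'valueText': 'alice'},
#     {'path': ['views'], 'operator': 'GreaterThan', 'valueInt': 1000}]}})
```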

@ -123,6 +123,7 @@ if TYPE_CHECKING:
from langchain_community.retrievers.weaviate_hybrid_search import (
WeaviateHybridSearchRetriever,
)
from langchain_community.retrievers.web_research import WebResearchRetriever
from langchain_community.retrievers.wikipedia import (
WikipediaRetriever,
)
@ -174,6 +175,7 @@ _module_lookup = {
"TavilySearchAPIRetriever": "langchain_community.retrievers.tavily_search_api",
"VespaRetriever": "langchain_community.retrievers.vespa_retriever",
"WeaviateHybridSearchRetriever": "langchain_community.retrievers.weaviate_hybrid_search", # noqa: E501
"WebResearchRetriever": "langchain_community.retrievers.web_research",
"WikipediaRetriever": "langchain_community.retrievers.wikipedia",
"YouRetriever": "langchain_community.retrievers.you",
"ZepRetriever": "langchain_community.retrievers.zep",
@ -194,8 +196,8 @@ __all__ = [
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureCognitiveSearchRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"BM25Retriever",
"BreebsRetriever",
"ChaindeskRetriever",
@ -209,8 +211,8 @@ __all__ = [
"GoogleDocumentAIWarehouseRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever",
"KNNRetriever",
"KayAiRetriever",
"KNNRetriever",
"LlamaIndexGraphRetriever",
"LlamaIndexRetriever",
"MetalRetriever",
@ -223,10 +225,11 @@ __all__ = [
"RememberizerRetriever",
"RemoteLangChainRetriever",
"SVMRetriever",
"TFIDFRetriever",
"TavilySearchAPIRetriever",
"TFIDFRetriever",
"VespaRetriever",
"WeaviateHybridSearchRetriever",
"WebResearchRetriever",
"WikipediaRetriever",
"YouRetriever",
"ZepRetriever",

@ -0,0 +1,223 @@
import logging
import re
from typing import List, Optional
from langchain.chains import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.document_transformers import Html2TextTransformer
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import GoogleSearchAPIWrapper
logger = logging.getLogger(__name__)
class SearchQueries(BaseModel):
"""Search queries to research for the user's goal."""
queries: List[str] = Field(
..., description="List of search queries to look up on Google"
)
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n {question} [/INST]""",
)
DEFAULT_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
)
class QuestionListOutputParser(BaseOutputParser[List[str]]):
"""Output parser for a list of numbered questions."""
def parse(self, text: str) -> List[str]:
lines = re.findall(r"\d+\..*?(?:\n|$)", text)
return lines
class WebResearchRetriever(BaseRetriever):
"""`Google Search API` retriever."""
# Inputs
vectorstore: VectorStore = Field(
..., description="Vector store for storing web pages"
)
llm_chain: LLMChain
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
num_search_results: int = Field(1, description="Number of pages per Google search")
text_splitter: TextSplitter = Field(
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
description="Text splitter for splitting web pages into chunks",
)
url_database: List[str] = Field(
default_factory=list, description="List of processed URLs"
)
@classmethod
def from_llm(
cls,
vectorstore: VectorStore,
llm: BaseLLM,
search: GoogleSearchAPIWrapper,
prompt: Optional[BasePromptTemplate] = None,
num_search_results: int = 1,
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=150
),
) -> "WebResearchRetriever":
"""Initialize from llm using default template.
Args:
vectorstore: Vector store for storing web pages
llm: llm for search question generation
search: GoogleSearchAPIWrapper
prompt: prompt for generating search questions
num_search_results: Number of pages per Google search
text_splitter: Text splitter for splitting web pages into chunks
Returns:
WebResearchRetriever
"""
if not prompt:
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_SEARCH_PROMPT,
conditionals=[
(lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
],
)
prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)
# Use chat model prompt
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
output_parser=QuestionListOutputParser(),
)
return cls(
vectorstore=vectorstore,
llm_chain=llm_chain,
search=search,
num_search_results=num_search_results,
text_splitter=text_splitter,
)
def clean_search_query(self, query: str) -> str:
# Some search tools (e.g., Google) will
# fail to return results if query has a
# leading digit: 1. "LangCh..."
# Check if the first character is a digit
if query[0].isdigit():
# Find the position of the first quote
first_quote_pos = query.find('"')
if first_quote_pos != -1:
# Extract the part of the string after the quote
query = query[first_quote_pos + 1 :]
# Remove the trailing quote if present
if query.endswith('"'):
query = query[:-1]
return query.strip()
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
"""Returns num_search_results pages per Google search."""
query_clean = self.clean_search_query(query)
result = self.search.results(query_clean, num_search_results)
return result
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Search Google for documents related to the query input.
Args:
query: user query
Returns:
Relevant documents from the various URLs.
"""
# Get search questions
logger.info("Generating questions for Google Search ...")
result = self.llm_chain({"question": query})
logger.info(f"Questions for Google Search (raw): {result}")
questions = result["text"]
logger.info(f"Questions for Google Search: {questions}")
# Get urls
logger.info("Searching for relevant urls...")
urls_to_look = []
for query in questions:
# Google search
search_results = self.search_tool(query, self.num_search_results)
logger.info("Searching for relevant urls...")
logger.info(f"Search results: {search_results}")
for res in search_results:
if res.get("link", None):
urls_to_look.append(res["link"])
# Relevant urls
urls = set(urls_to_look)
# Check for any new urls that we have not processed
new_urls = list(urls.difference(self.url_database))
logger.info(f"New URLs to load: {new_urls}")
# Load, split, and add new urls to vectorstore
if new_urls:
loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True)
html2text = Html2TextTransformer()
logger.info("Indexing new urls...")
docs = loader.load()
docs = list(html2text.transform_documents(docs))
docs = self.text_splitter.split_documents(docs)
self.vectorstore.add_documents(docs)
self.url_database.extend(new_urls)
# Search for relevant splits
# TODO: make this async
logger.info("Grabbing most relevant splits from urls...")
docs = []
for query in questions:
docs.extend(self.vectorstore.similarity_search(query))
# Get unique docs
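        # Key on (page_content, sorted metadata) so exact duplicates collapse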
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs
}
unique_documents = list(unique_documents_dict.values())
        return unique_documents

    async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError
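
For reference, a minimal usage sketch of the retriever above (illustrative only,
not part of the diff; assumes OPENAI_API_KEY, GOOGLE_API_KEY and GOOGLE_CSE_ID
are set, and uses Chroma as an arbitrary vector store choice):

    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.llms import OpenAI
    from langchain_community.utilities import GoogleSearchAPIWrapper
    from langchain_community.vectorstores import Chroma

    # Vector store that caches the fetched and split web pages
    vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
    retriever = WebResearchRetriever.from_llm(
        vectorstore=vectorstore,
        llm=OpenAI(temperature=0),
        search=GoogleSearchAPIWrapper(),
        num_search_results=3,
    )
    # Generates search questions, scrapes results, returns relevant splits
    docs = retriever.invoke("How do LangChain retrievers work?")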

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aenum"
@ -3454,6 +3454,7 @@ files = [
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
{file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
{file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@ -3962,9 +3963,51 @@ files = [
{file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
]
[[package]]
name = "langchain"
version = "0.2.0rc1"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
aiohttp = "^3.8.3"
async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
dataclasses-json = ">= 0.5.7, < 0.7"
langchain-core = "^0.1.48"
langchain-text-splitters = ">=0.0.1,<0.1"
langsmith = "^0.1.17"
numpy = "^1"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
requests = "^2"
SQLAlchemy = ">=1.4,<3"
tenacity = "^8.1.0"
[package.extras]
all = []
azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"]
clarifai = ["clarifai (>=9.1.0)"]
cli = ["typer (>=0.9.0,<0.10.0)"]
cohere = ["cohere (>=4,<6)"]
docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"]
embeddings = ["sentence-transformers (>=2,<3)"]
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<6)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "langchain-openai (>=0.0.2,<0.1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
javascript = ["esprima (>=4.0.1,<5.0.0)"]
llms = ["clarifai (>=9.1.0)", "cohere (>=4,<6)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"]
openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"]
text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
[package.source]
type = "directory"
url = "../langchain"
[[package]]
name = "langchain-core"
version = "0.1.51"
version = "0.1.52"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -6064,8 +6107,6 @@ files = [
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@ -6108,7 +6149,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@ -6117,8 +6157,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@ -6651,26 +6689,31 @@ python-versions = ">=3.8"
files = [
{file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"},
{file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"},
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"},
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"},
{file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"},
{file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"},
{file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"},
{file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"},
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"},
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"},
{file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"},
{file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"},
{file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"},
{file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"},
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"},
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"},
{file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"},
{file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"},
{file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"},
{file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"},
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"},
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"},
{file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"},
{file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"},
{file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"},
{file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"},
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"},
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"},
{file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"},
{file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"},
@ -7111,7 +7154,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -10044,4 +10086,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "ca64e52a60e8ee6f2f4ea303e1779a4508f401e283f63861161cb6a9560e2178"
content-hash = "9ad37aae2905701ec099c1f9cdec59692de43e8d047ceb2ce25898b4c873b190"

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-community"
version = "0.0.37"
version = "0.0.38rc1"
description = "Community contributed LangChain integrations."
authors = []
license = "MIT"
@ -10,6 +10,7 @@ repository = "https://github.com/langchain-ai/langchain"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.51"
langchain = "~0.2.0rc1"
SQLAlchemy = ">=1.4,<3"
requests = "^2"
PyYAML = ">=5.3"
@ -126,6 +127,7 @@ pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
langchain-core = {path = "../core", develop = true}
langchain = {path = "../langchain", develop = true}
[tool.poetry.group.codespell]
optional = true
@ -160,6 +162,7 @@ cassio = "^0.1.6"
tiktoken = ">=0.3.2,<0.6.0"
anthropic = "^0.3.11"
langchain-core = { path = "../core", develop = true }
langchain = {path = "../langchain", develop = true}
fireworks-ai = "^0.9.0"
vdms = "^0.0.20"
exllamav2 = "^0.0.18"
@ -181,6 +184,7 @@ types-redis = "^4.3.21.6"
mypy-protobuf = "^3.0.0"
langchain-core = {path = "../core", develop = true}
langchain-text-splitters = {path = "../text-splitters", develop = true}
langchain = {path = "../langchain", develop = true}
[tool.poetry.group.dev]
optional = true

@ -8,12 +8,12 @@ from typing import Any
from urllib.error import HTTPError
import pytest
from langchain.agents import AgentType, initialize_agent
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools.ainetwork.utils import authenticate
from langchain.agents import AgentType, initialize_agent
class Match(Enum):
__test__ = False

@ -1,10 +1,10 @@
import pytest
from langchain_core.utils import get_from_env
from langchain_community.agent_toolkits import PowerBIToolkit, create_pbi_agent
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities.powerbi import PowerBIDataset
from langchain.utils import get_from_env
def azure_installed() -> bool:
try:

@ -0,0 +1,81 @@
"""Fake Embedding class for testing purposes."""
import math
from typing import List
from langchain_core.embeddings import Embeddings
fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts."""
def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
class AngularTwoDimensionalEmbeddings(Embeddings):
"""
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the
unit vector in units of pi.
Any other input text becomes the singular result [0, 0] !
"""
try:
angle = float(text)
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
except ValueError:
# Assume: just test string, no attention is paid to values.
return [0.0, 0.0]
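
A quick sanity check of the helpers above (illustrative only, not part of the
diff):

    emb = ConsistentFakeEmbeddings()
    assert emb.embed_query("foo") == [1.0] * 9 + [0.0]  # "foo" -> index 0
    assert emb.embed_query("bar") == [1.0] * 9 + [1.0]  # "bar" -> index 1
    assert emb.embed_query("foo") == [1.0] * 9 + [0.0]  # repeats stay stable

    angular = AngularTwoDimensionalEmbeddings()
    x, y = angular.embed_query("0.5")  # angle 0.5 * pi -> roughly [0.0, 1.0]
    assert abs(x) < 1e-9 and abs(y - 1.0) < 1e-9
    assert angular.embed_query("nonsense") == [0.0, 0.0]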

@ -15,13 +15,13 @@ import os
from typing import AsyncIterator, Iterator
import pytest
from langchain_community.utilities.astradb import SetupMode
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.caches import BaseCache
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import AstraDBCache, AstraDBSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_community.cache import AstraDBCache, AstraDBSemanticCache
from langchain_community.utilities.astradb import SetupMode
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM

@ -10,14 +10,14 @@ import os
import uuid
import pytest
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation
from langchain_community.cache import AzureCosmosDBSemanticCache
from langchain_community.vectorstores.azure_cosmos_db import (
CosmosDBSimilarityType,
CosmosDBVectorSearchType,
)
from langchain_core.outputs import Generation
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
FakeEmbeddings,
)

@ -5,11 +5,11 @@ import time
from typing import Any, Iterator, Tuple
import pytest
from langchain_community.utilities.cassandra import SetupMode
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import CassandraCache, CassandraSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_community.cache import CassandraCache, CassandraSemanticCache
from langchain_community.utilities.cassandra import SetupMode
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM

@ -2,10 +2,10 @@ import os
from typing import Any, Callable, Union
import pytest
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation
from langchain.cache import GPTCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_community.cache import GPTCache
from tests.unit_tests.llms.fake_llm import FakeLLM
try:

@ -11,10 +11,10 @@ from datetime import timedelta
from typing import Iterator
import pytest
from langchain.globals import set_llm_cache
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import MomentoCache
from langchain.globals import set_llm_cache
from langchain_community.cache import MomentoCache
from tests.unit_tests.llms.fake_llm import FakeLLM

@ -1,7 +1,7 @@
from langchain_community.cache import OpenSearchSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.outputs import Generation
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_community.cache import OpenSearchSemanticCache
from tests.integration_tests.cache.fake_embeddings import (
FakeEmbeddings,
)

@ -5,13 +5,13 @@ from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator, List, Optional, cast
import pytest
from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache
from tests.integration_tests.cache.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,

@ -1,11 +1,11 @@
"""Test Upstash Redis cache functionality."""
import uuid
import langchain
import pytest
from langchain_core.outputs import Generation, LLMResult
import langchain
from langchain.cache import UpstashRedisCache
from langchain_community.cache import UpstashRedisCache
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM

@ -1,7 +1,8 @@
"""Integration test for Dall-E image generator agent."""
from langchain_community.llms import OpenAI
from langchain.agents import AgentType, initialize_agent
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.llms import OpenAI
def test_call() -> None:

@ -1,12 +1,12 @@
"""Test Graph Database Chain."""
import os
from langchain.chains.loading import load_chain
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms.openai import OpenAI
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.loading import load_chain
def test_connect_neo4j() -> None:
"""Test that Neo4j database is correctly instantiated and connected."""

@ -1,12 +1,11 @@
"""Test Graph Database Chain."""
from typing import Any
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_community.graphs import ArangoGraph
from langchain_community.graphs.arangodb_graph import get_arangodb_client
from langchain_community.llms.openai import OpenAI
from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain
def populate_arangodb_database(db: Any) -> None:
if db.has_graph("GameOfThrones"):

@ -3,10 +3,10 @@ import pathlib
import re
from unittest.mock import MagicMock, Mock
from langchain_community.graphs import RdfGraph
from langchain.chains import LLMChain
from langchain.chains.graph_qa.sparql import GraphSparqlQAChain
from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain
from langchain_community.graphs import RdfGraph
"""
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb

@ -1,9 +1,10 @@
from unittest.mock import MagicMock, Mock
import pytest
from langchain_community.graphs import OntotextGraphDBGraph
from langchain.chains import LLMChain
from langchain.chains import LLMChain, OntotextGraphDBQAChain
from langchain_community.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain
from langchain_community.graphs import OntotextGraphDBGraph
"""
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb

@ -1,9 +1,9 @@
"""Integration test for self ask with search."""
from langchain_community.llms.openai import OpenAI
from langchain.agents.react.base import ReActChain
from langchain.docstore.wikipedia import Wikipedia
from langchain_community.docstore import Wikipedia
from langchain_community.llms.openai import OpenAI
def test_react() -> None:

@ -1,14 +1,14 @@
"""Test RetrievalQA functionality."""
from pathlib import Path
from langchain.chains import RetrievalQA
from langchain.chains.loading import load_chain
from langchain_text_splitters.character import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.llms import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.chains.loading import load_chain
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:

@ -1,12 +1,12 @@
"""Test RetrievalQA functionality."""
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
from langchain_text_splitters.character import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.llms import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:

@ -1,9 +1,9 @@
"""Integration test for self ask with search."""
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
from langchain_community.llms.openai import OpenAI
from langchain_community.utilities.searchapi import SearchApiAPIWrapper
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
def test_self_ask_with_search() -> None:
"""Test functionality on a prompt."""

@ -1,11 +1,12 @@
"""Integration test for embedding-based redundant doc filtering."""
from langchain_core.documents import Document
from langchain_community.document_transformers.embeddings_redundant_filter import (
EmbeddingsClusteringFilter,
EmbeddingsRedundantFilter,
_DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
def test_embeddings_redundant_filter() -> None:

@ -1,10 +1,10 @@
import json
import os
from langchain_community.chat_message_histories import CosmosDBChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import CosmosDBChatMessageHistory
# Replace these with your Azure Cosmos DB endpoint and key
endpoint = os.environ.get("COSMOS_DB_ENDPOINT", "")

@ -4,10 +4,10 @@ import uuid
from typing import Generator, Union
import pytest
from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory
"""
cd tests/integration_tests/memory/docker-compose

@ -1,9 +1,9 @@
import json
from langchain_community.chat_message_histories import FirestoreChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import FirestoreChatMessageHistory
def test_memory_with_message_store() -> None:

@ -2,13 +2,13 @@ import os
from typing import AsyncIterable, Iterable
import pytest
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain_community.utilities.astradb import SetupMode
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
def _has_env_vars() -> bool:

@ -2,12 +2,12 @@ import os
import time
from typing import Optional
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
def _chat_message_history(

@ -10,10 +10,10 @@ from datetime import timedelta
from typing import Iterator
import pytest
from langchain_community.chat_message_histories import MomentoChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import MomentoChatMessageHistory
def random_string() -> str:

@ -1,10 +1,10 @@
import json
import os
from langchain_community.chat_message_histories import MongoDBChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import MongoDBChatMessageHistory
# Replace these with your mongodb connection string
connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "")

@ -1,9 +1,9 @@
import json
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
def test_memory_with_message_store() -> None:

@ -1,9 +1,9 @@
import json
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import RedisChatMessageHistory
def test_memory_with_message_store() -> None:

@ -8,10 +8,10 @@ and ROCKSET_REGION environment variables set.
import json
import os
from langchain_community.chat_message_histories import RocksetChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import RocksetChatMessageHistory
collection_name = "langchain_demo"
session_id = "MySession"

@ -1,8 +1,9 @@
import json
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory
from langchain_community.chat_message_histories import SingleStoreDBChatMessageHistory
# Replace these with your mongodb connection string
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"

@ -1,12 +1,12 @@
import json
import pytest
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain_community.chat_message_histories.upstash_redis import (
UpstashRedisChatMessageHistory,
)
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
URL = "<UPSTASH_REDIS_REST_URL>"
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"

@ -6,10 +6,10 @@ Before running this test, please create a Xata database.
import json
import os
from langchain_community.chat_message_histories import XataChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import XataChatMessageHistory
class TestXata:

@ -3,7 +3,7 @@
import pytest
from langchain_core.prompts import PromptTemplate
from langchain.prompts.example_selector.ngram_overlap import (
from langchain_community.example_selectors import (
NGramOverlapExampleSelector,
ngram_overlap_score,
)

@ -1,13 +1,13 @@
"""Integration test for compression pipelines."""
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline,
EmbeddingsFilter,
)
from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import OpenAIEmbeddings
def test_document_compressor_pipeline() -> None:

@ -1,8 +1,8 @@
"""Integration test for LLMChainExtractor."""
from langchain_community.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_community.chat_models import ChatOpenAI
def test_llm_construction_with_kwargs() -> None:

@ -1,8 +1,8 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain_community.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain_community.chat_models import ChatOpenAI
def test_llm_chain_filter() -> None:

@ -1,12 +1,12 @@
"""Integration test for embedding-based relevant doc filtering."""
import numpy as np
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_core.documents import Document
from langchain_community.document_transformers.embeddings_redundant_filter import (
_DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import EmbeddingsFilter
def test_embeddings_filter() -> None:

@ -1,9 +1,9 @@
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
def test_contextual_compression_retriever_get_relevant_docs() -> None:
"""Test get_relevant_docs."""

@ -1,8 +1,8 @@
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers.merger_retriever import MergerRetriever
def test_merger_retriever_get_relevant_docs() -> None:
"""Test get_relevant_docs."""

@ -2,19 +2,19 @@ from typing import Iterator, List, Optional
from uuid import uuid4
import pytest
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.openai import OpenAI
from langchain.chains.llm import LLMChain
from langchain.evaluation import EvaluatorType
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langsmith import Client as Client
from langsmith.evaluation import run_evaluator
from langsmith.schemas import DataType, Example, Run
from langchain.chains.llm import LLMChain
from langchain.evaluation import EvaluatorType
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.openai import OpenAI
def _check_all_feedback_passed(_project_name: str, client: Client) -> None:

@ -0,0 +1,73 @@
"""Integration test for embedding-based redundant doc filtering."""
from langchain_core.documents import Document
from langchain_community.document_transformers.embeddings_redundant_filter import (
EmbeddingsClusteringFilter,
EmbeddingsRedundantFilter,
_DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings


def test_embeddings_redundant_filter() -> None:
texts = [
"What happened to all of my cookies?",
"Where did all of my cookies go?",
"I wish there were better Italian restaurants in my neighborhood.",
]
docs = [Document(page_content=t) for t in texts]
embeddings = OpenAIEmbeddings()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
actual = redundant_filter.transform_documents(docs)
assert len(actual) == 2
    assert set(texts[:2]).intersection([d.page_content for d in actual])


def test_embeddings_redundant_filter_with_state() -> None:
texts = ["What happened to all of my cookies?", "foo bar baz"]
state = {"embedded_doc": [0.5] * 10}
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
embeddings = OpenAIEmbeddings()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 1


def test_embeddings_clustering_filter() -> None:
texts = [
"What happened to all of my cookies?",
"A cookie is a small, baked sweet treat and you can find it in the cookie",
"monsters' jar.",
"Cookies are good.",
"I have nightmares about the cookie monster.",
"The most popular pizza styles are: Neapolitan, New York-style and",
"Chicago-style. You can find them on iconic restaurants in major cities.",
"Neapolitan pizza: This is the original pizza style,hailing from Naples,",
"Italy.",
"I wish there were better Italian Pizza restaurants in my neighborhood.",
"New York-style pizza: This is characterized by its large, thin crust, and",
"generous toppings.",
"The first movie to feature a robot was 'A Trip to the Moon' (1902).",
"The first movie to feature a robot that could pass for a human was",
"'Blade Runner' (1982)",
"The first movie to feature a robot that could fall in love with a human",
"was 'Her' (2013)",
"A robot is a machine capable of carrying out complex actions automatically.",
"There are certainly hundreds, if not thousands movies about robots like:",
"'Blade Runner', 'Her' and 'A Trip to the Moon'",
]
docs = [Document(page_content=t) for t in texts]
embeddings = OpenAIEmbeddings()
redundant_filter = EmbeddingsClusteringFilter(
embeddings=embeddings,
num_clusters=3,
num_closest=1,
sorted=True,
)
actual = redundant_filter.transform_documents(docs)
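    # 3 clusters x 1 closest doc per cluster -> exactly 3 documents survive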
assert len(actual) == 3
assert texts[1] in [d.page_content for d in actual]
assert texts[4] in [d.page_content for d in actual]
assert texts[11] in [d.page_content for d in actual]

@ -3,11 +3,12 @@ import json
from typing import Any
from unittest import mock
from langchain_core.documents import Document
from langchain_community.document_transformers.nuclia_text_transform import (
NucliaTextTransformer,
)
from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
from langchain_core.documents import Document
def fakerun(**args: Any) -> Any:

Some files were not shown because too many files have changed in this diff.