From f92006de3ce2ef6795c9a3f5bc798a8d2fa02bb7 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 8 May 2024 16:46:52 -0400 Subject: [PATCH] multiple: langchain 0.2 in master (#21191) 0.2rc migrations - [x] Move memory - [x] Move remaining retrievers - [x] Move graph_qa chains - [x] Address a potential dependency of the evaluation code on math utils - [x] Move openapi chain from `langchain.chains.api.openapi` to `langchain_community.chains.openapi` - [x] Migrate `langchain.chains.ernie_functions` to `langchain_community.chains.ernie_functions` - [x] Migrate `langchain/chains/llm_requests.py` to `langchain_community.chains.llm_requests` - [x] Move `langchain_community.cross_encoders.base:BaseCrossEncoder` -> `langchain.retrievers.document_compressors.cross_encoder:BaseCrossEncoder` (namespace not ideal, but it needs to be moved to `langchain` to avoid circular deps) - [x] Unit tests langchain -- add pytest.mark.community to some unit tests that will stay in langchain - [x] Unit tests community -- move unit tests that depend on community to community - [x] Move integration tests that depend on community to community - [x] mypy checks Other todo - [x] Make deprecation warnings not noisy (need to use `warn_deprecated` and check that things are implemented properly) - [x] Update deprecation messages with a timeline for code removal (we likely won't actually remove things until the 0.4 release) -- this will give people more time to transition their code. - [ ] Add information to the deprecation warning showing users how to migrate their code base using langchain-cli - [ ] Remove any unnecessary requirements in langchain (e.g., is SQLAlchemy required?) --------- Co-authored-by: Erick Friis --- .../agent_toolkits/sql/base.py | 8 +- .../langchain_community/chains/__init__.py | 0 .../chains/ernie_functions/__init__.py | 17 + .../chains/ernie_functions/base.py | 551 ++++++++++++++++ .../chains/graph_qa/__init__.py | 1 + .../chains/graph_qa/arangodb.py | 241 +++++++ .../chains/graph_qa/base.py | 103 +++ .../chains/graph_qa/cypher.py | 298 +++++++++ .../chains/graph_qa/cypher_utils.py | 260 ++++++++ .../chains/graph_qa/falkordb.py | 157 +++++ .../chains/graph_qa/gremlin.py | 221 +++++++ .../chains/graph_qa/hugegraph.py | 106 ++++ .../chains/graph_qa/kuzu.py | 143 +++++ .../chains/graph_qa/nebulagraph.py | 106 ++++ .../chains/graph_qa/neptune_cypher.py | 217 +++++++ .../chains/graph_qa/neptune_sparql.py | 204 ++++++ .../chains/graph_qa/ontotext_graphdb.py | 190 ++++++ .../chains/graph_qa/prompts.py | 415 +++++++++++++ .../chains/graph_qa/sparql.py | 152 +++++ .../chains/llm_requests.py | 97 +++ .../chains/openapi/__init__.py | 0 .../chains/openapi/chain.py | 229 +++++++ .../chains/openapi/prompts.py | 57 ++ .../chains/openapi/requests_chain.py | 62 ++ .../chains/openapi/response_chain.py | 57 ++ .../cross_encoders/base.py | 18 +- .../query_constructors/__init__.py | 0 .../query_constructors/astradb.py | 70 +++ .../query_constructors/chroma.py | 50 ++ .../query_constructors/dashvector.py | 64 ++ .../databricks_vector_search.py | 94 +++ .../query_constructors/deeplake.py | 88 +++ .../query_constructors/dingo.py | 49 ++ .../query_constructors/elasticsearch.py | 100 +++ .../query_constructors/milvus.py | 103 +++ .../query_constructors/mongodb_atlas.py | 74 +++ .../query_constructors/myscale.py | 125 ++++ .../query_constructors/opensearch.py | 104 ++++ .../query_constructors/pgvector.py | 52 ++ .../query_constructors/pinecone.py | 57 ++ .../query_constructors/qdrant.py | 98 +++ .../query_constructors/redis.py | 103 
+++ .../query_constructors/supabase.py | 97 +++ .../query_constructors/tencentvectordb.py | 116 ++++ .../query_constructors/timescalevector.py | 84 +++ .../query_constructors/vectara.py | 70 +++ .../query_constructors/weaviate.py | 79 +++ .../retrievers/__init__.py | 9 +- .../retrievers/web_research.py | 223 +++++++ libs/community/poetry.lock | 60 +- libs/community/pyproject.toml | 6 +- .../agent/test_ainetwork_agent.py | 4 +- .../agent/test_powerbi_agent.py | 4 +- .../cache/fake_embeddings.py | 81 +++ .../integration_tests/cache/test_astradb.py | 6 +- .../cache/test_azure_cosmosdb_cache.py | 6 +- .../integration_tests/cache/test_cassandra.py | 6 +- .../integration_tests/cache/test_gptcache.py | 4 +- .../cache/test_momento_cache.py | 4 +- .../cache/test_opensearch_cache.py | 4 +- .../cache/test_redis_cache.py | 4 +- .../cache/test_upstash_redis_cache.py | 4 +- .../chains/test_dalle_agent.py | 5 +- .../chains/test_graph_database.py | 6 +- .../chains/test_graph_database_arangodb.py | 3 +- .../chains/test_graph_database_sparql.py | 6 +- .../chains/test_ontotext_graphdb_qa.py | 5 +- .../integration_tests/chains/test_react.py | 6 +- .../chains/test_retrieval_qa.py | 8 +- .../chains/test_retrieval_qa_with_sources.py | 8 +- .../chains/test_self_ask_with_search.py | 4 +- .../document_transformers/__init__.py | 0 .../test_embeddings_filter.py} | 3 +- .../memory/test_cosmos_db.py | 4 +- .../memory/test_elasticsearch.py | 4 +- .../memory/test_firestore.py | 4 +- .../memory/test_memory_astradb.py} | 6 +- .../memory/test_memory_cassandra.py} | 6 +- .../integration_tests/memory/test_momento.py | 4 +- .../integration_tests/memory/test_mongodb.py | 4 +- .../integration_tests/memory/test_neo4j.py | 4 +- .../integration_tests/memory/test_redis.py | 4 +- .../integration_tests/memory/test_rockset.py | 4 +- .../memory/test_singlestoredb.py | 3 +- .../memory/test_upstash_redis.py | 6 +- .../integration_tests/memory/test_xata.py | 4 +- .../test_ngram_overlap_example_selector.py | 2 +- .../document_compressors/test_base.py | 10 +- .../test_chain_extract.py | 4 +- .../document_compressors/test_chain_filter.py | 4 +- .../test_embeddings_filter.py | 6 +- .../retrievers/test_contextual_compression.py | 6 +- .../retrievers/test_merger_retriever.py | 4 +- .../smith/evaluation/test_runner_utils.py | 14 +- .../tests/integration_tests/test_dalle.py | 0 .../test_document_transformers.py | 73 +++ .../test_long_context_reorder.py | 0 .../test_nuclia_transformer.py | 3 +- .../test_pdf_pagesplitter.py | 0 .../tests/unit_tests/agents/__init__.py | 0 .../tests/unit_tests/agents/test_react.py | 4 +- .../tests/unit_tests/agents/test_sql.py | 1 - .../tests/unit_tests/agents/test_tools.py | 6 +- .../tests/unit_tests/chains/__init__.py | 0 .../tests/unit_tests/chains/test_api.py | 4 +- .../tests/unit_tests/chains/test_graph_qa.py | 23 +- .../tests/unit_tests/chains/test_llm.py | 2 +- .../unit_tests/data/cypher_corrector.csv | 0 .../unit_tests/query_constructors/__init__.py | 0 .../query_constructors}/test_astradb.py | 2 +- .../query_constructors}/test_chroma.py | 2 +- .../query_constructors}/test_dashvector.py | 2 +- .../test_databricks_vector_search.py | 2 +- .../query_constructors}/test_deeplake.py | 2 +- .../query_constructors}/test_dingo.py | 2 +- .../query_constructors}/test_elasticsearch.py | 2 +- .../query_constructors}/test_milvus.py | 2 +- .../query_constructors}/test_mongodb_atlas.py | 2 +- .../query_constructors}/test_myscale.py | 2 +- .../query_constructors}/test_opensearch.py | 2 +- 
.../query_constructors}/test_pgvector.py | 2 +- .../query_constructors}/test_pinecone.py | 2 +- .../query_constructors}/test_redis.py | 18 +- .../query_constructors}/test_supabase.py | 2 +- .../test_tencentvectordb.py | 4 +- .../test_timescalevector.py | 4 +- .../query_constructors}/test_vectara.py | 2 +- .../query_constructors}/test_weaviate.py | 2 +- .../unit_tests/retrievers/test_imports.py | 1 + .../retrievers/test_web_research.py | 2 +- .../tests/unit_tests/test_cache.py | 6 +- .../tests/unit_tests/test_dependencies.py | 2 + .../unit_tests/test_document_transformers.py | 9 +- libs/experimental/poetry.lock | 22 +- libs/experimental/pyproject.toml | 2 +- .../agent_toolkits/vectorstore/toolkit.py | 30 +- .../langchain/callbacks/streamlit/__init__.py | 25 +- .../langchain/callbacks/tracers/__init__.py | 1 - .../langchain/chains/api/openapi/chain.py | 236 +------ .../langchain/chains/api/openapi/prompts.py | 68 +- .../chains/api/openapi/requests_chain.py | 92 +-- .../chains/api/openapi/response_chain.py | 87 +-- .../langchain/chains/conversation/memory.py | 22 +- .../chains/ernie_functions/__init__.py | 43 +- .../langchain/chains/ernie_functions/base.py | 586 ++---------------- libs/langchain/langchain/chains/flare/base.py | 17 +- .../langchain/chains/graph_qa/__init__.py | 1 - .../langchain/chains/graph_qa/arangodb.py | 248 +------- .../langchain/chains/graph_qa/base.py | 107 +--- .../langchain/chains/graph_qa/cypher.py | 313 +--------- .../langchain/chains/graph_qa/cypher_utils.py | 269 +------- .../langchain/chains/graph_qa/falkordb.py | 167 +---- .../langchain/chains/graph_qa/gremlin.py | 257 ++------ .../langchain/chains/graph_qa/hugegraph.py | 113 +--- .../langchain/chains/graph_qa/kuzu.py | 153 +---- .../langchain/chains/graph_qa/nebulagraph.py | 110 +--- .../chains/graph_qa/neptune_cypher.py | 238 +------ .../chains/graph_qa/neptune_sparql.py | 240 ++----- .../chains/graph_qa/ontotext_graphdb.py | 197 +----- .../langchain/chains/graph_qa/prompts.py | 511 +++------------ .../langchain/chains/graph_qa/sparql.py | 159 +---- .../langchain/chains/llm_requests.py | 102 +-- libs/langchain/langchain/chains/loading.py | 27 +- .../langchain/langchain/chains/natbot/base.py | 8 +- .../chains/openai_functions/openapi.py | 28 +- .../chains/router/multi_retrieval_qa.py | 17 +- .../document_loaders/blob_loaders/schema.py | 2 +- .../evaluation/embedding_distance/base.py | 9 +- .../langchain/indexes/vectorstore.py | 65 +- .../langchain/retrievers/__init__.py | 10 +- .../document_compressors/cross_encoder.py | 17 + .../cross_encoder_rerank.py | 3 +- .../document_compressors/embeddings_filter.py | 30 +- libs/langchain/langchain/retrievers/pupmed.py | 1 - .../retrievers/self_query/astradb.py | 77 +-- .../langchain/retrievers/self_query/base.py | 113 ++-- .../langchain/retrievers/self_query/chroma.py | 73 +-- .../retrievers/self_query/dashvector.py | 71 +-- .../self_query/databricks_vector_search.py | 105 +--- .../retrievers/self_query/deeplake.py | 103 +-- .../langchain/retrievers/self_query/dingo.py | 72 +-- .../retrievers/self_query/elasticsearch.py | 109 +--- .../langchain/retrievers/self_query/milvus.py | 118 +--- .../retrievers/self_query/mongodb_atlas.py | 83 +-- .../retrievers/self_query/myscale.py | 132 +--- .../retrievers/self_query/opensearch.py | 111 +--- .../retrievers/self_query/pgvector.py | 75 +-- .../retrievers/self_query/pinecone.py | 80 +-- .../langchain/retrievers/self_query/qdrant.py | 103 +-- .../langchain/retrievers/self_query/redis.py | 107 +--- 
.../retrievers/self_query/supabase.py | 104 +--- .../retrievers/self_query/tencentvectordb.py | 127 +--- .../retrievers/self_query/timescalevector.py | 93 +-- .../retrievers/self_query/vectara.py | 81 +-- .../retrievers/self_query/weaviate.py | 86 +-- .../langchain/retrievers/web_research.py | 234 +------ libs/langchain/poetry.lock | 47 +- libs/langchain/pyproject.toml | 7 +- .../agents/agent_toolkits/test_imports.py | 2 - .../tests/unit_tests/agents/test_agent.py | 3 +- .../tests/unit_tests/agents/test_imports.py | 2 - .../unit_tests/agents/test_serialization.py | 3 + .../unit_tests/callbacks/test_imports.py | 2 - .../tests/unit_tests/chains/test_imports.py | 2 - .../chains/test_neptune_cypher_qa.py | 2 - .../chains/test_ontotext_graphdb_qa.py | 2 - .../unit_tests/chat_models/test_imports.py | 2 - libs/langchain/tests/unit_tests/conftest.py | 14 + .../tests/unit_tests/docstore/test_imports.py | 2 - .../unit_tests/document_loaders/test_base.py | 3 +- .../document_loaders/test_imports.py | 2 - .../document_transformers/test_imports.py | 2 - .../unit_tests/embeddings/test_imports.py | 2 - .../unit_tests/evaluation/test_loading.py | 1 + .../tests/unit_tests/graphs/test_imports.py | 2 - .../tests/unit_tests/indexes/test_indexing.py | 2 +- .../tests/unit_tests/llms/test_base.py | 11 +- .../tests/unit_tests/llms/test_imports.py | 11 +- .../tests/unit_tests/load/test_dump.py | 17 +- .../tests/unit_tests/load/test_load.py | 8 +- .../unit_tests/load/test_serializable.py | 3 + .../chat_message_histories/test_imports.py | 2 - .../tests/unit_tests/memory/test_imports.py | 2 - .../unit_tests/output_parsers/test_imports.py | 2 - .../test_pandas_dataframe_parser.py | 2 +- .../output_parsers/test_pydantic_parser.py | 3 +- .../tests/unit_tests/prompts/test_imports.py | 2 - .../unit_tests/retrievers/test_ensemble.py | 16 +- .../unit_tests/retrievers/test_imports.py | 2 - .../tests/unit_tests/smith/test_imports.py | 2 - .../tests/unit_tests/storage/test_imports.py | 2 - .../tests/unit_tests/test_dependencies.py | 6 +- .../tests/unit_tests/test_imports.py | 28 + .../tests/unit_tests/tools/test_imports.py | 2 - .../unit_tests/utilities/test_imports.py | 2 - .../tests/unit_tests/utils/test_imports.py | 2 - .../unit_tests/utils/test_openai_functions.py | 3 +- .../unit_tests/vectorstores/test_imports.py | 2 + 238 files changed, 7572 insertions(+), 5919 deletions(-) create mode 100644 libs/community/langchain_community/chains/__init__.py create mode 100644 libs/community/langchain_community/chains/ernie_functions/__init__.py create mode 100644 libs/community/langchain_community/chains/ernie_functions/base.py create mode 100644 libs/community/langchain_community/chains/graph_qa/__init__.py create mode 100644 libs/community/langchain_community/chains/graph_qa/arangodb.py create mode 100644 libs/community/langchain_community/chains/graph_qa/base.py create mode 100644 libs/community/langchain_community/chains/graph_qa/cypher.py create mode 100644 libs/community/langchain_community/chains/graph_qa/cypher_utils.py create mode 100644 libs/community/langchain_community/chains/graph_qa/falkordb.py create mode 100644 libs/community/langchain_community/chains/graph_qa/gremlin.py create mode 100644 libs/community/langchain_community/chains/graph_qa/hugegraph.py create mode 100644 libs/community/langchain_community/chains/graph_qa/kuzu.py create mode 100644 libs/community/langchain_community/chains/graph_qa/nebulagraph.py create mode 100644 libs/community/langchain_community/chains/graph_qa/neptune_cypher.py create mode 100644 
libs/community/langchain_community/chains/graph_qa/neptune_sparql.py create mode 100644 libs/community/langchain_community/chains/graph_qa/ontotext_graphdb.py create mode 100644 libs/community/langchain_community/chains/graph_qa/prompts.py create mode 100644 libs/community/langchain_community/chains/graph_qa/sparql.py create mode 100644 libs/community/langchain_community/chains/llm_requests.py create mode 100644 libs/community/langchain_community/chains/openapi/__init__.py create mode 100644 libs/community/langchain_community/chains/openapi/chain.py create mode 100644 libs/community/langchain_community/chains/openapi/prompts.py create mode 100644 libs/community/langchain_community/chains/openapi/requests_chain.py create mode 100644 libs/community/langchain_community/chains/openapi/response_chain.py create mode 100644 libs/community/langchain_community/query_constructors/__init__.py create mode 100644 libs/community/langchain_community/query_constructors/astradb.py create mode 100644 libs/community/langchain_community/query_constructors/chroma.py create mode 100644 libs/community/langchain_community/query_constructors/dashvector.py create mode 100644 libs/community/langchain_community/query_constructors/databricks_vector_search.py create mode 100644 libs/community/langchain_community/query_constructors/deeplake.py create mode 100644 libs/community/langchain_community/query_constructors/dingo.py create mode 100644 libs/community/langchain_community/query_constructors/elasticsearch.py create mode 100644 libs/community/langchain_community/query_constructors/milvus.py create mode 100644 libs/community/langchain_community/query_constructors/mongodb_atlas.py create mode 100644 libs/community/langchain_community/query_constructors/myscale.py create mode 100644 libs/community/langchain_community/query_constructors/opensearch.py create mode 100644 libs/community/langchain_community/query_constructors/pgvector.py create mode 100644 libs/community/langchain_community/query_constructors/pinecone.py create mode 100644 libs/community/langchain_community/query_constructors/qdrant.py create mode 100644 libs/community/langchain_community/query_constructors/redis.py create mode 100644 libs/community/langchain_community/query_constructors/supabase.py create mode 100644 libs/community/langchain_community/query_constructors/tencentvectordb.py create mode 100644 libs/community/langchain_community/query_constructors/timescalevector.py create mode 100644 libs/community/langchain_community/query_constructors/vectara.py create mode 100644 libs/community/langchain_community/query_constructors/weaviate.py create mode 100644 libs/community/langchain_community/retrievers/web_research.py rename libs/{langchain => community}/tests/integration_tests/agent/test_ainetwork_agent.py (100%) rename libs/{langchain => community}/tests/integration_tests/agent/test_powerbi_agent.py (97%) create mode 100644 libs/community/tests/integration_tests/cache/fake_embeddings.py rename libs/{langchain => community}/tests/integration_tests/cache/test_astradb.py (98%) rename libs/{langchain => community}/tests/integration_tests/cache/test_azure_cosmosdb_cache.py (100%) rename libs/{langchain => community}/tests/integration_tests/cache/test_cassandra.py (98%) rename libs/{langchain => community}/tests/integration_tests/cache/test_gptcache.py (97%) rename libs/{langchain => community}/tests/integration_tests/cache/test_momento_cache.py (98%) rename libs/{langchain => community}/tests/integration_tests/cache/test_opensearch_cache.py (100%) rename 
libs/{langchain => community}/tests/integration_tests/cache/test_redis_cache.py (100%) rename libs/{langchain => community}/tests/integration_tests/cache/test_upstash_redis_cache.py (98%) rename libs/{langchain => community}/tests/integration_tests/chains/test_dalle_agent.py (79%) rename libs/{langchain => community}/tests/integration_tests/chains/test_graph_database.py (99%) rename libs/{langchain => community}/tests/integration_tests/chains/test_graph_database_arangodb.py (97%) rename libs/{langchain => community}/tests/integration_tests/chains/test_graph_database_sparql.py (99%) rename libs/{langchain => community}/tests/integration_tests/chains/test_ontotext_graphdb_qa.py (99%) rename libs/{langchain => community}/tests/integration_tests/chains/test_react.py (92%) rename libs/{langchain => community}/tests/integration_tests/chains/test_retrieval_qa.py (100%) rename libs/{langchain => community}/tests/integration_tests/chains/test_retrieval_qa_with_sources.py (100%) rename libs/{langchain => community}/tests/integration_tests/chains/test_self_ask_with_search.py (100%) create mode 100644 libs/community/tests/integration_tests/document_transformers/__init__.py rename libs/{langchain/tests/integration_tests/test_document_transformers.py => community/tests/integration_tests/document_transformers/test_embeddings_filter.py} (99%) rename libs/{langchain => community}/tests/integration_tests/memory/test_cosmos_db.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_elasticsearch.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_firestore.py (100%) rename libs/{langchain/tests/integration_tests/memory/test_astradb.py => community/tests/integration_tests/memory/test_memory_astradb.py} (100%) rename libs/{langchain/tests/integration_tests/memory/test_cassandra.py => community/tests/integration_tests/memory/test_memory_cassandra.py} (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_momento.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_mongodb.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_neo4j.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_redis.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_rockset.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_singlestoredb.py (91%) rename libs/{langchain => community}/tests/integration_tests/memory/test_upstash_redis.py (100%) rename libs/{langchain => community}/tests/integration_tests/memory/test_xata.py (100%) rename libs/{langchain => community}/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py (97%) rename libs/{langchain => community}/tests/integration_tests/retrievers/document_compressors/test_base.py (100%) rename libs/{langchain => community}/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py (100%) rename libs/{langchain => community}/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py (100%) rename libs/{langchain => community}/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py (100%) rename libs/{langchain => community}/tests/integration_tests/retrievers/test_contextual_compression.py (100%) rename libs/{langchain => community}/tests/integration_tests/retrievers/test_merger_retriever.py (100%) rename libs/{langchain => 
community}/tests/integration_tests/smith/evaluation/test_runner_utils.py (100%) rename libs/{langchain => community}/tests/integration_tests/test_dalle.py (100%) create mode 100644 libs/community/tests/integration_tests/test_document_transformers.py rename libs/{langchain => community}/tests/integration_tests/test_long_context_reorder.py (100%) rename libs/{langchain => community}/tests/integration_tests/test_nuclia_transformer.py (99%) rename libs/{langchain => community}/tests/integration_tests/test_pdf_pagesplitter.py (100%) create mode 100644 libs/community/tests/unit_tests/agents/__init__.py rename libs/{langchain => community}/tests/unit_tests/agents/test_react.py (97%) rename libs/{langchain => community}/tests/unit_tests/agents/test_sql.py (99%) rename libs/{langchain => community}/tests/unit_tests/agents/test_tools.py (98%) create mode 100644 libs/community/tests/unit_tests/chains/__init__.py rename libs/{langchain => community}/tests/unit_tests/chains/test_api.py (100%) rename libs/{langchain => community}/tests/unit_tests/chains/test_graph_qa.py (97%) rename libs/{langchain => community}/tests/unit_tests/chains/test_llm.py (100%) rename libs/{langchain => community}/tests/unit_tests/data/cypher_corrector.csv (100%) create mode 100644 libs/community/tests/unit_tests/query_constructors/__init__.py rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_astradb.py (98%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_chroma.py (97%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_dashvector.py (94%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_databricks_vector_search.py (98%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_deeplake.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_dingo.py (97%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_elasticsearch.py (99%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_milvus.py (97%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_mongodb_atlas.py (97%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_myscale.py (97%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_opensearch.py (98%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_pgvector.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_pinecone.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_redis.py (98%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_supabase.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_tencentvectordb.py (96%) rename 
libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_timescalevector.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_vectara.py (96%) rename libs/{langchain/tests/unit_tests/retrievers/self_query => community/tests/unit_tests/query_constructors}/test_weaviate.py (98%) rename libs/{langchain => community}/tests/unit_tests/retrievers/test_web_research.py (90%) rename libs/{langchain => community}/tests/unit_tests/test_cache.py (98%) rename libs/{langchain => community}/tests/unit_tests/test_document_transformers.py (76%) create mode 100644 libs/langchain/langchain/retrievers/document_compressors/cross_encoder.py delete mode 100644 libs/langchain/tests/unit_tests/chains/test_neptune_cypher_qa.py delete mode 100644 libs/langchain/tests/unit_tests/chains/test_ontotext_graphdb_qa.py diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py index 2e75f78ac0..75f3cb97b4 100644 --- a/libs/community/langchain_community/agent_toolkits/sql/base.py +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -193,7 +193,7 @@ def create_sql_agent( ] prompt = ChatPromptTemplate.from_messages(messages) agent = RunnableAgent( - runnable=create_openai_functions_agent(llm, tools, prompt), + runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore input_keys_arg=["input"], return_keys_arg=["output"], **kwargs, @@ -208,10 +208,10 @@ def create_sql_agent( ] prompt = ChatPromptTemplate.from_messages(messages) if agent_type == "openai-tools": - runnable = create_openai_tools_agent(llm, tools, prompt) + runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore else: - runnable = create_tool_calling_agent(llm, tools, prompt) - agent = RunnableMultiActionAgent( + runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore + agent = RunnableMultiActionAgent( # type: ignore[assignment] runnable=runnable, input_keys_arg=["input"], return_keys_arg=["output"], diff --git a/libs/community/langchain_community/chains/__init__.py b/libs/community/langchain_community/chains/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/community/langchain_community/chains/ernie_functions/__init__.py b/libs/community/langchain_community/chains/ernie_functions/__init__.py new file mode 100644 index 0000000000..3efc22d419 --- /dev/null +++ b/libs/community/langchain_community/chains/ernie_functions/__init__.py @@ -0,0 +1,17 @@ +from langchain.chains.ernie_functions.base import ( + convert_to_ernie_function, + create_ernie_fn_chain, + create_ernie_fn_runnable, + create_structured_output_chain, + create_structured_output_runnable, + get_ernie_output_parser, +) + +__all__ = [ + "convert_to_ernie_function", + "create_structured_output_chain", + "create_ernie_fn_chain", + "create_structured_output_runnable", + "create_ernie_fn_runnable", + "get_ernie_output_parser", +] diff --git a/libs/community/langchain_community/chains/ernie_functions/base.py b/libs/community/langchain_community/chains/ernie_functions/base.py new file mode 100644 index 0000000000..3749d66776 --- /dev/null +++ b/libs/community/langchain_community/chains/ernie_functions/base.py @@ -0,0 +1,551 @@ +"""Methods for creating chains that use Ernie function-calling APIs.""" +import inspect +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + 
Union, + cast, +) + +from langchain.chains import LLMChain +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import ( + BaseGenerationOutputParser, + BaseLLMOutputParser, + BaseOutputParser, +) +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable + +from langchain_community.output_parsers.ernie_functions import ( + JsonOutputFunctionsParser, + PydanticAttrOutputFunctionsParser, + PydanticOutputFunctionsParser, +) +from langchain_community.utils.ernie_functions import convert_pydantic_to_ernie_function + +PYTHON_TO_JSON_TYPES = { + "str": "string", + "int": "number", + "float": "number", + "bool": "boolean", +} + + +def _get_python_function_name(function: Callable) -> str: + """Get the name of a Python function.""" + return function.__name__ + + +def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: + """Parse the function and argument descriptions from the docstring of a function. + + Assumes the function docstring follows Google Python style guide. + """ + docstring = inspect.getdoc(function) + if docstring: + docstring_blocks = docstring.split("\n\n") + descriptors = [] + args_block = None + past_descriptors = False + for block in docstring_blocks: + if block.startswith("Args:"): + args_block = block + break + elif block.startswith("Returns:") or block.startswith("Example:"): + # Don't break in case Args come after + past_descriptors = True + elif not past_descriptors: + descriptors.append(block) + else: + continue + description = " ".join(descriptors) + else: + description = "" + args_block = None + arg_descriptions = {} + if args_block: + arg = None + for line in args_block.split("\n")[1:]: + if ":" in line: + arg, desc = line.split(":") + arg_descriptions[arg.strip()] = desc.strip() + elif arg: + arg_descriptions[arg.strip()] += " " + line.strip() + return description, arg_descriptions + + +def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: + """Get JsonSchema describing a Python functions arguments. + + Assumes all function arguments are of primitive types (int, float, str, bool) or + are subclasses of pydantic.BaseModel. + """ + properties = {} + annotations = inspect.getfullargspec(function).annotations + for arg, arg_type in annotations.items(): + if arg == "return": + continue + if isinstance(arg_type, type) and issubclass(arg_type, BaseModel): + # Mypy error: + # "type" has no attribute "schema" + properties[arg] = arg_type.schema() # type: ignore[attr-defined] + elif arg_type.__name__ in PYTHON_TO_JSON_TYPES: + properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]} + if arg in arg_descriptions: + if arg not in properties: + properties[arg] = {} + properties[arg]["description"] = arg_descriptions[arg] + return properties + + +def _get_python_function_required_args(function: Callable) -> List[str]: + """Get the required arguments for a Python function.""" + spec = inspect.getfullargspec(function) + required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args + required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})] + + is_class = type(function) is type + if is_class and required[0] == "self": + required = required[1:] + return required + + +def convert_python_function_to_ernie_function( + function: Callable, +) -> Dict[str, Any]: + """Convert a Python function to an Ernie function-calling API compatible dict. 
+ + Assumes the Python function has type hints and a docstring with a description. If + the docstring has Google Python style argument descriptions, these will be + included as well. + """ + description, arg_descriptions = _parse_python_function_docstring(function) + return { + "name": _get_python_function_name(function), + "description": description, + "parameters": { + "type": "object", + "properties": _get_python_function_arguments(function, arg_descriptions), + "required": _get_python_function_required_args(function), + }, + } + + +def convert_to_ernie_function( + function: Union[Dict[str, Any], Type[BaseModel], Callable], +) -> Dict[str, Any]: + """Convert a raw function/class to an Ernie function. + + Args: + function: Either a dictionary, a pydantic.BaseModel class, or a Python function. + If a dictionary is passed in, it is assumed to already be a valid Ernie + function. + + Returns: + A dict version of the passed in function which is compatible with the + Ernie function-calling API. + """ + if isinstance(function, dict): + return function + elif isinstance(function, type) and issubclass(function, BaseModel): + return cast(Dict, convert_pydantic_to_ernie_function(function)) + elif callable(function): + return convert_python_function_to_ernie_function(function) + + else: + raise ValueError( + f"Unsupported function type {type(function)}. Functions must be passed in" + f" as Dict, pydantic.BaseModel, or Callable." + ) + + +def get_ernie_output_parser( + functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], +) -> Union[BaseOutputParser, BaseGenerationOutputParser]: + """Get the appropriate function output parser given the user functions. + + Args: + functions: Sequence where element is a dictionary, a pydantic.BaseModel class, + or a Python function. If a dictionary is passed in, it is assumed to + already be a valid Ernie function. + + Returns: + A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise + a JsonOutputFunctionsParser. If there's only one function and it is + not a Pydantic class, then the output parser will automatically extract + only the function arguments and not the function name. + """ + function_names = [convert_to_ernie_function(f)["name"] for f in functions] + if isinstance(functions[0], type) and issubclass(functions[0], BaseModel): + if len(functions) > 1: + pydantic_schema: Union[Dict, Type[BaseModel]] = { + name: fn for name, fn in zip(function_names, functions) + } + else: + pydantic_schema = functions[0] + output_parser: Union[ + BaseOutputParser, BaseGenerationOutputParser + ] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) + else: + output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1) + return output_parser + + +def create_ernie_fn_runnable( + functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], + llm: Runnable, + prompt: BasePromptTemplate, + *, + output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, + **kwargs: Any, +) -> Runnable: + """Create a runnable sequence that uses Ernie functions. + + Args: + functions: A sequence of either dictionaries, pydantic.BaseModels classes, or + Python functions. If dictionaries are passed in, they are assumed to + already be a valid Ernie functions. If only a single + function is passed in, then it will be enforced that the model use that + function. pydantic.BaseModels and Python functions should have docstrings + describing what the function does. 
For best results, pydantic.BaseModels + should have descriptions of the parameters and Python functions should have + Google Python style args descriptions in the docstring. Additionally, + Python functions should only use primitive types (str, int, float, bool) or + pydantic.BaseModels for arguments. + llm: Language model to use, assumed to support the Ernie function-calling API. + prompt: BasePromptTemplate to pass to the model. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. If multiple functions are + passed in and they are not pydantic.BaseModels, the chain output will + include both the name of the function that was returned and the arguments + to pass to the function. + + Returns: + A runnable sequence that will pass in the given functions to the model when run. + + Example: + .. code-block:: python + + from typing import Optional + + from langchain.chains.ernie_functions import create_ernie_fn_runnable + from langchain_community.chat_models import ErnieBotChat + from langchain_core.prompts import ChatPromptTemplate + from langchain.pydantic_v1 import BaseModel, Field + + + class RecordPerson(BaseModel): + \"\"\"Record some identifying information about a person.\"\"\" + + name: str = Field(..., description="The person's name") + age: int = Field(..., description="The person's age") + fav_food: Optional[str] = Field(None, description="The person's favorite food") + + + class RecordDog(BaseModel): + \"\"\"Record some identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + + llm = ErnieBotChat(model_name="ERNIE-Bot-4") + prompt = ChatPromptTemplate.from_messages( + [ + ("user", "Make calls to the relevant function to record the entities in the following input: {input}"), + ("assistant", "OK!"), + ("user", "Tip: Make sure to answer in the correct format"), + ] + ) + chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt) + chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) + # -> RecordDog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if not functions: + raise ValueError("Need to pass in at least one function. Received zero.") + ernie_functions = [convert_to_ernie_function(f) for f in functions] + llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs} + if len(ernie_functions) == 1: + llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]} + output_parser = output_parser or get_ernie_output_parser(functions) + return prompt | llm.bind(**llm_kwargs) | output_parser + + +def create_structured_output_runnable( + output_schema: Union[Dict[str, Any], Type[BaseModel]], + llm: Runnable, + prompt: BasePromptTemplate, + *, + output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, + **kwargs: Any, +) -> Runnable: + """Create a runnable that uses an Ernie function to get a structured output. + + Args: + output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary + is passed in, it's assumed to already be a valid JsonSchema. 
+ For best results, pydantic.BaseModels should have docstrings describing what + the schema represents and descriptions for the parameters. + llm: Language model to use, assumed to support the Ernie function-calling API. + prompt: BasePromptTemplate to pass to the model. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. + + Returns: + A runnable sequence that will pass the given function to the model when run. + + Example: + .. code-block:: python + + from typing import Optional + + from langchain.chains.ernie_functions import create_structured_output_runnable + from langchain_community.chat_models import ErnieBotChat + from langchain_core.prompts import ChatPromptTemplate + from langchain.pydantic_v1 import BaseModel, Field + + class Dog(BaseModel): + \"\"\"Identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + llm = ErnieBotChat(model_name="ERNIE-Bot-4") + prompt = ChatPromptTemplate.from_messages( + [ + ("user", "Use the given format to extract information from the following input: {input}"), + ("assistant", "OK!"), + ("user", "Tip: Make sure to answer in the correct format"), + ] + ) + chain = create_structured_output_runnable(Dog, llm, prompt) + chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) + # -> Dog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if isinstance(output_schema, dict): + function: Any = { + "name": "output_formatter", + "description": ( + "Output formatter. Should always be used to format your response to the" + " user." + ), + "parameters": output_schema, + } + else: + + class _OutputFormatter(BaseModel): + """Output formatter. Should always be used to format your response to the user.""" # noqa: E501 + + output: output_schema # type: ignore + + function = _OutputFormatter + output_parser = output_parser or PydanticAttrOutputFunctionsParser( + pydantic_schema=_OutputFormatter, attr_name="output" + ) + return create_ernie_fn_runnable( + [function], + llm, + prompt, + output_parser=output_parser, + **kwargs, + ) + + +""" --- Legacy --- """ + + +def create_ernie_fn_chain( + functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], + llm: BaseLanguageModel, + prompt: BasePromptTemplate, + *, + output_key: str = "function", + output_parser: Optional[BaseLLMOutputParser] = None, + **kwargs: Any, +) -> LLMChain: + """[Legacy] Create an LLM chain that uses Ernie functions. + + Args: + functions: A sequence of either dictionaries, pydantic.BaseModels classes, or + Python functions. If dictionaries are passed in, they are assumed to + already be valid Ernie functions. If only a single + function is passed in, then it will be enforced that the model use that + function. pydantic.BaseModels and Python functions should have docstrings + describing what the function does. For best results, pydantic.BaseModels + should have descriptions of the parameters and Python functions should have + Google Python style args descriptions in the docstring. Additionally, + Python functions should only use primitive types (str, int, float, bool) or + pydantic.BaseModels for arguments. 
+ llm: Language model to use, assumed to support the Ernie function-calling API. + prompt: BasePromptTemplate to pass to the model. + output_key: The key to use when returning the output in LLMChain.__call__. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. If multiple functions are + passed in and they are not pydantic.BaseModels, the chain output will + include both the name of the function that was returned and the arguments + to pass to the function. + + Returns: + An LLMChain that will pass in the given functions to the model when run. + + Example: + .. code-block:: python + + from typing import Optional + + from langchain.chains.ernie_functions import create_ernie_fn_chain + from langchain_community.chat_models import ErnieBotChat + from langchain_core.prompts import ChatPromptTemplate + + from langchain.pydantic_v1 import BaseModel, Field + + + class RecordPerson(BaseModel): + \"\"\"Record some identifying information about a person.\"\"\" + + name: str = Field(..., description="The person's name") + age: int = Field(..., description="The person's age") + fav_food: Optional[str] = Field(None, description="The person's favorite food") + + + class RecordDog(BaseModel): + \"\"\"Record some identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + + llm = ErnieBotChat(model_name="ERNIE-Bot-4") + prompt = ChatPromptTemplate.from_messages( + [ + ("user", "Make calls to the relevant function to record the entities in the following input: {input}"), + ("assistant", "OK!"), + ("user", "Tip: Make sure to answer in the correct format"), + ] + ) + chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt) + chain.run("Harry was a chubby brown beagle who loved chicken") + # -> RecordDog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if not functions: + raise ValueError("Need to pass in at least one function. Received zero.") + ernie_functions = [convert_to_ernie_function(f) for f in functions] + output_parser = output_parser or get_ernie_output_parser(functions) + llm_kwargs: Dict[str, Any] = { + "functions": ernie_functions, + } + if len(ernie_functions) == 1: + llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]} + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + output_parser=output_parser, + llm_kwargs=llm_kwargs, + output_key=output_key, + **kwargs, + ) + return llm_chain + + +def create_structured_output_chain( + output_schema: Union[Dict[str, Any], Type[BaseModel]], + llm: BaseLanguageModel, + prompt: BasePromptTemplate, + *, + output_key: str = "function", + output_parser: Optional[BaseLLMOutputParser] = None, + **kwargs: Any, +) -> LLMChain: + """[Legacy] Create an LLMChain that uses an Ernie function to get a structured output. + + Args: + output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary + is passed in, it's assumed to already be a valid JsonSchema. + For best results, pydantic.BaseModels should have docstrings describing what + the schema represents and descriptions for the parameters. + llm: Language model to use, assumed to support the Ernie function-calling API. 
+ prompt: BasePromptTemplate to pass to the model. + output_key: The key to use when returning the output in LLMChain.__call__. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. + + Returns: + An LLMChain that will pass the given function to the model. + + Example: + .. code-block:: python + + from typing import Optional + + from langchain.chains.ernie_functions import create_structured_output_chain + from langchain_community.chat_models import ErnieBotChat + from langchain_core.prompts import ChatPromptTemplate + + from langchain.pydantic_v1 import BaseModel, Field + + class Dog(BaseModel): + \"\"\"Identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + llm = ErnieBotChat(model_name="ERNIE-Bot-4") + prompt = ChatPromptTemplate.from_messages( + [ + ("user", "Use the given format to extract information from the following input: {input}"), + ("assistant", "OK!"), + ("user", "Tip: Make sure to answer in the correct format"), + ] + ) + chain = create_structured_output_chain(Dog, llm, prompt) + chain.run("Harry was a chubby brown beagle who loved chicken") + # -> Dog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if isinstance(output_schema, dict): + function: Any = { + "name": "output_formatter", + "description": ( + "Output formatter. Should always be used to format your response to the" + " user." + ), + "parameters": output_schema, + } + else: + + class _OutputFormatter(BaseModel): + """Output formatter. 
Should always be used to format your response to the user.""" # noqa: E501 + + output: output_schema # type: ignore + + function = _OutputFormatter + output_parser = output_parser or PydanticAttrOutputFunctionsParser( + pydantic_schema=_OutputFormatter, attr_name="output" + ) + return create_ernie_fn_chain( + [function], + llm, + prompt, + output_key=output_key, + output_parser=output_parser, + **kwargs, + ) diff --git a/libs/community/langchain_community/chains/graph_qa/__init__.py b/libs/community/langchain_community/chains/graph_qa/__init__.py new file mode 100644 index 0000000000..f3bc55efbc --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/__init__.py @@ -0,0 +1 @@ +"""Question answering over a knowledge graph.""" diff --git a/libs/community/langchain_community/chains/graph_qa/arangodb.py b/libs/community/langchain_community/chains/graph_qa/arangodb.py new file mode 100644 index 0000000000..e875cf2f07 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/arangodb.py @@ -0,0 +1,241 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + AQL_FIX_PROMPT, + AQL_GENERATION_PROMPT, + AQL_QA_PROMPT, +) +from langchain_community.graphs.arangodb_graph import ArangoGraph + + +class ArangoGraphQAChain(Chain): + """Chain for question-answering against a graph by generating AQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. 
+ """ + + graph: ArangoGraph = Field(exclude=True) + aql_generation_chain: LLMChain + aql_fix_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + # Specifies the maximum number of AQL Query Results to return + top_k: int = 10 + + # Specifies the set of AQL Query Examples that promote few-shot-learning + aql_examples: str = "" + + # Specify whether to return the AQL Query in the output dictionary + return_aql_query: bool = False + + # Specify whether to return the AQL JSON Result in the output dictionary + return_aql_result: bool = False + + # Specify the maximum amount of AQL Generation attempts that should be made + max_aql_generation_attempts: int = 3 + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + return [self.output_key] + + @property + def _chain_type(self) -> str: + return "graph_aql_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = AQL_QA_PROMPT, + aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT, + aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT, + **kwargs: Any, + ) -> ArangoGraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt) + aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt) + + return cls( + qa_chain=qa_chain, + aql_generation_chain=aql_generation_chain, + aql_fix_chain=aql_fix_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """ + Generate an AQL statement from user input, use it retrieve a response + from an ArangoDB Database instance, and respond to the user input + in natural language. + + Users can modify the following ArangoGraphQAChain Class Variables: + + :var top_k: The maximum number of AQL Query Results to return + :type top_k: int + + :var aql_examples: A set of AQL Query Examples that are passed to + the AQL Generation Prompt Template to promote few-shot-learning. + Defaults to an empty string. + :type aql_examples: str + + :var return_aql_query: Whether to return the AQL Query in the + output dictionary. Defaults to False. + :type return_aql_query: bool + + :var return_aql_result: Whether to return the AQL Query in the + output dictionary. Defaults to False + :type return_aql_result: bool + + :var max_aql_generation_attempts: The maximum amount of AQL + Generation attempts to be made prior to raising the last + AQL Query Execution Error. Defaults to 3. 
+ :type max_aql_generation_attempts: int + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + user_input = inputs[self.input_key] + + ######################### + # Generate AQL Query # + aql_generation_output = self.aql_generation_chain.run( + { + "adb_schema": self.graph.schema, + "aql_examples": self.aql_examples, + "user_input": user_input, + }, + callbacks=callbacks, + ) + ######################### + + aql_query = "" + aql_error = "" + aql_result = None + aql_generation_attempt = 1 + + while ( + aql_result is None + and aql_generation_attempt < self.max_aql_generation_attempts + 1 + ): + ##################### + # Extract AQL Query # + pattern = r"```(?i:aql)?(.*?)```" + matches = re.findall(pattern, aql_generation_output, re.DOTALL) + if not matches: + _run_manager.on_text( + "Invalid Response: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + aql_generation_output, color="red", end="\n", verbose=self.verbose + ) + raise ValueError(f"Response is Invalid: {aql_generation_output}") + + aql_query = matches[0] + ##################### + + _run_manager.on_text( + f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose + ) + _run_manager.on_text( + aql_query, color="green", end="\n", verbose=self.verbose + ) + + ##################### + # Execute AQL Query # + from arango import AQLQueryExecuteError + + try: + aql_result = self.graph.query(aql_query, self.top_k) + except AQLQueryExecuteError as e: + aql_error = e.error_message + + _run_manager.on_text( + "AQL Query Execution Error: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + aql_error, color="yellow", end="\n\n", verbose=self.verbose + ) + + ######################## + # Retry AQL Generation # + aql_generation_output = self.aql_fix_chain.run( + { + "adb_schema": self.graph.schema, + "aql_query": aql_query, + "aql_error": aql_error, + }, + callbacks=callbacks, + ) + ######################## + + ##################### + + aql_generation_attempt += 1 + + if aql_result is None: + m = f""" + Maximum amount of AQL Query Generation attempts reached. 
+ Unable to execute the AQL Query due to the following error: + {aql_error} + """ + raise ValueError(m) + + _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(aql_result), color="green", end="\n", verbose=self.verbose + ) + + ######################## + # Interpret AQL Result # + result = self.qa_chain( + { + "adb_schema": self.graph.schema, + "user_input": user_input, + "aql_query": aql_query, + "aql_result": aql_result, + }, + callbacks=callbacks, + ) + ######################## + + # Return results # + result = {self.output_key: result[self.qa_chain.output_key]} + + if self.return_aql_query: + result["aql_query"] = aql_query + + if self.return_aql_result: + result["aql_result"] = aql_result + + return result diff --git a/libs/community/langchain_community/chains/graph_qa/base.py b/libs/community/langchain_community/chains/graph_qa/base.py new file mode 100644 index 0000000000..e315c3340d --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/base.py @@ -0,0 +1,103 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks.manager import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + ENTITY_EXTRACTION_PROMPT, + GRAPH_QA_PROMPT, +) +from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities + + +class GraphQAChain(Chain): + """Chain for question-answering against a graph. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: NetworkxEntityGraph = Field(exclude=True) + entity_extraction_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT, + entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, + **kwargs: Any, + ) -> GraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + entity_chain = LLMChain(llm=llm, prompt=entity_prompt) + + return cls( + qa_chain=qa_chain, + entity_extraction_chain=entity_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Extract entities, look up info and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + question = inputs[self.input_key] + + entity_string = self.entity_extraction_chain.run(question) + + _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose) + _run_manager.on_text( + entity_string, color="green", end="\n", verbose=self.verbose + ) + entities = get_entities(entity_string) + context = "" + all_triplets = [] + for entity in entities: + all_triplets.extend(self.graph.get_entity_knowledge(entity)) + context = "\n".join(all_triplets) + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose) + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=_run_manager.get_child(), + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/libs/community/langchain_community/chains/graph_qa/cypher.py b/libs/community/langchain_community/chains/graph_qa/cypher.py new file mode 100644 index 0000000000..e43f24037b --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/cypher.py @@ -0,0 +1,298 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.cypher_utils import ( + CypherQueryCorrector, + Schema, +) +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_GENERATION_PROMPT, + CYPHER_QA_PROMPT, +) +from langchain_community.graphs.graph_store import GraphStore + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def extract_cypher(text: str) -> str: + """Extract Cypher code from a text. + + Args: + text: Text to extract Cypher code from. + + Returns: + Cypher code extracted from the text. 
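+
+    Example (illustrative; the pattern captures any fenced block, not only
+    Cypher, and the input is returned unchanged when no backticks are found):
+        .. code-block:: python
+
+            extract_cypher("```MATCH (n) RETURN n LIMIT 1```")
+            # -> "MATCH (n) RETURN n LIMIT 1"
+            extract_cypher("MATCH (n) RETURN n")
+            # -> "MATCH (n) RETURN n"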
+ """ + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + +def construct_schema( + structured_schema: Dict[str, Any], + include_types: List[str], + exclude_types: List[str], +) -> str: + """Filter the schema based on included or excluded types""" + + def filter_func(x: str) -> bool: + return x in include_types if include_types else x not in exclude_types + + filtered_schema: Dict[str, Any] = { + "node_props": { + k: v + for k, v in structured_schema.get("node_props", {}).items() + if filter_func(k) + }, + "rel_props": { + k: v + for k, v in structured_schema.get("rel_props", {}).items() + if filter_func(k) + }, + "relationships": [ + r + for r in structured_schema.get("relationships", []) + if all(filter_func(r[t]) for t in ["start", "end", "type"]) + ], + } + + # Format node properties + formatted_node_props = [] + for label, properties in filtered_schema["node_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in properties] + ) + formatted_node_props.append(f"{label} {{{props_str}}}") + + # Format relationship properties + formatted_rel_props = [] + for rel_type, properties in filtered_schema["rel_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in properties] + ) + formatted_rel_props.append(f"{rel_type} {{{props_str}}}") + + # Format relationships + formatted_rels = [ + f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" + for el in filtered_schema["relationships"] + ] + + return "\n".join( + [ + "Node properties are the following:", + ",".join(formatted_node_props), + "Relationship properties are the following:", + ",".join(formatted_rel_props), + "The relationships are the following:", + ",".join(formatted_rels), + ] + ) + + +class GraphCypherQAChain(Chain): + """Chain for question-answering against a graph by generating Cypher statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: GraphStore = Field(exclude=True) + cypher_generation_chain: LLMChain + qa_chain: LLMChain + graph_schema: str + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + """Number of results to return from the query""" + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Whether or not to return the result of querying the graph directly.""" + cypher_query_corrector: Optional[CypherQueryCorrector] = None + """Optional cypher validation tool""" + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @property + def _chain_type(self) -> str: + return "graph_cypher_chain" + + @classmethod + def from_llm( + cls, + llm: Optional[BaseLanguageModel] = None, + *, + qa_prompt: Optional[BasePromptTemplate] = None, + cypher_prompt: Optional[BasePromptTemplate] = None, + cypher_llm: Optional[BaseLanguageModel] = None, + qa_llm: Optional[BaseLanguageModel] = None, + exclude_types: List[str] = [], + include_types: List[str] = [], + validate_cypher: bool = False, + qa_llm_kwargs: Optional[Dict[str, Any]] = None, + cypher_llm_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> GraphCypherQAChain: + """Initialize from LLM.""" + + if not cypher_llm and not llm: + raise ValueError("Either `llm` or `cypher_llm` parameters must be provided") + if not qa_llm and not llm: + raise ValueError("Either `llm` or `qa_llm` parameters must be provided") + if cypher_llm and qa_llm and llm: + raise ValueError( + "You can specify up to two of 'cypher_llm', 'qa_llm'" + ", and 'llm', but not all three simultaneously." + ) + if cypher_prompt and cypher_llm_kwargs: + raise ValueError( + "Specifying cypher_prompt and cypher_llm_kwargs together is" + " not allowed. Please pass prompt via cypher_llm_kwargs." + ) + if qa_prompt and qa_llm_kwargs: + raise ValueError( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." + ) + use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} + use_cypher_llm_kwargs = ( + cypher_llm_kwargs if cypher_llm_kwargs is not None else {} + ) + if "prompt" not in use_qa_llm_kwargs: + use_qa_llm_kwargs["prompt"] = ( + qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT + ) + if "prompt" not in use_cypher_llm_kwargs: + use_cypher_llm_kwargs["prompt"] = ( + cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT + ) + + qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type] + + cypher_generation_chain = LLMChain( + llm=cypher_llm or llm, # type: ignore[arg-type] + **use_cypher_llm_kwargs, # type: ignore[arg-type] + ) + + if exclude_types and include_types: + raise ValueError( + "Either `exclude_types` or `include_types` " + "can be provided, but not both" + ) + + graph_schema = construct_schema( + kwargs["graph"].get_structured_schema, include_types, exclude_types + ) + + cypher_query_corrector = None + if validate_cypher: + corrector_schema = [ + Schema(el["start"], el["type"], el["end"]) + for el in kwargs["graph"].structured_schema.get("relationships") + ] + cypher_query_corrector = CypherQueryCorrector(corrector_schema) + + return cls( + graph_schema=graph_schema, + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + cypher_query_corrector=cypher_query_corrector, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_cypher = self.cypher_generation_chain.run( + {"question": question, "schema": self.graph_schema}, callbacks=callbacks + ) + + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + + # Correct Cypher query if enabled + if 
self.cypher_query_corrector:
+            generated_cypher = self.cypher_query_corrector(generated_cypher)
+
+        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
+        _run_manager.on_text(
+            generated_cypher, color="green", end="\n", verbose=self.verbose
+        )
+
+        intermediate_steps.append({"query": generated_cypher})
+
+        # Retrieve and limit the number of results
+        # Generated Cypher may be empty if the query corrector
+        # identifies an invalid schema
+        if generated_cypher:
+            context = self.graph.query(generated_cypher)[: self.top_k]
+        else:
+            context = []
+
+        if self.return_direct:
+            final_result = context
+        else:
+            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
+            _run_manager.on_text(
+                str(context), color="green", end="\n", verbose=self.verbose
+            )
+
+            intermediate_steps.append({"context": context})
+
+            result = self.qa_chain(
+                {"question": question, "context": context},
+                callbacks=callbacks,
+            )
+            final_result = result[self.qa_chain.output_key]
+
+        chain_result: Dict[str, Any] = {self.output_key: final_result}
+        if self.return_intermediate_steps:
+            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
+
+        return chain_result
diff --git a/libs/community/langchain_community/chains/graph_qa/cypher_utils.py b/libs/community/langchain_community/chains/graph_qa/cypher_utils.py
new file mode 100644
index 0000000000..c123cac9b5
--- /dev/null
+++ b/libs/community/langchain_community/chains/graph_qa/cypher_utils.py
@@ -0,0 +1,260 @@
+import re
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Tuple
+
+Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])
+
+
+class CypherQueryCorrector:
+    """
+    Used to correct relationship direction in generated Cypher statements.
+    This code is copied from the winner's submission to the Cypher competition:
+    https://github.com/sakusaku-rich/cypher-direction-competition
+    """
+
+    property_pattern = re.compile(r"\{.+?\}")
+    node_pattern = re.compile(r"\(.+?\)")
+    path_pattern = re.compile(
+        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-.*?->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
+    )
+    node_relation_node_pattern = re.compile(
+        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
+    )
+    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")
+
+    def __init__(self, schemas: List[Schema]):
+        """
+        Args:
+            schemas: list of schemas
+        """
+        self.schemas = schemas
+
+    def clean_node(self, node: str) -> str:
+        """
+        Args:
+            node: node in string format
+
+        """
+        node = re.sub(self.property_pattern, "", node)
+        node = node.replace("(", "")
+        node = node.replace(")", "")
+        node = node.strip()
+        return node
+
+    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
+        """
+        Args:
+            query: cypher query
+        """
+        nodes = re.findall(self.node_pattern, query)
+        nodes = [self.clean_node(node) for node in nodes]
+        res: Dict[str, Any] = {}
+        for node in nodes:
+            parts = node.split(":")
+            if parts == [""]:  # skip empty node strings
+                continue
+            variable = parts[0]
+            if variable not in res:
+                res[variable] = []
+            res[variable] += parts[1:]
+        return res
+
+    def extract_paths(self, query: str) -> "List[str]":
+        """
+        Args:
+            query: cypher query
+        """
+        paths = []
+        idx = 0
+        while matched := self.path_pattern.findall(query[idx:]):
+            matched = matched[0]
+            matched = [
+                m for i, m in enumerate(matched) if i not in [1, len(matched) - 1]
+            ]
+            path = "".join(matched)
+            idx = query.find(path) + len(path) - len(matched[-1])
+            paths.append(path)
+        return paths
+
+    def judge_direction(self, relation: str) -> str:
+        """
+        Args:
+            relation: relation in
string format + """ + direction = "BIDIRECTIONAL" + if relation[0] == "<": + direction = "INCOMING" + if relation[-1] == ">": + direction = "OUTGOING" + return direction + + def extract_node_variable(self, part: str) -> Optional[str]: + """ + Args: + part: node in string format + """ + part = part.lstrip("(").rstrip(")") + idx = part.find(":") + if idx != -1: + part = part[:idx] + return None if part == "" else part + + def detect_labels( + self, str_node: str, node_variable_dict: Dict[str, Any] + ) -> List[str]: + """ + Args: + str_node: node in string format + node_variable_dict: dictionary of node variables + """ + splitted_node = str_node.split(":") + variable = splitted_node[0] + labels = [] + if variable in node_variable_dict: + labels = node_variable_dict[variable] + elif variable == "" and len(splitted_node) > 1: + labels = splitted_node[1:] + return labels + + def verify_schema( + self, + from_node_labels: List[str], + relation_types: List[str], + to_node_labels: List[str], + ) -> bool: + """ + Args: + from_node_labels: labels of the from node + relation_type: type of the relation + to_node_labels: labels of the to node + """ + valid_schemas = self.schemas + if from_node_labels != []: + from_node_labels = [label.strip("`") for label in from_node_labels] + valid_schemas = [ + schema for schema in valid_schemas if schema[0] in from_node_labels + ] + if to_node_labels != []: + to_node_labels = [label.strip("`") for label in to_node_labels] + valid_schemas = [ + schema for schema in valid_schemas if schema[2] in to_node_labels + ] + if relation_types != []: + relation_types = [type.strip("`") for type in relation_types] + valid_schemas = [ + schema for schema in valid_schemas if schema[1] in relation_types + ] + return valid_schemas != [] + + def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]: + """ + Args: + str_relation: relation in string format + """ + relation_direction = self.judge_direction(str_relation) + relation_type = self.relation_type_pattern.search(str_relation) + if relation_type is None or relation_type.group("relation_type") is None: + return relation_direction, [] + relation_types = [ + t.strip().strip("!") + for t in relation_type.group("relation_type").split("|") + ] + return relation_direction, relation_types + + def correct_query(self, query: str) -> str: + """ + Args: + query: cypher query + """ + node_variable_dict = self.detect_node_variables(query) + paths = self.extract_paths(query) + for path in paths: + original_path = path + start_idx = 0 + while start_idx < len(path): + match_res = re.match(self.node_relation_node_pattern, path[start_idx:]) + if match_res is None: + break + start_idx += match_res.start() + match_dict = match_res.groupdict() + left_node_labels = self.detect_labels( + match_dict["left_node"], node_variable_dict + ) + right_node_labels = self.detect_labels( + match_dict["right_node"], node_variable_dict + ) + end_idx = ( + start_idx + + 4 + + len(match_dict["left_node"]) + + len(match_dict["relation"]) + + len(match_dict["right_node"]) + ) + original_partial_path = original_path[start_idx : end_idx + 1] + relation_direction, relation_types = self.detect_relation_types( + match_dict["relation"] + ) + + if relation_types != [] and "".join(relation_types).find("*") != -1: + start_idx += ( + len(match_dict["left_node"]) + len(match_dict["relation"]) + 2 + ) + continue + + if relation_direction == "OUTGOING": + is_legal = self.verify_schema( + left_node_labels, relation_types, right_node_labels + ) + if not is_legal: + 
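+                        # The stated OUTGOING pattern is absent from the
+                        # schema; check whether the reversed direction would
+                        # be legal before flipping the arrow below.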
is_legal = self.verify_schema( + right_node_labels, relation_types, left_node_labels + ) + if is_legal: + corrected_relation = "<" + match_dict["relation"][:-1] + corrected_partial_path = original_partial_path.replace( + match_dict["relation"], corrected_relation + ) + query = query.replace( + original_partial_path, corrected_partial_path + ) + else: + return "" + elif relation_direction == "INCOMING": + is_legal = self.verify_schema( + right_node_labels, relation_types, left_node_labels + ) + if not is_legal: + is_legal = self.verify_schema( + left_node_labels, relation_types, right_node_labels + ) + if is_legal: + corrected_relation = match_dict["relation"][1:] + ">" + corrected_partial_path = original_partial_path.replace( + match_dict["relation"], corrected_relation + ) + query = query.replace( + original_partial_path, corrected_partial_path + ) + else: + return "" + else: + is_legal = self.verify_schema( + left_node_labels, relation_types, right_node_labels + ) + is_legal |= self.verify_schema( + right_node_labels, relation_types, left_node_labels + ) + if not is_legal: + return "" + + start_idx += ( + len(match_dict["left_node"]) + len(match_dict["relation"]) + 2 + ) + return query + + def __call__(self, query: str) -> str: + """Correct the query to make it valid. If + Args: + query: cypher query + """ + return self.correct_query(query) diff --git a/libs/community/langchain_community/chains/graph_qa/falkordb.py b/libs/community/langchain_community/chains/graph_qa/falkordb.py new file mode 100644 index 0000000000..5d27adfa42 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/falkordb.py @@ -0,0 +1,157 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_GENERATION_PROMPT, + CYPHER_QA_PROMPT, +) +from langchain_community.graphs import FalkorDBGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def extract_cypher(text: str) -> str: + """ + Extract Cypher code from a text. + Args: + text: Text to extract Cypher code from. + + Returns: + Cypher code extracted from the text. + """ + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + +class FalkorDBQAChain(Chain): + """Chain for question-answering against a graph by generating Cypher statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. 
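+
+    Example (illustrative; assumes a reachable FalkorDB instance and an
+    already-constructed ``llm``; the database name is a placeholder):
+        .. code-block:: python
+
+            from langchain_community.graphs import FalkorDBGraph
+
+            graph = FalkorDBGraph(database="movies")
+            chain = FalkorDBQAChain.from_llm(llm=llm, graph=graph, verbose=True)
+            response = chain.run("Which actors played in The Godfather?")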
+ """ + + graph: FalkorDBGraph = Field(exclude=True) + cypher_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + """Number of results to return from the query""" + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Whether or not to return the result of querying the graph directly.""" + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @property + def _chain_type(self) -> str: + return "graph_cypher_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT, + **kwargs: Any, + ) -> FalkorDBQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + + return cls( + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_cypher = self.cypher_generation_chain.run( + {"question": question, "schema": self.graph.schema}, callbacks=callbacks + ) + + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_cypher}) + + # Retrieve and limit the number of results + context = self.graph.query(generated_cypher)[: self.top_k] + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/community/langchain_community/chains/graph_qa/gremlin.py b/libs/community/langchain_community/chains/graph_qa/gremlin.py new file mode 100644 index 0000000000..c75157e4a5 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/gremlin.py @@ -0,0 +1,221 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun +from 
langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_QA_PROMPT, + GRAPHDB_SPARQL_FIX_TEMPLATE, + GREMLIN_GENERATION_PROMPT, +) +from langchain_community.graphs import GremlinGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def extract_gremlin(text: str) -> str: + """Extract Gremlin code from a text. + + Args: + text: Text to extract Gremlin code from. + + Returns: + Gremlin code extracted from the text. + """ + text = text.replace("`", "") + if text.startswith("gremlin"): + text = text[len("gremlin") :] + return text.replace("\n", "") + + +class GremlinQAChain(Chain): + """Chain for question-answering against a graph by generating gremlin statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: GremlinGraph = Field(exclude=True) + gremlin_generation_chain: LLMChain + qa_chain: LLMChain + gremlin_fix_chain: LLMChain + max_fix_retries: int = 3 + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 100 + return_direct: bool = False + return_intermediate_steps: bool = False + + @property + def input_keys(self) -> List[str]: + """Input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Output keys. 
+
+        :meta private:
+        """
+        _output_keys = [self.output_key]
+        return _output_keys
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        *,
+        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
+            input_variables=["error_message", "generated_sparql", "schema"],
+            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
+                "in Turtle format", ""
+            ),
+        ),
+        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
+        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
+        **kwargs: Any,
+    ) -> GremlinQAChain:
+        """Initialize from LLM."""
+        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
+        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
+        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
+        return cls(
+            qa_chain=qa_chain,
+            gremlin_generation_chain=gremlin_generation_chain,
+            gremlin_fix_chain=gremlin_fix_chain,
+            **kwargs,
+        )
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        """Generate gremlin statement, use it to look up in db and answer question."""
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        callbacks = _run_manager.get_child()
+        question = inputs[self.input_key]
+
+        intermediate_steps: List = []
+
+        chain_response = self.gremlin_generation_chain.invoke(
+            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
+        )
+
+        generated_gremlin = extract_gremlin(
+            chain_response[self.gremlin_generation_chain.output_key]
+        )
+
+        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
+        _run_manager.on_text(
+            generated_gremlin, color="green", end="\n", verbose=self.verbose
+        )
+
+        intermediate_steps.append({"query": generated_gremlin})
+
+        if generated_gremlin:
+            context = self.execute_with_retry(
+                _run_manager, callbacks, generated_gremlin
+            )[: self.top_k]
+        else:
+            context = []
+
+        if self.return_direct:
+            final_result = context
+        else:
+            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
+            _run_manager.on_text(
+                str(context), color="green", end="\n", verbose=self.verbose
+            )
+
+            intermediate_steps.append({"context": context})
+
+            result = self.qa_chain.invoke(
+                {"question": question, "context": context},
+                callbacks=callbacks,
+            )
+            final_result = result[self.qa_chain.output_key]
+
+        chain_result: Dict[str, Any] = {self.output_key: final_result}
+        if self.return_intermediate_steps:
+            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
+
+        return chain_result
+
+    def execute_query(self, query: str) -> List[Any]:
+        try:
+            return self.graph.query(query)
+        except Exception as e:
+            if hasattr(e, "status_message"):
+                raise ValueError(e.status_message)
+            else:
+                raise ValueError(str(e))
+
+    def execute_with_retry(
+        self,
+        _run_manager: CallbackManagerForChainRun,
+        callbacks: CallbackManager,
+        generated_gremlin: str,
+    ) -> List[Any]:
+        try:
+            return self.execute_query(generated_gremlin)
+        except Exception as e:
+            retries = 0
+            error_message = str(e)
+            self.log_invalid_query(_run_manager, generated_gremlin, error_message)
+
+            while retries < self.max_fix_retries:
+                try:
+                    fix_chain_result = self.gremlin_fix_chain.invoke(
+                        {
+                            "error_message": error_message,
+                            # we are borrowing the template from sparql
+                            "generated_sparql": generated_gremlin,
+                            "schema": self.graph.get_schema,
+                        },
+                        callbacks=callbacks,
+                    )
+                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
+                    return self.execute_query(fixed_gremlin)
+                except Exception as e:
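+                    # The fix attempt failed as well: count it, log the new
+                    # parse error, and loop so the fix chain can try again.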
retries += 1 + parse_exception = str(e) + self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception) + + raise ValueError("The generated Gremlin query is invalid.") + + def log_invalid_query( + self, + _run_manager: CallbackManagerForChainRun, + generated_query: str, + error_message: str, + ) -> None: + _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_query, color="red", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + error_message, color="red", end="\n\n", verbose=self.verbose + ) diff --git a/libs/community/langchain_community/chains/graph_qa/hugegraph.py b/libs/community/langchain_community/chains/graph_qa/hugegraph.py new file mode 100644 index 0000000000..a20ce9f3d5 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/hugegraph.py @@ -0,0 +1,106 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_QA_PROMPT, + GREMLIN_GENERATION_PROMPT, +) +from langchain_community.graphs.hugegraph import HugeGraph + + +class HugeGraphQAChain(Chain): + """Chain for question-answering against a graph by generating gremlin statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: HugeGraph = Field(exclude=True) + gremlin_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT, + **kwargs: Any, + ) -> HugeGraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt) + + return cls( + qa_chain=qa_chain, + gremlin_generation_chain=gremlin_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate gremlin statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + generated_gremlin = self.gremlin_generation_chain.run( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_gremlin, color="green", end="\n", verbose=self.verbose + ) + context = self.graph.query(generated_gremlin) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/libs/community/langchain_community/chains/graph_qa/kuzu.py b/libs/community/langchain_community/chains/graph_qa/kuzu.py new file mode 100644 index 0000000000..752b7c20df --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/kuzu.py @@ -0,0 +1,143 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_QA_PROMPT, + KUZU_GENERATION_PROMPT, +) +from langchain_community.graphs.kuzu_graph import KuzuGraph + + +def remove_prefix(text: str, prefix: str) -> str: + """Remove a prefix from a text. + + Args: + text: Text to remove the prefix from. + prefix: Prefix to remove from the text. + + Returns: + Text with the prefix removed. + """ + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +def extract_cypher(text: str) -> str: + """Extract Cypher code from a text. + + Args: + text: Text to extract Cypher code from. + + Returns: + Cypher code extracted from the text. + """ + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + +class KuzuQAChain(Chain): + """Question-answering against a graph by generating Cypher statements for Kùzu. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. 
+ Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: KuzuGraph = Field(exclude=True) + cypher_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT, + **kwargs: Any, + ) -> KuzuQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + + return cls( + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + generated_cypher = self.cypher_generation_chain.run( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + # Extract Cypher code if it is wrapped in triple backticks + # with the language marker "cypher" + generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher") + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + context = self.graph.query(generated_cypher) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/libs/community/langchain_community/chains/graph_qa/nebulagraph.py b/libs/community/langchain_community/chains/graph_qa/nebulagraph.py new file mode 100644 index 0000000000..9429a8ea88 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/nebulagraph.py @@ -0,0 +1,106 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + CYPHER_QA_PROMPT, + NGQL_GENERATION_PROMPT, +) +from 
langchain_community.graphs.nebula_graph import NebulaGraph + + +class NebulaGraphQAChain(Chain): + """Chain for question-answering against a graph by generating nGQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: NebulaGraph = Field(exclude=True) + ngql_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT, + **kwargs: Any, + ) -> NebulaGraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt) + + return cls( + qa_chain=qa_chain, + ngql_generation_chain=ngql_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate nGQL statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + generated_ngql = self.ngql_generation_chain.run( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_ngql, color="green", end="\n", verbose=self.verbose + ) + context = self.graph.query(generated_ngql) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py b/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py new file mode 100644 index 0000000000..5b786f9f57 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.prompt_selector import ConditionalPromptSelector +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from 
langchain_core.pydantic_v1 import Field
+
+from langchain_community.chains.graph_qa.prompts import (
+    CYPHER_QA_PROMPT,
+    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
+    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
+)
+from langchain_community.graphs import BaseNeptuneGraph
+
+INTERMEDIATE_STEPS_KEY = "intermediate_steps"
+
+
+def trim_query(query: str) -> str:
+    """Trim the query to only include Cypher keywords."""
+    keywords = (
+        "CALL",
+        "CREATE",
+        "DELETE",
+        "DETACH",
+        "LIMIT",
+        "MATCH",
+        "MERGE",
+        "OPTIONAL",
+        "ORDER",
+        "REMOVE",
+        "RETURN",
+        "SET",
+        "SKIP",
+        "UNWIND",
+        "WITH",
+        "WHERE",
+        "//",
+    )
+
+    lines = query.split("\n")
+    new_query = ""
+
+    for line in lines:
+        if line.strip().upper().startswith(keywords):
+            new_query += line + "\n"
+
+    return new_query
+
+
+def extract_cypher(text: str) -> str:
+    """Extract Cypher code from text using Regex."""
+    # The pattern to find Cypher code enclosed in triple backticks
+    pattern = r"```(.*?)```"
+
+    # Find all matches in the input text
+    matches = re.findall(pattern, text, re.DOTALL)
+
+    return matches[0] if matches else text
+
+
+def use_simple_prompt(llm: BaseLanguageModel) -> bool:
+    """Decide whether to use the simple prompt."""
+    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
+        return True
+
+    # Bedrock anthropic
+    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
+        return True
+
+    return False
+
+
+PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
+    conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
+
+
+class NeptuneOpenCypherQAChain(Chain):
+    """Chain for question-answering against a Neptune graph
+    by generating openCypher statements.
+
+    *Security note*: Make sure that the database connection uses credentials
+    that are narrowly-scoped to only include necessary permissions.
+    Failure to do so may result in data corruption or loss, since the calling
+    code may attempt commands that would result in deletion, mutation
+    of data if appropriately prompted or reading sensitive data if such
+    data is present in the database.
+    The best way to guard against such negative outcomes is to (as appropriate)
+    limit the permissions granted to the credentials used with this tool.
+
+    See https://python.langchain.com/docs/security for more information.
+
+    Example:
+        .. code-block:: python
+
+            chain = NeptuneOpenCypherQAChain.from_llm(
+                llm=llm,
+                graph=graph
+            )
+            response = chain.run(query)
+    """
+
+    graph: BaseNeptuneGraph = Field(exclude=True)
+    cypher_generation_chain: LLMChain
+    qa_chain: LLMChain
+    input_key: str = "query"  #: :meta private:
+    output_key: str = "result"  #: :meta private:
+    top_k: int = 10
+    return_intermediate_steps: bool = False
+    """Whether or not to return the intermediate steps along with the final answer."""
+    return_direct: bool = False
+    """Whether or not to return the result of querying the graph directly."""
+    extra_instructions: Optional[str] = None
+    """Extra instructions to be appended to the query generation prompt."""
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Return the input keys.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Return the output keys.
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + cypher_prompt: Optional[BasePromptTemplate] = None, + extra_instructions: Optional[str] = None, + **kwargs: Any, + ) -> NeptuneOpenCypherQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + + _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) + cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) + + return cls( + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + extra_instructions=extra_instructions, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_cypher = self.cypher_generation_chain.run( + { + "question": question, + "schema": self.graph.get_schema, + "extra_instructions": self.extra_instructions or "", + }, + callbacks=callbacks, + ) + + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + generated_cypher = trim_query(generated_cypher) + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_cypher}) + + context = self.graph.query(generated_cypher) + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py b/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py new file mode 100644 index 0000000000..3348b783a8 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py @@ -0,0 +1,204 @@ +""" +Question answering over an RDF or OWL graph using SPARQL. +""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks.manager import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT +from langchain_community.graphs import NeptuneRdfGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + +SPARQL_GENERATION_TEMPLATE = """ +Task: Generate a SPARQL SELECT statement for querying a graph database. 
+For instance, to find all email addresses of John Doe, the following
+query in backticks would be suitable:
+```
+PREFIX foaf: <http://xmlns.com/foaf/0.1/>
+SELECT ?email
+WHERE {{
+    ?person foaf:name "John Doe" .
+    ?person foaf:mbox ?email .
+}}
+```
+Instructions:
+Use only the node types and properties provided in the schema.
+Do not use any node types and properties that are not explicitly provided.
+Include all necessary prefixes.
+
+Examples:
+
+Schema:
+{schema}
+Note: Be as concise as possible.
+Do not include any explanations or apologies in your responses.
+Do not respond to any questions that ask for anything else than
+for you to construct a SPARQL query.
+Do not include any text except the SPARQL query generated.
+
+The question is:
+{prompt}"""
+
+SPARQL_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
+)
+
+
+def extract_sparql(query: str) -> str:
+    """Extract SPARQL code from a text.
+
+    Args:
+        query: Text to extract SPARQL code from.
+
+    Returns:
+        SPARQL code extracted from the text.
+    """
+    query = query.strip()
+    querytoks = query.split("```")
+    if len(querytoks) == 3:
+        query = querytoks[1]
+
+        if query.startswith("sparql"):
+            query = query[6:]
+    elif query.startswith("<sparql>") and query.endswith("</sparql>"):
+        query = query[8:-9]
+    return query
+
+
+class NeptuneSparqlQAChain(Chain):
+    """Chain for question-answering against a Neptune graph
+    by generating SPARQL statements.
+
+    *Security note*: Make sure that the database connection uses credentials
+    that are narrowly-scoped to only include necessary permissions.
+    Failure to do so may result in data corruption or loss, since the calling
+    code may attempt commands that would result in deletion, mutation
+    of data if appropriately prompted or reading sensitive data if such
+    data is present in the database.
+    The best way to guard against such negative outcomes is to (as appropriate)
+    limit the permissions granted to the credentials used with this tool.
+
+    See https://python.langchain.com/docs/security for more information.
+
+    Example:
+        .. code-block:: python
+
+            chain = NeptuneSparqlQAChain.from_llm(
+                llm=llm,
+                graph=graph
+            )
+            response = chain.invoke(query)
+    """
+
+    graph: NeptuneRdfGraph = Field(exclude=True)
+    sparql_generation_chain: LLMChain
+    qa_chain: LLMChain
+    input_key: str = "query"  #: :meta private:
+    output_key: str = "result"  #: :meta private:
+    top_k: int = 10
+    return_intermediate_steps: bool = False
+    """Whether or not to return the intermediate steps along with the final answer."""
+    return_direct: bool = False
+    """Whether or not to return the result of querying the graph directly."""
+    extra_instructions: Optional[str] = None
+    """Extra instructions to be appended to the query generation prompt."""
+
+    @property
+    def input_keys(self) -> List[str]:
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        _output_keys = [self.output_key]
+        return _output_keys
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        *,
+        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
+        sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
+        examples: Optional[str] = None,
+        **kwargs: Any,
+    ) -> NeptuneSparqlQAChain:
+        """Initialize from LLM."""
+        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
+        template_to_use = SPARQL_GENERATION_TEMPLATE
+        if examples:
+            template_to_use = template_to_use.replace(
+                "Examples:", "Examples: " + examples
+            )
+            sparql_prompt = PromptTemplate(
+                input_variables=["schema", "prompt"], template=template_to_use
+            )
+        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
+
+        return cls(  # type: ignore[call-arg]
+            qa_chain=qa_chain,
+            sparql_generation_chain=sparql_generation_chain,
+            examples=examples,
+            **kwargs,
+        )
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        """
+        Generate SPARQL query, use it to retrieve a response from the graph
+        database, and answer the question.
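+
+        Returns a dict keyed by ``output_key``; when
+        ``return_intermediate_steps`` is set, ``intermediate_steps`` is
+        included as well.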
+ """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_sparql = self.sparql_generation_chain.run( + {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + # Extract SPARQL + generated_sparql = extract_sparql(generated_sparql) + + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_sparql, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_sparql}) + + context = self.graph.query(generated_sparql) + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"prompt": prompt, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/community/langchain_community/chains/graph_qa/ontotext_graphdb.py b/libs/community/langchain_community/chains/graph_qa/ontotext_graphdb.py new file mode 100644 index 0000000000..bf8a8419e7 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/ontotext_graphdb.py @@ -0,0 +1,190 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + import rdflib + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + GRAPHDB_QA_PROMPT, + GRAPHDB_SPARQL_FIX_PROMPT, + GRAPHDB_SPARQL_GENERATION_PROMPT, +) +from langchain_community.graphs import OntotextGraphDBGraph + + +class OntotextGraphDBQAChain(Chain): + """Question-answering against Ontotext GraphDB + https://graphdb.ontotext.com/ by generating SPARQL queries. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. 
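+
+    Example (illustrative; assumes a running GraphDB repository and an
+    already-constructed ``llm``; the endpoint and ontology query are
+    placeholders):
+        .. code-block:: python
+
+            from langchain_community.graphs import OntotextGraphDBGraph
+
+            graph = OntotextGraphDBGraph(
+                query_endpoint="http://localhost:7200/repositories/starwars",
+                query_ontology="CONSTRUCT {?s ?p ?o} WHERE {?s ?p ?o}",
+            )
+            chain = OntotextGraphDBQAChain.from_llm(llm=llm, graph=graph)
+            response = chain.invoke({chain.input_key: "What is Tatooine's climate?"})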
+ """ + + graph: OntotextGraphDBGraph = Field(exclude=True) + sparql_generation_chain: LLMChain + sparql_fix_chain: LLMChain + max_fix_retries: int + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT, + sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT, + max_fix_retries: int = 5, + qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT, + **kwargs: Any, + ) -> OntotextGraphDBQAChain: + """Initialize from LLM.""" + sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt) + sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt) + max_fix_retries = max_fix_retries + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + return cls( + qa_chain=qa_chain, + sparql_generation_chain=sparql_generation_chain, + sparql_fix_chain=sparql_fix_chain, + max_fix_retries=max_fix_retries, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """ + Generate a SPARQL query, use it to retrieve a response from GraphDB and answer + the question. + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + ontology_schema = self.graph.get_schema + + sparql_generation_chain_result = self.sparql_generation_chain.invoke( + {"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks + ) + generated_sparql = sparql_generation_chain_result[ + self.sparql_generation_chain.output_key + ] + + generated_sparql = self._get_prepared_sparql_query( + _run_manager, callbacks, generated_sparql, ontology_schema + ) + query_results = self._execute_query(generated_sparql) + + qa_chain_result = self.qa_chain.invoke( + {"prompt": prompt, "context": query_results}, callbacks=callbacks + ) + result = qa_chain_result[self.qa_chain.output_key] + return {self.output_key: result} + + def _get_prepared_sparql_query( + self, + _run_manager: CallbackManagerForChainRun, + callbacks: CallbackManager, + generated_sparql: str, + ontology_schema: str, + ) -> str: + try: + return self._prepare_sparql_query(_run_manager, generated_sparql) + except Exception as e: + retries = 0 + error_message = str(e) + self._log_invalid_sparql_query( + _run_manager, generated_sparql, error_message + ) + + while retries < self.max_fix_retries: + try: + sparql_fix_chain_result = self.sparql_fix_chain.invoke( + { + "error_message": error_message, + "generated_sparql": generated_sparql, + "schema": ontology_schema, + }, + callbacks=callbacks, + ) + generated_sparql = sparql_fix_chain_result[ + self.sparql_fix_chain.output_key + ] + return self._prepare_sparql_query(_run_manager, generated_sparql) + except Exception as e: + retries += 1 + parse_exception = str(e) + self._log_invalid_sparql_query( + _run_manager, generated_sparql, parse_exception + ) + + raise ValueError("The generated SPARQL query is invalid.") + + def _prepare_sparql_query( + self, _run_manager: CallbackManagerForChainRun, generated_sparql: str + ) -> str: + from rdflib.plugins.sparql import prepareQuery + + prepareQuery(generated_sparql) + 
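+        # prepareQuery raises on malformed SPARQL, so reaching this point
+        # means the generated query parsed successfully.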
self._log_prepared_sparql_query(_run_manager, generated_sparql) + return generated_sparql + + def _log_prepared_sparql_query( + self, _run_manager: CallbackManagerForChainRun, generated_query: str + ) -> None: + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_query, color="green", end="\n", verbose=self.verbose + ) + + def _log_invalid_sparql_query( + self, + _run_manager: CallbackManagerForChainRun, + generated_query: str, + error_message: str, + ) -> None: + _run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_query, color="red", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + "SPARQL Query Parse Error: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + error_message, color="red", end="\n\n", verbose=self.verbose + ) + + def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]: + try: + return self.graph.query(query) + except Exception: + raise ValueError("Failed to execute the generated SPARQL query.") diff --git a/libs/community/langchain_community/chains/graph_qa/prompts.py b/libs/community/langchain_community/chains/graph_qa/prompts.py new file mode 100644 index 0000000000..a4b5db9583 --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/prompts.py @@ -0,0 +1,415 @@ +# flake8: noqa +from langchain_core.prompts.prompt import PromptTemplate + +_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places. + +Return the output as a single comma-separated list, or NONE if there is nothing of note to return. + +EXAMPLE +i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. +Output: Langchain +END OF EXAMPLE + +EXAMPLE +i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam. +Output: Langchain, Sam +END OF EXAMPLE + +Begin! + +{input} +Output:""" +ENTITY_EXTRACTION_PROMPT = PromptTemplate( + input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE +) + +_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + +{context} + +Question: {question} +Helpful Answer:""" +GRAPH_QA_PROMPT = PromptTemplate( + template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"] +) + +CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database. +Instructions: +Use only the provided relationship types and properties in the schema. +Do not use any other relationship types or properties that are not provided. +Schema: +{schema} +Note: Do not include any explanations or apologies in your responses. +Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. +Do not include any text except the generated Cypher statement. + +The question is: +{question}""" +CYPHER_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE +) + +NEBULAGRAPH_EXTRA_INSTRUCTIONS = """ +Instructions: + +First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard): +1. 
it requires explicit label specification only when referring to node properties: v.`Foo`.name +2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name +3. it uses double equals sign for comparison: `==` rather than `=` +For instance: +```diff +< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II' +< RETURN p.name, e.year, m.name; +--- +> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II' +> RETURN p.`person`.`name`, e.year, m.`movie`.`name`; +```\n""" + +NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( + "Generate Cypher", "Generate NebulaGraph Cypher" +).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS) + +NGQL_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE +) + +KUZU_EXTRA_INSTRUCTIONS = """ +Instructions: + +Generate the Kùzu dialect of Cypher with the following rules in mind: + +1. Do not use a `WHERE EXISTS` clause to check the existence of a property. +2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`. +3. Do not include any notes or comments even if the statement does not produce the expected result. +```\n""" + +KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( + "Generate Cypher", "Generate Kùzu Cypher" +).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS) + +KUZU_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE +) + +GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin") + +GREMLIN_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE +) + +CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. +The information part contains the provided information that you must use to construct an answer. +The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it. +Make the answer sound as a response to the question. Do not mention that you based the result on the given information. +Here is an example: + +Question: Which managers own Neo4j stocks? +Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] +Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. + +Follow this example when generating answers. +If the provided information is empty, say that you don't know the answer. +Information: +{context} + +Question: {question} +Helpful Answer:""" +CYPHER_QA_PROMPT = PromptTemplate( + input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE +) + +SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type. +You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types. +Consider only the following query types: +* SELECT: this query type corresponds to questions +* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type. +Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'. 
+
+The prompt is:
+{prompt}
+Helpful Answer:"""
+SPARQL_INTENT_PROMPT = PromptTemplate(
+    input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
+)
+
+SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
+For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
+```
+PREFIX foaf: <http://xmlns.com/foaf/0.1/>
+SELECT ?email
+WHERE {{
+    ?person foaf:name "John Doe" .
+    ?person foaf:mbox ?email .
+}}
+```
+Instructions:
+Use only the node types and properties provided in the schema.
+Do not use any node types and properties that are not explicitly provided.
+Include all necessary prefixes.
+Schema:
+{schema}
+Note: Be as concise as possible.
+Do not include any explanations or apologies in your responses.
+Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
+Do not include any text except the SPARQL query generated.
+
+The question is:
+{prompt}"""
+SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
+    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
+)
+
+SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
+For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
+```
+PREFIX foaf: <http://xmlns.com/foaf/0.1/>
+INSERT {{
+    ?person foaf:mbox <mailto:jane.doe@foo.bar> .
+}}
+WHERE {{
+    ?person foaf:name "Jane Doe" .
+}}
+```
+Instructions:
+Make the query as short as possible and avoid adding unnecessary triples.
+Use only the node types and properties provided in the schema.
+Do not use any node types and properties that are not explicitly provided.
+Include all necessary prefixes.
+Schema:
+{schema}
+Note: Be as concise as possible.
+Do not include any explanations or apologies in your responses.
+Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
+Return only the generated SPARQL query, nothing else.
+
+The information to be inserted is:
+{prompt}"""
+SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
+    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
+)
+
+SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
+You are an assistant that creates well-written and human understandable answers.
+The information part contains the information provided, which you can use to construct an answer.
+The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
+Make your response sound like the information is coming from an AI assistant, but don't add any information.
+Information:
+{context}
+
+Question: {prompt}
+Helpful Answer:"""
+SPARQL_QA_PROMPT = PromptTemplate(
+    input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
+)
+
+GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
+Write a SPARQL SELECT query for querying a graph database.
+The ontology schema delimited by triple backticks in Turtle format is:
+```
+{schema}
+```
+Use only the classes and properties provided in the schema to construct the SPARQL query.
+Do not use any classes or properties that are not explicitly provided in the SPARQL query.
+Include all necessary prefixes.
+Do not include any explanations or apologies in your responses.
+Do not wrap the query in backticks.
+Do not include any text except the SPARQL query generated.
+The question delimited by triple backticks is:
+```
+{prompt}
+```
+"""
+GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["schema", "prompt"],
+    template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
+)
+
+GRAPHDB_SPARQL_FIX_TEMPLATE = """
+The following SPARQL query delimited by triple backticks
+```
+{generated_sparql}
+```
+is not valid.
+The error delimited by triple backticks is
+```
+{error_message}
+```
+Give me a correct version of the SPARQL query.
+Do not change the logic of the query.
+Do not include any explanations or apologies in your responses.
+Do not wrap the query in backticks.
+Do not include any text except the SPARQL query generated.
+The ontology schema delimited by triple backticks in Turtle format is:
+```
+{schema}
+```
+"""
+
+GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
+    input_variables=["error_message", "generated_sparql", "schema"],
+    template=GRAPHDB_SPARQL_FIX_TEMPLATE,
+)
+
+GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
+You are an assistant that creates well-written and human understandable answers.
+The information part contains the information provided, which you can use to construct an answer.
+The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
+Make your response sound like the information is coming from an AI assistant, but don't add any information.
+Don't use internal knowledge to answer the question, just say you don't know if no information is available.
+Information:
+{context}
+
+Question: {prompt}
+Helpful Answer:"""
+GRAPHDB_QA_PROMPT = PromptTemplate(
+    input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE
+)
+
+AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.
+
+You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.
+
+You are given an `ArangoDB Schema`. It is a JSON Object containing:
+1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
+2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
+
+You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.
+
+Things you should do:
+- Think step by step.
+- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
+- Begin the `AQL Query` with the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
+- Return the `AQL Query` wrapped in 3 backticks (```).
+- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries.
+- Only answer requests related to generating an AQL Query.
+- If a request is unrelated to generating an AQL Query, say that you cannot help the user.
+
+Things you should not do:
+- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
+- Do not include any text except the generated AQL Query.
+- Do not provide explanations or apologies in your responses.
+- Do not generate an AQL Query that removes or deletes any data.
+
+Under no circumstance should you generate an AQL Query that deletes any data whatsoever.
+
+ArangoDB Schema:
+{adb_schema}
+
+AQL Query Examples (Optional):
+{aql_examples}
+
+User Input:
+{user_input}
+
+AQL Query:
+"""
+
+AQL_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["adb_schema", "aql_examples", "user_input"],
+    template=AQL_GENERATION_TEMPLATE,
+)
+
+AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query.
+
+You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`.
+
+The `AQL Error` explains why the `AQL Query` could not be executed in the database.
+The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`.
+For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`.
+
+You are also given the `ArangoDB Schema`. It is a JSON Object containing:
+1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
+2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
+
+You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query.
+
+Remember to think step by step.
+
+ArangoDB Schema:
+{adb_schema}
+
+AQL Query:
+{aql_query}
+
+AQL Error:
+{aql_error}
+
+Corrected AQL Query:
+"""
+
+AQL_FIX_PROMPT = PromptTemplate(
+    input_variables=[
+        "adb_schema",
+        "aql_query",
+        "aql_error",
+    ],
+    template=AQL_FIX_TEMPLATE,
+)
+
+AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query.
+
+You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`.
+
+A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format.
+You are responsible for creating a `Summary` based on the AQL Result.
+
+You are given the following information:
+- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database.
+- `User Input`: the original question/request of the user, which has been translated into an AQL Query.
+- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query.
+- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database.
+
+Remember to think step by step.
+
+Your `Summary` should sound like it is a response to the `User Input`.
+Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`.
+
+ArangoDB Schema:
+{adb_schema}
+
+User Input:
+{user_input}
+
+AQL Query:
+{aql_query}
+
+AQL Result:
+{aql_result}
+"""
+AQL_QA_PROMPT = PromptTemplate(
+    input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
+    template=AQL_QA_TEMPLATE,
+)
+
+
+NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
+Instructions:
+Generate the query in openCypher format and follow these rules:
+Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
+Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
+Do not use `FOREACH` clause.
Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( + "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS +) + +NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question", "extra_instructions"], + template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, +) + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ +Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions} +Question: "{question}". +Here is the property graph schema: +{schema} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( + input_variables=["schema", "question", "extra_instructions"], + template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, +) diff --git a/libs/community/langchain_community/chains/graph_qa/sparql.py b/libs/community/langchain_community/chains/graph_qa/sparql.py new file mode 100644 index 0000000000..62198347fb --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/sparql.py @@ -0,0 +1,152 @@ +""" +Question answering over an RDF or OWL graph using SPARQL. +""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain_community.chains.graph_qa.prompts import ( + SPARQL_GENERATION_SELECT_PROMPT, + SPARQL_GENERATION_UPDATE_PROMPT, + SPARQL_INTENT_PROMPT, + SPARQL_QA_PROMPT, +) +from langchain_community.graphs.rdf_graph import RdfGraph + + +class GraphSparqlQAChain(Chain): + """Question-answering against an RDF or OWL graph by generating SPARQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: RdfGraph = Field(exclude=True) + sparql_generation_select_chain: LLMChain + sparql_generation_update_chain: LLMChain + sparql_intent_chain: LLMChain + qa_chain: LLMChain + return_sparql_query: bool = False + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + sparql_query_key: str = "sparql_query" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, + sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT, + sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT, + sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT, + **kwargs: Any, + ) -> GraphSparqlQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt) + sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt) + sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt) + + return cls( + qa_chain=qa_chain, + sparql_generation_select_chain=sparql_generation_select_chain, + sparql_generation_update_chain=sparql_generation_update_chain, + sparql_intent_chain=sparql_intent_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """ + Generate SPARQL query, use it to retrieve a response from the gdb and answer + the question. + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + + _intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks) + intent = _intent.strip() + + if "SELECT" in intent and "UPDATE" not in intent: + sparql_generation_chain = self.sparql_generation_select_chain + intent = "SELECT" + elif "UPDATE" in intent and "SELECT" not in intent: + sparql_generation_chain = self.sparql_generation_update_chain + intent = "UPDATE" + else: + raise ValueError( + "I am sorry, but this prompt seems to fit none of the currently " + "supported SPARQL query types, i.e., SELECT and UPDATE." + ) + + _run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose) + _run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose) + + generated_sparql = sparql_generation_chain.run( + {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_sparql, color="green", end="\n", verbose=self.verbose + ) + + if intent == "SELECT": + context = self.graph.query(generated_sparql) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + result = self.qa_chain( + {"prompt": prompt, "context": context}, + callbacks=callbacks, + ) + res = result[self.qa_chain.output_key] + elif intent == "UPDATE": + self.graph.update(generated_sparql) + res = "Successfully inserted triples into the graph." 
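+            # SPARQL UPDATE requests return no bindings, so the QA chain is
+            # skipped and a fixed confirmation message is returned instead.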
+        else:
+            raise ValueError("Unsupported SPARQL query type.")
+
+        chain_result: Dict[str, Any] = {self.output_key: res}
+        if self.return_sparql_query:
+            chain_result[self.sparql_query_key] = generated_sparql
+        return chain_result
diff --git a/libs/community/langchain_community/chains/llm_requests.py b/libs/community/langchain_community/chains/llm_requests.py
new file mode 100644
index 0000000000..304a74fc09
--- /dev/null
+++ b/libs/community/langchain_community/chains/llm_requests.py
@@ -0,0 +1,97 @@
+"""Chain that hits a URL and then uses an LLM to parse results."""
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from langchain.chains import LLMChain
+from langchain.chains.base import Chain
+from langchain_core.callbacks import CallbackManagerForChainRun
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
+
+from langchain_community.utilities.requests import TextRequestsWrapper
+
+DEFAULT_HEADERS = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"  # noqa: E501
+}
+
+
+class LLMRequestsChain(Chain):
+    """Chain that requests a URL and then uses an LLM to parse results.
+
+    **Security Note**: This chain can make GET requests to arbitrary URLs,
+    including internal URLs.
+
+    Control access to who can run this chain and what network access
+    this chain has.
+
+    See https://python.langchain.com/docs/security for more information.
+    """
+
+    llm_chain: LLMChain  # type: ignore[valid-type]
+    requests_wrapper: TextRequestsWrapper = Field(
+        default_factory=lambda: TextRequestsWrapper(headers=DEFAULT_HEADERS),
+        exclude=True,
+    )
+    text_length: int = 8000
+    requests_key: str = "requests_result"  #: :meta private:
+    input_key: str = "url"  #: :meta private:
+    output_key: str = "output"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the bs4 Python package is installed in the environment."""
+        try:
+            from bs4 import BeautifulSoup  # noqa: F401
+
+        except ImportError:
+            raise ImportError(
+                "Could not import bs4 python package. "
+                "Please install it with `pip install bs4`."
+ ) + return values + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + from bs4 import BeautifulSoup + + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + # Other keys are assumed to be needed for LLM prediction + other_keys = {k: v for k, v in inputs.items() if k != self.input_key} + url = inputs[self.input_key] + res = self.requests_wrapper.get(url) + # extract the text from the html + soup = BeautifulSoup(res, "html.parser") + other_keys[self.requests_key] = soup.get_text()[: self.text_length] + result = self.llm_chain.predict( # type: ignore[attr-defined] + callbacks=_run_manager.get_child(), **other_keys + ) + return {self.output_key: result} + + @property + def _chain_type(self) -> str: + return "llm_requests_chain" diff --git a/libs/community/langchain_community/chains/openapi/__init__.py b/libs/community/langchain_community/chains/openapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/community/langchain_community/chains/openapi/chain.py b/libs/community/langchain_community/chains/openapi/chain.py new file mode 100644 index 0000000000..8112a1aaf5 --- /dev/null +++ b/libs/community/langchain_community/chains/openapi/chain.py @@ -0,0 +1,229 @@ +"""Chain that makes API calls and summarizes the responses to answer a question.""" +from __future__ import annotations + +import json +from typing import Any, Dict, List, NamedTuple, Optional, cast + +from langchain.chains.api.openapi.requests_chain import APIRequesterChain +from langchain.chains.api.openapi.response_chain import APIResponderChain +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks +from langchain_core.language_models import BaseLanguageModel +from langchain_core.pydantic_v1 import BaseModel, Field +from requests import Response + +from langchain_community.tools.openapi.utils.api_models import APIOperation +from langchain_community.utilities.requests import Requests + + +class _ParamMapping(NamedTuple): + """Mapping from parameter name to parameter value.""" + + query_params: List[str] + body_params: List[str] + path_params: List[str] + + +class OpenAPIEndpointChain(Chain, BaseModel): + """Chain interacts with an OpenAPI endpoint using natural language.""" + + api_request_chain: LLMChain + api_response_chain: Optional[LLMChain] + api_operation: APIOperation + requests: Requests = Field(exclude=True, default_factory=Requests) + param_mapping: _ParamMapping = Field(alias="param_mapping") + return_intermediate_steps: bool = False + instructions_key: str = "instructions" #: :meta private: + output_key: str = "output" #: :meta private: + max_text_length: Optional[int] = Field(ge=0) #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.instructions_key] + + @property + def output_keys(self) -> List[str]: + """Expect output key. 
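+        Also includes "intermediate_steps" when ``return_intermediate_steps``
+        is True.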
+ + :meta private: + """ + if not self.return_intermediate_steps: + return [self.output_key] + else: + return [self.output_key, "intermediate_steps"] + + def _construct_path(self, args: Dict[str, str]) -> str: + """Construct the path from the deserialized input.""" + path = self.api_operation.base_url + self.api_operation.path + for param in self.param_mapping.path_params: + path = path.replace(f"{{{param}}}", str(args.pop(param, ""))) + return path + + def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]: + """Extract the query params from the deserialized input.""" + query_params = {} + for param in self.param_mapping.query_params: + if param in args: + query_params[param] = args.pop(param) + return query_params + + def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]: + """Extract the request body params from the deserialized input.""" + body_params = None + if self.param_mapping.body_params: + body_params = {} + for param in self.param_mapping.body_params: + if param in args: + body_params[param] = args.pop(param) + return body_params + + def deserialize_json_input(self, serialized_args: str) -> dict: + """Use the serialized typescript dictionary. + + Resolve the path, query params dict, and optional requestBody dict. + """ + args: dict = json.loads(serialized_args) + path = self._construct_path(args) + body_params = self._extract_body_params(args) + query_params = self._extract_query_params(args) + return { + "url": path, + "data": body_params, + "params": query_params, + } + + def _get_output(self, output: str, intermediate_steps: dict) -> dict: + """Return the output from the API call.""" + if self.return_intermediate_steps: + return { + self.output_key: output, + "intermediate_steps": intermediate_steps, + } + else: + return {self.output_key: output} + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + intermediate_steps = {} + instructions = inputs[self.instructions_key] + instructions = instructions[: self.max_text_length] + _api_arguments = self.api_request_chain.predict_and_parse( + instructions=instructions, callbacks=_run_manager.get_child() + ) + api_arguments = cast(str, _api_arguments) + intermediate_steps["request_args"] = api_arguments + _run_manager.on_text( + api_arguments, color="green", end="\n", verbose=self.verbose + ) + if api_arguments.startswith("ERROR"): + return self._get_output(api_arguments, intermediate_steps) + elif api_arguments.startswith("MESSAGE:"): + return self._get_output( + api_arguments[len("MESSAGE:") :], intermediate_steps + ) + try: + request_args = self.deserialize_json_input(api_arguments) + method = getattr(self.requests, self.api_operation.method.value) + api_response: Response = method(**request_args) + if api_response.status_code != 200: + method_str = str(self.api_operation.method.value) + response_text = ( + f"{api_response.status_code}: {api_response.reason}" + + f"\nFor {method_str.upper()} {request_args['url']}\n" + + f"Called with args: {request_args['params']}" + ) + else: + response_text = api_response.text + except Exception as e: + response_text = f"Error with message {str(e)}" + response_text = response_text[: self.max_text_length] + intermediate_steps["response_text"] = response_text + _run_manager.on_text( + response_text, color="blue", end="\n", verbose=self.verbose + ) + if self.api_response_chain is not None: + _answer = 
self.api_response_chain.predict_and_parse( + response=response_text, + instructions=instructions, + callbacks=_run_manager.get_child(), + ) + answer = cast(str, _answer) + _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose) + return self._get_output(answer, intermediate_steps) + else: + return self._get_output(response_text, intermediate_steps) + + @classmethod + def from_url_and_method( + cls, + spec_url: str, + path: str, + method: str, + llm: BaseLanguageModel, + requests: Optional[Requests] = None, + return_intermediate_steps: bool = False, + **kwargs: Any, + # TODO: Handle async + ) -> "OpenAPIEndpointChain": + """Create an OpenAPIEndpoint from a spec at the specified url.""" + operation = APIOperation.from_openapi_url(spec_url, path, method) + return cls.from_api_operation( + operation, + requests=requests, + llm=llm, + return_intermediate_steps=return_intermediate_steps, + **kwargs, + ) + + @classmethod + def from_api_operation( + cls, + operation: APIOperation, + llm: BaseLanguageModel, + requests: Optional[Requests] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + raw_response: bool = False, + callbacks: Callbacks = None, + **kwargs: Any, + # TODO: Handle async + ) -> "OpenAPIEndpointChain": + """Create an OpenAPIEndpointChain from an operation and a spec.""" + param_mapping = _ParamMapping( + query_params=operation.query_params, + body_params=operation.body_params, + path_params=operation.path_params, + ) + requests_chain = APIRequesterChain.from_llm_and_typescript( + llm, + typescript_definition=operation.to_typescript(), + verbose=verbose, + callbacks=callbacks, + ) + if raw_response: + response_chain = None + else: + response_chain = APIResponderChain.from_llm( + llm, verbose=verbose, callbacks=callbacks + ) + _requests = requests or Requests() + return cls( + api_request_chain=requests_chain, + api_response_chain=response_chain, + api_operation=operation, + requests=_requests, + param_mapping=param_mapping, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + callbacks=callbacks, + **kwargs, + ) diff --git a/libs/community/langchain_community/chains/openapi/prompts.py b/libs/community/langchain_community/chains/openapi/prompts.py new file mode 100644 index 0000000000..84e5a2baee --- /dev/null +++ b/libs/community/langchain_community/chains/openapi/prompts.py @@ -0,0 +1,57 @@ +# flake8: noqa +REQUEST_TEMPLATE = """You are a helpful AI Assistant. Please provide JSON arguments to agentFunc() based on the user's instructions. + +API_SCHEMA: ```typescript +{schema} +``` + +USER_INSTRUCTIONS: "{instructions}" + +Your arguments must be plain json provided in a markdown block: + +ARGS: ```json +{{valid json conforming to API_SCHEMA}} +``` + +Example +----- + +ARGS: ```json +{{"foo": "bar", "baz": {{"qux": "quux"}}}} +``` + +The block must be no more than 1 line long, and all arguments must be valid JSON. All string arguments must be wrapped in double quotes. +You MUST strictly comply to the types indicated by the provided schema, including all required args. + +If you don't have sufficient information to call the function due to things like requiring specific uuid's, you can reply with the following message: + +Message: ```text +Concise response requesting the additional information that would make calling the function successful. +``` + +Begin +----- +ARGS: +""" +RESPONSE_TEMPLATE = """You are a helpful AI assistant trained to answer user queries from API responses. 
+You attempted to call an API, which resulted in: +API_RESPONSE: {response} + +USER_COMMENT: "{instructions}" + + +If the API_RESPONSE can answer the USER_COMMENT respond with the following markdown json block: +Response: ```json +{{"response": "Human-understandable synthesis of the API_RESPONSE"}} +``` + +Otherwise respond with the following markdown json block: +Response Error: ```json +{{"response": "What you did and a concise statement of the resulting error. If it can be easily fixed, provide a suggestion."}} +``` + +You MUST respond as a markdown json code block. The person you are responding to CANNOT see the API_RESPONSE, so if there is any relevant information there you must include it in your response. + +Begin: +--- +""" diff --git a/libs/community/langchain_community/chains/openapi/requests_chain.py b/libs/community/langchain_community/chains/openapi/requests_chain.py new file mode 100644 index 0000000000..ab7655ad65 --- /dev/null +++ b/libs/community/langchain_community/chains/openapi/requests_chain.py @@ -0,0 +1,62 @@ +"""request parser.""" + +import json +import re +from typing import Any + +from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE +from langchain.chains.llm import LLMChain +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts.prompt import PromptTemplate + + +class APIRequesterOutputParser(BaseOutputParser): + """Parse the request and error tags.""" + + def _load_json_block(self, serialized_block: str) -> str: + try: + return json.dumps(json.loads(serialized_block, strict=False)) + except json.JSONDecodeError: + return "ERROR serializing request." + + def parse(self, llm_output: str) -> str: + """Parse the request and error tags.""" + + json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL) + if json_match: + return self._load_json_block(json_match.group(1).strip()) + message_match = re.search(r"```text(.*?)```", llm_output, re.DOTALL) + if message_match: + return f"MESSAGE: {message_match.group(1).strip()}" + return "ERROR making request" + + @property + def _type(self) -> str: + return "api_requester" + + +class APIRequesterChain(LLMChain): + """Get the request parser.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + return False + + @classmethod + def from_llm_and_typescript( + cls, + llm: BaseLanguageModel, + typescript_definition: str, + verbose: bool = True, + **kwargs: Any, + ) -> LLMChain: + """Get the request parser.""" + output_parser = APIRequesterOutputParser() + prompt = PromptTemplate( + template=REQUEST_TEMPLATE, + output_parser=output_parser, + partial_variables={"schema": typescript_definition}, + input_variables=["instructions"], + ) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/libs/community/langchain_community/chains/openapi/response_chain.py b/libs/community/langchain_community/chains/openapi/response_chain.py new file mode 100644 index 0000000000..7e13afa6fd --- /dev/null +++ b/libs/community/langchain_community/chains/openapi/response_chain.py @@ -0,0 +1,57 @@ +"""Response parser.""" + +import json +import re +from typing import Any + +from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE +from langchain.chains.llm import LLMChain +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts.prompt import PromptTemplate + + +class APIResponderOutputParser(BaseOutputParser): + 
"""Parse the response and error tags.""" + + def _load_json_block(self, serialized_block: str) -> str: + try: + response_content = json.loads(serialized_block, strict=False) + return response_content.get("response", "ERROR parsing response.") + except json.JSONDecodeError: + return "ERROR parsing response." + except: + raise + + def parse(self, llm_output: str) -> str: + """Parse the response and error tags.""" + json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL) + if json_match: + return self._load_json_block(json_match.group(1).strip()) + else: + raise ValueError(f"No response found in output: {llm_output}.") + + @property + def _type(self) -> str: + return "api_responder" + + +class APIResponderChain(LLMChain): + """Get the response parser.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + return False + + @classmethod + def from_llm( + cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any + ) -> LLMChain: + """Get the response parser.""" + output_parser = APIResponderOutputParser() + prompt = PromptTemplate( + template=RESPONSE_TEMPLATE, + output_parser=output_parser, + input_variables=["response", "instructions"], + ) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/libs/community/langchain_community/cross_encoders/base.py b/libs/community/langchain_community/cross_encoders/base.py index 98fa056898..8f46e0f0dc 100644 --- a/libs/community/langchain_community/cross_encoders/base.py +++ b/libs/community/langchain_community/cross_encoders/base.py @@ -1,17 +1,3 @@ -from abc import ABC, abstractmethod -from typing import List, Tuple +from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder - -class BaseCrossEncoder(ABC): - """Interface for cross encoder models.""" - - @abstractmethod - def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: - """Score pairs' similarity. - - Args: - text_pairs: List of pairs of texts. - - Returns: - List of scores. 
- """ +__all__ = ["BaseCrossEncoder"] diff --git a/libs/community/langchain_community/query_constructors/__init__.py b/libs/community/langchain_community/query_constructors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/community/langchain_community/query_constructors/astradb.py b/libs/community/langchain_community/query_constructors/astradb.py new file mode 100644 index 0000000000..006972935f --- /dev/null +++ b/libs/community/langchain_community/query_constructors/astradb.py @@ -0,0 +1,70 @@ +"""Logic for converting internal query language to a valid AstraDB query.""" +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN] + + +class AstraDBTranslator(Visitor): + """Translate AstraDB internal query language elements to valid filters.""" + + """Subset of allowed logical comparators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.IN, + Comparator.NIN, + ] + + """Subset of allowed logical operators.""" + allowed_operators = [Operator.AND, Operator.OR] + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + map_dict = { + Operator.AND: "$and", + Operator.OR: "$or", + Comparator.EQ: "$eq", + Comparator.NE: "$ne", + Comparator.GTE: "$gte", + Comparator.LTE: "$lte", + Comparator.LT: "$lt", + Comparator.GT: "$gt", + Comparator.IN: "$in", + Comparator.NIN: "$nin", + } + return map_dict[func] + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return {self._format_func(operation.operator): args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance( + comparison.value, list + ): + comparison.value = [comparison.value] + + comparator = self._format_func(comparison.comparator) + return {comparison.attribute: {comparator: comparison.value}} + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/chroma.py b/libs/community/langchain_community/query_constructors/chroma.py new file mode 100644 index 0000000000..6f766e7e13 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/chroma.py @@ -0,0 +1,50 @@ +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class ChromaTranslator(Visitor): + """Translate `Chroma` internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + """Subset of allowed logical operators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + ] + """Subset of allowed logical comparators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + return f"${func.value}" + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return 
{self._format_func(operation.operator): args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + return { + comparison.attribute: { + self._format_func(comparison.comparator): comparison.value + } + } + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/dashvector.py b/libs/community/langchain_community/query_constructors/dashvector.py new file mode 100644 index 0000000000..c1d63d1aae --- /dev/null +++ b/libs/community/langchain_community/query_constructors/dashvector.py @@ -0,0 +1,64 @@ +"""Logic for converting internal query language to a valid DashVector query.""" +from typing import Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class DashvectorTranslator(Visitor): + """Logic for converting internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + allowed_comparators = [ + Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.LIKE, + ] + + map_dict = { + Operator.AND: " AND ", + Operator.OR: " OR ", + Comparator.EQ: " = ", + Comparator.GT: " > ", + Comparator.GTE: " >= ", + Comparator.LT: " < ", + Comparator.LTE: " <= ", + Comparator.LIKE: " LIKE ", + } + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + return self.map_dict[func] + + def visit_operation(self, operation: Operation) -> str: + args = [arg.accept(self) for arg in operation.arguments] + return self._format_func(operation.operator).join(args) + + def visit_comparison(self, comparison: Comparison) -> str: + value = comparison.value + if isinstance(value, str): + if comparison.comparator == Comparator.LIKE: + value = f"'%{value}%'" + else: + value = f"'{value}'" + return ( + f"{comparison.attribute}{self._format_func(comparison.comparator)}{value}" + ) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/databricks_vector_search.py b/libs/community/langchain_community/query_constructors/databricks_vector_search.py new file mode 100644 index 0000000000..03e7de8efa --- /dev/null +++ b/libs/community/langchain_community/query_constructors/databricks_vector_search.py @@ -0,0 +1,94 @@ +from collections import ChainMap +from itertools import chain +from typing import Dict, Tuple + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +_COMPARATOR_TO_SYMBOL = { + Comparator.EQ: "", + Comparator.GT: " >", + Comparator.GTE: " >=", + Comparator.LT: " <", + Comparator.LTE: " <=", + Comparator.IN: "", + Comparator.LIKE: " LIKE", +} + + +class DatabricksVectorSearchTranslator(Visitor): + """Translate `Databricks vector search` internal query language elements to + valid filters.""" + + """Subset of allowed logical operators.""" + allowed_operators = [Operator.AND, Operator.NOT, Operator.OR] + + """Subset of allowed logical comparators.""" + allowed_comparators = [ + 
Comparator.EQ,
+        Comparator.GT,
+        Comparator.GTE,
+        Comparator.LT,
+        Comparator.LTE,
+        Comparator.IN,
+        Comparator.LIKE,
+    ]
+
+    def _visit_and_operation(self, operation: Operation) -> Dict:
+        return dict(ChainMap(*[arg.accept(self) for arg in operation.arguments]))
+
+    def _visit_or_operation(self, operation: Operation) -> Dict:
+        filter_args = [arg.accept(self) for arg in operation.arguments]
+        flattened_args = list(
+            chain.from_iterable(filter_arg.items() for filter_arg in filter_args)
+        )
+        return {
+            " OR ".join(key for key, _ in flattened_args): [
+                value for _, value in flattened_args
+            ]
+        }
+
+    def _visit_not_operation(self, operation: Operation) -> Dict:
+        if len(operation.arguments) > 1:
+            raise ValueError(
+                f'"{operation.operator.value}" can have only one argument '
+                f"in Databricks vector search"
+            )
+        filter_arg = operation.arguments[0].accept(self)
+        return {
+            f"{column_with_bool_expression} NOT": value
+            for column_with_bool_expression, value in filter_arg.items()
+        }
+
+    def visit_operation(self, operation: Operation) -> Dict:
+        self._validate_func(operation.operator)
+        if operation.operator == Operator.AND:
+            return self._visit_and_operation(operation)
+        elif operation.operator == Operator.OR:
+            return self._visit_or_operation(operation)
+        elif operation.operator == Operator.NOT:
+            return self._visit_not_operation(operation)
+        else:
+            raise NotImplementedError(
+                f'Operator "{operation.operator}" is not supported'
+            )
+
+    def visit_comparison(self, comparison: Comparison) -> Dict:
+        self._validate_func(comparison.comparator)
+        comparator_symbol = _COMPARATOR_TO_SYMBOL[comparison.comparator]
+        return {f"{comparison.attribute}{comparator_symbol}": comparison.value}
+
+    def visit_structured_query(
+        self, structured_query: StructuredQuery
+    ) -> Tuple[str, dict]:
+        if structured_query.filter is None:
+            kwargs = {}
+        else:
+            kwargs = {"filters": structured_query.filter.accept(self)}
+        return structured_query.query, kwargs
diff --git a/libs/community/langchain_community/query_constructors/deeplake.py b/libs/community/langchain_community/query_constructors/deeplake.py
new file mode 100644
index 0000000000..d7e2ab87d6
--- /dev/null
+++ b/libs/community/langchain_community/query_constructors/deeplake.py
@@ -0,0 +1,88 @@
+"""Logic for converting internal query language to a valid DeepLake (TQL) query."""
+from typing import Tuple, Union
+
+from langchain_core.structured_query import (
+    Comparator,
+    Comparison,
+    Operation,
+    Operator,
+    StructuredQuery,
+    Visitor,
+)
+
+COMPARATOR_TO_TQL = {
+    Comparator.EQ: "==",
+    Comparator.GT: ">",
+    Comparator.GTE: ">=",
+    Comparator.LT: "<",
+    Comparator.LTE: "<=",
+}
+
+
+OPERATOR_TO_TQL = {
+    Operator.AND: "and",
+    Operator.OR: "or",
+    Operator.NOT: "NOT",
+}
+
+
+def can_cast_to_float(string: str) -> bool:
+    """Check if a string can be cast to a float."""
+    try:
+        float(string)
+        return True
+    except ValueError:
+        return False
+
+
+class DeepLakeTranslator(Visitor):
+    """Translate `DeepLake` internal query language elements to valid filters."""
+
+    allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
+    """Subset of allowed logical operators."""
+    allowed_comparators = [
+        Comparator.EQ,
+        Comparator.GT,
+        Comparator.GTE,
+        Comparator.LT,
+        Comparator.LTE,
+    ]
+    """Subset of allowed logical comparators."""
+
+    def _format_func(self, func: Union[Operator, Comparator]) -> str:
+        self._validate_func(func)
+        if isinstance(func, Operator):
+            value = OPERATOR_TO_TQL[func.value]  # type: ignore
+        elif isinstance(func, Comparator):
+            value =
COMPARATOR_TO_TQL[func.value] # type: ignore + return f"{value}" + + def visit_operation(self, operation: Operation) -> str: + args = [arg.accept(self) for arg in operation.arguments] + operator = self._format_func(operation.operator) + return "(" + (" " + operator + " ").join(args) + ")" + + def visit_comparison(self, comparison: Comparison) -> str: + comparator = self._format_func(comparison.comparator) + values = comparison.value + if isinstance(values, list): + tql = [] + for value in values: + comparison.value = value + tql.append(self.visit_comparison(comparison)) + + return "(" + (" or ").join(tql) + ")" + + if not can_cast_to_float(comparison.value): + values = f"'{values}'" + return f"metadata['{comparison.attribute}'] {comparator} {values}" + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + tqL = f"SELECT * WHERE {structured_query.filter.accept(self)}" + kwargs = {"tql": tqL} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/dingo.py b/libs/community/langchain_community/query_constructors/dingo.py new file mode 100644 index 0000000000..6c2402f65c --- /dev/null +++ b/libs/community/langchain_community/query_constructors/dingo.py @@ -0,0 +1,49 @@ +from typing import Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class DingoDBTranslator(Visitor): + """Translate `DingoDB` internal query language elements to valid filters.""" + + allowed_comparators = ( + Comparator.EQ, + Comparator.NE, + Comparator.LT, + Comparator.LTE, + Comparator.GT, + Comparator.GTE, + ) + """Subset of allowed logical comparators.""" + allowed_operators = (Operator.AND, Operator.OR) + """Subset of allowed logical operators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + return f"${func.value}" + + def visit_operation(self, operation: Operation) -> Operation: + return operation + + def visit_comparison(self, comparison: Comparison) -> Comparison: + return comparison + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = { + "search_params": { + "langchain_expr": structured_query.filter.accept(self) + } + } + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/elasticsearch.py b/libs/community/langchain_community/query_constructors/elasticsearch.py new file mode 100644 index 0000000000..d07c284b12 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/elasticsearch.py @@ -0,0 +1,100 @@ +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class ElasticsearchTranslator(Visitor): + """Translate `Elasticsearch` internal query language elements to valid filters.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.CONTAIN, + Comparator.LIKE, + ] + """Subset of allowed logical comparators.""" + + allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] + """Subset of allowed logical operators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + map_dict = { + 
Operator.OR: "should",
+            Operator.NOT: "must_not",
+            Operator.AND: "must",
+            Comparator.EQ: "term",
+            Comparator.GT: "gt",
+            Comparator.GTE: "gte",
+            Comparator.LT: "lt",
+            Comparator.LTE: "lte",
+            Comparator.CONTAIN: "match",
+            Comparator.LIKE: "match",
+        }
+        return map_dict[func]
+
+    def visit_operation(self, operation: Operation) -> Dict:
+        args = [arg.accept(self) for arg in operation.arguments]
+
+        return {"bool": {self._format_func(operation.operator): args}}
+
+    def visit_comparison(self, comparison: Comparison) -> Dict:
+        # ElasticsearchStore filters need to target the nested metadata
+        # object field.
+        field = f"metadata.{comparison.attribute}"
+
+        is_range_comparator = comparison.comparator in [
+            Comparator.GT,
+            Comparator.GTE,
+            Comparator.LT,
+            Comparator.LTE,
+        ]
+
+        if is_range_comparator:
+            value = comparison.value
+            if isinstance(comparison.value, dict) and "date" in comparison.value:
+                value = comparison.value["date"]
+            return {"range": {field: {self._format_func(comparison.comparator): value}}}
+
+        if comparison.comparator == Comparator.CONTAIN:
+            return {
+                self._format_func(comparison.comparator): {
+                    field: {"query": comparison.value}
+                }
+            }
+
+        if comparison.comparator == Comparator.LIKE:
+            return {
+                self._format_func(comparison.comparator): {
+                    field: {"query": comparison.value, "fuzziness": "AUTO"}
+                }
+            }
+
+        # We assume that a string value should be matched against the
+        # keyword sub-field rather than the analyzed text field.
+        field = f"{field}.keyword" if isinstance(comparison.value, str) else field
+
+        if isinstance(comparison.value, dict):
+            if "date" in comparison.value:
+                comparison.value = comparison.value["date"]
+
+        return {self._format_func(comparison.comparator): {field: comparison.value}}
+
+    def visit_structured_query(
+        self, structured_query: StructuredQuery
+    ) -> Tuple[str, dict]:
+        if structured_query.filter is None:
+            kwargs = {}
+        else:
+            kwargs = {"filter": [structured_query.filter.accept(self)]}
+        return structured_query.query, kwargs
diff --git a/libs/community/langchain_community/query_constructors/milvus.py b/libs/community/langchain_community/query_constructors/milvus.py
new file mode 100644
index 0000000000..6fb1cc5c4e
--- /dev/null
+++ b/libs/community/langchain_community/query_constructors/milvus.py
@@ -0,0 +1,103 @@
+"""Logic for converting internal query language to a valid Milvus query."""
+from typing import Tuple, Union
+
+from langchain_core.structured_query import (
+    Comparator,
+    Comparison,
+    Operation,
+    Operator,
+    StructuredQuery,
+    Visitor,
+)
+
+COMPARATOR_TO_BER = {
+    Comparator.EQ: "==",
+    Comparator.GT: ">",
+    Comparator.GTE: ">=",
+    Comparator.LT: "<",
+    Comparator.LTE: "<=",
+    Comparator.IN: "in",
+    Comparator.LIKE: "like",
+}
+
+UNARY_OPERATORS = [Operator.NOT]
+
+
+def process_value(value: Union[int, float, str], comparator: Comparator) -> str:
+    """Convert a value to a string and add double quotes if it is a string.
+
+    It is required for comparators involving strings.
+
+    Args:
+        value: The value to convert.
+        comparator: The comparator.
+
+    Returns:
+        The converted value as a string.
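+
+    Example (illustrative):
+        ``process_value("abc", Comparator.LIKE)`` returns ``'"abc%"'``,
+        while ``process_value(5, Comparator.EQ)`` returns ``"5"``.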
+ """ + # + if isinstance(value, str): + if comparator is Comparator.LIKE: + # If the comparator is LIKE, add a percent sign after it for prefix matching + # and add double quotes + return f'"{value}%"' + else: + # If the value is already a string, add double quotes + return f'"{value}"' + else: + # If the value is not a string, convert it to a string without double quotes + return str(value) + + +class MilvusTranslator(Visitor): + """Translate Milvus internal query language elements to valid filters.""" + + """Subset of allowed logical operators.""" + allowed_operators = [Operator.AND, Operator.NOT, Operator.OR] + + """Subset of allowed logical comparators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.IN, + Comparator.LIKE, + ] + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + value = func.value + if isinstance(func, Comparator): + value = COMPARATOR_TO_BER[func] + return f"{value}" + + def visit_operation(self, operation: Operation) -> str: + if operation.operator in UNARY_OPERATORS and len(operation.arguments) == 1: + operator = self._format_func(operation.operator) + return operator + "(" + operation.arguments[0].accept(self) + ")" + elif operation.operator in UNARY_OPERATORS: + raise ValueError( + f'"{operation.operator.value}" can have only one argument in Milvus' + ) + else: + args = [arg.accept(self) for arg in operation.arguments] + operator = self._format_func(operation.operator) + return "(" + (" " + operator + " ").join(args) + ")" + + def visit_comparison(self, comparison: Comparison) -> str: + comparator = self._format_func(comparison.comparator) + processed_value = process_value(comparison.value, comparison.comparator) + attribute = comparison.attribute + + return "( " + attribute + " " + comparator + " " + processed_value + " )" + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"expr": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/mongodb_atlas.py b/libs/community/langchain_community/query_constructors/mongodb_atlas.py new file mode 100644 index 0000000000..ebef2163be --- /dev/null +++ b/libs/community/langchain_community/query_constructors/mongodb_atlas.py @@ -0,0 +1,74 @@ +"""Logic for converting internal query language to a valid MongoDB Atlas query.""" +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN] + + +class MongoDBAtlasTranslator(Visitor): + """Translate Mongo internal query language elements to valid filters.""" + + """Subset of allowed logical comparators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.IN, + Comparator.NIN, + ] + + """Subset of allowed logical operators.""" + allowed_operators = [Operator.AND, Operator.OR] + + ## Convert a operator or a comparator to Mongo Query Format + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + map_dict = { + Operator.AND: "$and", + Operator.OR: "$or", + Comparator.EQ: "$eq", + Comparator.NE: "$ne", + Comparator.GTE: "$gte", + Comparator.LTE: 
"$lte", + Comparator.LT: "$lt", + Comparator.GT: "$gt", + Comparator.IN: "$in", + Comparator.NIN: "$nin", + } + return map_dict[func] + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return {self._format_func(operation.operator): args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance( + comparison.value, list + ): + comparison.value = [comparison.value] + + comparator = self._format_func(comparison.comparator) + + attribute = comparison.attribute + + return {attribute: {comparator: comparison.value}} + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"pre_filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/myscale.py b/libs/community/langchain_community/query_constructors/myscale.py new file mode 100644 index 0000000000..50a74c568b --- /dev/null +++ b/libs/community/langchain_community/query_constructors/myscale.py @@ -0,0 +1,125 @@ +import re +from typing import Any, Callable, Dict, Tuple + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +def _DEFAULT_COMPOSER(op_name: str) -> Callable: + """ + Default composer for logical operators. + + Args: + op_name: Name of the operator. + + Returns: + Callable that takes a list of arguments and returns a string. + """ + + def f(*args: Any) -> str: + args_: map[str] = map(str, args) + return f" {op_name} ".join(args_) + + return f + + +def _FUNCTION_COMPOSER(op_name: str) -> Callable: + """ + Composer for functions. + + Args: + op_name: Name of the function. + + Returns: + Callable that takes a list of arguments and returns a string. 
+ """ + + def f(*args: Any) -> str: + args_: map[str] = map(str, args) + return f"{op_name}({','.join(args_)})" + + return f + + +class MyScaleTranslator(Visitor): + """Translate `MyScale` internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] + """Subset of allowed logical operators.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.CONTAIN, + Comparator.LIKE, + ] + + map_dict = { + Operator.AND: _DEFAULT_COMPOSER("AND"), + Operator.OR: _DEFAULT_COMPOSER("OR"), + Operator.NOT: _DEFAULT_COMPOSER("NOT"), + Comparator.EQ: _DEFAULT_COMPOSER("="), + Comparator.GT: _DEFAULT_COMPOSER(">"), + Comparator.GTE: _DEFAULT_COMPOSER(">="), + Comparator.LT: _DEFAULT_COMPOSER("<"), + Comparator.LTE: _DEFAULT_COMPOSER("<="), + Comparator.CONTAIN: _FUNCTION_COMPOSER("has"), + Comparator.LIKE: _DEFAULT_COMPOSER("ILIKE"), + } + + def __init__(self, metadata_key: str = "metadata") -> None: + super().__init__() + self.metadata_key = metadata_key + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + func = operation.operator + self._validate_func(func) + return self.map_dict[func](*args) + + def visit_comparison(self, comparison: Comparison) -> Dict: + regex = r"\((.*?)\)" + matched = re.search(r"\(\w+\)", comparison.attribute) + + # If arbitrary function is applied to an attribute + if matched: + attr = re.sub( + regex, + f"({self.metadata_key}.{matched.group(0)[1:-1]})", + comparison.attribute, + ) + else: + attr = f"{self.metadata_key}.{comparison.attribute}" + value = comparison.value + comp = comparison.comparator + + value = f"'{value}'" if isinstance(value, str) else value + + # convert timestamp for datetime objects + if isinstance(value, dict) and value.get("type") == "date": + attr = f"parseDateTime32BestEffort({attr})" + value = f"parseDateTime32BestEffort('{value['date']}')" + + # string pattern match + if comp is Comparator.LIKE: + value = f"'%{value[1:-1]}%'" + return self.map_dict[comp](attr, value) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + print(structured_query) # noqa: T201 + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"where_str": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/opensearch.py b/libs/community/langchain_community/query_constructors/opensearch.py new file mode 100644 index 0000000000..e01ec66639 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/opensearch.py @@ -0,0 +1,104 @@ +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class OpenSearchTranslator(Visitor): + """Translate `OpenSearch` internal query domain-specific + language elements to valid filters.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.LT, + Comparator.LTE, + Comparator.GT, + Comparator.GTE, + Comparator.CONTAIN, + Comparator.LIKE, + ] + """Subset of allowed logical comparators.""" + + allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] + """Subset of allowed logical operators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + comp_operator_map = { + Comparator.EQ: "term", + Comparator.LT: "lt", + 
Comparator.LTE: "lte", + Comparator.GT: "gt", + Comparator.GTE: "gte", + Comparator.CONTAIN: "match", + Comparator.LIKE: "fuzzy", + Operator.AND: "must", + Operator.OR: "should", + Operator.NOT: "must_not", + } + return comp_operator_map[func] + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + + return {"bool": {self._format_func(operation.operator): args}} + + def visit_comparison(self, comparison: Comparison) -> Dict: + field = f"metadata.{comparison.attribute}" + + if comparison.comparator in [ + Comparator.LT, + Comparator.LTE, + Comparator.GT, + Comparator.GTE, + ]: + if isinstance(comparison.value, dict): + if "date" in comparison.value: + return { + "range": { + field: { + self._format_func( + comparison.comparator + ): comparison.value["date"] + } + } + } + else: + return { + "range": { + field: { + self._format_func(comparison.comparator): comparison.value + } + } + } + + if comparison.comparator == Comparator.LIKE: + return { + self._format_func(comparison.comparator): { + field: {"value": comparison.value} + } + } + + field = f"{field}.keyword" if isinstance(comparison.value, str) else field + + if isinstance(comparison.value, dict): + if "date" in comparison.value: + comparison.value = comparison.value["date"] + + return {self._format_func(comparison.comparator): {field: comparison.value}} + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/pgvector.py b/libs/community/langchain_community/query_constructors/pgvector.py new file mode 100644 index 0000000000..5fea65b01c --- /dev/null +++ b/libs/community/langchain_community/query_constructors/pgvector.py @@ -0,0 +1,52 @@ +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class PGVectorTranslator(Visitor): + """Translate `PGVector` internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + """Subset of allowed logical operators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.LT, + Comparator.IN, + Comparator.NIN, + Comparator.CONTAIN, + Comparator.LIKE, + ] + """Subset of allowed logical comparators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + return f"{func.value}" + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return {self._format_func(operation.operator): args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + return { + comparison.attribute: { + self._format_func(comparison.comparator): comparison.value + } + } + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/pinecone.py b/libs/community/langchain_community/query_constructors/pinecone.py new file mode 100644 index 0000000000..99c42f393b --- /dev/null +++ 
diff --git a/libs/community/langchain_community/query_constructors/pinecone.py b/libs/community/langchain_community/query_constructors/pinecone.py
new file mode 100644
index 0000000000..99c42f393b
--- /dev/null
+++ b/libs/community/langchain_community/query_constructors/pinecone.py
@@ -0,0 +1,57 @@
+from typing import Dict, Tuple, Union
+
+from langchain_core.structured_query import (
+    Comparator,
+    Comparison,
+    Operation,
+    Operator,
+    StructuredQuery,
+    Visitor,
+)
+
+
+class PineconeTranslator(Visitor):
+    """Translate `Pinecone` internal query language elements to valid filters."""
+
+    allowed_comparators = (
+        Comparator.EQ,
+        Comparator.NE,
+        Comparator.LT,
+        Comparator.LTE,
+        Comparator.GT,
+        Comparator.GTE,
+        Comparator.IN,
+        Comparator.NIN,
+    )
+    """Subset of allowed logical comparators."""
+    allowed_operators = (Operator.AND, Operator.OR)
+    """Subset of allowed logical operators."""
+
+    def _format_func(self, func: Union[Operator, Comparator]) -> str:
+        self._validate_func(func)
+        return f"${func.value}"
+
+    def visit_operation(self, operation: Operation) -> Dict:
+        args = [arg.accept(self) for arg in operation.arguments]
+        return {self._format_func(operation.operator): args}
+
+    def visit_comparison(self, comparison: Comparison) -> Dict:
+        if comparison.comparator in (Comparator.IN, Comparator.NIN) and not isinstance(
+            comparison.value, list
+        ):
+            comparison.value = [comparison.value]
+
+        return {
+            comparison.attribute: {
+                self._format_func(comparison.comparator): comparison.value
+            }
+        }
+
+    def visit_structured_query(
+        self, structured_query: StructuredQuery
+    ) -> Tuple[str, dict]:
+        if structured_query.filter is None:
+            kwargs = {}
+        else:
+            kwargs = {"filter": structured_query.filter.accept(self)}
+        return structured_query.query, kwargs
diff --git a/libs/community/langchain_community/query_constructors/qdrant.py b/libs/community/langchain_community/query_constructors/qdrant.py
new file mode 100644
index 0000000000..f4c3298b66
--- /dev/null
+++ b/libs/community/langchain_community/query_constructors/qdrant.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Tuple
+
+from langchain_core.structured_query import (
+    Comparator,
+    Comparison,
+    Operation,
+    Operator,
+    StructuredQuery,
+    Visitor,
+)
+
+if TYPE_CHECKING:
+    from qdrant_client.http import models as rest
+
+
+class QdrantTranslator(Visitor):
+    """Translate `Qdrant` internal query language elements to valid filters."""
+
+    allowed_operators = (
+        Operator.AND,
+        Operator.OR,
+        Operator.NOT,
+    )
+    """Subset of allowed logical operators."""
+
+    allowed_comparators = (
+        Comparator.EQ,
+        Comparator.LT,
+        Comparator.LTE,
+        Comparator.GT,
+        Comparator.GTE,
+        Comparator.LIKE,
+    )
+    """Subset of allowed logical comparators."""
+
+    def __init__(self, metadata_key: str):
+        self.metadata_key = metadata_key
+
+    def visit_operation(self, operation: Operation) -> rest.Filter:
+        try:
+            from qdrant_client.http import models as rest
+        except ImportError as e:
+            raise ImportError(
+                "Cannot import qdrant_client. Please install with `pip install "
+                "qdrant-client`."
+            ) from e
+
+        args = [arg.accept(self) for arg in operation.arguments]
+        operator = {
+            Operator.AND: "must",
+            Operator.OR: "should",
+            Operator.NOT: "must_not",
+        }[operation.operator]
+        return rest.Filter(**{operator: args})
+
+    def visit_comparison(self, comparison: Comparison) -> rest.FieldCondition:
+        try:
+            from qdrant_client.http import models as rest
+        except ImportError as e:
+            raise ImportError(
+                "Cannot import qdrant_client. Please install with `pip install "
+                "qdrant-client`."
+            ) from e
+
+        self._validate_func(comparison.comparator)
+        attribute = self.metadata_key + "."
+ comparison.attribute + if comparison.comparator == Comparator.EQ: + return rest.FieldCondition( + key=attribute, match=rest.MatchValue(value=comparison.value) + ) + if comparison.comparator == Comparator.LIKE: + return rest.FieldCondition( + key=attribute, match=rest.MatchText(text=comparison.value) + ) + kwargs = {comparison.comparator.value: comparison.value} + return rest.FieldCondition(key=attribute, range=rest.Range(**kwargs)) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + try: + from qdrant_client.http import models as rest + except ImportError as e: + raise ImportError( + "Cannot import qdrant_client. Please install with `pip install " + "qdrant-client`." + ) from e + + if structured_query.filter is None: + kwargs = {} + else: + filter = structured_query.filter.accept(self) + if isinstance(filter, rest.FieldCondition): + filter = rest.Filter(must=[filter]) + kwargs = {"filter": filter} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/redis.py b/libs/community/langchain_community/query_constructors/redis.py new file mode 100644 index 0000000000..e74d1eb199 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/redis.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any, Tuple + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +from langchain_community.vectorstores.redis import Redis +from langchain_community.vectorstores.redis.filters import ( + RedisFilterExpression, + RedisFilterField, + RedisFilterOperator, + RedisNum, + RedisTag, + RedisText, +) +from langchain_community.vectorstores.redis.schema import RedisModel + +_COMPARATOR_TO_BUILTIN_METHOD = { + Comparator.EQ: "__eq__", + Comparator.NE: "__ne__", + Comparator.LT: "__lt__", + Comparator.GT: "__gt__", + Comparator.LTE: "__le__", + Comparator.GTE: "__ge__", + Comparator.CONTAIN: "__eq__", + Comparator.LIKE: "__mod__", +} + + +class RedisTranslator(Visitor): + """Visitor for translating structured queries to Redis filter expressions.""" + + allowed_comparators = ( + Comparator.EQ, + Comparator.NE, + Comparator.LT, + Comparator.LTE, + Comparator.GT, + Comparator.GTE, + Comparator.CONTAIN, + Comparator.LIKE, + ) + """Subset of allowed logical comparators.""" + allowed_operators = (Operator.AND, Operator.OR) + """Subset of allowed logical operators.""" + + def __init__(self, schema: RedisModel) -> None: + self._schema = schema + + def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField: + if attribute in [tf.name for tf in self._schema.text]: + return RedisText(attribute) + elif attribute in [tf.name for tf in self._schema.tag or []]: + return RedisTag(attribute) + elif attribute in [tf.name for tf in self._schema.numeric or []]: + return RedisNum(attribute) + else: + raise ValueError( + f"Invalid attribute {attribute} not in vector store schema. 
Schema is:" + f"\n{self._schema.as_dict()}" + ) + + def visit_comparison(self, comparison: Comparison) -> RedisFilterExpression: + filter_field = self._attribute_to_filter_field(comparison.attribute) + comparison_method = _COMPARATOR_TO_BUILTIN_METHOD[comparison.comparator] + return getattr(filter_field, comparison_method)(comparison.value) + + def visit_operation(self, operation: Operation) -> Any: + left = operation.arguments[0].accept(self) + if len(operation.arguments) > 2: + right = self.visit_operation( + Operation( + operator=operation.operator, arguments=operation.arguments[1:] + ) + ) + else: + right = operation.arguments[1].accept(self) + redis_operator = ( + RedisFilterOperator.OR + if operation.operator == Operator.OR + else RedisFilterOperator.AND + ) + return RedisFilterExpression(operator=redis_operator, left=left, right=right) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs + + @classmethod + def from_vectorstore(cls, vectorstore: Redis) -> RedisTranslator: + return cls(vectorstore._schema) diff --git a/libs/community/langchain_community/query_constructors/supabase.py b/libs/community/langchain_community/query_constructors/supabase.py new file mode 100644 index 0000000000..63794cf378 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/supabase.py @@ -0,0 +1,97 @@ +from typing import Any, Dict, Tuple + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class SupabaseVectorTranslator(Visitor): + """Translate Langchain filters to Supabase PostgREST filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + """Subset of allowed logical operators.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.LIKE, + ] + """Subset of allowed logical comparators.""" + + metadata_column = "metadata" + + def _map_comparator(self, comparator: Comparator) -> str: + """ + Maps Langchain comparator to PostgREST comparator: + + https://postgrest.org/en/stable/references/api/tables_views.html#operators + """ + postgrest_comparator = { + Comparator.EQ: "eq", + Comparator.NE: "neq", + Comparator.GT: "gt", + Comparator.GTE: "gte", + Comparator.LT: "lt", + Comparator.LTE: "lte", + Comparator.LIKE: "like", + }.get(comparator) + + if postgrest_comparator is None: + raise Exception( + f"Comparator '{comparator}' is not currently " + "supported in Supabase Vector" + ) + + return postgrest_comparator + + def _get_json_operator(self, value: Any) -> str: + if isinstance(value, str): + return "->>" + else: + return "->" + + def visit_operation(self, operation: Operation) -> str: + args = [arg.accept(self) for arg in operation.arguments] + return f"{operation.operator.value}({','.join(args)})" + + def visit_comparison(self, comparison: Comparison) -> str: + if isinstance(comparison.value, list): + return self.visit_operation( + Operation( + operator=Operator.AND, + arguments=[ + Comparison( + comparator=comparison.comparator, + attribute=comparison.attribute, + value=value, + ) + for value in comparison.value + ], + ) + ) + + return ".".join( + [ + f"{self.metadata_column}{self._get_json_operator(comparison.value)}{comparison.attribute}", + f"{self._map_comparator(comparison.comparator)}", + 
f"{comparison.value}", + ] + ) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, Dict[str, str]]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"postgrest_filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/tencentvectordb.py b/libs/community/langchain_community/query_constructors/tencentvectordb.py new file mode 100644 index 0000000000..b1ec31a1a2 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/tencentvectordb.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Optional, Sequence, Tuple + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class TencentVectorDBTranslator(Visitor): + """Translate StructuredQuery to Tencent VectorDB query.""" + + COMPARATOR_MAP = { + Comparator.EQ: "=", + Comparator.NE: "!=", + Comparator.GT: ">", + Comparator.GTE: ">=", + Comparator.LT: "<", + Comparator.LTE: "<=", + Comparator.IN: "in", + Comparator.NIN: "not in", + } + + allowed_comparators: Optional[Sequence[Comparator]] = list(COMPARATOR_MAP.keys()) + allowed_operators: Optional[Sequence[Operator]] = [ + Operator.AND, + Operator.OR, + Operator.NOT, + ] + + def __init__(self, meta_keys: Optional[Sequence[str]] = None): + """Initialize the translator. + + Args: + meta_keys: List of meta keys to be used in the query. Default: []. + """ + self.meta_keys = meta_keys or [] + + def visit_operation(self, operation: Operation) -> str: + """Visit an operation node and return the translated query. + + Args: + operation: Operation node to be visited. + + Returns: + Translated query. + """ + if operation.operator in (Operator.AND, Operator.OR): + ret = f" {operation.operator.value} ".join( + [arg.accept(self) for arg in operation.arguments] + ) + if operation.operator == Operator.OR: + ret = f"({ret})" + return ret + else: + return f"not ({operation.arguments[0].accept(self)})" + + def visit_comparison(self, comparison: Comparison) -> str: + """Visit a comparison node and return the translated query. + + Args: + comparison: Comparison node to be visited. + + Returns: + Translated query. + """ + if self.meta_keys and comparison.attribute not in self.meta_keys: + raise ValueError( + f"Expr Filtering found Unsupported attribute: {comparison.attribute}" + ) + + if comparison.comparator in self.COMPARATOR_MAP: + if comparison.comparator in [Comparator.IN, Comparator.NIN]: + value = map( + lambda x: f'"{x}"' if isinstance(x, str) else x, comparison.value + ) + return ( + f"{comparison.attribute}" + f" {self.COMPARATOR_MAP[comparison.comparator]} " + f"({', '.join(value)})" + ) + if isinstance(comparison.value, str): + return ( + f"{comparison.attribute} " + f"{self.COMPARATOR_MAP[comparison.comparator]}" + f' "{comparison.value}"' + ) + return ( + f"{comparison.attribute}" + f" {self.COMPARATOR_MAP[comparison.comparator]} " + f"{comparison.value}" + ) + else: + raise ValueError(f"Unsupported comparator {comparison.comparator}") + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + """Visit a structured query node and return the translated query. + + Args: + structured_query: StructuredQuery node to be visited. + + Returns: + Translated query and query kwargs. 
+ """ + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"expr": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/timescalevector.py b/libs/community/langchain_community/query_constructors/timescalevector.py new file mode 100644 index 0000000000..bfac120bde --- /dev/null +++ b/libs/community/langchain_community/query_constructors/timescalevector.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +if TYPE_CHECKING: + from timescale_vector import client + + +class TimescaleVectorTranslator(Visitor): + """Translate the internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] + """Subset of allowed logical operators.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + ] + + COMPARATOR_MAP = { + Comparator.EQ: "==", + Comparator.GT: ">", + Comparator.GTE: ">=", + Comparator.LT: "<", + Comparator.LTE: "<=", + } + + OPERATOR_MAP = {Operator.AND: "AND", Operator.OR: "OR", Operator.NOT: "NOT"} + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + if isinstance(func, Operator): + value = self.OPERATOR_MAP[func.value] # type: ignore + elif isinstance(func, Comparator): + value = self.COMPARATOR_MAP[func.value] # type: ignore + return f"{value}" + + def visit_operation(self, operation: Operation) -> client.Predicates: + try: + from timescale_vector import client + except ImportError as e: + raise ImportError( + "Cannot import timescale-vector. Please install with `pip install " + "timescale-vector`." + ) from e + args = [arg.accept(self) for arg in operation.arguments] + return client.Predicates(*args, operator=self._format_func(operation.operator)) + + def visit_comparison(self, comparison: Comparison) -> client.Predicates: + try: + from timescale_vector import client + except ImportError as e: + raise ImportError( + "Cannot import timescale-vector. Please install with `pip install " + "timescale-vector`." 
+ ) from e + return client.Predicates( + ( + comparison.attribute, + self._format_func(comparison.comparator), + comparison.value, + ) + ) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"predicates": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/vectara.py b/libs/community/langchain_community/query_constructors/vectara.py new file mode 100644 index 0000000000..24886a1af9 --- /dev/null +++ b/libs/community/langchain_community/query_constructors/vectara.py @@ -0,0 +1,70 @@ +from typing import Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +def process_value(value: Union[int, float, str]) -> str: + """Convert a value to a string and add single quotes if it is a string.""" + if isinstance(value, str): + return f"'{value}'" + else: + return str(value) + + +class VectaraTranslator(Visitor): + """Translate `Vectara` internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + """Subset of allowed logical operators.""" + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + ] + """Subset of allowed logical comparators.""" + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + map_dict = { + Operator.AND: " and ", + Operator.OR: " or ", + Comparator.EQ: "=", + Comparator.NE: "!=", + Comparator.GT: ">", + Comparator.GTE: ">=", + Comparator.LT: "<", + Comparator.LTE: "<=", + } + self._validate_func(func) + return map_dict[func] + + def visit_operation(self, operation: Operation) -> str: + args = [arg.accept(self) for arg in operation.arguments] + operator = self._format_func(operation.operator) + return "( " + operator.join(args) + " )" + + def visit_comparison(self, comparison: Comparison) -> str: + comparator = self._format_func(comparison.comparator) + processed_value = process_value(comparison.value) + attribute = comparison.attribute + return ( + "( " + "doc." 
+ attribute + " " + comparator + " " + processed_value + " )" + ) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/query_constructors/weaviate.py b/libs/community/langchain_community/query_constructors/weaviate.py new file mode 100644 index 0000000000..2e5e3e691e --- /dev/null +++ b/libs/community/langchain_community/query_constructors/weaviate.py @@ -0,0 +1,79 @@ +from datetime import datetime +from typing import Dict, Tuple, Union + +from langchain_core.structured_query import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + + +class WeaviateTranslator(Visitor): + """Translate `Weaviate` internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR] + """Subset of allowed logical operators.""" + + allowed_comparators = [ + Comparator.EQ, + Comparator.NE, + Comparator.GTE, + Comparator.LTE, + Comparator.LT, + Comparator.GT, + ] + + def _format_func(self, func: Union[Operator, Comparator]) -> str: + self._validate_func(func) + # https://weaviate.io/developers/weaviate/api/graphql/filters + map_dict = { + Operator.AND: "And", + Operator.OR: "Or", + Comparator.EQ: "Equal", + Comparator.NE: "NotEqual", + Comparator.GTE: "GreaterThanEqual", + Comparator.LTE: "LessThanEqual", + Comparator.LT: "LessThan", + Comparator.GT: "GreaterThan", + } + return map_dict[func] + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + return {"operator": self._format_func(operation.operator), "operands": args} + + def visit_comparison(self, comparison: Comparison) -> Dict: + value_type = "valueText" + value = comparison.value + if isinstance(comparison.value, bool): + value_type = "valueBoolean" + elif isinstance(comparison.value, float): + value_type = "valueNumber" + elif isinstance(comparison.value, int): + value_type = "valueInt" + elif ( + isinstance(comparison.value, dict) + and comparison.value.get("type") == "date" + ): + value_type = "valueDate" + # ISO 8601 timestamp, formatted as RFC3339 + date = datetime.strptime(comparison.value["date"], "%Y-%m-%d") + value = date.strftime("%Y-%m-%dT%H:%M:%SZ") + filter = { + "path": [comparison.attribute], + "operator": self._format_func(comparison.comparator), + value_type: value, + } + return filter + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"where_filter": structured_query.filter.accept(self)} + return structured_query.query, kwargs diff --git a/libs/community/langchain_community/retrievers/__init__.py b/libs/community/langchain_community/retrievers/__init__.py index 82e90c0741..9fcadebed7 100644 --- a/libs/community/langchain_community/retrievers/__init__.py +++ b/libs/community/langchain_community/retrievers/__init__.py @@ -123,6 +123,7 @@ if TYPE_CHECKING: from langchain_community.retrievers.weaviate_hybrid_search import ( WeaviateHybridSearchRetriever, ) + from langchain_community.retrievers.web_research import WebResearchRetriever from langchain_community.retrievers.wikipedia import ( WikipediaRetriever, ) @@ -174,6 +175,7 @@ _module_lookup = { "TavilySearchAPIRetriever": "langchain_community.retrievers.tavily_search_api", 
"VespaRetriever": "langchain_community.retrievers.vespa_retriever", "WeaviateHybridSearchRetriever": "langchain_community.retrievers.weaviate_hybrid_search", # noqa: E501 + "WebResearchRetriever": "langchain_community.retrievers.web_research", "WikipediaRetriever": "langchain_community.retrievers.wikipedia", "YouRetriever": "langchain_community.retrievers.you", "ZepRetriever": "langchain_community.retrievers.zep", @@ -194,8 +196,8 @@ __all__ = [ "AmazonKnowledgeBasesRetriever", "ArceeRetriever", "ArxivRetriever", - "AzureCognitiveSearchRetriever", "AzureAISearchRetriever", + "AzureCognitiveSearchRetriever", "BM25Retriever", "BreebsRetriever", "ChaindeskRetriever", @@ -209,8 +211,8 @@ __all__ = [ "GoogleDocumentAIWarehouseRetriever", "GoogleVertexAIMultiTurnSearchRetriever", "GoogleVertexAISearchRetriever", - "KNNRetriever", "KayAiRetriever", + "KNNRetriever", "LlamaIndexGraphRetriever", "LlamaIndexRetriever", "MetalRetriever", @@ -223,10 +225,11 @@ __all__ = [ "RememberizerRetriever", "RemoteLangChainRetriever", "SVMRetriever", - "TFIDFRetriever", "TavilySearchAPIRetriever", + "TFIDFRetriever", "VespaRetriever", "WeaviateHybridSearchRetriever", + "WebResearchRetriever", "WikipediaRetriever", "YouRetriever", "ZepRetriever", diff --git a/libs/community/langchain_community/retrievers/web_research.py b/libs/community/langchain_community/retrievers/web_research.py new file mode 100644 index 0000000000..9003f51740 --- /dev/null +++ b/libs/community/langchain_community/retrievers/web_research.py @@ -0,0 +1,223 @@ +import logging +import re +from typing import List, Optional + +from langchain.chains import LLMChain +from langchain.chains.prompt_selector import ConditionalPromptSelector +from langchain_core.callbacks import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) +from langchain_core.documents import Document +from langchain_core.language_models import BaseLLM +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.retrievers import BaseRetriever +from langchain_core.vectorstores import VectorStore +from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter + +from langchain_community.document_loaders import AsyncHtmlLoader +from langchain_community.document_transformers import Html2TextTransformer +from langchain_community.llms import LlamaCpp +from langchain_community.utilities import GoogleSearchAPIWrapper + +logger = logging.getLogger(__name__) + + +class SearchQueries(BaseModel): + """Search queries to research for the user's goal.""" + + queries: List[str] = Field( + ..., description="List of search queries to look up on Google" + ) + + +DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate( + input_variables=["question"], + template="""<> \n You are an assistant tasked with improving Google search \ +results. \n <> \n\n [INST] Generate THREE Google search queries that \ +are similar to this question. The output should be a numbered list of questions \ +and each should have a question mark at the end: \n\n {question} [/INST]""", +) + +DEFAULT_SEARCH_PROMPT = PromptTemplate( + input_variables=["question"], + template="""You are an assistant tasked with improving Google search \ +results. Generate THREE Google search queries that are similar to \ +this question. 
+
+
+class WebResearchRetriever(BaseRetriever):
+    """`Google Search API` retriever."""
+
+    # Inputs
+    vectorstore: VectorStore = Field(
+        ..., description="Vector store for storing web pages"
+    )
+    llm_chain: LLMChain
+    search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
+    num_search_results: int = Field(1, description="Number of pages per Google search")
+    text_splitter: TextSplitter = Field(
+        RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
+        description="Text splitter for splitting web pages into chunks",
+    )
+    url_database: List[str] = Field(
+        default_factory=list, description="List of processed URLs"
+    )
+
+    @classmethod
+    def from_llm(
+        cls,
+        vectorstore: VectorStore,
+        llm: BaseLLM,
+        search: GoogleSearchAPIWrapper,
+        prompt: Optional[BasePromptTemplate] = None,
+        num_search_results: int = 1,
+        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
+            chunk_size=1500, chunk_overlap=150
+        ),
+    ) -> "WebResearchRetriever":
+        """Initialize from llm using default template.
+
+        Args:
+            vectorstore: Vector store for storing web pages
+            llm: LLM for search question generation
+            search: GoogleSearchAPIWrapper
+            prompt: prompt for generating search questions
+            num_search_results: Number of pages per Google search
+            text_splitter: Text splitter for splitting web pages into chunks
+
+        Returns:
+            WebResearchRetriever
+        """
+
+        if not prompt:
+            QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
+                default_prompt=DEFAULT_SEARCH_PROMPT,
+                conditionals=[
+                    (lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
+                ],
+            )
+            prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)
+
+        # Use chat model prompt
+        llm_chain = LLMChain(
+            llm=llm,
+            prompt=prompt,
+            output_parser=QuestionListOutputParser(),
+        )
+
+        return cls(
+            vectorstore=vectorstore,
+            llm_chain=llm_chain,
+            search=search,
+            num_search_results=num_search_results,
+            text_splitter=text_splitter,
+        )
+
+    def clean_search_query(self, query: str) -> str:
+        # Some search tools (e.g., Google) will
+        # fail to return results if query has a
+        # leading digit: 1. "LangCh..."
+        # Check if the first character is a digit
+        if query[0].isdigit():
+            # Find the position of the first quote
+            first_quote_pos = query.find('"')
+            if first_quote_pos != -1:
+                # Extract the part of the string after the quote
+                query = query[first_quote_pos + 1 :]
+                # Remove the trailing quote if present
+                if query.endswith('"'):
+                    query = query[:-1]
+        return query.strip()
+
+    def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
+        """Returns num_search_results pages per Google search."""
+        query_clean = self.clean_search_query(query)
+        result = self.search.results(query_clean, num_search_results)
+        return result
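For context, a usage sketch (editor's illustration, not part of the patch): `vectorstore` and `llm` below are placeholders for objects you construct yourself, and Google CSE credentials are assumed to be set in the environment for `GoogleSearchAPIWrapper`.

    from langchain_community.retrievers.web_research import WebResearchRetriever
    from langchain_community.utilities import GoogleSearchAPIWrapper

    # Placeholders: any VectorStore and BaseLLM instances work here.
    retriever = WebResearchRetriever.from_llm(
        vectorstore=vectorstore,
        llm=llm,
        search=GoogleSearchAPIWrapper(),
        num_search_results=2,
    )
    docs = retriever.get_relevant_documents("How do structured query translators work?")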
+ """ + + # Get search questions + logger.info("Generating questions for Google Search ...") + result = self.llm_chain({"question": query}) + logger.info(f"Questions for Google Search (raw): {result}") + questions = result["text"] + logger.info(f"Questions for Google Search: {questions}") + + # Get urls + logger.info("Searching for relevant urls...") + urls_to_look = [] + for query in questions: + # Google search + search_results = self.search_tool(query, self.num_search_results) + logger.info("Searching for relevant urls...") + logger.info(f"Search results: {search_results}") + for res in search_results: + if res.get("link", None): + urls_to_look.append(res["link"]) + + # Relevant urls + urls = set(urls_to_look) + + # Check for any new urls that we have not processed + new_urls = list(urls.difference(self.url_database)) + + logger.info(f"New URLs to load: {new_urls}") + # Load, split, and add new urls to vectorstore + if new_urls: + loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True) + html2text = Html2TextTransformer() + logger.info("Indexing new urls...") + docs = loader.load() + docs = list(html2text.transform_documents(docs)) + docs = self.text_splitter.split_documents(docs) + self.vectorstore.add_documents(docs) + self.url_database.extend(new_urls) + + # Search for relevant splits + # TODO: make this async + logger.info("Grabbing most relevant splits from urls...") + docs = [] + for query in questions: + docs.extend(self.vectorstore.similarity_search(query)) + + # Get unique docs + unique_documents_dict = { + (doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs + } + unique_documents = list(unique_documents_dict.values()) + return unique_documents + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + ) -> List[Document]: + raise NotImplementedError diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index 148947130e..25ec29beea 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aenum" @@ -3454,6 +3454,7 @@ files = [ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"}, {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"}, {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"}, + {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -3962,9 +3963,51 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] +[[package]] +name = "langchain" +version = "0.2.0rc1" +description = "Building applications with LLMs through composability" +optional = false +python-versions = ">=3.8.1,<4.0" +files = [] +develop = true + +[package.dependencies] +aiohttp = "^3.8.3" +async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""} +dataclasses-json = ">= 0.5.7, < 0.7" +langchain-core = "^0.1.48" +langchain-text-splitters = ">=0.0.1,<0.1" +langsmith = "^0.1.17" +numpy = "^1" +pydantic = ">=1,<3" +PyYAML = ">=5.3" +requests = "^2" +SQLAlchemy = ">=1.4,<3" +tenacity = "^8.1.0" + +[package.extras] +all = [] +azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"] +clarifai = ["clarifai (>=9.1.0)"] +cli = ["typer (>=0.9.0,<0.10.0)"] +cohere = ["cohere (>=4,<6)"] +docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"] +embeddings = ["sentence-transformers (>=2,<3)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<6)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "langchain-openai (>=0.0.2,<0.1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k 
(>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] +javascript = ["esprima (>=4.0.1,<5.0.0)"] +llms = ["clarifai (>=9.1.0)", "cohere (>=4,<6)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] +openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] +qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] +text-helpers = ["chardet (>=5.1.0,<6.0.0)"] + +[package.source] +type = "directory" +url = "../langchain" + [[package]] name = "langchain-core" -version = "0.1.51" +version = "0.1.52" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -6064,8 +6107,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -6108,7 +6149,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -6117,8 +6157,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -6651,26 +6689,31 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"}, {file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"}, + {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"}, {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"}, {file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"}, {file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"}, {file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"}, {file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"}, + {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"}, {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"}, {file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"}, {file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"}, {file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"}, {file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = 
"sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"}, + {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"}, {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"}, {file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"}, {file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"}, {file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"}, {file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"}, + {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"}, {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"}, {file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"}, {file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"}, {file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"}, {file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"}, + {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"}, {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"}, {file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"}, {file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"}, @@ -7111,7 +7154,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -10044,4 +10086,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" 
-content-hash = "ca64e52a60e8ee6f2f4ea303e1779a4508f401e283f63861161cb6a9560e2178" +content-hash = "9ad37aae2905701ec099c1f9cdec59692de43e8d047ceb2ce25898b4c873b190" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index e7cb149b36..7512ca250c 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-community" -version = "0.0.37" +version = "0.0.38rc1" description = "Community contributed LangChain integrations." authors = [] license = "MIT" @@ -10,6 +10,7 @@ repository = "https://github.com/langchain-ai/langchain" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.51" +langchain = "~0.2.0rc1" SQLAlchemy = ">=1.4,<3" requests = "^2" PyYAML = ">=5.3" @@ -126,6 +127,7 @@ pytest-socket = "^0.6.0" syrupy = "^4.0.2" requests-mock = "^1.11.0" langchain-core = {path = "../core", develop = true} +langchain = {path = "../langchain", develop = true} [tool.poetry.group.codespell] optional = true @@ -160,6 +162,7 @@ cassio = "^0.1.6" tiktoken = ">=0.3.2,<0.6.0" anthropic = "^0.3.11" langchain-core = { path = "../core", develop = true } +langchain = {path = "../langchain", develop = true} fireworks-ai = "^0.9.0" vdms = "^0.0.20" exllamav2 = "^0.0.18" @@ -181,6 +184,7 @@ types-redis = "^4.3.21.6" mypy-protobuf = "^3.0.0" langchain-core = {path = "../core", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} +langchain = {path = "../langchain", develop = true} [tool.poetry.group.dev] optional = true diff --git a/libs/langchain/tests/integration_tests/agent/test_ainetwork_agent.py b/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py similarity index 100% rename from libs/langchain/tests/integration_tests/agent/test_ainetwork_agent.py rename to libs/community/tests/integration_tests/agent/test_ainetwork_agent.py index e9fe03c148..bc93af4250 100644 --- a/libs/langchain/tests/integration_tests/agent/test_ainetwork_agent.py +++ b/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py @@ -8,12 +8,12 @@ from typing import Any from urllib.error import HTTPError import pytest +from langchain.agents import AgentType, initialize_agent + from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit from langchain_community.chat_models import ChatOpenAI from langchain_community.tools.ainetwork.utils import authenticate -from langchain.agents import AgentType, initialize_agent - class Match(Enum): __test__ = False diff --git a/libs/langchain/tests/integration_tests/agent/test_powerbi_agent.py b/libs/community/tests/integration_tests/agent/test_powerbi_agent.py similarity index 97% rename from libs/langchain/tests/integration_tests/agent/test_powerbi_agent.py rename to libs/community/tests/integration_tests/agent/test_powerbi_agent.py index 26c8873b74..7f16ce8d2e 100644 --- a/libs/langchain/tests/integration_tests/agent/test_powerbi_agent.py +++ b/libs/community/tests/integration_tests/agent/test_powerbi_agent.py @@ -1,10 +1,10 @@ import pytest +from langchain_core.utils import get_from_env + from langchain_community.agent_toolkits import PowerBIToolkit, create_pbi_agent from langchain_community.chat_models import ChatOpenAI from langchain_community.utilities.powerbi import PowerBIDataset -from langchain.utils import get_from_env - def azure_installed() -> bool: try: diff --git a/libs/community/tests/integration_tests/cache/fake_embeddings.py b/libs/community/tests/integration_tests/cache/fake_embeddings.py new file 
mode 100644 index 0000000000..5de74832de --- /dev/null +++ b/libs/community/tests/integration_tests/cache/fake_embeddings.py @@ -0,0 +1,81 @@ +"""Fake Embedding class for testing purposes.""" +import math +from typing import List + +from langchain_core.embeddings import Embeddings + +fake_texts = ["foo", "bar", "baz"] + + +class FakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. + Each text is encoded by its index in the batch.""" + return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + return self.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents.""" + return [float(1.0)] * 9 + [float(0.0)] + + async def aembed_query(self, text: str) -> List[float]: + return self.embed_query(text) + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text; a previously unseen text + is assigned the next index, exactly as in embed_documents.""" + return self.embed_documents([text])[0] + + +class AngularTwoDimensionalEmbeddings(Embeddings): + """ + From angles (as strings in units of pi) to unit embedding vectors on a circle. + """ + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Make a list of texts into a list of embedding vectors. + """ + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + """ + Convert input text to a 'vector' (list of floats). + If the text is a number, use it as the angle for the + unit vector in units of pi. + Any other input text becomes the singular result [0, 0]! + """ + try: + angle = float(text) + return [math.cos(angle * math.pi), math.sin(angle * math.pi)] + except ValueError: + # Not a number: treat the text as an arbitrary test string.
+ return [0.0, 0.0] diff --git a/libs/langchain/tests/integration_tests/cache/test_astradb.py b/libs/community/tests/integration_tests/cache/test_astradb.py similarity index 98% rename from libs/langchain/tests/integration_tests/cache/test_astradb.py rename to libs/community/tests/integration_tests/cache/test_astradb.py index 0c973e60fd..f880767797 100644 --- a/libs/langchain/tests/integration_tests/cache/test_astradb.py +++ b/libs/community/tests/integration_tests/cache/test_astradb.py @@ -15,13 +15,13 @@ import os from typing import AsyncIterator, Iterator import pytest -from langchain_community.utilities.astradb import SetupMode +from langchain.globals import get_llm_cache, set_llm_cache from langchain_core.caches import BaseCache from langchain_core.language_models import LLM from langchain_core.outputs import Generation, LLMResult -from langchain.cache import AstraDBCache, AstraDBSemanticCache -from langchain.globals import get_llm_cache, set_llm_cache +from langchain_community.cache import AstraDBCache, AstraDBSemanticCache +from langchain_community.utilities.astradb import SetupMode from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/cache/test_azure_cosmosdb_cache.py b/libs/community/tests/integration_tests/cache/test_azure_cosmosdb_cache.py similarity index 100% rename from libs/langchain/tests/integration_tests/cache/test_azure_cosmosdb_cache.py rename to libs/community/tests/integration_tests/cache/test_azure_cosmosdb_cache.py index b068b3d900..d6eb6fc43e 100644 --- a/libs/langchain/tests/integration_tests/cache/test_azure_cosmosdb_cache.py +++ b/libs/community/tests/integration_tests/cache/test_azure_cosmosdb_cache.py @@ -10,14 +10,14 @@ import os import uuid import pytest +from langchain.globals import get_llm_cache, set_llm_cache +from langchain_core.outputs import Generation + from langchain_community.cache import AzureCosmosDBSemanticCache from langchain_community.vectorstores.azure_cosmos_db import ( CosmosDBSimilarityType, CosmosDBVectorSearchType, ) -from langchain_core.outputs import Generation - -from langchain.globals import get_llm_cache, set_llm_cache from tests.integration_tests.cache.fake_embeddings import ( FakeEmbeddings, ) diff --git a/libs/langchain/tests/integration_tests/cache/test_cassandra.py b/libs/community/tests/integration_tests/cache/test_cassandra.py similarity index 98% rename from libs/langchain/tests/integration_tests/cache/test_cassandra.py rename to libs/community/tests/integration_tests/cache/test_cassandra.py index a61eb764eb..44308db1bf 100644 --- a/libs/langchain/tests/integration_tests/cache/test_cassandra.py +++ b/libs/community/tests/integration_tests/cache/test_cassandra.py @@ -5,11 +5,11 @@ import time from typing import Any, Iterator, Tuple import pytest -from langchain_community.utilities.cassandra import SetupMode +from langchain.globals import get_llm_cache, set_llm_cache from langchain_core.outputs import Generation, LLMResult -from langchain.cache import CassandraCache, CassandraSemanticCache -from langchain.globals import get_llm_cache, set_llm_cache +from langchain_community.cache import CassandraCache, CassandraSemanticCache +from langchain_community.utilities.cassandra import SetupMode from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/cache/test_gptcache.py 
b/libs/community/tests/integration_tests/cache/test_gptcache.py similarity index 97% rename from libs/langchain/tests/integration_tests/cache/test_gptcache.py rename to libs/community/tests/integration_tests/cache/test_gptcache.py index 2e2f4c084c..4126dd9e3d 100644 --- a/libs/langchain/tests/integration_tests/cache/test_gptcache.py +++ b/libs/community/tests/integration_tests/cache/test_gptcache.py @@ -2,10 +2,10 @@ import os from typing import Any, Callable, Union import pytest +from langchain.globals import get_llm_cache, set_llm_cache from langchain_core.outputs import Generation -from langchain.cache import GPTCache -from langchain.globals import get_llm_cache, set_llm_cache +from langchain_community.cache import GPTCache from tests.unit_tests.llms.fake_llm import FakeLLM try: diff --git a/libs/langchain/tests/integration_tests/cache/test_momento_cache.py b/libs/community/tests/integration_tests/cache/test_momento_cache.py similarity index 98% rename from libs/langchain/tests/integration_tests/cache/test_momento_cache.py rename to libs/community/tests/integration_tests/cache/test_momento_cache.py index f5ef26ba66..6e4b41e3ce 100644 --- a/libs/langchain/tests/integration_tests/cache/test_momento_cache.py +++ b/libs/community/tests/integration_tests/cache/test_momento_cache.py @@ -11,10 +11,10 @@ from datetime import timedelta from typing import Iterator import pytest +from langchain.globals import set_llm_cache from langchain_core.outputs import Generation, LLMResult -from langchain.cache import MomentoCache -from langchain.globals import set_llm_cache +from langchain_community.cache import MomentoCache from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/cache/test_opensearch_cache.py b/libs/community/tests/integration_tests/cache/test_opensearch_cache.py similarity index 100% rename from libs/langchain/tests/integration_tests/cache/test_opensearch_cache.py rename to libs/community/tests/integration_tests/cache/test_opensearch_cache.py index eda3297e86..db203bbb39 100644 --- a/libs/langchain/tests/integration_tests/cache/test_opensearch_cache.py +++ b/libs/community/tests/integration_tests/cache/test_opensearch_cache.py @@ -1,7 +1,7 @@ -from langchain_community.cache import OpenSearchSemanticCache +from langchain.globals import get_llm_cache, set_llm_cache from langchain_core.outputs import Generation -from langchain.globals import get_llm_cache, set_llm_cache +from langchain_community.cache import OpenSearchSemanticCache from tests.integration_tests.cache.fake_embeddings import ( FakeEmbeddings, ) diff --git a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py b/libs/community/tests/integration_tests/cache/test_redis_cache.py similarity index 100% rename from libs/langchain/tests/integration_tests/cache/test_redis_cache.py rename to libs/community/tests/integration_tests/cache/test_redis_cache.py index 20ed08e678..eb9f00da81 100644 --- a/libs/langchain/tests/integration_tests/cache/test_redis_cache.py +++ b/libs/community/tests/integration_tests/cache/test_redis_cache.py @@ -5,13 +5,13 @@ from contextlib import asynccontextmanager, contextmanager from typing import AsyncGenerator, Generator, List, Optional, cast import pytest -from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache +from langchain.globals import get_llm_cache, set_llm_cache from langchain_core.embeddings import Embeddings from langchain_core.load.dump import dumps from langchain_core.messages import AIMessage, BaseMessage, 
HumanMessage from langchain_core.outputs import ChatGeneration, Generation, LLMResult -from langchain.globals import get_llm_cache, set_llm_cache +from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache from tests.integration_tests.cache.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, diff --git a/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py b/libs/community/tests/integration_tests/cache/test_upstash_redis_cache.py similarity index 98% rename from libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py rename to libs/community/tests/integration_tests/cache/test_upstash_redis_cache.py index 8b1b0d4dcf..dfbf8e9d69 100644 --- a/libs/langchain/tests/integration_tests/cache/test_upstash_redis_cache.py +++ b/libs/community/tests/integration_tests/cache/test_upstash_redis_cache.py @@ -1,11 +1,11 @@ """Test Upstash Redis cache functionality.""" import uuid +import langchain import pytest from langchain_core.outputs import Generation, LLMResult -import langchain -from langchain.cache import UpstashRedisCache +from langchain_community.cache import UpstashRedisCache from tests.unit_tests.llms.fake_chat_model import FakeChatModel from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/integration_tests/chains/test_dalle_agent.py b/libs/community/tests/integration_tests/chains/test_dalle_agent.py similarity index 79% rename from libs/langchain/tests/integration_tests/chains/test_dalle_agent.py rename to libs/community/tests/integration_tests/chains/test_dalle_agent.py index 943d8fcbe7..f393522ab6 100644 --- a/libs/langchain/tests/integration_tests/chains/test_dalle_agent.py +++ b/libs/community/tests/integration_tests/chains/test_dalle_agent.py @@ -1,7 +1,8 @@ """Integration test for Dall-E image generator agent.""" -from langchain_community.llms import OpenAI +from langchain.agents import AgentType, initialize_agent -from langchain.agents import AgentType, initialize_agent, load_tools +from langchain_community.agent_toolkits.load_tools import load_tools +from langchain_community.llms import OpenAI def test_call() -> None: diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database.py b/libs/community/tests/integration_tests/chains/test_graph_database.py similarity index 99% rename from libs/langchain/tests/integration_tests/chains/test_graph_database.py rename to libs/community/tests/integration_tests/chains/test_graph_database.py index 0543e8dc25..91a46a579d 100644 --- a/libs/langchain/tests/integration_tests/chains/test_graph_database.py +++ b/libs/community/tests/integration_tests/chains/test_graph_database.py @@ -1,12 +1,12 @@ """Test Graph Database Chain.""" import os +from langchain.chains.loading import load_chain + +from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain from langchain_community.graphs import Neo4jGraph from langchain_community.llms.openai import OpenAI -from langchain.chains.graph_qa.cypher import GraphCypherQAChain -from langchain.chains.loading import load_chain - def test_connect_neo4j() -> None: """Test that Neo4j database is correctly instantiated and connected.""" diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database_arangodb.py b/libs/community/tests/integration_tests/chains/test_graph_database_arangodb.py similarity index 97% rename from libs/langchain/tests/integration_tests/chains/test_graph_database_arangodb.py rename to 
libs/community/tests/integration_tests/chains/test_graph_database_arangodb.py index 9de49ff9a5..35e8614489 100644 --- a/libs/langchain/tests/integration_tests/chains/test_graph_database_arangodb.py +++ b/libs/community/tests/integration_tests/chains/test_graph_database_arangodb.py @@ -1,12 +1,11 @@ """Test Graph Database Chain.""" from typing import Any +from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain from langchain_community.graphs import ArangoGraph from langchain_community.graphs.arangodb_graph import get_arangodb_client from langchain_community.llms.openai import OpenAI -from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain - def populate_arangodb_database(db: Any) -> None: if db.has_graph("GameOfThrones"): diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py b/libs/community/tests/integration_tests/chains/test_graph_database_sparql.py similarity index 99% rename from libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py rename to libs/community/tests/integration_tests/chains/test_graph_database_sparql.py index ffa1afae04..2b2a3b676f 100644 --- a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py +++ b/libs/community/tests/integration_tests/chains/test_graph_database_sparql.py @@ -3,10 +3,10 @@ import pathlib import re from unittest.mock import MagicMock, Mock -from langchain_community.graphs import RdfGraph - from langchain.chains import LLMChain -from langchain.chains.graph_qa.sparql import GraphSparqlQAChain + +from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain +from langchain_community.graphs import RdfGraph """ cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb diff --git a/libs/langchain/tests/integration_tests/chains/test_ontotext_graphdb_qa.py b/libs/community/tests/integration_tests/chains/test_ontotext_graphdb_qa.py similarity index 99% rename from libs/langchain/tests/integration_tests/chains/test_ontotext_graphdb_qa.py rename to libs/community/tests/integration_tests/chains/test_ontotext_graphdb_qa.py index 77b3c26a5c..34955b6643 100644 --- a/libs/langchain/tests/integration_tests/chains/test_ontotext_graphdb_qa.py +++ b/libs/community/tests/integration_tests/chains/test_ontotext_graphdb_qa.py @@ -1,9 +1,10 @@ from unittest.mock import MagicMock, Mock import pytest -from langchain_community.graphs import OntotextGraphDBGraph +from langchain.chains import LLMChain -from langchain.chains import LLMChain, OntotextGraphDBQAChain +from langchain_community.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain +from langchain_community.graphs import OntotextGraphDBGraph """ cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb diff --git a/libs/langchain/tests/integration_tests/chains/test_react.py b/libs/community/tests/integration_tests/chains/test_react.py similarity index 92% rename from libs/langchain/tests/integration_tests/chains/test_react.py rename to libs/community/tests/integration_tests/chains/test_react.py index b577b77871..a1fc7ba83e 100644 --- a/libs/langchain/tests/integration_tests/chains/test_react.py +++ b/libs/community/tests/integration_tests/chains/test_react.py @@ -1,9 +1,9 @@ """Integration test for ReAct.""" -from langchain_community.llms.openai import OpenAI - from langchain.agents.react.base import ReActChain -from langchain.docstore.wikipedia import Wikipedia + +from langchain_community.docstore import Wikipedia +from
langchain_community.llms.openai import OpenAI def test_react() -> None: diff --git a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa.py b/libs/community/tests/integration_tests/chains/test_retrieval_qa.py similarity index 100% rename from libs/langchain/tests/integration_tests/chains/test_retrieval_qa.py rename to libs/community/tests/integration_tests/chains/test_retrieval_qa.py index eb7e281136..16091aa4b7 100644 --- a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa.py +++ b/libs/community/tests/integration_tests/chains/test_retrieval_qa.py @@ -1,14 +1,14 @@ """Test RetrievalQA functionality.""" from pathlib import Path +from langchain.chains import RetrievalQA +from langchain.chains.loading import load_chain +from langchain_text_splitters.character import CharacterTextSplitter + from langchain_community.document_loaders import TextLoader from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.llms import OpenAI from langchain_community.vectorstores import FAISS -from langchain_text_splitters.character import CharacterTextSplitter - -from langchain.chains import RetrievalQA -from langchain.chains.loading import load_chain def test_retrieval_qa_saving_loading(tmp_path: Path) -> None: diff --git a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py b/libs/community/tests/integration_tests/chains/test_retrieval_qa_with_sources.py similarity index 100% rename from libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py rename to libs/community/tests/integration_tests/chains/test_retrieval_qa_with_sources.py index c8752cd7e7..f3c1661fce 100644 --- a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py +++ b/libs/community/tests/integration_tests/chains/test_retrieval_qa_with_sources.py @@ -1,12 +1,12 @@ """Test RetrievalQA functionality.""" +from langchain.chains import RetrievalQAWithSourcesChain +from langchain.chains.loading import load_chain +from langchain_text_splitters.character import CharacterTextSplitter + from langchain_community.document_loaders import DirectoryLoader from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.llms import OpenAI from langchain_community.vectorstores import FAISS -from langchain_text_splitters.character import CharacterTextSplitter - -from langchain.chains import RetrievalQAWithSourcesChain -from langchain.chains.loading import load_chain def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None: diff --git a/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py b/libs/community/tests/integration_tests/chains/test_self_ask_with_search.py similarity index 100% rename from libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py rename to libs/community/tests/integration_tests/chains/test_self_ask_with_search.py index 5cf0f93b8c..543790e14a 100644 --- a/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py +++ b/libs/community/tests/integration_tests/chains/test_self_ask_with_search.py @@ -1,9 +1,9 @@ """Integration test for self ask with search.""" +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain + from langchain_community.llms.openai import OpenAI from langchain_community.utilities.searchapi import SearchApiAPIWrapper -from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain - def test_self_ask_with_search() -> None: """Test functionality on a 
prompt.""" diff --git a/libs/community/tests/integration_tests/document_transformers/__init__.py b/libs/community/tests/integration_tests/document_transformers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/integration_tests/test_document_transformers.py b/libs/community/tests/integration_tests/document_transformers/test_embeddings_filter.py similarity index 99% rename from libs/langchain/tests/integration_tests/test_document_transformers.py rename to libs/community/tests/integration_tests/document_transformers/test_embeddings_filter.py index 25188f2ae7..b9fba7681b 100644 --- a/libs/langchain/tests/integration_tests/test_document_transformers.py +++ b/libs/community/tests/integration_tests/document_transformers/test_embeddings_filter.py @@ -1,11 +1,12 @@ """Integration test for embedding-based redundant doc filtering.""" +from langchain_core.documents import Document + from langchain_community.document_transformers.embeddings_redundant_filter import ( EmbeddingsClusteringFilter, EmbeddingsRedundantFilter, _DocumentWithState, ) from langchain_community.embeddings import OpenAIEmbeddings -from langchain_core.documents import Document def test_embeddings_redundant_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py b/libs/community/tests/integration_tests/memory/test_cosmos_db.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_cosmos_db.py rename to libs/community/tests/integration_tests/memory/test_cosmos_db.py index c2c085144c..94ca42af1c 100644 --- a/libs/langchain/tests/integration_tests/memory/test_cosmos_db.py +++ b/libs/community/tests/integration_tests/memory/test_cosmos_db.py @@ -1,10 +1,10 @@ import json import os -from langchain_community.chat_message_histories import CosmosDBChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import CosmosDBChatMessageHistory # Replace these with your Azure Cosmos DB endpoint and key endpoint = os.environ.get("COSMOS_DB_ENDPOINT", "") diff --git a/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py b/libs/community/tests/integration_tests/memory/test_elasticsearch.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_elasticsearch.py rename to libs/community/tests/integration_tests/memory/test_elasticsearch.py index 38a5d0635d..4e4ab4e135 100644 --- a/libs/langchain/tests/integration_tests/memory/test_elasticsearch.py +++ b/libs/community/tests/integration_tests/memory/test_elasticsearch.py @@ -4,10 +4,10 @@ import uuid from typing import Generator, Union import pytest -from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory """ cd tests/integration_tests/memory/docker-compose diff --git a/libs/langchain/tests/integration_tests/memory/test_firestore.py b/libs/community/tests/integration_tests/memory/test_firestore.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_firestore.py rename to libs/community/tests/integration_tests/memory/test_firestore.py index 5d9fabf4ae..ae4be495c5 100644 
--- a/libs/langchain/tests/integration_tests/memory/test_firestore.py +++ b/libs/community/tests/integration_tests/memory/test_firestore.py @@ -1,9 +1,9 @@ import json -from langchain_community.chat_message_histories import FirestoreChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import FirestoreChatMessageHistory def test_memory_with_message_store() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_astradb.py b/libs/community/tests/integration_tests/memory/test_memory_astradb.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_astradb.py rename to libs/community/tests/integration_tests/memory/test_memory_astradb.py index 4caf39985e..0728535d80 100644 --- a/libs/langchain/tests/integration_tests/memory/test_astradb.py +++ b/libs/community/tests/integration_tests/memory/test_memory_astradb.py @@ -2,13 +2,13 @@ import os from typing import AsyncIterable, Iterable import pytest +from langchain.memory import ConversationBufferMemory +from langchain_core.messages import AIMessage, HumanMessage + from langchain_community.chat_message_histories.astradb import ( AstraDBChatMessageHistory, ) from langchain_community.utilities.astradb import SetupMode -from langchain_core.messages import AIMessage, HumanMessage - -from langchain.memory import ConversationBufferMemory def _has_env_vars() -> bool: diff --git a/libs/langchain/tests/integration_tests/memory/test_cassandra.py b/libs/community/tests/integration_tests/memory/test_memory_cassandra.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_cassandra.py rename to libs/community/tests/integration_tests/memory/test_memory_cassandra.py index 668cc87a44..6ff03ba6e7 100644 --- a/libs/langchain/tests/integration_tests/memory/test_cassandra.py +++ b/libs/community/tests/integration_tests/memory/test_memory_cassandra.py @@ -2,12 +2,12 @@ import os import time from typing import Optional +from langchain.memory import ConversationBufferMemory +from langchain_core.messages import AIMessage, HumanMessage + from langchain_community.chat_message_histories.cassandra import ( CassandraChatMessageHistory, ) -from langchain_core.messages import AIMessage, HumanMessage - -from langchain.memory import ConversationBufferMemory def _chat_message_history( diff --git a/libs/langchain/tests/integration_tests/memory/test_momento.py b/libs/community/tests/integration_tests/memory/test_momento.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_momento.py rename to libs/community/tests/integration_tests/memory/test_momento.py index b5d0e7f43c..5f0d468554 100644 --- a/libs/langchain/tests/integration_tests/memory/test_momento.py +++ b/libs/community/tests/integration_tests/memory/test_momento.py @@ -10,10 +10,10 @@ from datetime import timedelta from typing import Iterator import pytest -from langchain_community.chat_message_histories import MomentoChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import MomentoChatMessageHistory def random_string() -> str: diff --git a/libs/langchain/tests/integration_tests/memory/test_mongodb.py b/libs/community/tests/integration_tests/memory/test_mongodb.py 
similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_mongodb.py rename to libs/community/tests/integration_tests/memory/test_mongodb.py index f99242ebee..384bbb52e7 100644 --- a/libs/langchain/tests/integration_tests/memory/test_mongodb.py +++ b/libs/community/tests/integration_tests/memory/test_mongodb.py @@ -1,10 +1,10 @@ import json import os -from langchain_community.chat_message_histories import MongoDBChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import MongoDBChatMessageHistory # Replace these with your mongodb connection string connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "") diff --git a/libs/langchain/tests/integration_tests/memory/test_neo4j.py b/libs/community/tests/integration_tests/memory/test_neo4j.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_neo4j.py rename to libs/community/tests/integration_tests/memory/test_neo4j.py index 06a4956944..5f1aa954d1 100644 --- a/libs/langchain/tests/integration_tests/memory/test_neo4j.py +++ b/libs/community/tests/integration_tests/memory/test_neo4j.py @@ -1,9 +1,9 @@ import json -from langchain_community.chat_message_histories import Neo4jChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import Neo4jChatMessageHistory def test_memory_with_message_store() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_redis.py b/libs/community/tests/integration_tests/memory/test_redis.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_redis.py rename to libs/community/tests/integration_tests/memory/test_redis.py index cb00d04fa0..22863d1b8f 100644 --- a/libs/langchain/tests/integration_tests/memory/test_redis.py +++ b/libs/community/tests/integration_tests/memory/test_redis.py @@ -1,9 +1,9 @@ import json -from langchain_community.chat_message_histories import RedisChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import RedisChatMessageHistory def test_memory_with_message_store() -> None: diff --git a/libs/langchain/tests/integration_tests/memory/test_rockset.py b/libs/community/tests/integration_tests/memory/test_rockset.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_rockset.py rename to libs/community/tests/integration_tests/memory/test_rockset.py index ce220aea72..21540d462d 100644 --- a/libs/langchain/tests/integration_tests/memory/test_rockset.py +++ b/libs/community/tests/integration_tests/memory/test_rockset.py @@ -8,10 +8,10 @@ and ROCKSET_REGION environment variables set. 
import json import os -from langchain_community.chat_message_histories import RocksetChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import RocksetChatMessageHistory collection_name = "langchain_demo" session_id = "MySession" diff --git a/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py b/libs/community/tests/integration_tests/memory/test_singlestoredb.py similarity index 91% rename from libs/langchain/tests/integration_tests/memory/test_singlestoredb.py rename to libs/community/tests/integration_tests/memory/test_singlestoredb.py index 94bb0e7197..fbc895455b 100644 --- a/libs/langchain/tests/integration_tests/memory/test_singlestoredb.py +++ b/libs/community/tests/integration_tests/memory/test_singlestoredb.py @@ -1,8 +1,9 @@ import json +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory, SingleStoreDBChatMessageHistory +from langchain_community.chat_message_histories import SingleStoreDBChatMessageHistory # Replace this with your SingleStoreDB connection string TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db" diff --git a/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py b/libs/community/tests/integration_tests/memory/test_upstash_redis.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_upstash_redis.py rename to libs/community/tests/integration_tests/memory/test_upstash_redis.py index f03aa5fca1..4e6b24362f 100644 --- a/libs/langchain/tests/integration_tests/memory/test_upstash_redis.py +++ b/libs/community/tests/integration_tests/memory/test_upstash_redis.py @@ -1,12 +1,12 @@ import json import pytest +from langchain.memory import ConversationBufferMemory +from langchain_core.messages import message_to_dict + from langchain_community.chat_message_histories.upstash_redis import ( UpstashRedisChatMessageHistory, ) -from langchain_core.messages import message_to_dict - -from langchain.memory import ConversationBufferMemory URL = "" TOKEN = "" diff --git a/libs/langchain/tests/integration_tests/memory/test_xata.py b/libs/community/tests/integration_tests/memory/test_xata.py similarity index 100% rename from libs/langchain/tests/integration_tests/memory/test_xata.py rename to libs/community/tests/integration_tests/memory/test_xata.py index bfcf90e98a..c53c2963fd 100644 --- a/libs/langchain/tests/integration_tests/memory/test_xata.py +++ b/libs/community/tests/integration_tests/memory/test_xata.py @@ -6,10 +6,10 @@ Before running this test, please create a Xata database.
import json import os -from langchain_community.chat_message_histories import XataChatMessageHistory +from langchain.memory import ConversationBufferMemory from langchain_core.messages import message_to_dict -from langchain.memory import ConversationBufferMemory +from langchain_community.chat_message_histories import XataChatMessageHistory class TestXata: diff --git a/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py b/libs/community/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py similarity index 97% rename from libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py rename to libs/community/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py index 401042bec9..c803298797 100644 --- a/libs/langchain/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py +++ b/libs/community/tests/integration_tests/prompts/test_ngram_overlap_example_selector.py @@ -3,7 +3,7 @@ import pytest from langchain_core.prompts import PromptTemplate -from langchain.prompts.example_selector.ngram_overlap import ( +from langchain_community.example_selectors import ( NGramOverlapExampleSelector, ngram_overlap_score, ) diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py rename to libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py index 9414592511..500e32a392 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_base.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_base.py @@ -1,13 +1,13 @@ """Integration test for compression pipelines.""" -from langchain_community.document_transformers import EmbeddingsRedundantFilter -from langchain_community.embeddings import OpenAIEmbeddings -from langchain_core.documents import Document -from langchain_text_splitters.character import CharacterTextSplitter - from langchain.retrievers.document_compressors import ( DocumentCompressorPipeline, EmbeddingsFilter, ) +from langchain_core.documents import Document +from langchain_text_splitters.character import CharacterTextSplitter + +from langchain_community.document_transformers import EmbeddingsRedundantFilter +from langchain_community.embeddings import OpenAIEmbeddings def test_document_compressor_pipeline() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py rename to libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py index eb58416200..a99b115990 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py @@ -1,8 +1,8 @@ """Integration test for LLMChainExtractor.""" -from langchain_community.chat_models import ChatOpenAI +from langchain.retrievers.document_compressors import LLMChainExtractor from langchain_core.documents import Document -from langchain.retrievers.document_compressors import 
LLMChainExtractor +from langchain_community.chat_models import ChatOpenAI def test_llm_construction_with_kwargs() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py rename to libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py index 55249a879a..c6e1d5b41c 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py @@ -1,8 +1,8 @@ """Integration test for llm-based relevant doc filtering.""" -from langchain_community.chat_models import ChatOpenAI +from langchain.retrievers.document_compressors import LLMChainFilter from langchain_core.documents import Document -from langchain.retrievers.document_compressors import LLMChainFilter +from langchain_community.chat_models import ChatOpenAI def test_llm_chain_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py rename to libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py index b433d57bf2..d90f09ff31 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py @@ -1,12 +1,12 @@ """Integration test for embedding-based relevant doc filtering.""" import numpy as np +from langchain.retrievers.document_compressors import EmbeddingsFilter +from langchain_core.documents import Document + from langchain_community.document_transformers.embeddings_redundant_filter import ( _DocumentWithState, ) from langchain_community.embeddings import OpenAIEmbeddings -from langchain_core.documents import Document - -from langchain.retrievers.document_compressors import EmbeddingsFilter def test_embeddings_filter() -> None: diff --git a/libs/langchain/tests/integration_tests/retrievers/test_contextual_compression.py b/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/test_contextual_compression.py rename to libs/community/tests/integration_tests/retrievers/test_contextual_compression.py index 020cb1133c..203cd222d4 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_contextual_compression.py +++ b/libs/community/tests/integration_tests/retrievers/test_contextual_compression.py @@ -1,9 +1,9 @@ -from langchain_community.embeddings import OpenAIEmbeddings -from langchain_community.vectorstores import FAISS - from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.document_compressors import EmbeddingsFilter +from langchain_community.embeddings import OpenAIEmbeddings +from langchain_community.vectorstores import FAISS + def test_contextual_compression_retriever_get_relevant_docs() -> None: """Test get_relevant_docs.""" diff --git 
a/libs/langchain/tests/integration_tests/retrievers/test_merger_retriever.py b/libs/community/tests/integration_tests/retrievers/test_merger_retriever.py similarity index 100% rename from libs/langchain/tests/integration_tests/retrievers/test_merger_retriever.py rename to libs/community/tests/integration_tests/retrievers/test_merger_retriever.py index ec0eeb4cf3..b5b575e7c3 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_merger_retriever.py +++ b/libs/community/tests/integration_tests/retrievers/test_merger_retriever.py @@ -1,8 +1,8 @@ +from langchain.retrievers.merger_retriever import MergerRetriever + from langchain_community.embeddings import OpenAIEmbeddings from langchain_community.vectorstores import Chroma -from langchain.retrievers.merger_retriever import MergerRetriever - def test_merger_retriever_get_relevant_docs() -> None: """Test get_relevant_docs.""" diff --git a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py b/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py similarity index 100% rename from libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py rename to libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py index 285e63b9d9..37ff9de58a 100644 --- a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py +++ b/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py @@ -2,19 +2,19 @@ from typing import Iterator, List, Optional from uuid import uuid4 import pytest -from langchain_community.chat_models import ChatOpenAI -from langchain_community.llms.openai import OpenAI +from langchain.chains.llm import LLMChain +from langchain.evaluation import EvaluatorType +from langchain.smith import RunEvalConfig, run_on_dataset +from langchain.smith.evaluation import InputFormatError +from langchain.smith.evaluation.runner_utils import arun_on_dataset from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.prompts.chat import ChatPromptTemplate from langsmith import Client as Client from langsmith.evaluation import run_evaluator from langsmith.schemas import DataType, Example, Run -from langchain.chains.llm import LLMChain -from langchain.evaluation import EvaluatorType -from langchain.smith import RunEvalConfig, run_on_dataset -from langchain.smith.evaluation import InputFormatError -from langchain.smith.evaluation.runner_utils import arun_on_dataset +from langchain_community.chat_models import ChatOpenAI +from langchain_community.llms.openai import OpenAI def _check_all_feedback_passed(_project_name: str, client: Client) -> None: diff --git a/libs/langchain/tests/integration_tests/test_dalle.py b/libs/community/tests/integration_tests/test_dalle.py similarity index 100% rename from libs/langchain/tests/integration_tests/test_dalle.py rename to libs/community/tests/integration_tests/test_dalle.py diff --git a/libs/community/tests/integration_tests/test_document_transformers.py b/libs/community/tests/integration_tests/test_document_transformers.py new file mode 100644 index 0000000000..c8a3d1a948 --- /dev/null +++ b/libs/community/tests/integration_tests/test_document_transformers.py @@ -0,0 +1,73 @@ +"""Integration test for embedding-based redundant doc filtering.""" + +from langchain_core.documents import Document + +from langchain_community.document_transformers.embeddings_redundant_filter import ( + EmbeddingsClusteringFilter, + EmbeddingsRedundantFilter, + _DocumentWithState, +) +from 
langchain_community.embeddings import OpenAIEmbeddings + + +def test_embeddings_redundant_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "Where did all of my cookies go?", + "I wish there were better Italian restaurants in my neighborhood.", + ] + docs = [Document(page_content=t) for t in texts] + embeddings = OpenAIEmbeddings() + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 2 + assert set(texts[:2]).intersection([d.page_content for d in actual]) + + +def test_embeddings_redundant_filter_with_state() -> None: + texts = ["What happened to all of my cookies?", "foo bar baz"] + state = {"embedded_doc": [0.5] * 10} + docs = [_DocumentWithState(page_content=t, state=state) for t in texts] + embeddings = OpenAIEmbeddings() + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 1 + + +def test_embeddings_clustering_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "A cookie is a small, baked sweet treat and you can find it in the cookie", + "monsters' jar.", + "Cookies are good.", + "I have nightmares about the cookie monster.", + "The most popular pizza styles are: Neapolitan, New York-style and", + "Chicago-style. You can find them on iconic restaurants in major cities.", + "Neapolitan pizza: This is the original pizza style,hailing from Naples,", + "Italy.", + "I wish there were better Italian Pizza restaurants in my neighborhood.", + "New York-style pizza: This is characterized by its large, thin crust, and", + "generous toppings.", + "The first movie to feature a robot was 'A Trip to the Moon' (1902).", + "The first movie to feature a robot that could pass for a human was", + "'Blade Runner' (1982)", + "The first movie to feature a robot that could fall in love with a human", + "was 'Her' (2013)", + "A robot is a machine capable of carrying out complex actions automatically.", + "There are certainly hundreds, if not thousands movies about robots like:", + "'Blade Runner', 'Her' and 'A Trip to the Moon'", + ] + + docs = [Document(page_content=t) for t in texts] + embeddings = OpenAIEmbeddings() + redundant_filter = EmbeddingsClusteringFilter( + embeddings=embeddings, + num_clusters=3, + num_closest=1, + sorted=True, + ) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 3 + assert texts[1] in [d.page_content for d in actual] + assert texts[4] in [d.page_content for d in actual] + assert texts[11] in [d.page_content for d in actual] diff --git a/libs/langchain/tests/integration_tests/test_long_context_reorder.py b/libs/community/tests/integration_tests/test_long_context_reorder.py similarity index 100% rename from libs/langchain/tests/integration_tests/test_long_context_reorder.py rename to libs/community/tests/integration_tests/test_long_context_reorder.py diff --git a/libs/langchain/tests/integration_tests/test_nuclia_transformer.py b/libs/community/tests/integration_tests/test_nuclia_transformer.py similarity index 99% rename from libs/langchain/tests/integration_tests/test_nuclia_transformer.py rename to libs/community/tests/integration_tests/test_nuclia_transformer.py index 5dc1c70ed6..0d3bc3ee42 100644 --- a/libs/langchain/tests/integration_tests/test_nuclia_transformer.py +++ b/libs/community/tests/integration_tests/test_nuclia_transformer.py @@ -3,11 +3,12 @@ import json from typing import Any from unittest import mock 
+from langchain_core.documents import Document + from langchain_community.document_transformers.nuclia_text_transform import ( NucliaTextTransformer, ) from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI -from langchain_core.documents import Document def fakerun(**args: Any) -> Any: diff --git a/libs/langchain/tests/integration_tests/test_pdf_pagesplitter.py b/libs/community/tests/integration_tests/test_pdf_pagesplitter.py similarity index 100% rename from libs/langchain/tests/integration_tests/test_pdf_pagesplitter.py rename to libs/community/tests/integration_tests/test_pdf_pagesplitter.py diff --git a/libs/community/tests/unit_tests/agents/__init__.py b/libs/community/tests/unit_tests/agents/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/unit_tests/agents/test_react.py b/libs/community/tests/unit_tests/agents/test_react.py similarity index 97% rename from libs/langchain/tests/unit_tests/agents/test_react.py rename to libs/community/tests/unit_tests/agents/test_react.py index 4b7cbe1c02..7b9a0b1eea 100644 --- a/libs/langchain/tests/unit_tests/agents/test_react.py +++ b/libs/community/tests/unit_tests/agents/test_react.py @@ -2,14 +2,14 @@ from typing import Union +from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain_core.agents import AgentAction from langchain_core.documents import Document from langchain_core.language_models import FakeListLLM from langchain_core.prompts.prompt import PromptTemplate from langchain_core.tools import Tool -from langchain.agents.react.base import ReActChain, ReActDocstoreAgent -from langchain.docstore.base import Docstore +from langchain_community.docstore.base import Docstore _PAGE_CONTENT = """This is a page about LangChain. 
diff --git a/libs/langchain/tests/unit_tests/agents/test_sql.py b/libs/community/tests/unit_tests/agents/test_sql.py similarity index 99% rename from libs/langchain/tests/unit_tests/agents/test_sql.py rename to libs/community/tests/unit_tests/agents/test_sql.py index e694847832..2d7f42f194 100644 --- a/libs/langchain/tests/unit_tests/agents/test_sql.py +++ b/libs/community/tests/unit_tests/agents/test_sql.py @@ -1,6 +1,5 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent from langchain_community.utilities.sql_database import SQLDatabase - from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/agents/test_tools.py b/libs/community/tests/unit_tests/agents/test_tools.py similarity index 98% rename from libs/langchain/tests/unit_tests/agents/test_tools.py rename to libs/community/tests/unit_tests/agents/test_tools.py index d32a57e4e4..028cf71b44 100644 --- a/libs/langchain/tests/unit_tests/agents/test_tools.py +++ b/libs/community/tests/unit_tests/agents/test_tools.py @@ -4,9 +4,6 @@ from typing import Any, Type from unittest.mock import MagicMock, Mock import pytest -from langchain_core.tools import Tool, ToolException, tool - -from langchain.agents import load_tools from langchain.agents.agent import Agent from langchain.agents.chat.base import ChatAgent from langchain.agents.conversational.base import ConversationalAgent @@ -14,6 +11,9 @@ from langchain.agents.conversational_chat.base import ConversationalChatAgent from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent +from langchain_core.tools import Tool, ToolException, tool + +from langchain_community.agent_toolkits.load_tools import load_tools from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/libs/community/tests/unit_tests/chains/__init__.py b/libs/community/tests/unit_tests/chains/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/unit_tests/chains/test_api.py b/libs/community/tests/unit_tests/chains/test_api.py similarity index 100% rename from libs/langchain/tests/unit_tests/chains/test_api.py rename to libs/community/tests/unit_tests/chains/test_api.py index 449f713d68..c5fa8cf086 100644 --- a/libs/langchain/tests/unit_tests/chains/test_api.py +++ b/libs/community/tests/unit_tests/chains/test_api.py @@ -4,11 +4,11 @@ import json from typing import Any import pytest -from langchain_community.utilities.requests import TextRequestsWrapper - from langchain.chains.api.base import APIChain from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.llm import LLMChain + +from langchain_community.utilities.requests import TextRequestsWrapper from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py b/libs/community/tests/unit_tests/chains/test_graph_qa.py similarity index 97% rename from libs/langchain/tests/unit_tests/chains/test_graph_qa.py rename to libs/community/tests/unit_tests/chains/test_graph_qa.py index 32dfbfe18b..51c41bffec 100644 --- a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/community/tests/unit_tests/chains/test_graph_qa.py @@ -1,18 +1,22 @@ +import pathlib from typing import Any, Dict, List import pandas as pd -from langchain_community.graphs.graph_document import GraphDocument 
-from langchain_community.graphs.graph_store import GraphStore +from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT +from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory from langchain_core.prompts import PromptTemplate -from langchain.chains.graph_qa.cypher import ( +from langchain_community.chains.graph_qa.cypher import ( GraphCypherQAChain, construct_schema, extract_cypher, ) -from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema -from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT -from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory +from langchain_community.chains.graph_qa.cypher_utils import ( + CypherQueryCorrector, + Schema, +) +from langchain_community.graphs.graph_document import GraphDocument +from langchain_community.graphs.graph_store import GraphStore from tests.unit_tests.llms.fake_llm import FakeLLM @@ -298,8 +302,13 @@ def test_include_types3() -> None: assert output == expected_schema +HERE = pathlib.Path(__file__).parent + +UNIT_TESTS_ROOT = HERE.parent + + def test_validating_cypher_statements() -> None: - cypher_file = "tests/unit_tests/data/cypher_corrector.csv" + cypher_file = str(UNIT_TESTS_ROOT / "data/cypher_corrector.csv") examples = pd.read_csv(cypher_file) examples.fillna("", inplace=True) for _, row in examples.iterrows(): diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/community/tests/unit_tests/chains/test_llm.py similarity index 100% rename from libs/langchain/tests/unit_tests/chains/test_llm.py rename to libs/community/tests/unit_tests/chains/test_llm.py index 0179cd135f..f19c9b776b 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm.py +++ b/libs/community/tests/unit_tests/chains/test_llm.py @@ -4,10 +4,10 @@ from typing import Dict, List, Union from unittest.mock import patch import pytest +from langchain.chains.llm import LLMChain from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import PromptTemplate -from langchain.chains.llm import LLMChain from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/libs/langchain/tests/unit_tests/data/cypher_corrector.csv b/libs/community/tests/unit_tests/data/cypher_corrector.csv similarity index 100% rename from libs/langchain/tests/unit_tests/data/cypher_corrector.csv rename to libs/community/tests/unit_tests/data/cypher_corrector.csv diff --git a/libs/community/tests/unit_tests/query_constructors/__init__.py b/libs/community/tests/unit_tests/query_constructors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py b/libs/community/tests/unit_tests/query_constructors/test_astradb.py similarity index 98% rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py rename to libs/community/tests/unit_tests/query_constructors/test_astradb.py index 8112871607..212b311f4e 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py +++ b/libs/community/tests/unit_tests/query_constructors/test_astradb.py @@ -8,7 +8,7 @@ from langchain_core.structured_query import ( StructuredQuery, ) -from langchain.retrievers.self_query.astradb import AstraDBTranslator +from langchain_community.query_constructors.astradb import AstraDBTranslator DEFAULT_TRANSLATOR = AstraDBTranslator() diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py 
diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/community/tests/unit_tests/chains/test_llm.py
similarity index 100%
rename from libs/langchain/tests/unit_tests/chains/test_llm.py
rename to libs/community/tests/unit_tests/chains/test_llm.py
index 0179cd135f..f19c9b776b 100644
--- a/libs/langchain/tests/unit_tests/chains/test_llm.py
+++ b/libs/community/tests/unit_tests/chains/test_llm.py
@@ -4,10 +4,10 @@ from typing import Dict, List, Union
 from unittest.mock import patch
 
 import pytest
+from langchain.chains.llm import LLMChain
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompts import PromptTemplate
 
-from langchain.chains.llm import LLMChain
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
diff --git a/libs/langchain/tests/unit_tests/data/cypher_corrector.csv b/libs/community/tests/unit_tests/data/cypher_corrector.csv
similarity index 100%
rename from libs/langchain/tests/unit_tests/data/cypher_corrector.csv
rename to libs/community/tests/unit_tests/data/cypher_corrector.csv
diff --git a/libs/community/tests/unit_tests/query_constructors/__init__.py b/libs/community/tests/unit_tests/query_constructors/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py b/libs/community/tests/unit_tests/query_constructors/test_astradb.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py
rename to libs/community/tests/unit_tests/query_constructors/test_astradb.py
index 8112871607..212b311f4e 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_astradb.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_astradb.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.astradb import AstraDBTranslator
+from langchain_community.query_constructors.astradb import AstraDBTranslator
 
 DEFAULT_TRANSLATOR = AstraDBTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py b/libs/community/tests/unit_tests/query_constructors/test_chroma.py
similarity index 97%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py
rename to libs/community/tests/unit_tests/query_constructors/test_chroma.py
index 45fed14f56..87e8448386 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_chroma.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.chroma import ChromaTranslator
+from langchain_community.query_constructors.chroma import ChromaTranslator
 
 DEFAULT_TRANSLATOR = ChromaTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_dashvector.py b/libs/community/tests/unit_tests/query_constructors/test_dashvector.py
similarity index 94%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_dashvector.py
rename to libs/community/tests/unit_tests/query_constructors/test_dashvector.py
index 29c50c36f0..da4b5e8ad1 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_dashvector.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_dashvector.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     Operator,
 )
 
-from langchain.retrievers.self_query.dashvector import DashvectorTranslator
+from langchain_community.query_constructors.dashvector import DashvectorTranslator
 
 DEFAULT_TRANSLATOR = DashvectorTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_databricks_vector_search.py b/libs/community/tests/unit_tests/query_constructors/test_databricks_vector_search.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_databricks_vector_search.py
rename to libs/community/tests/unit_tests/query_constructors/test_databricks_vector_search.py
index 4b1728ce41..1f38f3b1b6 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_databricks_vector_search.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_databricks_vector_search.py
@@ -9,7 +9,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.databricks_vector_search import (
+from langchain_community.query_constructors.databricks_vector_search import (
     DatabricksVectorSearchTranslator,
 )
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py b/libs/community/tests/unit_tests/query_constructors/test_deeplake.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py
rename to libs/community/tests/unit_tests/query_constructors/test_deeplake.py
index bbc61f7bde..881c13062f 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_deeplake.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
+from langchain_community.query_constructors.deeplake import DeepLakeTranslator
 
 DEFAULT_TRANSLATOR = DeepLakeTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_dingo.py b/libs/community/tests/unit_tests/query_constructors/test_dingo.py
similarity index 97%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_dingo.py
rename to libs/community/tests/unit_tests/query_constructors/test_dingo.py
index b853ff081f..c4c1520c75 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_dingo.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_dingo.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.dingo import DingoDBTranslator
+from langchain_community.query_constructors.dingo import DingoDBTranslator
 
 DEFAULT_TRANSLATOR = DingoDBTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_elasticsearch.py b/libs/community/tests/unit_tests/query_constructors/test_elasticsearch.py
similarity index 99%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_elasticsearch.py
rename to libs/community/tests/unit_tests/query_constructors/test_elasticsearch.py
index 56a0328377..fb41d38fcc 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_elasticsearch.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_elasticsearch.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
+from langchain_community.query_constructors.elasticsearch import ElasticsearchTranslator
 
 DEFAULT_TRANSLATOR = ElasticsearchTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py b/libs/community/tests/unit_tests/query_constructors/test_milvus.py
similarity index 97%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py
rename to libs/community/tests/unit_tests/query_constructors/test_milvus.py
index 9e5e994908..70035628ad 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_milvus.py
@@ -9,7 +9,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.milvus import MilvusTranslator
+from langchain_community.query_constructors.milvus import MilvusTranslator
 
 DEFAULT_TRANSLATOR = MilvusTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_mongodb_atlas.py b/libs/community/tests/unit_tests/query_constructors/test_mongodb_atlas.py
similarity index 97%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_mongodb_atlas.py
rename to libs/community/tests/unit_tests/query_constructors/test_mongodb_atlas.py
index 6827c56683..9ba0e27238 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_mongodb_atlas.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_mongodb_atlas.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
+from langchain_community.query_constructors.mongodb_atlas import MongoDBAtlasTranslator
 
 DEFAULT_TRANSLATOR = MongoDBAtlasTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py b/libs/community/tests/unit_tests/query_constructors/test_myscale.py
similarity index 97%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py
rename to libs/community/tests/unit_tests/query_constructors/test_myscale.py
index ca8b878ea8..bf17d4a37e 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_myscale.py
@@ -9,7 +9,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.myscale import MyScaleTranslator
+from langchain_community.query_constructors.myscale import MyScaleTranslator
 
 DEFAULT_TRANSLATOR = MyScaleTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py b/libs/community/tests/unit_tests/query_constructors/test_opensearch.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py
rename to libs/community/tests/unit_tests/query_constructors/test_opensearch.py
index 0b18f1180e..d7e9b525e8 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_opensearch.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_opensearch.py
@@ -6,7 +6,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.opensearch import OpenSearchTranslator
+from langchain_community.query_constructors.opensearch import OpenSearchTranslator
 
 DEFAULT_TRANSLATOR = OpenSearchTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pgvector.py b/libs/community/tests/unit_tests/query_constructors/test_pgvector.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_pgvector.py
rename to libs/community/tests/unit_tests/query_constructors/test_pgvector.py
index 43bdae8949..e6ca92e655 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pgvector.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_pgvector.py
@@ -9,7 +9,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.pgvector import PGVectorTranslator
+from langchain_community.query_constructors.pgvector import PGVectorTranslator
 
 DEFAULT_TRANSLATOR = PGVectorTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py b/libs/community/tests/unit_tests/query_constructors/test_pinecone.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py
rename to libs/community/tests/unit_tests/query_constructors/test_pinecone.py
index cf175667b3..30906dc807 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_pinecone.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.pinecone import PineconeTranslator
+from langchain_community.query_constructors.pinecone import PineconeTranslator
 
 DEFAULT_TRANSLATOR = PineconeTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_redis.py b/libs/community/tests/unit_tests/query_constructors/test_redis.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_redis.py
rename to libs/community/tests/unit_tests/query_constructors/test_redis.py
index 44801f6d8e..f2a90e9e54 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_redis.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_redis.py
@@ -1,6 +1,15 @@
 from typing import Dict, Tuple
 
 import pytest
+from langchain_core.structured_query import (
+    Comparator,
+    Comparison,
+    Operation,
+    Operator,
+    StructuredQuery,
+)
+
+from langchain_community.query_constructors.redis import RedisTranslator
 from langchain_community.vectorstores.redis.filters import (
     RedisFilterExpression,
     RedisNum,
@@ -13,15 +22,6 @@ from langchain_community.vectorstores.redis.schema import (
     TagFieldSchema,
     TextFieldSchema,
 )
-from langchain_core.structured_query import (
-    Comparator,
-    Comparison,
-    Operation,
-    Operator,
-    StructuredQuery,
-)
-
-from langchain.retrievers.self_query.redis import RedisTranslator
 
 
 @pytest.fixture
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_supabase.py b/libs/community/tests/unit_tests/query_constructors/test_supabase.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_supabase.py
rename to libs/community/tests/unit_tests/query_constructors/test_supabase.py
index 8fc45dd2ee..a0f0944838 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_supabase.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_supabase.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
+from langchain_community.query_constructors.supabase import SupabaseVectorTranslator
 
 DEFAULT_TRANSLATOR = SupabaseVectorTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_tencentvectordb.py b/libs/community/tests/unit_tests/query_constructors/test_tencentvectordb.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_tencentvectordb.py
rename to libs/community/tests/unit_tests/query_constructors/test_tencentvectordb.py
index d7ee4a8c30..f3ce130664 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_tencentvectordb.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_tencentvectordb.py
@@ -6,7 +6,9 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.tencentvectordb import TencentVectorDBTranslator
+from langchain_community.query_constructors.tencentvectordb import (
+    TencentVectorDBTranslator,
+)
 
 
 def test_translate_with_operator() -> None:
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_timescalevector.py b/libs/community/tests/unit_tests/query_constructors/test_timescalevector.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_timescalevector.py
rename to libs/community/tests/unit_tests/query_constructors/test_timescalevector.py
index e813d63e2e..79ba14a419 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_timescalevector.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_timescalevector.py
@@ -9,7 +9,9 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator
+from langchain_community.query_constructors.timescalevector import (
+    TimescaleVectorTranslator,
+)
 
 DEFAULT_TRANSLATOR = TimescaleVectorTranslator()
 
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_vectara.py b/libs/community/tests/unit_tests/query_constructors/test_vectara.py
similarity index 96%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_vectara.py
rename to libs/community/tests/unit_tests/query_constructors/test_vectara.py
index fe41cdb214..aa42d3107d 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_vectara.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_vectara.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.vectara import VectaraTranslator
+from langchain_community.query_constructors.vectara import VectaraTranslator
 
 DEFAULT_TRANSLATOR = VectaraTranslator()
 
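Every rename above is mechanical: the translator classes are unchanged, only their import path moves from langchain.retrievers.self_query.* to langchain_community.query_constructors.*. All of them implement the langchain_core.structured_query Visitor interface that these tests exercise. A sketch of the kind of check the moved tests perform, assuming langchain-community at the 0.2 release candidate (the expected filter dict follows the Chroma convention used in the moved test file):

    from langchain_core.structured_query import Comparator, Comparison

    from langchain_community.query_constructors.chroma import ChromaTranslator

    translator = ChromaTranslator()
    # A structured-query node is rendered into a vendor-specific filter.
    comparison = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
    assert translator.visit_comparison(comparison) == {"foo": {"$lt": 2}}
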
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py b/libs/community/tests/unit_tests/query_constructors/test_weaviate.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py
rename to libs/community/tests/unit_tests/query_constructors/test_weaviate.py
index 999c7c6fb8..59d352bebc 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py
+++ b/libs/community/tests/unit_tests/query_constructors/test_weaviate.py
@@ -8,7 +8,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
 )
 
-from langchain.retrievers.self_query.weaviate import WeaviateTranslator
+from langchain_community.query_constructors.weaviate import WeaviateTranslator
 
 DEFAULT_TRANSLATOR = WeaviateTranslator()
 
diff --git a/libs/community/tests/unit_tests/retrievers/test_imports.py b/libs/community/tests/unit_tests/retrievers/test_imports.py
index 6e438bf5cd..ede083680c 100644
--- a/libs/community/tests/unit_tests/retrievers/test_imports.py
+++ b/libs/community/tests/unit_tests/retrievers/test_imports.py
@@ -39,6 +39,7 @@ EXPECTED_ALL = [
     "VespaRetriever",
     "WeaviateHybridSearchRetriever",
     "WikipediaRetriever",
+    "WebResearchRetriever",
     "YouRetriever",
     "ZepRetriever",
     "ZillizRetriever",
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_web_research.py b/libs/community/tests/unit_tests/retrievers/test_web_research.py
similarity index 90%
rename from libs/langchain/tests/unit_tests/retrievers/test_web_research.py
rename to libs/community/tests/unit_tests/retrievers/test_web_research.py
index 29878dd3b4..ca51382639 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_web_research.py
+++ b/libs/community/tests/unit_tests/retrievers/test_web_research.py
@@ -2,7 +2,7 @@ from typing import List
 
 import pytest
 
-from langchain.retrievers.web_research import QuestionListOutputParser
+from langchain_community.retrievers.web_research import QuestionListOutputParser
 
 
 @pytest.mark.parametrize(
diff --git a/libs/langchain/tests/unit_tests/test_cache.py b/libs/community/tests/unit_tests/test_cache.py
similarity index 98%
rename from libs/langchain/tests/unit_tests/test_cache.py
rename to libs/community/tests/unit_tests/test_cache.py
index c2391adda4..b275a61ca8 100644
--- a/libs/langchain/tests/unit_tests/test_cache.py
+++ b/libs/community/tests/unit_tests/test_cache.py
@@ -4,6 +4,7 @@ from typing import Dict, Generator, List, Union
 
 import pytest
 from _pytest.fixtures import FixtureRequest
+from langchain.globals import get_llm_cache, set_llm_cache
 from langchain_core.caches import InMemoryCache
 from langchain_core.language_models import FakeListChatModel, FakeListLLM
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -14,8 +15,9 @@ from langchain_core.outputs import ChatGeneration, Generation
 from sqlalchemy import create_engine
 from sqlalchemy.orm import Session
 
-from langchain.cache import SQLAlchemyCache
-from langchain.globals import get_llm_cache, set_llm_cache
+pytest.importorskip("langchain_community")
+
+from langchain_community.cache import SQLAlchemyCache  # noqa: E402
 
 
 def get_sqlite_cache() -> SQLAlchemyCache:
diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py
index 5f9c8bbd38..01209f1615 100644
--- a/libs/community/tests/unit_tests/test_dependencies.py
+++ b/libs/community/tests/unit_tests/test_dependencies.py
@@ -47,6 +47,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
             "python",
             "requests",
             "tenacity",
+            "langchain",
         ]
     )
 
@@ -73,6 +74,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
         "duckdb-engine",
         "freezegun",
         "langchain-core",
+        "langchain",
         "lark",
         "pandas",
         "pytest",
diff --git a/libs/langchain/tests/unit_tests/test_document_transformers.py b/libs/community/tests/unit_tests/test_document_transformers.py
similarity index 76%
rename from libs/langchain/tests/unit_tests/test_document_transformers.py
rename to libs/community/tests/unit_tests/test_document_transformers.py
index b995de0a17..88f1d6dc9d 100644
--- a/libs/langchain/tests/unit_tests/test_document_transformers.py
+++ b/libs/community/tests/unit_tests/test_document_transformers.py
@@ -1,9 +1,12 @@
 """Unit tests for document transformers."""
-from langchain_community.document_transformers.embeddings_redundant_filter import (
+import pytest
+
+pytest.importorskip("langchain_community")
+
+from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E402,E501
     _filter_similar_embeddings,
 )
-
-from langchain.utils.math import cosine_similarity
+from langchain_community.utils.math import cosine_similarity  # noqa: E402
 
 
 def test__filter_similar_embeddings() -> None:
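test_cache.py and test_document_transformers.py above both use pytest.importorskip to decouple collection from the now-optional langchain-community distribution: when the package is absent the module is skipped instead of erroring at import time. The shape of the pattern in isolation (the test body is illustrative only):

    import pytest

    # Skip this whole module at collection time if the optional
    # dependency is not installed; returns the module when it is.
    community = pytest.importorskip("langchain_community")

    def test_module_importable() -> None:
        assert community is not None
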
diff --git a/libs/experimental/poetry.lock b/libs/experimental/poetry.lock
index 5a1636abc8..0a7992edbe 100644
--- a/libs/experimental/poetry.lock
+++ b/libs/experimental/poetry.lock
@@ -1736,7 +1736,7 @@ files = [
 
 [[package]]
 name = "langchain"
-version = "0.1.18"
+version = "0.2.0rc1"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -1747,7 +1747,6 @@ develop = true
 aiohttp = "^3.8.3"
 async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
 dataclasses-json = ">= 0.5.7, < 0.7"
-langchain-community = ">=0.0.37,<0.1"
 langchain-core = "^0.1.48"
 langchain-text-splitters = ">=0.0.1,<0.1"
 langsmith = "^0.1.17"
@@ -1779,7 +1778,7 @@ url = "../langchain"
 
 [[package]]
 name = "langchain-community"
-version = "0.0.37"
+version = "0.0.38rc1"
 description = "Community contributed LangChain integrations."
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -1789,6 +1788,7 @@ develop = true
 [package.dependencies]
 aiohttp = "^3.8.3"
 dataclasses-json = ">= 0.5.7, < 0.7"
+langchain = "~0.2.0rc1"
 langchain-core = "^0.1.51"
 langsmith = "^0.1.0"
 numpy = "^1"
@@ -1884,13 +1884,13 @@ test = ["pytest", "pytest-cov"]
 
 [[package]]
 name = "langsmith"
-version = "0.1.55"
+version = "0.1.56"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.55-py3-none-any.whl", hash = "sha256:c198b4019d0e0948fa2c94efcafa0312bd5e7ce36aae8d62a38af2d6b16584fc"},
-    {file = "langsmith-0.1.55.tar.gz", hash = "sha256:08b75046471e3c32cb6b526e48ca4570bfe3911d6b0a3f8575ee062da940324c"},
+    {file = "langsmith-0.1.56-py3-none-any.whl", hash = "sha256:2f930e054ea8eccd8ff99f0f129ae7d2513973b2e706d5483f44ea9951a1dca0"},
+    {file = "langsmith-0.1.56.tar.gz", hash = "sha256:ff645b5bf16e2566740218ed6c048a1f8edbbedb4480a0d305a837ec71303fbf"},
 ]
 
 [package.dependencies]
@@ -2700,13 +2700,13 @@ files = [
 
 [[package]]
 name = "openai"
-version = "1.26.0"
+version = "1.27.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.26.0-py3-none-any.whl", hash = "sha256:884ced523fb0225780f8b0e0ed6f7e014049c32d049a41ad0ac962869f1055d1"},
-    {file = "openai-1.26.0.tar.gz", hash = "sha256:642e857b60855702ee6ff665e8fa80946164f77b92e58fd24e01b545685b8405"},
+    {file = "openai-1.27.0-py3-none-any.whl", hash = "sha256:1183346fae6e63cb3a9134e397c0067690dc9d94ceb36eb0eb2c1bb9a1542aca"},
+    {file = "openai-1.27.0.tar.gz", hash = "sha256:498adc80ba81a95324afdfd11a71fa43a37e1d94a5ca5f4542e52fe9568d995b"},
 ]
 
 [package.dependencies]
@@ -2835,8 +2835,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.20.3", markers = "python_version < \"3.10\""},
-    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
     {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -5556,4 +5556,4 @@ extended-testing = ["faker", "jinja2", "pandas", "presidio-analyzer", "presidio-
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "4321a2d1943c27b7c85292796a45a4eb89e19e1c736201c80a67fc51a8ce1eb0"
+content-hash = "8094eaac737aa33963b2c9c73e61bc0b366def7556f0ab081d8055fa96b342a5"
diff --git a/libs/experimental/pyproject.toml b/libs/experimental/pyproject.toml
index 6b758ffcbd..64205989f8 100644
--- a/libs/experimental/pyproject.toml
+++ b/libs/experimental/pyproject.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/langchain-ai/langchain"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.52"
-langchain = "^0.1.17"
+langchain-community = "^0.0.38rc1"
 presidio-anonymizer = {version = "^2.2.352", optional = true}
 presidio-analyzer = {version = "^2.2.352", optional = true}
 faker = {version = "^19.3.1", optional = true}
diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py
index 65e4e37458..2faa3fc7f2 100644
--- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py
+++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py
@@ -1,18 +1,11 @@
 """Toolkit for interacting with a vector store."""
 from typing import List
 
-from langchain_community.agent_toolkits.base import BaseToolkit
-from langchain_community.llms.openai import OpenAI
-from langchain_community.tools.vectorstore.tool import (
-    VectorStoreQATool,
-    VectorStoreQAWithSourcesTool,
-)
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool, BaseToolkit
 from langchain_core.vectorstores import VectorStore
 
-from langchain.tools import BaseTool
-
 
 class VectorStoreInfo(BaseModel):
     """Information about a VectorStore."""
@@ -31,7 +24,7 @@ class VectorStoreToolkit(BaseToolkit):
     """Toolkit for interacting with a Vector Store."""
 
     vectorstore_info: VectorStoreInfo = Field(exclude=True)
-    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
+    llm: BaseLanguageModel
 
     class Config:
         """Configuration for this pydantic object."""
@@ -40,6 +33,15 @@ class VectorStoreToolkit(BaseToolkit):
 
     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
+        try:
+            from langchain_community.tools.vectorstore.tool import (
+                VectorStoreQATool,
+                VectorStoreQAWithSourcesTool,
+            )
+        except ImportError:
+            raise ImportError(
+                "You need to install langchain-community to use this toolkit."
+            )
         description = VectorStoreQATool.get_description(
             self.vectorstore_info.name, self.vectorstore_info.description
         )
@@ -65,7 +67,7 @@ class VectorStoreRouterToolkit(BaseToolkit):
     """Toolkit for routing between Vector Stores."""
 
     vectorstores: List[VectorStoreInfo] = Field(exclude=True)
-    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
+    llm: BaseLanguageModel
 
     class Config:
         """Configuration for this pydantic object."""
@@ -75,6 +77,14 @@ class VectorStoreRouterToolkit(BaseToolkit):
     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
         tools: List[BaseTool] = []
+        try:
+            from langchain_community.tools.vectorstore.tool import (
+                VectorStoreQATool,
+            )
+        except ImportError:
+            raise ImportError(
+                "You need to install langchain-community to use this toolkit."
+            )
         for vectorstore_info in self.vectorstores:
             description = VectorStoreQATool.get_description(
                 vectorstore_info.name, vectorstore_info.description
diff --git a/libs/langchain/langchain/callbacks/streamlit/__init__.py b/libs/langchain/langchain/callbacks/streamlit/__init__.py
index 4b17f4d608..446ab9a27e 100644
--- a/libs/langchain/langchain/callbacks/streamlit/__init__.py
+++ b/libs/langchain/langchain/callbacks/streamlit/__init__.py
@@ -4,14 +4,8 @@ from typing import TYPE_CHECKING, Optional
 
 from langchain_core.callbacks.base import BaseCallbackHandler
 
-from langchain.callbacks.streamlit.streamlit_callback_handler import (
-    LLMThoughtLabeler as LLMThoughtLabeler,
-)
-from langchain.callbacks.streamlit.streamlit_callback_handler import (
-    StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
-)
-
 if TYPE_CHECKING:
+    from langchain_community.callbacks import LLMThoughtLabeler
     from streamlit.delta_generator import DeltaGenerator
 
 
@@ -61,11 +55,10 @@ def StreamlitCallbackHandler(
     # delegate to it instead of using our built-in handler. The official handler is
     # guaranteed to support the same set of kwargs.
     try:
-        from streamlit.external.langchain import (
-            StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,  # type: ignore # noqa: 501
-        )
+        from streamlit.external.langchain import StreamlitCallbackHandler
 
-        return OfficialStreamlitCallbackHandler(
+        # This is the official handler, so we can just return it.
+        return StreamlitCallbackHandler(
             parent_container,
             max_thought_containers=max_thought_containers,
             expand_new_thoughts=expand_new_thoughts,
@@ -73,6 +66,16 @@ def StreamlitCallbackHandler(
             thought_labeler=thought_labeler,
         )
     except ImportError:
+        try:
+            from langchain_community.callbacks.streamlit.streamlit_callback_handler import (  # noqa: E501
+                StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
+            )
+        except ImportError:
+            raise ImportError(
+                "To use the StreamlitCallbackHandler, please install "
+                "langchain-community with `pip install langchain-community`."
+            )
+
         return _InternalStreamlitCallbackHandler(
             parent_container,
             max_thought_containers=max_thought_containers,
diff --git a/libs/langchain/langchain/callbacks/tracers/__init__.py b/libs/langchain/langchain/callbacks/tracers/__init__.py
index 20b8816f07..744c6b863d 100644
--- a/libs/langchain/langchain/callbacks/tracers/__init__.py
+++ b/libs/langchain/langchain/callbacks/tracers/__init__.py
@@ -11,7 +11,6 @@ from langchain_core.tracers.stdout import (
 
 from langchain._api import create_importer
 from langchain.callbacks.tracers.logging import LoggingCallbackHandler
-from langchain.callbacks.tracers.wandb import WandbTracer
 
 if TYPE_CHECKING:
     from langchain_community.callbacks.tracers.wandb import WandbTracer
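The toolkit and callback changes above share one tactic for breaking the hard langchain -> langchain-community edge: move the community import from module scope into the function that needs it, and translate a failed import into an actionable error. The pattern reduced to its skeleton (the helper name is hypothetical; the imported tool mirrors the diff):

    from typing import Any

    def _import_qa_tool() -> Any:
        # Importing lazily keeps this module importable even when the
        # optional langchain-community package is missing.
        try:
            from langchain_community.tools.vectorstore.tool import VectorStoreQATool
        except ImportError:
            raise ImportError(
                "You need to install langchain-community to use this toolkit."
            )
        return VectorStoreQATool
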
diff --git a/libs/langchain/langchain/chains/api/openapi/chain.py b/libs/langchain/langchain/chains/api/openapi/chain.py
index 7fad47fedd..eb55218890 100644
--- a/libs/langchain/langchain/chains/api/openapi/chain.py
+++ b/libs/langchain/langchain/chains/api/openapi/chain.py
@@ -1,229 +1,23 @@
-"""Chain that makes API calls and summarizes the responses to answer a question."""
-from __future__ import annotations
+from typing import TYPE_CHECKING, Any
 
-import json
-from typing import Any, Dict, List, NamedTuple, Optional, cast
+from langchain._api import create_importer
 
-from langchain_community.tools.openapi.utils.api_models import APIOperation
-from langchain_community.utilities.requests import Requests
-from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.pydantic_v1 import BaseModel, Field
-from requests import Response
+if TYPE_CHECKING:
+    from langchain_community.chains.openapi.chain import OpenAPIEndpointChain
 
-from langchain.chains.api.openapi.requests_chain import APIRequesterChain
-from langchain.chains.api.openapi.response_chain import APIResponderChain
-from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {
+    "OpenAPIEndpointChain": "langchain_community.chains.openapi.chain",
+}
 
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
 
-class _ParamMapping(NamedTuple):
-    """Mapping from parameter name to parameter value."""
 
-    query_params: List[str]
-    body_params: List[str]
-    path_params: List[str]
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
 
 
-class OpenAPIEndpointChain(Chain, BaseModel):
-    """Chain interacts with an OpenAPI endpoint using natural language."""
-
-    api_request_chain: LLMChain
-    api_response_chain: Optional[LLMChain]
-    api_operation: APIOperation
-    requests: Requests = Field(exclude=True, default_factory=Requests)
-    param_mapping: _ParamMapping = Field(alias="param_mapping")
-    return_intermediate_steps: bool = False
-    instructions_key: str = "instructions"  #: :meta private:
-    output_key: str = "output"  #: :meta private:
-    max_text_length: Optional[int] = Field(ge=0)  #: :meta private:
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Expect input key.
-
-        :meta private:
-        """
-        return [self.instructions_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Expect output key.
-
-        :meta private:
-        """
-        if not self.return_intermediate_steps:
-            return [self.output_key]
-        else:
-            return [self.output_key, "intermediate_steps"]
-
-    def _construct_path(self, args: Dict[str, str]) -> str:
-        """Construct the path from the deserialized input."""
-        path = self.api_operation.base_url + self.api_operation.path
-        for param in self.param_mapping.path_params:
-            path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
-        return path
-
-    def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
-        """Extract the query params from the deserialized input."""
-        query_params = {}
-        for param in self.param_mapping.query_params:
-            if param in args:
-                query_params[param] = args.pop(param)
-        return query_params
-
-    def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]:
-        """Extract the request body params from the deserialized input."""
-        body_params = None
-        if self.param_mapping.body_params:
-            body_params = {}
-            for param in self.param_mapping.body_params:
-                if param in args:
-                    body_params[param] = args.pop(param)
-        return body_params
-
-    def deserialize_json_input(self, serialized_args: str) -> dict:
-        """Use the serialized typescript dictionary.
-
-        Resolve the path, query params dict, and optional requestBody dict.
-        """
-        args: dict = json.loads(serialized_args)
-        path = self._construct_path(args)
-        body_params = self._extract_body_params(args)
-        query_params = self._extract_query_params(args)
-        return {
-            "url": path,
-            "data": body_params,
-            "params": query_params,
-        }
-
-    def _get_output(self, output: str, intermediate_steps: dict) -> dict:
-        """Return the output from the API call."""
-        if self.return_intermediate_steps:
-            return {
-                self.output_key: output,
-                "intermediate_steps": intermediate_steps,
-            }
-        else:
-            return {self.output_key: output}
-
-    def _call(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        intermediate_steps = {}
-        instructions = inputs[self.instructions_key]
-        instructions = instructions[: self.max_text_length]
-        _api_arguments = self.api_request_chain.predict_and_parse(
-            instructions=instructions, callbacks=_run_manager.get_child()
-        )
-        api_arguments = cast(str, _api_arguments)
-        intermediate_steps["request_args"] = api_arguments
-        _run_manager.on_text(
-            api_arguments, color="green", end="\n", verbose=self.verbose
-        )
-        if api_arguments.startswith("ERROR"):
-            return self._get_output(api_arguments, intermediate_steps)
-        elif api_arguments.startswith("MESSAGE:"):
-            return self._get_output(
-                api_arguments[len("MESSAGE:") :], intermediate_steps
-            )
-        try:
-            request_args = self.deserialize_json_input(api_arguments)
-            method = getattr(self.requests, self.api_operation.method.value)
-            api_response: Response = method(**request_args)
-            if api_response.status_code != 200:
-                method_str = str(self.api_operation.method.value)
-                response_text = (
-                    f"{api_response.status_code}: {api_response.reason}"
-                    + f"\nFor {method_str.upper()} {request_args['url']}\n"
-                    + f"Called with args: {request_args['params']}"
-                )
-            else:
-                response_text = api_response.text
-        except Exception as e:
-            response_text = f"Error with message {str(e)}"
-        response_text = response_text[: self.max_text_length]
-        intermediate_steps["response_text"] = response_text
-        _run_manager.on_text(
-            response_text, color="blue", end="\n", verbose=self.verbose
-        )
-        if self.api_response_chain is not None:
-            _answer = self.api_response_chain.predict_and_parse(
-                response=response_text,
-                instructions=instructions,
-                callbacks=_run_manager.get_child(),
-            )
-            answer = cast(str, _answer)
-            _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
-            return self._get_output(answer, intermediate_steps)
-        else:
-            return self._get_output(response_text, intermediate_steps)
-
-    @classmethod
-    def from_url_and_method(
-        cls,
-        spec_url: str,
-        path: str,
-        method: str,
-        llm: BaseLanguageModel,
-        requests: Optional[Requests] = None,
-        return_intermediate_steps: bool = False,
-        **kwargs: Any,
-        # TODO: Handle async
-    ) -> "OpenAPIEndpointChain":
-        """Create an OpenAPIEndpoint from a spec at the specified url."""
-        operation = APIOperation.from_openapi_url(spec_url, path, method)
-        return cls.from_api_operation(
-            operation,
-            requests=requests,
-            llm=llm,
-            return_intermediate_steps=return_intermediate_steps,
-            **kwargs,
-        )
-
-    @classmethod
-    def from_api_operation(
-        cls,
-        operation: APIOperation,
-        llm: BaseLanguageModel,
-        requests: Optional[Requests] = None,
-        verbose: bool = False,
-        return_intermediate_steps: bool = False,
-        raw_response: bool = False,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-        # TODO: Handle async
-    ) -> "OpenAPIEndpointChain":
-        """Create an OpenAPIEndpointChain from an operation and a spec."""
-        param_mapping = _ParamMapping(
-            query_params=operation.query_params,
-            body_params=operation.body_params,
-            path_params=operation.path_params,
-        )
-        requests_chain = APIRequesterChain.from_llm_and_typescript(
-            llm,
-            typescript_definition=operation.to_typescript(),
-            verbose=verbose,
-            callbacks=callbacks,
-        )
-        if raw_response:
-            response_chain = None
-        else:
-            response_chain = APIResponderChain.from_llm(
-                llm, verbose=verbose, callbacks=callbacks
-            )
-        _requests = requests or Requests()
-        return cls(
-            api_request_chain=requests_chain,
-            api_response_chain=response_chain,
-            api_operation=operation,
-            requests=_requests,
-            param_mapping=param_mapping,
-            verbose=verbose,
-            return_intermediate_steps=return_intermediate_steps,
-            callbacks=callbacks,
-            **kwargs,
-        )
+__all__ = ["OpenAPIEndpointChain"]
diff --git a/libs/langchain/langchain/chains/api/openapi/prompts.py b/libs/langchain/langchain/chains/api/openapi/prompts.py
index 84e5a2baee..cfc7cf2b0f 100644
--- a/libs/langchain/langchain/chains/api/openapi/prompts.py
+++ b/libs/langchain/langchain/chains/api/openapi/prompts.py
@@ -1,57 +1,27 @@
-# flake8: noqa
-REQUEST_TEMPLATE = """You are a helpful AI Assistant. Please provide JSON arguments to agentFunc() based on the user's instructions.
+from typing import TYPE_CHECKING, Any
 
-API_SCHEMA: ```typescript
-{schema}
-```
+from langchain._api import create_importer
 
-USER_INSTRUCTIONS: "{instructions}"
+if TYPE_CHECKING:
+    from langchain_community.chains.openapi.prompts import (
+        REQUEST_TEMPLATE,
+        RESPONSE_TEMPLATE,
+    )
 
-Your arguments must be plain json provided in a markdown block:
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {
+    "REQUEST_TEMPLATE": "langchain_community.chains.openapi.prompts",
+    "RESPONSE_TEMPLATE": "langchain_community.chains.openapi.prompts",
+}
 
-ARGS: ```json
-{{valid json conforming to API_SCHEMA}}
-```
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
 
-Example
------
-ARGS: ```json
-{{"foo": "bar", "baz": {{"qux": "quux"}}}}
-```
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
 
-The block must be no more than 1 line long, and all arguments must be valid JSON. All string arguments must be wrapped in double quotes.
-You MUST strictly comply to the types indicated by the provided schema, including all required args.
-
-If you don't have sufficient information to call the function due to things like requiring specific uuid's, you can reply with the following message:
-
-Message: ```text
-Concise response requesting the additional information that would make calling the function successful.
-```
-
-Begin
------
-ARGS:
-"""
-RESPONSE_TEMPLATE = """You are a helpful AI assistant trained to answer user queries from API responses.
-You attempted to call an API, which resulted in:
-API_RESPONSE: {response}
-
-USER_COMMENT: "{instructions}"
-
-
-If the API_RESPONSE can answer the USER_COMMENT respond with the following markdown json block:
-Response: ```json
-{{"response": "Human-understandable synthesis of the API_RESPONSE"}}
-```
-
-Otherwise respond with the following markdown json block:
-Response Error: ```json
-{{"response": "What you did and a concise statement of the resulting error. If it can be easily fixed, provide a suggestion."}}
-```
-
-You MUST respond as a markdown json code block. The person you are responding to CANNOT see the API_RESPONSE, so if there is any relevant information there you must include it in your response.
-
-Begin:
----
-"""
+__all__ = ["REQUEST_TEMPLATE", "RESPONSE_TEMPLATE"]
diff --git a/libs/langchain/langchain/chains/api/openapi/requests_chain.py b/libs/langchain/langchain/chains/api/openapi/requests_chain.py
index cca19165ee..0d221ae24c 100644
--- a/libs/langchain/langchain/chains/api/openapi/requests_chain.py
+++ b/libs/langchain/langchain/chains/api/openapi/requests_chain.py
@@ -1,63 +1,29 @@
-"""request parser."""
-
-import json
-import re
-from typing import Any
-
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.output_parsers import BaseOutputParser
-from langchain_core.prompts.prompt import PromptTemplate
-
-from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
-from langchain.chains.llm import LLMChain
-
-
-class APIRequesterOutputParser(BaseOutputParser):
-    """Parse the request and error tags."""
-
-    def _load_json_block(self, serialized_block: str) -> str:
-        try:
-            return json.dumps(json.loads(serialized_block, strict=False))
-        except json.JSONDecodeError:
-            return "ERROR serializing request."
-
-    def parse(self, llm_output: str) -> str:
-        """Parse the request and error tags."""
-
-        json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
-        if json_match:
-            return self._load_json_block(json_match.group(1).strip())
-        message_match = re.search(r"```text(.*?)```", llm_output, re.DOTALL)
-        if message_match:
-            return f"MESSAGE: {message_match.group(1).strip()}"
-        return "ERROR making request"
-
-    @property
-    def _type(self) -> str:
-        return "api_requester"
-
-
-class APIRequesterChain(LLMChain):
-    """Get the request parser."""
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @classmethod
-    def from_llm_and_typescript(
-        cls,
-        llm: BaseLanguageModel,
-        typescript_definition: str,
-        verbose: bool = True,
-        **kwargs: Any,
-    ) -> LLMChain:
-        """Get the request parser."""
-        output_parser = APIRequesterOutputParser()
-        prompt = PromptTemplate(
-            template=REQUEST_TEMPLATE,
-            output_parser=output_parser,
-            partial_variables={"schema": typescript_definition},
-            input_variables=["instructions"],
-        )
-        return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
+from typing import TYPE_CHECKING, Any
+
+from langchain._api import create_importer
+
+if TYPE_CHECKING:
+    from langchain_community.chains.openapi.requests_chain import (
+        REQUEST_TEMPLATE,
+        APIRequesterChain,
+        APIRequesterOutputParser,
+    )
+
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {
+    "APIRequesterChain": "langchain_community.chains.openapi.requests_chain",
+    "APIRequesterOutputParser": "langchain_community.chains.openapi.requests_chain",
+    "REQUEST_TEMPLATE": "langchain_community.chains.openapi.requests_chain",
+}
+
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
+
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
+
+
+__all__ = ["APIRequesterChain", "APIRequesterOutputParser", "REQUEST_TEMPLATE"]
diff --git a/libs/langchain/langchain/chains/api/openapi/response_chain.py b/libs/langchain/langchain/chains/api/openapi/response_chain.py
index 06e7686b4a..8664476226 100644
--- a/libs/langchain/langchain/chains/api/openapi/response_chain.py
+++ b/libs/langchain/langchain/chains/api/openapi/response_chain.py
@@ -1,58 +1,29 @@
-"""Response parser."""
-
-import json
-import re
-from typing import Any
-
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.output_parsers import BaseOutputParser
-from langchain_core.prompts.prompt import PromptTemplate
-
-from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
-from langchain.chains.llm import LLMChain
-
-
-class APIResponderOutputParser(BaseOutputParser):
-    """Parse the response and error tags."""
-
-    def _load_json_block(self, serialized_block: str) -> str:
-        try:
-            response_content = json.loads(serialized_block, strict=False)
-            return response_content.get("response", "ERROR parsing response.")
-        except json.JSONDecodeError:
-            return "ERROR parsing response."
-        except:
-            raise
-
-    def parse(self, llm_output: str) -> str:
-        """Parse the response and error tags."""
-        json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
-        if json_match:
-            return self._load_json_block(json_match.group(1).strip())
-        else:
-            raise ValueError(f"No response found in output: {llm_output}.")
-
-    @property
-    def _type(self) -> str:
-        return "api_responder"
-
-
-class APIResponderChain(LLMChain):
-    """Get the response parser."""
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        return False
-
-    @classmethod
-    def from_llm(
-        cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any
-    ) -> LLMChain:
-        """Get the response parser."""
-        output_parser = APIResponderOutputParser()
-        prompt = PromptTemplate(
-            template=RESPONSE_TEMPLATE,
-            output_parser=output_parser,
-            input_variables=["response", "instructions"],
-        )
-        return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)
+from typing import TYPE_CHECKING, Any
+
+from langchain._api import create_importer
+
+if TYPE_CHECKING:
+    from langchain_community.chains.openapi.response_chain import (
+        RESPONSE_TEMPLATE,
+        APIResponderChain,
+        APIResponderOutputParser,
+    )
+
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {
+    "APIResponderChain": "langchain_community.chains.openapi.response_chain",
+    "APIResponderOutputParser": "langchain_community.chains.openapi.response_chain",
+    "RESPONSE_TEMPLATE": "langchain_community.chains.openapi.response_chain",
+}
+
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
+
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
+
+
+__all__ = ["APIResponderChain", "APIResponderOutputParser", "RESPONSE_TEMPLATE"]
diff --git a/libs/langchain/langchain/chains/conversation/memory.py b/libs/langchain/langchain/chains/conversation/memory.py
index 7aad58f8cc..03230c4d62 100644
--- a/libs/langchain/langchain/chains/conversation/memory.py
+++ b/libs/langchain/langchain/chains/conversation/memory.py
@@ -1,5 +1,8 @@
 """Memory modules for conversation prompts."""
+from typing import TYPE_CHECKING, Any
+
+from langchain._api import create_importer
 
 from langchain.memory.buffer import (
     ConversationBufferMemory,
     ConversationStringBufferMemory,
@@ -7,10 +10,27 @@ from langchain.memory.buffer import (
 from langchain.memory.buffer_window import ConversationBufferWindowMemory
 from langchain.memory.combined import CombinedMemory
 from langchain.memory.entity import ConversationEntityMemory
-from langchain.memory.kg import ConversationKGMemory
 from langchain.memory.summary import ConversationSummaryMemory
 from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
 
+if TYPE_CHECKING:
+    from langchain_community.memory.kg import ConversationKGMemory
+
+# Create a way to dynamically look up deprecated imports.
+# Used to consolidate logic for raising deprecation warnings and
+# handling optional imports.
+DEPRECATED_LOOKUP = {
+    "ConversationKGMemory": "langchain_community.memory.kg",
+}
+
+_importer = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
+
+
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _importer(name)
+
+
 # This is only for backwards compatibility.
 __all__ = [
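Every shim module in this patch repeats the same recipe: a DEPRECATED_LOOKUP table mapping old names to their new langchain_community homes, fed to langchain._api.create_importer, whose result is installed as the module-level __getattr__ (PEP 562). create_importer's body is not part of this diff, so the following is only a generic sketch of the mechanism it wraps, not LangChain's actual implementation:

    import importlib
    import warnings
    from typing import Any

    DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}

    def __getattr__(name: str) -> Any:
        # PEP 562: invoked only when `name` is not found in the module namespace.
        if name in DEPRECATED_LOOKUP:
            target = DEPRECATED_LOOKUP[name]
            warnings.warn(
                f"Importing {name} from this module is deprecated; "
                f"import it from {target} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return getattr(importlib.import_module(target), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Old import paths keep working (with a warning), and because the real names are imported only under TYPE_CHECKING, type checkers still resolve them without forcing the optional dependency at runtime.
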
+DEPRECATED_LOOKUP = { + "convert_to_ernie_function": "langchain_community.chains.ernie_functions.base", + "create_ernie_fn_chain": "langchain_community.chains.ernie_functions.base", + "create_ernie_fn_runnable": "langchain_community.chains.ernie_functions.base", + "create_structured_output_chain": "langchain_community.chains.ernie_functions.base", + "create_structured_output_runnable": ( + "langchain_community.chains.ernie_functions.base" + ), + "get_ernie_output_parser": "langchain_community.chains.ernie_functions.base", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + __all__ = [ "convert_to_ernie_function", diff --git a/libs/langchain/langchain/chains/ernie_functions/base.py b/libs/langchain/langchain/chains/ernie_functions/base.py index 5434136ba7..e81ac42a58 100644 --- a/libs/langchain/langchain/chains/ernie_functions/base.py +++ b/libs/langchain/langchain/chains/ernie_functions/base.py @@ -1,551 +1,49 @@ -"""Methods for creating chains that use Ernie function-calling APIs.""" -import inspect -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) - -from langchain_core.language_models import BaseLanguageModel -from langchain_core.output_parsers import ( - BaseGenerationOutputParser, - BaseLLMOutputParser, - BaseOutputParser, -) -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables import Runnable - -from langchain.chains import LLMChain -from langchain.output_parsers.ernie_functions import ( - JsonOutputFunctionsParser, - PydanticAttrOutputFunctionsParser, - PydanticOutputFunctionsParser, -) -from langchain.utils.ernie_functions import convert_pydantic_to_ernie_function - -PYTHON_TO_JSON_TYPES = { - "str": "string", - "int": "number", - "float": "number", - "bool": "boolean", -} - - -def _get_python_function_name(function: Callable) -> str: - """Get the name of a Python function.""" - return function.__name__ - - -def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: - """Parse the function and argument descriptions from the docstring of a function. - - Assumes the function docstring follows Google Python style guide. - """ - docstring = inspect.getdoc(function) - if docstring: - docstring_blocks = docstring.split("\n\n") - descriptors = [] - args_block = None - past_descriptors = False - for block in docstring_blocks: - if block.startswith("Args:"): - args_block = block - break - elif block.startswith("Returns:") or block.startswith("Example:"): - # Don't break in case Args come after - past_descriptors = True - elif not past_descriptors: - descriptors.append(block) - else: - continue - description = " ".join(descriptors) - else: - description = "" - args_block = None - arg_descriptions = {} - if args_block: - arg = None - for line in args_block.split("\n")[1:]: - if ":" in line: - arg, desc = line.split(":") - arg_descriptions[arg.strip()] = desc.strip() - elif arg: - arg_descriptions[arg.strip()] += " " + line.strip() - return description, arg_descriptions - - -def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: - """Get JsonSchema describing a Python functions arguments. - - Assumes all function arguments are of primitive types (int, float, str, bool) or - are subclasses of pydantic.BaseModel. 
- """ - properties = {} - annotations = inspect.getfullargspec(function).annotations - for arg, arg_type in annotations.items(): - if arg == "return": - continue - if isinstance(arg_type, type) and issubclass(arg_type, BaseModel): - # Mypy error: - # "type" has no attribute "schema" - properties[arg] = arg_type.schema() # type: ignore[attr-defined] - elif arg_type.__name__ in PYTHON_TO_JSON_TYPES: - properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]} - if arg in arg_descriptions: - if arg not in properties: - properties[arg] = {} - properties[arg]["description"] = arg_descriptions[arg] - return properties - - -def _get_python_function_required_args(function: Callable) -> List[str]: - """Get the required arguments for a Python function.""" - spec = inspect.getfullargspec(function) - required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args - required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})] - - is_class = type(function) is type - if is_class and required[0] == "self": - required = required[1:] - return required - - -def convert_python_function_to_ernie_function( - function: Callable, -) -> Dict[str, Any]: - """Convert a Python function to an Ernie function-calling API compatible dict. - - Assumes the Python function has type hints and a docstring with a description. If - the docstring has Google Python style argument descriptions, these will be - included as well. - """ - description, arg_descriptions = _parse_python_function_docstring(function) - return { - "name": _get_python_function_name(function), - "description": description, - "parameters": { - "type": "object", - "properties": _get_python_function_arguments(function, arg_descriptions), - "required": _get_python_function_required_args(function), - }, - } - - -def convert_to_ernie_function( - function: Union[Dict[str, Any], Type[BaseModel], Callable], -) -> Dict[str, Any]: - """Convert a raw function/class to an Ernie function. - - Args: - function: Either a dictionary, a pydantic.BaseModel class, or a Python function. - If a dictionary is passed in, it is assumed to already be a valid Ernie - function. - - Returns: - A dict version of the passed in function which is compatible with the - Ernie function-calling API. - """ - if isinstance(function, dict): - return function - elif isinstance(function, type) and issubclass(function, BaseModel): - return cast(Dict, convert_pydantic_to_ernie_function(function)) - elif callable(function): - return convert_python_function_to_ernie_function(function) - - else: - raise ValueError( - f"Unsupported function type {type(function)}. Functions must be passed in" - f" as Dict, pydantic.BaseModel, or Callable." - ) - - -def get_ernie_output_parser( - functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], -) -> Union[BaseOutputParser, BaseGenerationOutputParser]: - """Get the appropriate function output parser given the user functions. - - Args: - functions: Sequence where element is a dictionary, a pydantic.BaseModel class, - or a Python function. If a dictionary is passed in, it is assumed to - already be a valid Ernie function. - - Returns: - A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise - a JsonOutputFunctionsParser. If there's only one function and it is - not a Pydantic class, then the output parser will automatically extract - only the function arguments and not the function name. 
- """ - function_names = [convert_to_ernie_function(f)["name"] for f in functions] - if isinstance(functions[0], type) and issubclass(functions[0], BaseModel): - if len(functions) > 1: - pydantic_schema: Union[Dict, Type[BaseModel]] = { - name: fn for name, fn in zip(function_names, functions) - } - else: - pydantic_schema = functions[0] - output_parser: Union[ - BaseOutputParser, BaseGenerationOutputParser - ] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) - else: - output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1) - return output_parser - - -def create_ernie_fn_runnable( - functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], - llm: Runnable, - prompt: BasePromptTemplate, - *, - output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, - **kwargs: Any, -) -> Runnable: - """Create a runnable sequence that uses Ernie functions. - - Args: - functions: A sequence of either dictionaries, pydantic.BaseModels classes, or - Python functions. If dictionaries are passed in, they are assumed to - already be a valid Ernie functions. If only a single - function is passed in, then it will be enforced that the model use that - function. pydantic.BaseModels and Python functions should have docstrings - describing what the function does. For best results, pydantic.BaseModels - should have descriptions of the parameters and Python functions should have - Google Python style args descriptions in the docstring. Additionally, - Python functions should only use primitive types (str, int, float, bool) or - pydantic.BaseModels for arguments. - llm: Language model to use, assumed to support the Ernie function-calling API. - prompt: BasePromptTemplate to pass to the model. - output_parser: BaseLLMOutputParser to use for parsing model outputs. By default - will be inferred from the function types. If pydantic.BaseModels are passed - in, then the OutputParser will try to parse outputs using those. Otherwise - model outputs will simply be parsed as JSON. If multiple functions are - passed in and they are not pydantic.BaseModels, the chain output will - include both the name of the function that was returned and the arguments - to pass to the function. - - Returns: - A runnable sequence that will pass in the given functions to the model when run. - - Example: - .. 
code-block:: python - - from typing import Optional - - from langchain.chains.ernie_functions import create_ernie_fn_chain - from langchain_community.chat_models import ErnieBotChat - from langchain_core.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field - - - class RecordPerson(BaseModel): - \"\"\"Record some identifying information about a person.\"\"\" - - name: str = Field(..., description="The person's name") - age: int = Field(..., description="The person's age") - fav_food: Optional[str] = Field(None, description="The person's favorite food") - - - class RecordDog(BaseModel): - \"\"\"Record some identifying information about a dog.\"\"\" - - name: str = Field(..., description="The dog's name") - color: str = Field(..., description="The dog's color") - fav_food: Optional[str] = Field(None, description="The dog's favorite food") - - - llm = ErnieBotChat(model_name="ERNIE-Bot-4") - prompt = ChatPromptTemplate.from_messages( - [ - ("user", "Make calls to the relevant function to record the entities in the following input: {input}"), - ("assistant", "OK!"), - ("user", "Tip: Make sure to answer in the correct format"), - ] - ) - chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt) - chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) - # -> RecordDog(name="Harry", color="brown", fav_food="chicken") - """ # noqa: E501 - if not functions: - raise ValueError("Need to pass in at least one function. Received zero.") - ernie_functions = [convert_to_ernie_function(f) for f in functions] - llm_kwargs: Dict[str, Any] = {"functions": ernie_functions, **kwargs} - if len(ernie_functions) == 1: - llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]} - output_parser = output_parser or get_ernie_output_parser(functions) - return prompt | llm.bind(**llm_kwargs) | output_parser - - -def create_structured_output_runnable( - output_schema: Union[Dict[str, Any], Type[BaseModel]], - llm: Runnable, - prompt: BasePromptTemplate, - *, - output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, - **kwargs: Any, -) -> Runnable: - """Create a runnable that uses an Ernie function to get a structured output. - - Args: - output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary - is passed in, it's assumed to already be a valid JsonSchema. - For best results, pydantic.BaseModels should have docstrings describing what - the schema represents and descriptions for the parameters. - llm: Language model to use, assumed to support the Ernie function-calling API. - prompt: BasePromptTemplate to pass to the model. - output_parser: BaseLLMOutputParser to use for parsing model outputs. By default - will be inferred from the function types. If pydantic.BaseModels are passed - in, then the OutputParser will try to parse outputs using those. Otherwise - model outputs will simply be parsed as JSON. - - Returns: - A runnable sequence that will pass the given function to the model when run. - - Example: - .. 
code-block:: python - - from typing import Optional - - from langchain.chains.ernie_functions import create_structured_output_chain - from langchain_community.chat_models import ErnieBotChat - from langchain_core.prompts import ChatPromptTemplate - from langchain.pydantic_v1 import BaseModel, Field - - class Dog(BaseModel): - \"\"\"Identifying information about a dog.\"\"\" - - name: str = Field(..., description="The dog's name") - color: str = Field(..., description="The dog's color") - fav_food: Optional[str] = Field(None, description="The dog's favorite food") - - llm = ErnieBotChat(model_name="ERNIE-Bot-4") - prompt = ChatPromptTemplate.from_messages( - [ - ("user", "Use the given format to extract information from the following input: {input}"), - ("assistant", "OK!"), - ("user", "Tip: Make sure to answer in the correct format"), - ] - ) - chain = create_structured_output_chain(Dog, llm, prompt) - chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) - # -> Dog(name="Harry", color="brown", fav_food="chicken") - """ # noqa: E501 - if isinstance(output_schema, dict): - function: Any = { - "name": "output_formatter", - "description": ( - "Output formatter. Should always be used to format your response to the" - " user." - ), - "parameters": output_schema, - } - else: - - class _OutputFormatter(BaseModel): - """Output formatter. Should always be used to format your response to the user.""" # noqa: E501 - - output: output_schema # type: ignore - - function = _OutputFormatter - output_parser = output_parser or PydanticAttrOutputFunctionsParser( - pydantic_schema=_OutputFormatter, attr_name="output" - ) - return create_ernie_fn_runnable( - [function], - llm, - prompt, - output_parser=output_parser, - **kwargs, - ) - - -""" --- Legacy --- """ - - -def create_ernie_fn_chain( - functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], - llm: BaseLanguageModel, - prompt: BasePromptTemplate, - *, - output_key: str = "function", - output_parser: Optional[BaseLLMOutputParser] = None, - **kwargs: Any, -) -> LLMChain: - """[Legacy] Create an LLM chain that uses Ernie functions. - - Args: - functions: A sequence of either dictionaries, pydantic.BaseModels classes, or - Python functions. If dictionaries are passed in, they are assumed to - already be a valid Ernie functions. If only a single - function is passed in, then it will be enforced that the model use that - function. pydantic.BaseModels and Python functions should have docstrings - describing what the function does. For best results, pydantic.BaseModels - should have descriptions of the parameters and Python functions should have - Google Python style args descriptions in the docstring. Additionally, - Python functions should only use primitive types (str, int, float, bool) or - pydantic.BaseModels for arguments. - llm: Language model to use, assumed to support the Ernie function-calling API. - prompt: BasePromptTemplate to pass to the model. - output_key: The key to use when returning the output in LLMChain.__call__. - output_parser: BaseLLMOutputParser to use for parsing model outputs. By default - will be inferred from the function types. If pydantic.BaseModels are passed - in, then the OutputParser will try to parse outputs using those. Otherwise - model outputs will simply be parsed as JSON. If multiple functions are - passed in and they are not pydantic.BaseModels, the chain output will - include both the name of the function that was returned and the arguments - to pass to the function. 
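For the dictionary branch of `create_structured_output_runnable` above, a minimal sketch; the JSON schema and input are illustrative, and an Ernie-compatible chat model is assumed:

.. code-block:: python

    from langchain_community.chains.ernie_functions.base import (
        create_structured_output_runnable,
    )
    from langchain_community.chat_models import ErnieBotChat
    from langchain_core.prompts import ChatPromptTemplate

    json_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "description": "The person's name"},
            "age": {"type": "integer", "description": "The person's age"},
        },
        "required": ["name", "age"],
    }

    llm = ErnieBotChat(model_name="ERNIE-Bot-4")
    prompt = ChatPromptTemplate.from_messages(
        [("user", "Use the given format to extract information from: {input}")]
    )
    chain = create_structured_output_runnable(json_schema, llm, prompt)
    # With a plain dict schema, get_ernie_output_parser falls back to
    # JsonOutputFunctionsParser(args_only=True), so the output is a parsed dict:
    # chain.invoke({"input": "Sam is 30"})  # -> {"name": "Sam", "age": 30}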
- - Returns: - An LLMChain that will pass in the given functions to the model when run. - - Example: - .. code-block:: python - - from typing import Optional - - from langchain.chains.ernie_functions import create_ernie_fn_chain - from langchain_community.chat_models import ErnieBotChat - from langchain_core.prompts import ChatPromptTemplate - - from langchain.pydantic_v1 import BaseModel, Field - - - class RecordPerson(BaseModel): - \"\"\"Record some identifying information about a person.\"\"\" - - name: str = Field(..., description="The person's name") - age: int = Field(..., description="The person's age") - fav_food: Optional[str] = Field(None, description="The person's favorite food") - - - class RecordDog(BaseModel): - \"\"\"Record some identifying information about a dog.\"\"\" - - name: str = Field(..., description="The dog's name") - color: str = Field(..., description="The dog's color") - fav_food: Optional[str] = Field(None, description="The dog's favorite food") - - - llm = ErnieBotChat(model_name="ERNIE-Bot-4") - prompt = ChatPromptTemplate.from_messages( - [ - ("user", "Make calls to the relevant function to record the entities in the following input: {input}"), - ("assistant", "OK!"), - ("user", "Tip: Make sure to answer in the correct format"), - ] - ) - chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt) - chain.run("Harry was a chubby brown beagle who loved chicken") - # -> RecordDog(name="Harry", color="brown", fav_food="chicken") - """ # noqa: E501 - if not functions: - raise ValueError("Need to pass in at least one function. Received zero.") - ernie_functions = [convert_to_ernie_function(f) for f in functions] - output_parser = output_parser or get_ernie_output_parser(functions) - llm_kwargs: Dict[str, Any] = { - "functions": ernie_functions, - } - if len(ernie_functions) == 1: - llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]} - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - output_parser=output_parser, - llm_kwargs=llm_kwargs, - output_key=output_key, - **kwargs, +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.chains.ernie_functions.base import ( + convert_python_function_to_ernie_function, + convert_to_ernie_function, + create_ernie_fn_chain, + create_ernie_fn_runnable, + create_structured_output_chain, + create_structured_output_runnable, + get_ernie_output_parser, ) - return llm_chain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
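Conceptually, the `DEPRECATED_LOOKUP` machinery that follows reduces to the sketch below. This is simplified: the real `create_importer` in `langchain._api` also emits a deprecation warning with migration guidance and handles optional imports.

.. code-block:: python

    from importlib import import_module
    from typing import Any, Callable, Dict


    def _make_lookup(lookup: Dict[str, str]) -> Callable[[str], Any]:
        """Return a module-level __getattr__ that resolves names lazily."""

        def _getattr(name: str) -> Any:
            if name in lookup:
                # the real implementation also warns that this import
                # path is deprecated before forwarding the attribute
                module = import_module(lookup[name])
                return getattr(module, name)
            raise AttributeError(f"module has no attribute {name!r}")

        return _getattr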
+DEPRECATED_LOOKUP = { + "convert_python_function_to_ernie_function": ( + "langchain_community.chains.ernie_functions.base" + ), + "convert_to_ernie_function": "langchain_community.chains.ernie_functions.base", + "create_ernie_fn_chain": "langchain_community.chains.ernie_functions.base", + "create_ernie_fn_runnable": "langchain_community.chains.ernie_functions.base", + "create_structured_output_chain": "langchain_community.chains.ernie_functions.base", + "create_structured_output_runnable": ( + "langchain_community.chains.ernie_functions.base" + ), + "get_ernie_output_parser": "langchain_community.chains.ernie_functions.base", +} -def create_structured_output_chain( - output_schema: Union[Dict[str, Any], Type[BaseModel]], - llm: BaseLanguageModel, - prompt: BasePromptTemplate, - *, - output_key: str = "function", - output_parser: Optional[BaseLLMOutputParser] = None, - **kwargs: Any, -) -> LLMChain: - """[Legacy] Create an LLMChain that uses an Ernie function to get a structured output. - - Args: - output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary - is passed in, it's assumed to already be a valid JsonSchema. - For best results, pydantic.BaseModels should have docstrings describing what - the schema represents and descriptions for the parameters. - llm: Language model to use, assumed to support the Ernie function-calling API. - prompt: BasePromptTemplate to pass to the model. - output_key: The key to use when returning the output in LLMChain.__call__. - output_parser: BaseLLMOutputParser to use for parsing model outputs. By default - will be inferred from the function types. If pydantic.BaseModels are passed - in, then the OutputParser will try to parse outputs using those. Otherwise - model outputs will simply be parsed as JSON. - - Returns: - An LLMChain that will pass the given function to the model. - - Example: - .. code-block:: python - - from typing import Optional - - from langchain.chains.ernie_functions import create_structured_output_chain - from langchain_community.chat_models import ErnieBotChat - from langchain_core.prompts import ChatPromptTemplate - - from langchain.pydantic_v1 import BaseModel, Field - - class Dog(BaseModel): - \"\"\"Identifying information about a dog.\"\"\" - - name: str = Field(..., description="The dog's name") - color: str = Field(..., description="The dog's color") - fav_food: Optional[str] = Field(None, description="The dog's favorite food") +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - llm = ErnieBotChat(model_name="ERNIE-Bot-4") - prompt = ChatPromptTemplate.from_messages( - [ - ("user", "Use the given format to extract information from the following input: {input}"), - ("assistant", "OK!"), - ("user", "Tip: Make sure to answer in the correct format"), - ] - ) - chain = create_structured_output_chain(Dog, llm, prompt) - chain.run("Harry was a chubby brown beagle who loved chicken") - # -> Dog(name="Harry", color="brown", fav_food="chicken") - """ # noqa: E501 - if isinstance(output_schema, dict): - function: Any = { - "name": "output_formatter", - "description": ( - "Output formatter. Should always be used to format your response to the" - " user." - ), - "parameters": output_schema, - } - else: - class _OutputFormatter(BaseModel): - """Output formatter. 
Should always be used to format your response to the user.""" # noqa: E501 +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - output: output_schema # type: ignore - function = _OutputFormatter - output_parser = output_parser or PydanticAttrOutputFunctionsParser( - pydantic_schema=_OutputFormatter, attr_name="output" - ) - return create_ernie_fn_chain( - [function], - llm, - prompt, - output_key=output_key, - output_parser=output_parser, - **kwargs, - ) +__all__ = [ + "convert_python_function_to_ernie_function", + "convert_to_ernie_function", + "create_ernie_fn_chain", + "create_ernie_fn_runnable", + "create_structured_output_chain", + "create_structured_output_runnable", + "get_ernie_output_parser", +] diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 8beb5b82e2..8070fc1237 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -5,7 +5,6 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np -from langchain_community.llms.openai import OpenAI from langchain_core.callbacks import ( CallbackManagerForChainRun, ) @@ -56,11 +55,7 @@ class _ResponseChain(LLMChain): class _OpenAIResponseChain(_ResponseChain): """Chain that generates responses from user input and context.""" - llm: OpenAI = Field( - default_factory=lambda: OpenAI( - max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0 - ) - ) + llm: BaseLanguageModel def _extract_tokens_and_log_probs( self, generations: List[Generation] @@ -118,7 +113,7 @@ class FlareChain(Chain): question_generator_chain: QuestionGeneratorChain """Chain that generates questions from uncertain spans.""" - response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain) + response_chain: _ResponseChain """Chain that generates responses from user input and context.""" output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser) """Parser that determines whether the chain is finished.""" @@ -255,6 +250,14 @@ class FlareChain(Chain): Returns: FlareChain class with the given language model. """ + try: + from langchain_openai import OpenAI + except ImportError: + raise ImportError( + "OpenAI is required for FlareChain. " + "Please install langchain-openai." 
+ "pip install langchain-openai" + ) question_gen_chain = QuestionGeneratorChain(llm=llm) response_llm = OpenAI( max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0 diff --git a/libs/langchain/langchain/chains/graph_qa/__init__.py b/libs/langchain/langchain/chains/graph_qa/__init__.py index f3bc55efbc..e69de29bb2 100644 --- a/libs/langchain/langchain/chains/graph_qa/__init__.py +++ b/libs/langchain/langchain/chains/graph_qa/__init__.py @@ -1 +0,0 @@ -"""Question answering over a knowledge graph.""" diff --git a/libs/langchain/langchain/chains/graph_qa/arangodb.py b/libs/langchain/langchain/chains/graph_qa/arangodb.py index 4c723e2046..bf536fabe1 100644 --- a/libs/langchain/langchain/chains/graph_qa/arangodb.py +++ b/libs/langchain/langchain/chains/graph_qa/arangodb.py @@ -1,241 +1,23 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -import re -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.arangodb_graph import ArangoGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - AQL_FIX_PROMPT, - AQL_GENERATION_PROMPT, - AQL_QA_PROMPT, -) -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "ArangoGraphQAChain": "langchain_community.chains.graph_qa.arangodb", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class ArangoGraphQAChain(Chain): - """Chain for question-answering against a graph by generating AQL statements. - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - See https://python.langchain.com/docs/security for more information. 
- """ - graph: ArangoGraph = Field(exclude=True) - aql_generation_chain: LLMChain - aql_fix_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - # Specifies the maximum number of AQL Query Results to return - top_k: int = 10 - - # Specifies the set of AQL Query Examples that promote few-shot-learning - aql_examples: str = "" - - # Specify whether to return the AQL Query in the output dictionary - return_aql_query: bool = False - - # Specify whether to return the AQL JSON Result in the output dictionary - return_aql_result: bool = False - - # Specify the maximum amount of AQL Generation attempts that should be made - max_aql_generation_attempts: int = 3 - - @property - def input_keys(self) -> List[str]: - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - return [self.output_key] - - @property - def _chain_type(self) -> str: - return "graph_aql_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = AQL_QA_PROMPT, - aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT, - aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT, - **kwargs: Any, - ) -> ArangoGraphQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt) - aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt) - - return cls( - qa_chain=qa_chain, - aql_generation_chain=aql_generation_chain, - aql_fix_chain=aql_fix_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - """ - Generate an AQL statement from user input, use it retrieve a response - from an ArangoDB Database instance, and respond to the user input - in natural language. - - Users can modify the following ArangoGraphQAChain Class Variables: - - :var top_k: The maximum number of AQL Query Results to return - :type top_k: int - - :var aql_examples: A set of AQL Query Examples that are passed to - the AQL Generation Prompt Template to promote few-shot-learning. - Defaults to an empty string. - :type aql_examples: str - - :var return_aql_query: Whether to return the AQL Query in the - output dictionary. Defaults to False. - :type return_aql_query: bool - - :var return_aql_result: Whether to return the AQL Query in the - output dictionary. Defaults to False - :type return_aql_result: bool - - :var max_aql_generation_attempts: The maximum amount of AQL - Generation attempts to be made prior to raising the last - AQL Query Execution Error. Defaults to 3. 
- :type max_aql_generation_attempts: int - """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - user_input = inputs[self.input_key] - - ######################### - # Generate AQL Query # - aql_generation_output = self.aql_generation_chain.run( - { - "adb_schema": self.graph.schema, - "aql_examples": self.aql_examples, - "user_input": user_input, - }, - callbacks=callbacks, - ) - ######################### - - aql_query = "" - aql_error = "" - aql_result = None - aql_generation_attempt = 1 - - while ( - aql_result is None - and aql_generation_attempt < self.max_aql_generation_attempts + 1 - ): - ##################### - # Extract AQL Query # - pattern = r"```(?i:aql)?(.*?)```" - matches = re.findall(pattern, aql_generation_output, re.DOTALL) - if not matches: - _run_manager.on_text( - "Invalid Response: ", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - aql_generation_output, color="red", end="\n", verbose=self.verbose - ) - raise ValueError(f"Response is Invalid: {aql_generation_output}") - - aql_query = matches[0] - ##################### - - _run_manager.on_text( - f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose - ) - _run_manager.on_text( - aql_query, color="green", end="\n", verbose=self.verbose - ) - - ##################### - # Execute AQL Query # - from arango import AQLQueryExecuteError - - try: - aql_result = self.graph.query(aql_query, self.top_k) - except AQLQueryExecuteError as e: - aql_error = e.error_message - - _run_manager.on_text( - "AQL Query Execution Error: ", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - aql_error, color="yellow", end="\n\n", verbose=self.verbose - ) - - ######################## - # Retry AQL Generation # - aql_generation_output = self.aql_fix_chain.run( - { - "adb_schema": self.graph.schema, - "aql_query": aql_query, - "aql_error": aql_error, - }, - callbacks=callbacks, - ) - ######################## - - ##################### - - aql_generation_attempt += 1 - - if aql_result is None: - m = f""" - Maximum amount of AQL Query Generation attempts reached. 
- Unable to execute the AQL Query due to the following error: - {aql_error} - """ - raise ValueError(m) - - _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(aql_result), color="green", end="\n", verbose=self.verbose - ) - - ######################## - # Interpret AQL Result # - result = self.qa_chain( - { - "adb_schema": self.graph.schema, - "user_input": user_input, - "aql_query": aql_query, - "aql_result": aql_result, - }, - callbacks=callbacks, - ) - ######################## - - # Return results # - result = {self.output_key: result[self.qa_chain.output_key]} - - if self.return_aql_query: - result["aql_query"] = aql_query - - if self.return_aql_result: - result["aql_result"] = aql_result - - return result +__all__ = ["ArangoGraphQAChain"] diff --git a/libs/langchain/langchain/chains/graph_qa/base.py b/libs/langchain/langchain/chains/graph_qa/base.py index 5ca9d22f2c..0b2b5a324f 100644 --- a/libs/langchain/langchain/chains/graph_qa/base.py +++ b/libs/langchain/langchain/chains/graph_qa/base.py @@ -1,100 +1,23 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities -from langchain_core.callbacks.manager import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.base import GraphQAChain -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "GraphQAChain": "langchain_community.chains.graph_qa.base", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class GraphQAChain(Chain): - """Chain for question-answering against a graph. - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - See https://python.langchain.com/docs/security for more information. - """ - graph: NetworkxEntityGraph = Field(exclude=True) - entity_extraction_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Output keys. 
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT, - entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, - **kwargs: Any, - ) -> GraphQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - entity_chain = LLMChain(llm=llm, prompt=entity_prompt) - - return cls( - qa_chain=qa_chain, - entity_extraction_chain=entity_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """Extract entities, look up info and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - question = inputs[self.input_key] - - entity_string = self.entity_extraction_chain.run(question) - - _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose) - _run_manager.on_text( - entity_string, color="green", end="\n", verbose=self.verbose - ) - entities = get_entities(entity_string) - context = "" - all_triplets = [] - for entity in entities: - all_triplets.extend(self.graph.get_entity_knowledge(entity)) - context = "\n".join(all_triplets) - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose) - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=_run_manager.get_child(), - ) - return {self.output_key: result[self.qa_chain.output_key]} +__all__ = ["GraphQAChain"] diff --git a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py index 7682de853a..fbddc52eb5 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher.py @@ -1,292 +1,39 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -import re -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.graph_store import GraphStore -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field - -from langchain.chains.base import Chain -from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema -from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT -from langchain.chains.llm import LLMChain - -INTERMEDIATE_STEPS_KEY = "intermediate_steps" - - -def extract_cypher(text: str) -> str: - """Extract Cypher code from a text. - - Args: - text: Text to extract Cypher code from. - - Returns: - Cypher code extracted from the text. 
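Its behavior, as a quick sketch against the new community path:

.. code-block:: python

    from langchain_community.chains.graph_qa.cypher import extract_cypher

    # Triple-backtick fences are unwrapped; other text passes through unchanged.
    assert extract_cypher("```\nMATCH (n) RETURN n\n```") == "\nMATCH (n) RETURN n\n"
    assert extract_cypher("MATCH (n) RETURN n") == "MATCH (n) RETURN n"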
- """ - # The pattern to find Cypher code enclosed in triple backticks - pattern = r"```(.*?)```" - - # Find all matches in the input text - matches = re.findall(pattern, text, re.DOTALL) - - return matches[0] if matches else text - - -def construct_schema( - structured_schema: Dict[str, Any], - include_types: List[str], - exclude_types: List[str], -) -> str: - """Filter the schema based on included or excluded types""" - - def filter_func(x: str) -> bool: - return x in include_types if include_types else x not in exclude_types - - filtered_schema: Dict[str, Any] = { - "node_props": { - k: v - for k, v in structured_schema.get("node_props", {}).items() - if filter_func(k) - }, - "rel_props": { - k: v - for k, v in structured_schema.get("rel_props", {}).items() - if filter_func(k) - }, - "relationships": [ - r - for r in structured_schema.get("relationships", []) - if all(filter_func(r[t]) for t in ["start", "end", "type"]) - ], - } - - # Format node properties - formatted_node_props = [] - for label, properties in filtered_schema["node_props"].items(): - props_str = ", ".join( - [f"{prop['property']}: {prop['type']}" for prop in properties] - ) - formatted_node_props.append(f"{label} {{{props_str}}}") - - # Format relationship properties - formatted_rel_props = [] - for rel_type, properties in filtered_schema["rel_props"].items(): - props_str = ", ".join( - [f"{prop['property']}: {prop['type']}" for prop in properties] - ) - formatted_rel_props.append(f"{rel_type} {{{props_str}}}") - - # Format relationships - formatted_rels = [ - f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" - for el in filtered_schema["relationships"] - ] - - return "\n".join( - [ - "Node properties are the following:", - ",".join(formatted_node_props), - "Relationship properties are the following:", - ",".join(formatted_rel_props), - "The relationships are the following:", - ",".join(formatted_rels), - ] +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.cypher import ( + CYPHER_GENERATION_PROMPT, + INTERMEDIATE_STEPS_KEY, + GraphCypherQAChain, + construct_schema, + extract_cypher, ) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "GraphCypherQAChain": "langchain_community.chains.graph_qa.cypher", + "INTERMEDIATE_STEPS_KEY": "langchain_community.chains.graph_qa.cypher", + "construct_schema": "langchain_community.chains.graph_qa.cypher", + "extract_cypher": "langchain_community.chains.graph_qa.cypher", + "CYPHER_GENERATION_PROMPT": "langchain_community.chains.graph_qa.cypher", +} -class GraphCypherQAChain(Chain): - """Chain for question-answering against a graph by generating Cypher statements. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. 
- """ - - graph: GraphStore = Field(exclude=True) - cypher_generation_chain: LLMChain - qa_chain: LLMChain - graph_schema: str - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - top_k: int = 10 - """Number of results to return from the query""" - return_intermediate_steps: bool = False - """Whether or not to return the intermediate steps along with the final answer.""" - return_direct: bool = False - """Whether or not to return the result of querying the graph directly.""" - cypher_query_corrector: Optional[CypherQueryCorrector] = None - """Optional cypher validation tool""" - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. - - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @property - def _chain_type(self) -> str: - return "graph_cypher_chain" - - @classmethod - def from_llm( - cls, - llm: Optional[BaseLanguageModel] = None, - *, - qa_prompt: Optional[BasePromptTemplate] = None, - cypher_prompt: Optional[BasePromptTemplate] = None, - cypher_llm: Optional[BaseLanguageModel] = None, - qa_llm: Optional[BaseLanguageModel] = None, - exclude_types: List[str] = [], - include_types: List[str] = [], - validate_cypher: bool = False, - qa_llm_kwargs: Optional[Dict[str, Any]] = None, - cypher_llm_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> GraphCypherQAChain: - """Initialize from LLM.""" - - if not cypher_llm and not llm: - raise ValueError("Either `llm` or `cypher_llm` parameters must be provided") - if not qa_llm and not llm: - raise ValueError("Either `llm` or `qa_llm` parameters must be provided") - if cypher_llm and qa_llm and llm: - raise ValueError( - "You can specify up to two of 'cypher_llm', 'qa_llm'" - ", and 'llm', but not all three simultaneously." - ) - if cypher_prompt and cypher_llm_kwargs: - raise ValueError( - "Specifying cypher_prompt and cypher_llm_kwargs together is" - " not allowed. Please pass prompt via cypher_llm_kwargs." - ) - if qa_prompt and qa_llm_kwargs: - raise ValueError( - "Specifying qa_prompt and qa_llm_kwargs together is" - " not allowed. Please pass prompt via qa_llm_kwargs." 
- )
- use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
- use_cypher_llm_kwargs = (
- cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
- )
- if "prompt" not in use_qa_llm_kwargs:
- use_qa_llm_kwargs["prompt"] = (
- qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
- )
- if "prompt" not in use_cypher_llm_kwargs:
- use_cypher_llm_kwargs["prompt"] = (
- cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
- )
-
- qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type]
-
- cypher_generation_chain = LLMChain(
- llm=cypher_llm or llm, # type: ignore[arg-type]
- **use_cypher_llm_kwargs, # type: ignore[arg-type]
- )
-
- if exclude_types and include_types:
- raise ValueError(
- "Either `exclude_types` or `include_types` "
- "can be provided, but not both"
- )
-
- graph_schema = construct_schema(
- kwargs["graph"].get_structured_schema, include_types, exclude_types
- )
-
- cypher_query_corrector = None
- if validate_cypher:
- corrector_schema = [
- Schema(el["start"], el["type"], el["end"])
- for el in kwargs["graph"].structured_schema.get("relationships")
- ]
- cypher_query_corrector = CypherQueryCorrector(corrector_schema)
-
- return cls(
- graph_schema=graph_schema,
- qa_chain=qa_chain,
- cypher_generation_chain=cypher_generation_chain,
- cypher_query_corrector=cypher_query_corrector,
- **kwargs,
- )
-
- def _call(
- self,
- inputs: Dict[str, Any],
- run_manager: Optional[CallbackManagerForChainRun] = None,
- ) -> Dict[str, Any]:
- """Generate Cypher statement, use it to look up in db and answer question."""
- _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
- callbacks = _run_manager.get_child()
- question = inputs[self.input_key]
-
- intermediate_steps: List = []
-
- generated_cypher = self.cypher_generation_chain.run(
- {"question": question, "schema": self.graph_schema}, callbacks=callbacks
- )
-
- # Extract Cypher code if it is wrapped in backticks
- generated_cypher = extract_cypher(generated_cypher)
-
- # Correct Cypher query if enabled
- if self.cypher_query_corrector:
- generated_cypher = self.cypher_query_corrector(generated_cypher)
-
- _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
- _run_manager.on_text(
- generated_cypher, color="green", end="\n", verbose=self.verbose
- )
-
- intermediate_steps.append({"query": generated_cypher})
-
- # Retrieve and limit the number of results
- # Generated Cypher can be null if the query corrector identifies an invalid schema
- if generated_cypher:
- context = self.graph.query(generated_cypher)[: self.top_k]
- else:
- context = []
-
- if self.return_direct:
- final_result = context
- else:
- _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
- _run_manager.on_text(
- str(context), color="green", end="\n", verbose=self.verbose
- )
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
- intermediate_steps.append({"context": context})
- result = self.qa_chain(
- {"question": question, "context": context},
- callbacks=callbacks,
- )
- final_result = result[self.qa_chain.output_key]
+def __getattr__(name: str) -> Any:
+ """Look up attributes dynamically."""
+ return _import_attribute(name)
- chain_result: Dict[str, Any] = {self.output_key: final_result}
- if self.return_intermediate_steps:
- chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
- return chain_result
+__all__ = [
+ "GraphCypherQAChain",
+ "INTERMEDIATE_STEPS_KEY",
+ "construct_schema",
+ "extract_cypher",
+ "CYPHER_GENERATION_PROMPT", +] diff --git a/libs/langchain/langchain/chains/graph_qa/cypher_utils.py b/libs/langchain/langchain/chains/graph_qa/cypher_utils.py index c123cac9b5..deeb1051a2 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher_utils.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher_utils.py @@ -1,260 +1,27 @@ -import re -from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any -Schema = namedtuple("Schema", ["left_node", "relation", "right_node"]) +from langchain._api import create_importer - -class CypherQueryCorrector: - """ - Used to correct relationship direction in generated Cypher statements. - This code is copied from the winner's submission to the Cypher competition: - https://github.com/sakusaku-rich/cypher-direction-competition - """ - - property_pattern = re.compile(r"\{.+?\}") - node_pattern = re.compile(r"\(.+?\)") - path_pattern = re.compile( - r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))" - ) - node_relation_node_pattern = re.compile( - r"(\()+(?P[^()]*?)\)(?P.*?)\((?P[^()]*?)(\))+" +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.cypher_utils import ( + CypherQueryCorrector, + Schema, ) - relation_type_pattern = re.compile(r":(?P.+?)?(\{.+\})?]") - - def __init__(self, schemas: List[Schema]): - """ - Args: - schemas: list of schemas - """ - self.schemas = schemas - - def clean_node(self, node: str) -> str: - """ - Args: - node: node in string format - - """ - node = re.sub(self.property_pattern, "", node) - node = node.replace("(", "") - node = node.replace(")", "") - node = node.strip() - return node - - def detect_node_variables(self, query: str) -> Dict[str, List[str]]: - """ - Args: - query: cypher query - """ - nodes = re.findall(self.node_pattern, query) - nodes = [self.clean_node(node) for node in nodes] - res: Dict[str, Any] = {} - for node in nodes: - parts = node.split(":") - if parts == "": - continue - variable = parts[0] - if variable not in res: - res[variable] = [] - res[variable] += parts[1:] - return res - - def extract_paths(self, query: str) -> "List[str]": - """ - Args: - query: cypher query - """ - paths = [] - idx = 0 - while matched := self.path_pattern.findall(query[idx:]): - matched = matched[0] - matched = [ - m for i, m in enumerate(matched) if i not in [1, len(matched) - 1] - ] - path = "".join(matched) - idx = query.find(path) + len(path) - len(matched[-1]) - paths.append(path) - return paths - - def judge_direction(self, relation: str) -> str: - """ - Args: - relation: relation in string format - """ - direction = "BIDIRECTIONAL" - if relation[0] == "<": - direction = "INCOMING" - if relation[-1] == ">": - direction = "OUTGOING" - return direction - - def extract_node_variable(self, part: str) -> Optional[str]: - """ - Args: - part: node in string format - """ - part = part.lstrip("(").rstrip(")") - idx = part.find(":") - if idx != -1: - part = part[:idx] - return None if part == "" else part - - def detect_labels( - self, str_node: str, node_variable_dict: Dict[str, Any] - ) -> List[str]: - """ - Args: - str_node: node in string format - node_variable_dict: dictionary of node variables - """ - splitted_node = str_node.split(":") - variable = splitted_node[0] - labels = [] - if variable in node_variable_dict: - labels = node_variable_dict[variable] - elif variable == "" and len(splitted_node) > 1: - labels = splitted_node[1:] - return labels - - def verify_schema( - self, - from_node_labels: 
List[str], - relation_types: List[str], - to_node_labels: List[str], - ) -> bool: - """ - Args: - from_node_labels: labels of the from node - relation_type: type of the relation - to_node_labels: labels of the to node - """ - valid_schemas = self.schemas - if from_node_labels != []: - from_node_labels = [label.strip("`") for label in from_node_labels] - valid_schemas = [ - schema for schema in valid_schemas if schema[0] in from_node_labels - ] - if to_node_labels != []: - to_node_labels = [label.strip("`") for label in to_node_labels] - valid_schemas = [ - schema for schema in valid_schemas if schema[2] in to_node_labels - ] - if relation_types != []: - relation_types = [type.strip("`") for type in relation_types] - valid_schemas = [ - schema for schema in valid_schemas if schema[1] in relation_types - ] - return valid_schemas != [] - def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]: - """ - Args: - str_relation: relation in string format - """ - relation_direction = self.judge_direction(str_relation) - relation_type = self.relation_type_pattern.search(str_relation) - if relation_type is None or relation_type.group("relation_type") is None: - return relation_direction, [] - relation_types = [ - t.strip().strip("!") - for t in relation_type.group("relation_type").split("|") - ] - return relation_direction, relation_types +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "CypherQueryCorrector": "langchain_community.chains.graph_qa.cypher_utils", + "Schema": "langchain_community.chains.graph_qa.cypher_utils", +} - def correct_query(self, query: str) -> str: - """ - Args: - query: cypher query - """ - node_variable_dict = self.detect_node_variables(query) - paths = self.extract_paths(query) - for path in paths: - original_path = path - start_idx = 0 - while start_idx < len(path): - match_res = re.match(self.node_relation_node_pattern, path[start_idx:]) - if match_res is None: - break - start_idx += match_res.start() - match_dict = match_res.groupdict() - left_node_labels = self.detect_labels( - match_dict["left_node"], node_variable_dict - ) - right_node_labels = self.detect_labels( - match_dict["right_node"], node_variable_dict - ) - end_idx = ( - start_idx - + 4 - + len(match_dict["left_node"]) - + len(match_dict["relation"]) - + len(match_dict["right_node"]) - ) - original_partial_path = original_path[start_idx : end_idx + 1] - relation_direction, relation_types = self.detect_relation_types( - match_dict["relation"] - ) +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - if relation_types != [] and "".join(relation_types).find("*") != -1: - start_idx += ( - len(match_dict["left_node"]) + len(match_dict["relation"]) + 2 - ) - continue - if relation_direction == "OUTGOING": - is_legal = self.verify_schema( - left_node_labels, relation_types, right_node_labels - ) - if not is_legal: - is_legal = self.verify_schema( - right_node_labels, relation_types, left_node_labels - ) - if is_legal: - corrected_relation = "<" + match_dict["relation"][:-1] - corrected_partial_path = original_partial_path.replace( - match_dict["relation"], corrected_relation - ) - query = query.replace( - original_partial_path, corrected_partial_path - ) - else: - return "" - elif relation_direction == "INCOMING": - is_legal = self.verify_schema( - right_node_labels, relation_types, left_node_labels - ) - if not is_legal: - is_legal 
= self.verify_schema( - left_node_labels, relation_types, right_node_labels - ) - if is_legal: - corrected_relation = match_dict["relation"][1:] + ">" - corrected_partial_path = original_partial_path.replace( - match_dict["relation"], corrected_relation - ) - query = query.replace( - original_partial_path, corrected_partial_path - ) - else: - return "" - else: - is_legal = self.verify_schema( - left_node_labels, relation_types, right_node_labels - ) - is_legal |= self.verify_schema( - right_node_labels, relation_types, left_node_labels - ) - if not is_legal: - return "" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - start_idx += ( - len(match_dict["left_node"]) + len(match_dict["relation"]) + 2 - ) - return query - def __call__(self, query: str) -> str: - """Correct the query to make it valid. If - Args: - query: cypher query - """ - return self.correct_query(query) +__all__ = ["CypherQueryCorrector", "Schema"] diff --git a/libs/langchain/langchain/chains/graph_qa/falkordb.py b/libs/langchain/langchain/chains/graph_qa/falkordb.py index b7ead3d871..1aba6adff4 100644 --- a/libs/langchain/langchain/chains/graph_qa/falkordb.py +++ b/libs/langchain/langchain/chains/graph_qa/falkordb.py @@ -1,154 +1,29 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -import re -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs import FalkorDBGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.falkordb import ( + INTERMEDIATE_STEPS_KEY, + FalkorDBQAChain, + extract_cypher, + ) -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "FalkorDBQAChain": "langchain_community.chains.graph_qa.falkordb", + "INTERMEDIATE_STEPS_KEY": "langchain_community.chains.graph_qa.falkordb", + "extract_cypher": "langchain_community.chains.graph_qa.falkordb", +} -INTERMEDIATE_STEPS_KEY = "intermediate_steps" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -def extract_cypher(text: str) -> str: - """ - Extract Cypher code from a text. - Args: - text: Text to extract Cypher code from. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - Returns: - Cypher code extracted from the text. - """ - # The pattern to find Cypher code enclosed in triple backticks - pattern = r"```(.*?)```" - # Find all matches in the input text - matches = re.findall(pattern, text, re.DOTALL) - - return matches[0] if matches else text - - -class FalkorDBQAChain(Chain): - """Chain for question-answering against a graph by generating Cypher statements. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. 
- Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - """ - - graph: FalkorDBGraph = Field(exclude=True) - cypher_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - top_k: int = 10 - """Number of results to return from the query""" - return_intermediate_steps: bool = False - """Whether or not to return the intermediate steps along with the final answer.""" - return_direct: bool = False - """Whether or not to return the result of querying the graph directly.""" - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. - - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @property - def _chain_type(self) -> str: - return "graph_cypher_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT, - **kwargs: Any, - ) -> FalkorDBQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) - - return cls( - qa_chain=qa_chain, - cypher_generation_chain=cypher_generation_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - """Generate Cypher statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - intermediate_steps: List = [] - - generated_cypher = self.cypher_generation_chain.run( - {"question": question, "schema": self.graph.schema}, callbacks=callbacks - ) - - # Extract Cypher code if it is wrapped in backticks - generated_cypher = extract_cypher(generated_cypher) - - _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_cypher, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_cypher}) - - # Retrieve and limit the number of results - context = self.graph.query(generated_cypher)[: self.top_k] - - if self.return_direct: - final_result = context - else: - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"context": context}) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] - - chain_result: Dict[str, Any] = {self.output_key: final_result} - if self.return_intermediate_steps: - chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps - - return chain_result +__all__ = ["FalkorDBQAChain", "INTERMEDIATE_STEPS_KEY", "extract_cypher"] diff --git 
a/libs/langchain/langchain/chains/graph_qa/gremlin.py b/libs/langchain/langchain/chains/graph_qa/gremlin.py index 732e06bde8..20a7be6ab8 100644 --- a/libs/langchain/langchain/chains/graph_qa/gremlin.py +++ b/libs/langchain/langchain/chains/graph_qa/gremlin.py @@ -1,221 +1,36 @@ -"""Question answering over a graph.""" -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from langchain_community.graphs import GremlinGraph -from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.pydantic_v1 import Field - -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - CYPHER_QA_PROMPT, - GRAPHDB_SPARQL_FIX_TEMPLATE, - GREMLIN_GENERATION_PROMPT, -) -from langchain.chains.llm import LLMChain - -INTERMEDIATE_STEPS_KEY = "intermediate_steps" - - -def extract_gremlin(text: str) -> str: - """Extract Gremlin code from a text. - - Args: - text: Text to extract Gremlin code from. - - Returns: - Gremlin code extracted from the text. - """ - text = text.replace("`", "") - if text.startswith("gremlin"): - text = text[len("gremlin") :] - return text.replace("\n", "") - - -class GremlinQAChain(Chain): - """Chain for question-answering against a graph by generating gremlin statements. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - """ - - graph: GremlinGraph = Field(exclude=True) - gremlin_generation_chain: LLMChain - qa_chain: LLMChain - gremlin_fix_chain: LLMChain - max_fix_retries: int = 3 - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - top_k: int = 100 - return_direct: bool = False - return_intermediate_steps: bool = False - - @property - def input_keys(self) -> List[str]: - """Input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Output keys. 
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - gremlin_fix_prompt: BasePromptTemplate = PromptTemplate( - input_variables=["error_message", "generated_sparql", "schema"], - template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace( - "in Turtle format", "" - ), - ), - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT, - **kwargs: Any, - ) -> GremlinQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt) - gremlinl_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt) - return cls( - qa_chain=qa_chain, - gremlin_generation_chain=gremlin_generation_chain, - gremlin_fix_chain=gremlinl_fix_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """Generate gremlin statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - intermediate_steps: List = [] - - chain_response = self.gremlin_generation_chain.invoke( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - generated_gremlin = extract_gremlin( - chain_response[self.gremlin_generation_chain.output_key] - ) - - _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_gremlin, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_gremlin}) - - if generated_gremlin: - context = self.execute_with_retry( - _run_manager, callbacks, generated_gremlin - )[: self.top_k] - else: - context = [] - - if self.return_direct: - final_result = context - else: - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"context": context}) - - result = self.qa_chain.invoke( - {"question": question, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] - - chain_result: Dict[str, Any] = {self.output_key: final_result} - if self.return_intermediate_steps: - chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps - - return chain_result - - def execute_query(self, query: str) -> List[Any]: - try: - return self.graph.query(query) - except Exception as e: - if hasattr(e, "status_message"): - raise ValueError(e.status_message) - else: - raise ValueError(str(e)) - - def execute_with_retry( - self, - _run_manager: CallbackManagerForChainRun, - callbacks: CallbackManager, - generated_gremlin: str, - ) -> List[Any]: - try: - return self.execute_query(generated_gremlin) - except Exception as e: - retries = 0 - error_message = str(e) - self.log_invalid_query(_run_manager, generated_gremlin, error_message) - - while retries < self.max_fix_retries: - try: - fix_chain_result = self.gremlin_fix_chain.invoke( - { - "error_message": error_message, - # we are borrowing template from sparql - "generated_sparql": generated_gremlin, - "schema": self.schema, - }, - callbacks=callbacks, - ) - fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key] - return self.execute_query(fixed_gremlin) - except Exception as e: - 
retries += 1 - parse_exception = str(e) - self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception) - - raise ValueError("The generated Gremlin query is invalid.") - - def log_invalid_query( - self, - _run_manager: CallbackManagerForChainRun, - generated_query: str, - error_message: str, - ) -> None: - _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_query, color="red", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - error_message, color="red", end="\n\n", verbose=self.verbose - ) +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.gremlin import ( + GRAPHDB_SPARQL_FIX_TEMPLATE, + INTERMEDIATE_STEPS_KEY, + GremlinQAChain, + extract_gremlin, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "GRAPHDB_SPARQL_FIX_TEMPLATE": "langchain_community.chains.graph_qa.gremlin", + "GremlinQAChain": "langchain_community.chains.graph_qa.gremlin", + "INTERMEDIATE_STEPS_KEY": "langchain_community.chains.graph_qa.gremlin", + "extract_gremlin": "langchain_community.chains.graph_qa.gremlin", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = [ + "GRAPHDB_SPARQL_FIX_TEMPLATE", + "GremlinQAChain", + "INTERMEDIATE_STEPS_KEY", + "extract_gremlin", +] diff --git a/libs/langchain/langchain/chains/graph_qa/hugegraph.py b/libs/langchain/langchain/chains/graph_qa/hugegraph.py index 9e1f024937..a7d7d9019c 100644 --- a/libs/langchain/langchain/chains/graph_qa/hugegraph.py +++ b/libs/langchain/langchain/chains/graph_qa/hugegraph.py @@ -1,106 +1,23 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.hugegraph import HugeGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.hugegraph import HugeGraphQAChain -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - CYPHER_QA_PROMPT, - GREMLIN_GENERATION_PROMPT, -) -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "HugeGraphQAChain": "langchain_community.chains.graph_qa.hugegraph", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class HugeGraphQAChain(Chain): - """Chain for question-answering against a graph by generating gremlin statements. - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. 
- Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - See https://python.langchain.com/docs/security for more information. - """ - graph: HugeGraph = Field(exclude=True) - gremlin_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Output keys. - - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT, - **kwargs: Any, - ) -> HugeGraphQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt) - - return cls( - qa_chain=qa_chain, - gremlin_generation_chain=gremlin_generation_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """Generate gremlin statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - generated_gremlin = self.gremlin_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_gremlin, color="green", end="\n", verbose=self.verbose - ) - context = self.graph.query(generated_gremlin) - - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - return {self.output_key: result[self.qa_chain.output_key]} +__all__ = ["HugeGraphQAChain"] diff --git a/libs/langchain/langchain/chains/graph_qa/kuzu.py b/libs/langchain/langchain/chains/graph_qa/kuzu.py index aa1b51deb1..aa436d20b2 100644 --- a/libs/langchain/langchain/chains/graph_qa/kuzu.py +++ b/libs/langchain/langchain/chains/graph_qa/kuzu.py @@ -1,140 +1,29 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -import re -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.kuzu_graph import KuzuGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.kuzu import ( + KuzuQAChain, + extract_cypher, + remove_prefix, + ) -from 
langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "KuzuQAChain": "langchain_community.chains.graph_qa.kuzu", + "extract_cypher": "langchain_community.chains.graph_qa.kuzu", + "remove_prefix": "langchain_community.chains.graph_qa.kuzu", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -def remove_prefix(text: str, prefix: str) -> str: - """Remove a prefix from a text. - Args: - text: Text to remove the prefix from. - prefix: Prefix to remove from the text. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - Returns: - Text with the prefix removed. - """ - if text.startswith(prefix): - return text[len(prefix) :] - return text - -def extract_cypher(text: str) -> str: - """Extract Cypher code from a text. - - Args: - text: Text to extract Cypher code from. - - Returns: - Cypher code extracted from the text. - """ - # The pattern to find Cypher code enclosed in triple backticks - pattern = r"```(.*?)```" - - # Find all matches in the input text - matches = re.findall(pattern, text, re.DOTALL) - - return matches[0] if matches else text - - -class KuzuQAChain(Chain): - """Question-answering against a graph by generating Cypher statements for Kùzu. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - """ - - graph: KuzuGraph = Field(exclude=True) - cypher_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. 
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT, - **kwargs: Any, - ) -> KuzuQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) - - return cls( - qa_chain=qa_chain, - cypher_generation_chain=cypher_generation_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """Generate Cypher statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - generated_cypher = self.cypher_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks - ) - # Extract Cypher code if it is wrapped in triple backticks - # with the language marker "cypher" - generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher") - - _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_cypher, color="green", end="\n", verbose=self.verbose - ) - context = self.graph.query(generated_cypher) - - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - return {self.output_key: result[self.qa_chain.output_key]} +__all__ = ["KuzuQAChain", "extract_cypher", "remove_prefix"] diff --git a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py index e53eaefb35..8ea379f34a 100644 --- a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py +++ b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py @@ -1,103 +1,23 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.nebula_graph import NebulaGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.nebulagraph import NebulaGraphQAChain -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "NebulaGraphQAChain": "langchain_community.chains.graph_qa.nebulagraph", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class NebulaGraphQAChain(Chain): - """Chain for question-answering against a graph by generating nGQL statements. - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. 
- Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - See https://python.langchain.com/docs/security for more information. - """ - graph: NebulaGraph = Field(exclude=True) - ngql_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. - - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT, - **kwargs: Any, - ) -> NebulaGraphQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt) - - return cls( - qa_chain=qa_chain, - ngql_generation_chain=ngql_generation_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """Generate nGQL statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - generated_ngql = self.ngql_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_ngql, color="green", end="\n", verbose=self.verbose - ) - context = self.graph.query(generated_ngql) - - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - return {self.output_key: result[self.qa_chain.output_key]} +__all__ = ["NebulaGraphQAChain"] diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index 2b9447e70c..96fcf01f01 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -1,217 +1,39 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -import re -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs import BaseNeptuneGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.pydantic_v1 import Field - -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - CYPHER_QA_PROMPT, - NEPTUNE_OPENCYPHER_GENERATION_PROMPT, - 
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
-)
-from langchain.chains.llm import LLMChain
-from langchain.chains.prompt_selector import ConditionalPromptSelector
-
-INTERMEDIATE_STEPS_KEY = "intermediate_steps"
-
-
-def trim_query(query: str) -> str:
-    """Trim the query to only include Cypher keywords."""
-    keywords = (
-        "CALL",
-        "CREATE",
-        "DELETE",
-        "DETACH",
-        "LIMIT",
-        "MATCH",
-        "MERGE",
-        "OPTIONAL",
-        "ORDER",
-        "REMOVE",
-        "RETURN",
-        "SET",
-        "SKIP",
-        "UNWIND",
-        "WITH",
-        "WHERE",
-        "//",
+if TYPE_CHECKING:
+    from langchain_community.chains.graph_qa.neptune_cypher import (
+        INTERMEDIATE_STEPS_KEY,
+        NeptuneOpenCypherQAChain,
+        extract_cypher,
+        trim_query,
+        use_simple_prompt,
     )
-    lines = query.split("\n")
-    new_query = ""
-
-    for line in lines:
-        if line.strip().upper().startswith(keywords):
-            new_query += line + "\n"
-
-    return new_query
-
-
-def extract_cypher(text: str) -> str:
-    """Extract Cypher code from text using Regex."""
-    # The pattern to find Cypher code enclosed in triple backticks
-    pattern = r"```(.*?)```"
-
-    # Find all matches in the input text
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    return matches[0] if matches else text
-
-
-def use_simple_prompt(llm: BaseLanguageModel) -> bool:
-    """Decides whether to use the simple prompt."""
-    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
-        return True
-
-    # Bedrock anthropic
-    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
-        return True
-
-    return False
-
-
-PROMPT_SELECTOR = ConditionalPromptSelector(
-    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
-    conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
-)
-
-
-class NeptuneOpenCypherQAChain(Chain):
-    """Chain for question-answering against a Neptune graph
-    by generating openCypher statements.
-
-    *Security note*: Make sure that the database connection uses credentials
-    that are narrowly-scoped to only include necessary permissions.
-    Failure to do so may result in data corruption or loss, since the calling
-    code may attempt commands that would result in deletion, mutation
-    of data if appropriately prompted or reading sensitive data if such
-    data is present in the database.
-    The best way to guard against such negative outcomes is to (as appropriate)
-    limit the permissions granted to the credentials used with this tool.
-
-    See https://python.langchain.com/docs/security for more information.
-
-    Example:
-        .. code-block:: python
-
-        chain = NeptuneOpenCypherQAChain.from_llm(
-            llm=llm,
-            graph=graph
-        )
-        response = chain.run(query)
-    """
-
-    graph: BaseNeptuneGraph = Field(exclude=True)
-    cypher_generation_chain: LLMChain
-    qa_chain: LLMChain
-    input_key: str = "query"  #: :meta private:
-    output_key: str = "result"  #: :meta private:
-    top_k: int = 10
-    return_intermediate_steps: bool = False
-    """Whether or not to return the intermediate steps along with the final answer."""
-    return_direct: bool = False
-    """Whether or not to return the result of querying the graph directly."""
-    extra_instructions: Optional[str] = None
-    """Extra instructions to be appended to the query generation prompt."""
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Return the input keys.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return the output keys.
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: Optional[BasePromptTemplate] = None, - extra_instructions: Optional[str] = None, - **kwargs: Any, - ) -> NeptuneOpenCypherQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - - _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) - cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) - - return cls( - qa_chain=qa_chain, - cypher_generation_chain=cypher_generation_chain, - extra_instructions=extra_instructions, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - """Generate Cypher statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - intermediate_steps: List = [] - - generated_cypher = self.cypher_generation_chain.run( - { - "question": question, - "schema": self.graph.get_schema, - "extra_instructions": self.extra_instructions or "", - }, - callbacks=callbacks, - ) - - # Extract Cypher code if it is wrapped in backticks - generated_cypher = extract_cypher(generated_cypher) - generated_cypher = trim_query(generated_cypher) - - _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_cypher, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_cypher}) - - context = self.graph.query(generated_cypher) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
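+#
+# A minimal sketch of the intended behavior (assuming langchain-community
+# is installed; the exact warning text is an assumption based on the
+# deprecation tooling in langchain._api):
+#
+#     import warnings
+#
+#     with warnings.catch_warnings(record=True) as caught:
+#         warnings.simplefilter("always")
+#         # The old import path keeps working, but the symbol is now
+#         # resolved from langchain_community via __getattr__ below.
+#         from langchain.chains.graph_qa.neptune_cypher import (
+#             NeptuneOpenCypherQAChain,
+#         )
+#     assert any("langchain_community" in str(w.message) for w in caught)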
+DEPRECATED_LOOKUP = {
+    "INTERMEDIATE_STEPS_KEY": "langchain_community.chains.graph_qa.neptune_cypher",
+    "NeptuneOpenCypherQAChain": "langchain_community.chains.graph_qa.neptune_cypher",
+    "extract_cypher": "langchain_community.chains.graph_qa.neptune_cypher",
+    "trim_query": "langchain_community.chains.graph_qa.neptune_cypher",
+    "use_simple_prompt": "langchain_community.chains.graph_qa.neptune_cypher",
+}
-        if self.return_direct:
-            final_result = context
-        else:
-            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
-            _run_manager.on_text(
-                str(context), color="green", end="\n", verbose=self.verbose
-            )
+_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
-            intermediate_steps.append({"context": context})
-            result = self.qa_chain(
-                {"question": question, "context": context},
-                callbacks=callbacks,
-            )
-            final_result = result[self.qa_chain.output_key]
+def __getattr__(name: str) -> Any:
+    """Look up attributes dynamically."""
+    return _import_attribute(name)
-        chain_result: Dict[str, Any] = {self.output_key: final_result}
-        if self.return_intermediate_steps:
-            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
-        return chain_result
+__all__ = [
+    "INTERMEDIATE_STEPS_KEY",
+    "NeptuneOpenCypherQAChain",
+    "extract_cypher",
+    "trim_query",
+    "use_simple_prompt",
+]
diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py b/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
index d1e9a2f2e6..f4445b65d3 100644
--- a/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
+++ b/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
@@ -1,204 +1,36 @@
-"""
-Question answering over an RDF or OWL graph using SPARQL.
-"""
-from __future__ import annotations
-
-from typing import Any, Dict, List, Optional
-
-from langchain_community.graphs import NeptuneRdfGraph
-from langchain_core.callbacks.manager import CallbackManagerForChainRun
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts.base import BasePromptTemplate
-from langchain_core.prompts.prompt import PromptTemplate
-from langchain_core.pydantic_v1 import Field
-
-from langchain.chains.base import Chain
-from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT
-from langchain.chains.llm import LLMChain
-
-INTERMEDIATE_STEPS_KEY = "intermediate_steps"
-
-SPARQL_GENERATION_TEMPLATE = """
-Task: Generate a SPARQL SELECT statement for querying a graph database.
-For instance, to find all email addresses of John Doe, the following
-query in backticks would be suitable:
-```
-PREFIX foaf: <http://xmlns.com/foaf/0.1/>
-SELECT ?email
-WHERE {{
-    ?person foaf:name "John Doe" .
-    ?person foaf:mbox ?email .
-}}
-```
-Instructions:
-Use only the node types and properties provided in the schema.
-Do not use any node types and properties that are not explicitly provided.
-Include all necessary prefixes.
-
-Examples:
-
-Schema:
-{schema}
-Note: Be as concise as possible.
-Do not include any explanations or apologies in your responses.
-Do not respond to any questions that ask for anything else than
-for you to construct a SPARQL query.
-Do not include any text except the SPARQL query generated.
-
-The question is:
-{prompt}"""
-
-SPARQL_GENERATION_PROMPT = PromptTemplate(
-    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
-)
-
-
-def extract_sparql(query: str) -> str:
-    """Extract SPARQL code from a text.
-
-    Args:
-        query: Text to extract SPARQL code from.
-
-    Returns:
-        SPARQL code extracted from the text.
-    """
-    query = query.strip()
-    querytoks = query.split("```")
-    if len(querytoks) == 3:
-        query = querytoks[1]
-
-    if query.startswith("sparql"):
-        query = query[6:]
-    elif query.startswith("<sparql>") and query.endswith("</sparql>"):
-        query = query[8:-9]
-    return query
-
-
-class NeptuneSparqlQAChain(Chain):
-    """Chain for question-answering against a Neptune graph
-    by generating SPARQL statements.
-
-    *Security note*: Make sure that the database connection uses credentials
-    that are narrowly-scoped to only include necessary permissions.
-    Failure to do so may result in data corruption or loss, since the calling
-    code may attempt commands that would result in deletion, mutation
-    of data if appropriately prompted or reading sensitive data if such
-    data is present in the database.
-    The best way to guard against such negative outcomes is to (as appropriate)
-    limit the permissions granted to the credentials used with this tool.
-
-    See https://python.langchain.com/docs/security for more information.
-
-    Example:
-        .. code-block:: python
-
-        chain = NeptuneSparqlQAChain.from_llm(
-            llm=llm,
-            graph=graph
-        )
-        response = chain.invoke(query)
-    """
-
-    graph: NeptuneRdfGraph = Field(exclude=True)
-    sparql_generation_chain: LLMChain
-    qa_chain: LLMChain
-    input_key: str = "query"  #: :meta private:
-    output_key: str = "result"  #: :meta private:
-    top_k: int = 10
-    return_intermediate_steps: bool = False
-    """Whether or not to return the intermediate steps along with the final answer."""
-    return_direct: bool = False
-    """Whether or not to return the result of querying the graph directly."""
-    extra_instructions: Optional[str] = None
-    """Extra instructions to be appended to the query generation prompt."""
-
-    @property
-    def input_keys(self) -> List[str]:
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        _output_keys = [self.output_key]
-        return _output_keys
-
-    @classmethod
-    def from_llm(
-        cls,
-        llm: BaseLanguageModel,
-        *,
-        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
-        sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
-        examples: Optional[str] = None,
-        **kwargs: Any,
-    ) -> NeptuneSparqlQAChain:
-        """Initialize from LLM."""
-        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
-        template_to_use = SPARQL_GENERATION_TEMPLATE
-        if examples:
-            template_to_use = template_to_use.replace(
-                "Examples:", "Examples: " + examples
-            )
-            sparql_prompt = PromptTemplate(
-                input_variables=["schema", "prompt"], template=template_to_use
-            )
-        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
-
-        return cls(  # type: ignore[call-arg]
-            qa_chain=qa_chain,
-            sparql_generation_chain=sparql_generation_chain,
-            examples=examples,
-            **kwargs,
-        )
-
-    def _call(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        """
-        Generate SPARQL query, use it to retrieve a response from the gdb and answer
-        the question.
- """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - prompt = inputs[self.input_key] - - intermediate_steps: List = [] - - generated_sparql = self.sparql_generation_chain.run( - {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - # Extract SPARQL - generated_sparql = extract_sparql(generated_sparql) - - _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_sparql, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_sparql}) - - context = self.graph.query(generated_sparql) - - if self.return_direct: - final_result = context - else: - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"context": context}) - - result = self.qa_chain( - {"prompt": prompt, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] - - chain_result: Dict[str, Any] = {self.output_key: final_result} - if self.return_intermediate_steps: - chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps - - return chain_result +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.neptune_sparql import ( + INTERMEDIATE_STEPS_KEY, + SPARQL_GENERATION_TEMPLATE, + NeptuneSparqlQAChain, + extract_sparql, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "INTERMEDIATE_STEPS_KEY": "langchain_community.chains.graph_qa.neptune_sparql", + "NeptuneSparqlQAChain": "langchain_community.chains.graph_qa.neptune_sparql", + "SPARQL_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.neptune_sparql", + "extract_sparql": "langchain_community.chains.graph_qa.neptune_sparql", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = [ + "INTERMEDIATE_STEPS_KEY", + "NeptuneSparqlQAChain", + "SPARQL_GENERATION_TEMPLATE", + "extract_sparql", +] diff --git a/libs/langchain/langchain/chains/graph_qa/ontotext_graphdb.py b/libs/langchain/langchain/chains/graph_qa/ontotext_graphdb.py index d3e8d365c6..d1e8a11bbf 100644 --- a/libs/langchain/langchain/chains/graph_qa/ontotext_graphdb.py +++ b/libs/langchain/langchain/chains/graph_qa/ontotext_graphdb.py @@ -1,190 +1,25 @@ -"""Question answering over a graph.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from langchain._api import create_importer if TYPE_CHECKING: - import rdflib + from langchain_community.chains.graph_qa.ontotext_graphdb import ( + OntotextGraphDBQAChain, + ) -from langchain_community.graphs import OntotextGraphDBGraph -from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
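+#
+# Conceptually, the importer maps an attribute name to the module that now
+# hosts it, imports that module lazily, and re-exports the attribute after
+# emitting a deprecation warning. A simplified sketch (not the actual
+# implementation of create_importer; _resolve is a hypothetical helper):
+#
+#     from importlib import import_module
+#
+#     def _resolve(name: str) -> Any:
+#         new_module = DEPRECATED_LOOKUP[name]  # KeyError for unknown names
+#         return getattr(import_module(new_module), name)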
+DEPRECATED_LOOKUP = { + "OntotextGraphDBQAChain": "langchain_community.chains.graph_qa.ontotext_graphdb", +} -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - GRAPHDB_QA_PROMPT, - GRAPHDB_SPARQL_FIX_PROMPT, - GRAPHDB_SPARQL_GENERATION_PROMPT, -) -from langchain.chains.llm import LLMChain +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class OntotextGraphDBQAChain(Chain): - """Question-answering against Ontotext GraphDB - https://graphdb.ontotext.com/ by generating SPARQL queries. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - See https://python.langchain.com/docs/security for more information. - """ - - graph: OntotextGraphDBGraph = Field(exclude=True) - sparql_generation_chain: LLMChain - sparql_fix_chain: LLMChain - max_fix_retries: int - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - - @property - def input_keys(self) -> List[str]: - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT, - sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT, - max_fix_retries: int = 5, - qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT, - **kwargs: Any, - ) -> OntotextGraphDBQAChain: - """Initialize from LLM.""" - sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt) - sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt) - max_fix_retries = max_fix_retries - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - return cls( - qa_chain=qa_chain, - sparql_generation_chain=sparql_generation_chain, - sparql_fix_chain=sparql_fix_chain, - max_fix_retries=max_fix_retries, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """ - Generate a SPARQL query, use it to retrieve a response from GraphDB and answer - the question. 
- """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - prompt = inputs[self.input_key] - ontology_schema = self.graph.get_schema - - sparql_generation_chain_result = self.sparql_generation_chain.invoke( - {"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks - ) - generated_sparql = sparql_generation_chain_result[ - self.sparql_generation_chain.output_key - ] - - generated_sparql = self._get_prepared_sparql_query( - _run_manager, callbacks, generated_sparql, ontology_schema - ) - query_results = self._execute_query(generated_sparql) - - qa_chain_result = self.qa_chain.invoke( - {"prompt": prompt, "context": query_results}, callbacks=callbacks - ) - result = qa_chain_result[self.qa_chain.output_key] - return {self.output_key: result} - - def _get_prepared_sparql_query( - self, - _run_manager: CallbackManagerForChainRun, - callbacks: CallbackManager, - generated_sparql: str, - ontology_schema: str, - ) -> str: - try: - return self._prepare_sparql_query(_run_manager, generated_sparql) - except Exception as e: - retries = 0 - error_message = str(e) - self._log_invalid_sparql_query( - _run_manager, generated_sparql, error_message - ) - - while retries < self.max_fix_retries: - try: - sparql_fix_chain_result = self.sparql_fix_chain.invoke( - { - "error_message": error_message, - "generated_sparql": generated_sparql, - "schema": ontology_schema, - }, - callbacks=callbacks, - ) - generated_sparql = sparql_fix_chain_result[ - self.sparql_fix_chain.output_key - ] - return self._prepare_sparql_query(_run_manager, generated_sparql) - except Exception as e: - retries += 1 - parse_exception = str(e) - self._log_invalid_sparql_query( - _run_manager, generated_sparql, parse_exception - ) - - raise ValueError("The generated SPARQL query is invalid.") - - def _prepare_sparql_query( - self, _run_manager: CallbackManagerForChainRun, generated_sparql: str - ) -> str: - from rdflib.plugins.sparql import prepareQuery - - prepareQuery(generated_sparql) - self._log_prepared_sparql_query(_run_manager, generated_sparql) - return generated_sparql - - def _log_prepared_sparql_query( - self, _run_manager: CallbackManagerForChainRun, generated_query: str - ) -> None: - _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_query, color="green", end="\n", verbose=self.verbose - ) - - def _log_invalid_sparql_query( - self, - _run_manager: CallbackManagerForChainRun, - generated_query: str, - error_message: str, - ) -> None: - _run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_query, color="red", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - "SPARQL Query Parse Error: ", end="\n", verbose=self.verbose - ) - _run_manager.on_text( - error_message, color="red", end="\n\n", verbose=self.verbose - ) - - def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]: - try: - return self.graph.query(query) - except Exception: - raise ValueError("Failed to execute the generated SPARQL query.") +__all__ = ["OntotextGraphDBQAChain"] diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index a4b5db9583..1b7ac18131 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -1,415 +1,96 @@ -# flake8: noqa -from langchain_core.prompts.prompt import PromptTemplate - 
-_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places. - -Return the output as a single comma-separated list, or NONE if there is nothing of note to return. - -EXAMPLE -i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. -Output: Langchain -END OF EXAMPLE - -EXAMPLE -i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam. -Output: Langchain, Sam -END OF EXAMPLE - -Begin! - -{input} -Output:""" -ENTITY_EXTRACTION_PROMPT = PromptTemplate( - input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE -) - -_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. - -{context} - -Question: {question} -Helpful Answer:""" -GRAPH_QA_PROMPT = PromptTemplate( - template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"] -) - -CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database. -Instructions: -Use only the provided relationship types and properties in the schema. -Do not use any other relationship types or properties that are not provided. -Schema: -{schema} -Note: Do not include any explanations or apologies in your responses. -Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. -Do not include any text except the generated Cypher statement. - -The question is: -{question}""" -CYPHER_GENERATION_PROMPT = PromptTemplate( - input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE -) - -NEBULAGRAPH_EXTRA_INSTRUCTIONS = """ -Instructions: - -First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard): -1. it requires explicit label specification only when referring to node properties: v.`Foo`.name -2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name -3. it uses double equals sign for comparison: `==` rather than `=` -For instance: -```diff -< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II' -< RETURN p.name, e.year, m.name; ---- -> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II' -> RETURN p.`person`.`name`, e.year, m.`movie`.`name`; -```\n""" - -NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( - "Generate Cypher", "Generate NebulaGraph Cypher" -).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS) - -NGQL_GENERATION_PROMPT = PromptTemplate( - input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE -) - -KUZU_EXTRA_INSTRUCTIONS = """ -Instructions: - -Generate the Kùzu dialect of Cypher with the following rules in mind: - -1. Do not use a `WHERE EXISTS` clause to check the existence of a property. -2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`. -3. Do not include any notes or comments even if the statement does not produce the expected result. 
-```\n"""
-
-KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
-    "Generate Cypher", "Generate Kùzu Cypher"
-).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS)
-
-KUZU_GENERATION_PROMPT = PromptTemplate(
-    input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE
-)
-
-GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin")
-
-GREMLIN_GENERATION_PROMPT = PromptTemplate(
-    input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE
-)
-
-CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
-The information part contains the provided information that you must use to construct an answer.
-The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
-Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
-Here is an example:
-
-Question: Which managers own Neo4j stocks?
-Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
-Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
-
-Follow this example when generating answers.
-If the provided information is empty, say that you don't know the answer.
-Information:
-{context}
-
-Question: {question}
-Helpful Answer:"""
-CYPHER_QA_PROMPT = PromptTemplate(
-    input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
-)
-
-SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type.
-You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types.
-Consider only the following query types:
-* SELECT: this query type corresponds to questions
-* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples
-Note: Be as concise as possible.
-Do not include any explanations or apologies in your responses.
-Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type.
-Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'.
-
-The prompt is:
-{prompt}
-Helpful Answer:"""
-SPARQL_INTENT_PROMPT = PromptTemplate(
-    input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
-)
-
-SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
-For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
-```
-PREFIX foaf: <http://xmlns.com/foaf/0.1/>
-SELECT ?email
-WHERE {{
-    ?person foaf:name "John Doe" .
-    ?person foaf:mbox ?email .
-}}
-```
-Instructions:
-Use only the node types and properties provided in the schema.
-Do not use any node types and properties that are not explicitly provided.
-Include all necessary prefixes.
-Schema:
-{schema}
-Note: Be as concise as possible.
-Do not include any explanations or apologies in your responses.
-Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
-Do not include any text except the SPARQL query generated.
-
-The question is:
-{prompt}"""
-SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
-    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
-)
-
-SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
-For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
-```
-PREFIX foaf: <http://xmlns.com/foaf/0.1/>
-INSERT {{
-    ?person foaf:mbox <mailto:jane.doe@foo.bar> .
-}}
-WHERE {{
-    ?person foaf:name "Jane Doe" .
-}}
-```
-Instructions:
-Make the query as short as possible and avoid adding unnecessary triples.
-Use only the node types and properties provided in the schema.
-Do not use any node types and properties that are not explicitly provided.
-Include all necessary prefixes.
-Schema:
-{schema}
-Note: Be as concise as possible.
-Do not include any explanations or apologies in your responses.
-Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
-Return only the generated SPARQL query, nothing else.
-
-The information to be inserted is:
-{prompt}"""
-SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
-    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
-)
-
-SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
-You are an assistant that creates well-written and human understandable answers.
-The information part contains the information provided, which you can use to construct an answer.
-The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
-Make your response sound like the information is coming from an AI assistant, but don't add any information.
-Information:
-{context}
-
-Question: {prompt}
-Helpful Answer:"""
-SPARQL_QA_PROMPT = PromptTemplate(
-    input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
-)
-
-GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
-Write a SPARQL SELECT query for querying a graph database.
-The ontology schema delimited by triple backticks in Turtle format is:
-```
-{schema}
-```
-Use only the classes and properties provided in the schema to construct the SPARQL query.
-Do not use any classes or properties that are not explicitly provided in the SPARQL query.
-Include all necessary prefixes.
-Do not include any explanations or apologies in your responses.
-Do not wrap the query in backticks.
-Do not include any text except the SPARQL query generated.
-The question delimited by triple backticks is:
-```
-{prompt}
-```
-"""
-GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
-    input_variables=["schema", "prompt"],
-    template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
-)
-
-GRAPHDB_SPARQL_FIX_TEMPLATE = """
-The following SPARQL query delimited by triple backticks
-```
-{generated_sparql}
-```
-is not valid.
-The error delimited by triple backticks is
-```
-{error_message}
-```
-Give me a correct version of the SPARQL query.
-Do not change the logic of the query.
-Do not include any explanations or apologies in your responses.
-Do not wrap the query in backticks.
-Do not include any text except the SPARQL query generated.
-The ontology schema delimited by triple backticks in Turtle format is:
-```
-{schema}
-```
-"""
-
-GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
-    input_variables=["error_message", "generated_sparql", "schema"],
-    template=GRAPHDB_SPARQL_FIX_TEMPLATE,
-)
-
-GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
-You are an assistant that creates well-written and human understandable answers.
-The information part contains the information provided, which you can use to construct an answer.
-The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it. -Make your response sound like the information is coming from an AI assistant, but don't add any information. -Don't use internal knowledge to answer the question, just say you don't know if no information is available. -Information: -{context} - -Question: {prompt} -Helpful Answer:""" -GRAPHDB_QA_PROMPT = PromptTemplate( - input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE -) - -AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input. - -You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query. - -You are given an `ArangoDB Schema`. It is a JSON Object containing: -1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships. -2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example. - -You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used. - -Things you should do: -- Think step by step. -- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query. -- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required. -- Return the `AQL Query` wrapped in 3 backticks (```). -- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries. -- Only answer to requests related to generating an AQL Query. -- If a request is unrelated to generating AQL Query, say that you cannot help the user. - -Things you should not do: -- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`. -- Do not include any text except the generated AQL Query. -- Do not provide explanations or apologies in your responses. -- Do not generate an AQL Query that removes or deletes any data. - -Under no circumstance should you generate an AQL Query that deletes any data whatsoever. - -ArangoDB Schema: -{adb_schema} - -AQL Query Examples (Optional): -{aql_examples} - -User Input: -{user_input} - -AQL Query: -""" - -AQL_GENERATION_PROMPT = PromptTemplate( - input_variables=["adb_schema", "aql_examples", "user_input"], - template=AQL_GENERATION_TEMPLATE, -) - -AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query. - -You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`. - -The `AQL Error` explains why the `AQL Query` could not be executed in the database. -The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`. -For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`. - -You are also given the `ArangoDB Schema`. It is a JSON Object containing: -1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships. -2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example. 
- -You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query. - -Remember to think step by step. - -ArangoDB Schema: -{adb_schema} - -AQL Query: -{aql_query} - -AQL Error: -{aql_error} - -Corrected AQL Query: -""" - -AQL_FIX_PROMPT = PromptTemplate( - input_variables=[ - "adb_schema", - "aql_query", - "aql_error", - ], - template=AQL_FIX_TEMPLATE, -) - -AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query. - -You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`. - -A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format. -You are responsible for creating an `Summary` based on the AQL Result. - -You are given the following information: -- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database. -- `User Input`: the original question/request of the user, which has been translated into an AQL Query. -- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query. -- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database. - -Remember to think step by step. - -Your `Summary` should sound like it is a response to the `User Input`. -Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`. - -ArangoDB Schema: -{adb_schema} - -User Input: -{user_input} - -AQL Query: -{aql_query} - -AQL Result: -{aql_result} -""" -AQL_QA_PROMPT = PromptTemplate( - input_variables=["adb_schema", "user_input", "aql_query", "aql_result"], - template=AQL_QA_TEMPLATE, -) - - -NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """ -Instructions: -Generate the query in openCypher format and follow these rules: -Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions. -Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results. -Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions} -\n""" - -NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( - "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS -) - -NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( - input_variables=["schema", "question", "extra_instructions"], - template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, -) - -NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ -Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions} -Question: "{question}". 
-Here is the property graph schema: -{schema} -\n""" - -NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( - input_variables=["schema", "question", "extra_instructions"], - template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, -) +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.prompts import ( + AQL_FIX_TEMPLATE, + AQL_GENERATION_TEMPLATE, + AQL_QA_TEMPLATE, + CYPHER_GENERATION_PROMPT, + CYPHER_GENERATION_TEMPLATE, + CYPHER_QA_PROMPT, + CYPHER_QA_TEMPLATE, + GRAPHDB_QA_TEMPLATE, + GRAPHDB_SPARQL_FIX_TEMPLATE, + GRAPHDB_SPARQL_GENERATION_TEMPLATE, + GREMLIN_GENERATION_TEMPLATE, + KUZU_EXTRA_INSTRUCTIONS, + KUZU_GENERATION_TEMPLATE, + NEBULAGRAPH_EXTRA_INSTRUCTIONS, + NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS, + NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, + NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, + NGQL_GENERATION_TEMPLATE, + SPARQL_GENERATION_SELECT_TEMPLATE, + SPARQL_GENERATION_UPDATE_TEMPLATE, + SPARQL_INTENT_TEMPLATE, + SPARQL_QA_TEMPLATE, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "AQL_FIX_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "AQL_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "AQL_QA_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "CYPHER_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "CYPHER_QA_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "CYPHER_QA_PROMPT": "langchain_community.chains.graph_qa.prompts", + "CYPHER_GENERATION_PROMPT": "langchain_community.chains.graph_qa.prompts", + "GRAPHDB_QA_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "GRAPHDB_SPARQL_FIX_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "GRAPHDB_SPARQL_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "GREMLIN_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "KUZU_EXTRA_INSTRUCTIONS": "langchain_community.chains.graph_qa.prompts", + "KUZU_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "NEBULAGRAPH_EXTRA_INSTRUCTIONS": "langchain_community.chains.graph_qa.prompts", + "NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS": ( + "langchain_community.chains.graph_qa.prompts" + ), + "NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE": ( + "langchain_community.chains.graph_qa.prompts" + ), + "NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE": ( + "langchain_community.chains.graph_qa.prompts" + ), + "NGQL_GENERATION_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "SPARQL_GENERATION_SELECT_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "SPARQL_GENERATION_UPDATE_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "SPARQL_INTENT_TEMPLATE": "langchain_community.chains.graph_qa.prompts", + "SPARQL_QA_TEMPLATE": "langchain_community.chains.graph_qa.prompts", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = [ + "AQL_FIX_TEMPLATE", + "AQL_GENERATION_TEMPLATE", + "AQL_QA_TEMPLATE", + "CYPHER_GENERATION_TEMPLATE", + "CYPHER_QA_TEMPLATE", + "GRAPHDB_QA_TEMPLATE", + "GRAPHDB_SPARQL_FIX_TEMPLATE", + "GRAPHDB_SPARQL_GENERATION_TEMPLATE", + "GREMLIN_GENERATION_TEMPLATE", + "KUZU_EXTRA_INSTRUCTIONS", + 
"KUZU_GENERATION_TEMPLATE", + "NEBULAGRAPH_EXTRA_INSTRUCTIONS", + "NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS", + "NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE", + "NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE", + "NGQL_GENERATION_TEMPLATE", + "SPARQL_GENERATION_SELECT_TEMPLATE", + "SPARQL_GENERATION_UPDATE_TEMPLATE", + "SPARQL_INTENT_TEMPLATE", + "SPARQL_QA_TEMPLATE", + "CYPHER_QA_PROMPT", + "CYPHER_GENERATION_PROMPT", +] diff --git a/libs/langchain/langchain/chains/graph_qa/sparql.py b/libs/langchain/langchain/chains/graph_qa/sparql.py index f1c5d2fc81..363f9942c7 100644 --- a/libs/langchain/langchain/chains/graph_qa/sparql.py +++ b/libs/langchain/langchain/chains/graph_qa/sparql.py @@ -1,152 +1,23 @@ -""" -Question answering over an RDF or OWL graph using SPARQL. -""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.graphs.rdf_graph import RdfGraph -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.pydantic_v1 import Field +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain -from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ( - SPARQL_GENERATION_SELECT_PROMPT, - SPARQL_GENERATION_UPDATE_PROMPT, - SPARQL_INTENT_PROMPT, - SPARQL_QA_PROMPT, -) -from langchain.chains.llm import LLMChain +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "GraphSparqlQAChain": "langchain_community.chains.graph_qa.sparql", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class GraphSparqlQAChain(Chain): - """Question-answering against an RDF or OWL graph by generating SPARQL statements. - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - See https://python.langchain.com/docs/security for more information. - """ - graph: RdfGraph = Field(exclude=True) - sparql_generation_select_chain: LLMChain - sparql_generation_update_chain: LLMChain - sparql_intent_chain: LLMChain - qa_chain: LLMChain - return_sparql_query: bool = False - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - sparql_query_key: str = "sparql_query" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. 
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, - sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT, - sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT, - sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT, - **kwargs: Any, - ) -> GraphSparqlQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt) - sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt) - sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt) - - return cls( - qa_chain=qa_chain, - sparql_generation_select_chain=sparql_generation_select_chain, - sparql_generation_update_chain=sparql_generation_update_chain, - sparql_intent_chain=sparql_intent_chain, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """ - Generate SPARQL query, use it to retrieve a response from the gdb and answer - the question. - """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - prompt = inputs[self.input_key] - - _intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks) - intent = _intent.strip() - - if "SELECT" in intent and "UPDATE" not in intent: - sparql_generation_chain = self.sparql_generation_select_chain - intent = "SELECT" - elif "UPDATE" in intent and "SELECT" not in intent: - sparql_generation_chain = self.sparql_generation_update_chain - intent = "UPDATE" - else: - raise ValueError( - "I am sorry, but this prompt seems to fit none of the currently " - "supported SPARQL query types, i.e., SELECT and UPDATE." - ) - - _run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose) - _run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose) - - generated_sparql = sparql_generation_chain.run( - {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_sparql, color="green", end="\n", verbose=self.verbose - ) - - if intent == "SELECT": - context = self.graph.query(generated_sparql) - - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - result = self.qa_chain( - {"prompt": prompt, "context": context}, - callbacks=callbacks, - ) - res = result[self.qa_chain.output_key] - elif intent == "UPDATE": - self.graph.update(generated_sparql) - res = "Successfully inserted triples into the graph." 
- else: - raise ValueError("Unsupported SPARQL query type.") - - chain_result: Dict[str, Any] = {self.output_key: res} - if self.return_sparql_query: - chain_result[self.sparql_query_key] = generated_sparql - return chain_result +__all__ = ["GraphSparqlQAChain"] diff --git a/libs/langchain/langchain/chains/llm_requests.py b/libs/langchain/langchain/chains/llm_requests.py index 5727746874..dca6613203 100644 --- a/libs/langchain/langchain/chains/llm_requests.py +++ b/libs/langchain/langchain/chains/llm_requests.py @@ -1,97 +1,23 @@ -"""Chain that hits a URL and then uses an LLM to parse results.""" -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Dict, List, Optional +from langchain._api import create_importer -from langchain_community.utilities.requests import TextRequestsWrapper -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.pydantic_v1 import Extra, Field, root_validator +if TYPE_CHECKING: + from langchain_community.chains.llm_requests import LLMRequestsChain -from langchain.chains import LLMChain -from langchain.chains.base import Chain - -DEFAULT_HEADERS = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501 +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "LLMRequestsChain": "langchain_community.chains.llm_requests", } +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class LLMRequestsChain(Chain): - """Chain that requests a URL and then uses an LLM to parse results. - - **Security Note**: This chain can make GET requests to arbitrary URLs, - including internal URLs. - - Control access to who can run this chain and what network access - this chain has. - - See https://python.langchain.com/docs/security for more information. - """ - - llm_chain: LLMChain # type: ignore[valid-type] - requests_wrapper: TextRequestsWrapper = Field( - default_factory=lambda: TextRequestsWrapper(headers=DEFAULT_HEADERS), - exclude=True, - ) - text_length: int = 8000 - requests_key: str = "requests_result" #: :meta private: - input_key: str = "url" #: :meta private: - output_key: str = "output" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Will be whatever keys the prompt expects. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Will always return text key. - - :meta private: - """ - return [self.output_key] - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - try: - from bs4 import BeautifulSoup # noqa: F401 - - except ImportError: - raise ImportError( - "Could not import bs4 python package. " - "Please install it with `pip install bs4`." 
- ) - return values - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - from bs4 import BeautifulSoup +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - # Other keys are assumed to be needed for LLM prediction - other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - url = inputs[self.input_key] - res = self.requests_wrapper.get(url) - # extract the text from the html - soup = BeautifulSoup(res, "html.parser") - other_keys[self.requests_key] = soup.get_text()[: self.text_length] - result = self.llm_chain.predict( # type: ignore[attr-defined] - callbacks=_run_manager.get_child(), **other_keys - ) - return {self.output_key: result} - @property - def _chain_type(self) -> str: - return "llm_requests_chain" +__all__ = ["LLMRequestsChain"] diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 97798e7f15..41176477ea 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -1,8 +1,9 @@ """Functionality for loading chains.""" +from __future__ import annotations import json from pathlib import Path -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union import yaml from langchain_core.prompts.loading import ( @@ -18,17 +19,20 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain -from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain -from langchain.chains.llm_requests import LLMRequestsChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA +if TYPE_CHECKING: + from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain + + from langchain.chains.llm_requests import LLMRequestsChain + try: from langchain_community.llms.loading import load_llm, load_llm_from_config except ImportError: @@ -547,6 +551,14 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain: else: raise ValueError("`qa_chain` must be present.") + try: + from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain + except ImportError: + raise ImportError( + "To use this GraphCypherQAChain functionality you must install the " + "langchain_community package. 
" + "You can install it with `pip install langchain_community`" + ) return GraphCypherQAChain( graph=graph, cypher_generation_chain=cypher_generation_chain, # type: ignore[arg-type] @@ -587,6 +599,15 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain: def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: + try: + from langchain.chains.llm_requests import LLMRequestsChain + except ImportError: + raise ImportError( + "To use this LLMRequestsChain functionality you must install the " + "langchain package. " + "You can install it with `pip install langchain`" + ) + if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") llm_chain = load_chain_from_config(llm_chain_config, **kwargs) diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index e74c3477b1..3c6a9aa2e8 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -4,7 +4,6 @@ from __future__ import annotations import warnings from typing import Any, Dict, List, Optional -from langchain_community.llms.openai import OpenAI from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import Extra, root_validator @@ -68,8 +67,11 @@ class NatBotChain(Chain): @classmethod def from_default(cls, objective: str, **kwargs: Any) -> NatBotChain: """Load with default LLMChain.""" - llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) - return cls.from_llm(llm, objective, **kwargs) + raise NotImplementedError( + "This method is no longer implemented. Please use from_llm." + "llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)" + "For example, NatBotChain.from_llm(llm, objective)" + ) @classmethod def from_llm( diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py index 681618815e..86819dc9ee 100644 --- a/libs/langchain/langchain/chains/openai_functions/openapi.py +++ b/libs/langchain/langchain/chains/openai_functions/openapi.py @@ -6,8 +6,6 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import requests -from langchain_community.chat_models import ChatOpenAI -from langchain_community.utilities.openapi import OpenAPISpec from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser @@ -18,9 +16,9 @@ from requests import Response from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain -from langchain.tools import APIOperation if TYPE_CHECKING: + from langchain_community.utilities.openapi import OpenAPISpec from openapi_pydantic import Parameter @@ -101,6 +99,14 @@ def openapi_spec_to_openai_fn( Tuple of the OpenAI functions JSON schema and a default function for executing a request based on the OpenAI function schema. """ + try: + from langchain_community.tools import APIOperation + except ImportError: + raise ImportError( + "Could not import langchain_community.tools. " + "Please install it with `pip install langchain-community`." + ) + if not spec.paths: return [], lambda: None functions = [] @@ -256,6 +262,13 @@ def get_openapi_chain( prompt: Main prompt template to use. 
request_chain: Chain for taking the functions output and executing the request. """ + try: + from langchain_community.utilities.openapi import OpenAPISpec + except ImportError as e: + raise ImportError( + "Could not import langchain_community.utilities.openapi. " + "Please install it with `pip install langchain-community`." + ) from e if isinstance(spec, str): for conversion in ( OpenAPISpec.from_url, @@ -272,9 +285,12 @@ if isinstance(spec, str): raise ValueError(f"Unable to parse spec from source {spec}") openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec) - llm = llm or ChatOpenAI( - model="gpt-3.5-turbo-0613", - ) + if not llm: + raise ValueError( + "Must provide an LLM for this chain. For example,\n" + "from langchain_openai import ChatOpenAI\n" + "llm = ChatOpenAI()\n" + ) prompt = prompt or ChatPromptTemplate.from_template( "Use the provided APIs to respond to this user query:\n\n{query}" ) diff --git a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py index d9b0b924ed..90d84a0113 100644 --- a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py +++ b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional -from langchain_community.chat_models import ChatOpenAI from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate from langchain_core.retrievers import BaseRetriever @@ -42,6 +41,8 @@ class MultiRetrievalQAChain(MultiRouteChain): default_retriever: Optional[BaseRetriever] = None, default_prompt: Optional[PromptTemplate] = None, default_chain: Optional[Chain] = None, + *, + default_chain_llm: Optional[BaseLanguageModel] = None, **kwargs: Any, ) -> MultiRetrievalQAChain: if default_prompt and not default_retriever: @@ -78,8 +79,20 @@ class MultiRetrievalQAChain(MultiRouteChain): prompt = PromptTemplate( template=prompt_template, input_variables=["history", "query"] ) + if default_chain_llm is None: + raise NotImplementedError( + "default_chain_llm must be provided if default_chain is not " + "specified. This API has been changed to avoid instantiating " + "default LLMs on behalf of users. " + "You can provide a default chain LLM like so:\n" + "from langchain_openai import ChatOpenAI\n" + "llm = ChatOpenAI()" + ) _default_chain = ConversationChain( - llm=ChatOpenAI(), prompt=prompt, input_key="query", output_key="result" + llm=default_chain_llm, + prompt=prompt, + input_key="query", + output_key="result", ) return cls( router_chain=router_chain, diff --git a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py index 2eb83ec5b3..677b9cfd98 100644 --- a/libs/langchain/langchain/document_loaders/blob_loaders/schema.py +++ b/libs/langchain/langchain/document_loaders/blob_loaders/schema.py @@ -5,7 +5,7 @@ from langchain_core.document_loaders import Blob, BlobLoader from langchain._api import create_importer if TYPE_CHECKING: - from langchain_community.document_loaders import Blob, BlobLoader + pass # Create a way to dynamically look up deprecated imports.
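The `get_openapi_chain` change above removes the implicitly constructed `ChatOpenAI`; callers must now supply a model. A usage sketch under stated assumptions (the spec URL is a placeholder, and `langchain-openai` is assumed to be installed):

    # Hypothetical caller code after this PR: the LLM is passed in explicitly
    # instead of being a default ChatOpenAI created inside the chain.
    from langchain_openai import ChatOpenAI

    from langchain.chains.openai_functions.openapi import get_openapi_chain

    llm = ChatOpenAI(model="gpt-3.5-turbo-0613")
    chain = get_openapi_chain("https://example.com/openapi.json", llm=llm)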
# Used to consolidate logic for raising deprecation warnings and diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index 9db1df4f1e..5cadafa0f0 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -14,7 +14,6 @@ from langchain_core.pydantic_v1 import Field, root_validator from langchain.chains.base import Chain from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator from langchain.schema import RUN_KEY -from langchain.utils.math import cosine_similarity def _embedding_factory() -> Embeddings: @@ -164,6 +163,14 @@ class _EmbeddingDistanceChainMixin(Chain): Returns: np.ndarray: The cosine distance. """ + try: + from langchain_community.utils.math import cosine_similarity + except ImportError: + raise ImportError( + "The cosine_similarity function is required to compute cosine distance." + " Please install the langchain-community package using" + " `pip install langchain-community`." + ) return 1.0 - cosine_similarity(a, b) @staticmethod diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index b70cf33f0f..55e773ebdc 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -1,9 +1,6 @@ from typing import Any, Dict, List, Optional, Type -from langchain_community.document_loaders.base import BaseLoader -from langchain_community.embeddings.openai import OpenAIEmbeddings -from langchain_community.llms.openai import OpenAI -from langchain_community.vectorstores.inmemory import InMemoryVectorStore +from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel @@ -38,7 +35,14 @@ class VectorStoreIndexWrapper(BaseModel): **kwargs: Any, ) -> str: """Query the vectorstore.""" - llm = llm or OpenAI(temperature=0) + if llm is None: + raise NotImplementedError( + "This API has been changed to require an LLM. " + "Please provide an llm to use for querying the vectorstore.\n" + "For example,\n" + "from langchain_openai import OpenAI\n" + "llm = OpenAI(temperature=0)" + ) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQA.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs @@ -53,7 +57,14 @@ class VectorStoreIndexWrapper(BaseModel): **kwargs: Any, ) -> str: """Query the vectorstore.""" - llm = llm or OpenAI(temperature=0) + if llm is None: + raise NotImplementedError( + "This API has been changed to require an LLM. " + "Please provide an llm to use for querying the vectorstore.\n" + "For example,\n" + "from langchain_openai import OpenAI\n" + "llm = OpenAI(temperature=0)" + ) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQA.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs @@ -68,7 +79,14 @@ class VectorStoreIndexWrapper(BaseModel): **kwargs: Any, ) -> dict: """Query the vectorstore and get back sources.""" - llm = llm or OpenAI(temperature=0) + if llm is None: + raise NotImplementedError( + "This API has been changed to require an LLM. 
" + "Please provide an llm to use for querying the vectorstore.\n" + "For example,\n" + "from langchain_openai import OpenAI\n" + "llm = OpenAI(temperature=0)" + ) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQAWithSourcesChain.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs @@ -83,7 +101,14 @@ class VectorStoreIndexWrapper(BaseModel): **kwargs: Any, ) -> dict: """Query the vectorstore and get back sources.""" - llm = llm or OpenAI(temperature=0) + if llm is None: + raise NotImplementedError( + "This API has been changed to require an LLM. " + "Please provide an llm to use for querying the vectorstore.\n" + "For example,\n" + "from langchain_openai import OpenAI\n" + "llm = OpenAI(temperature=0)" + ) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQAWithSourcesChain.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs @@ -91,11 +116,31 @@ class VectorStoreIndexWrapper(BaseModel): return await chain.ainvoke({chain.question_key: question}) +def _get_in_memory_vectorstore() -> Type[VectorStore]: + """Get the InMemoryVectorStore.""" + import warnings + + try: + from langchain_community.vectorstores.inmemory import InMemoryVectorStore + except ImportError: + raise ImportError( + "Please install langchain-community to use the InMemoryVectorStore." + ) + warnings.warn( + "Using InMemoryVectorStore as the default vectorstore." + "This memory store won't persist data. You should explicitly" + "specify a vectorstore when using VectorstoreIndexCreator" + ) + return InMemoryVectorStore + + class VectorstoreIndexCreator(BaseModel): """Logic for creating indexes.""" - vectorstore_cls: Type[VectorStore] = InMemoryVectorStore - embedding: Embeddings = Field(default_factory=OpenAIEmbeddings) + vectorstore_cls: Type[VectorStore] = Field( + default_factory=_get_in_memory_vectorstore + ) + embedding: Embeddings text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) vectorstore_kwargs: dict = Field(default_factory=dict) diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index 514b98c092..dc1d38070f 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -25,14 +25,12 @@ from langchain.retrievers.ensemble import EnsembleRetriever from langchain.retrievers.merger_retriever import MergerRetriever from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.retrievers.multi_vector import MultiVectorRetriever -from langchain.retrievers.outline import OutlineRetriever from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever from langchain.retrievers.re_phraser import RePhraseQueryRetriever from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain.retrievers.time_weighted_retriever import ( TimeWeightedVectorStoreRetriever, ) -from langchain.retrievers.web_research import WebResearchRetriever if TYPE_CHECKING: from langchain_community.retrievers import ( @@ -70,6 +68,7 @@ if TYPE_CHECKING: TFIDFRetriever, VespaRetriever, WeaviateHybridSearchRetriever, + WebResearchRetriever, WikipediaRetriever, ZepRetriever, ZillizRetriever, @@ -106,12 +105,13 @@ DEPRECATED_LOOKUP = { "RemoteLangChainRetriever": "langchain_community.retrievers", "SVMRetriever": "langchain_community.retrievers", "TavilySearchAPIRetriever": "langchain_community.retrievers", - "TFIDFRetriever": "langchain_community.retrievers", 
"BM25Retriever": "langchain_community.retrievers", - "VespaRetriever": "langchain_community.retrievers", - "NeuralDBRetriever": "langchain_community.retrievers", "DriaRetriever": "langchain_community.retrievers", + "NeuralDBRetriever": "langchain_community.retrievers", + "TFIDFRetriever": "langchain_community.retrievers", + "VespaRetriever": "langchain_community.retrievers", "WeaviateHybridSearchRetriever": "langchain_community.retrievers", + "WebResearchRetriever": "langchain_community.retrievers", "WikipediaRetriever": "langchain_community.retrievers", "ZepRetriever": "langchain_community.retrievers", "ZillizRetriever": "langchain_community.retrievers", diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder.py new file mode 100644 index 0000000000..98fa056898 --- /dev/null +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple + + +class BaseCrossEncoder(ABC): + """Interface for cross encoder models.""" + + @abstractmethod + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + """Score pairs' similarity. + + Args: + text_pairs: List of pairs of texts. + + Returns: + List of scores. + """ diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py index e4047fc072..245722b364 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -3,11 +3,12 @@ from __future__ import annotations import operator from typing import Optional, Sequence -from langchain_community.cross_encoders import BaseCrossEncoder from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.pydantic_v1 import Extra +from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder + class CrossEncoderReranker(BaseDocumentCompressor): """Document compressor that uses CrossEncoder for reranking.""" diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 2f2e0ec914..270d315def 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -1,19 +1,25 @@ from typing import Callable, Dict, Optional, Sequence import numpy as np -from langchain_community.document_transformers.embeddings_redundant_filter import ( - _get_embeddings_from_stateful_docs, - get_stateful_documents, -) from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import Field, root_validator from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, ) -from langchain.utils.math import cosine_similarity + + +def _get_similarity_function() -> Callable: + try: + from langchain_community.utils.math import cosine_similarity + except ImportError: + raise ImportError( + "To use please install langchain-community " + "with `pip install langchain-community`." 
+ ) + return cosine_similarity class EmbeddingsFilter(BaseDocumentCompressor): @@ -22,7 +28,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): embeddings: Embeddings """Embeddings to use for embedding document contents and queries.""" - similarity_fn: Callable = cosine_similarity + similarity_fn: Callable = Field(default_factory=_get_similarity_function) """Similarity function for comparing documents. Function expected to take as input two matrices (List[List[float]]) and return a matrix of scores where higher values indicate greater similarity.""" @@ -53,6 +59,16 @@ class EmbeddingsFilter(BaseDocumentCompressor): callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """Filter documents based on similarity of their embeddings to the query.""" + try: + from langchain_community.document_transformers.embeddings_redundant_filter import ( # noqa: E501 + _get_embeddings_from_stateful_docs, + get_stateful_documents, + ) + except ImportError: + raise ImportError( + "To use this functionality, please install langchain-community " + "with `pip install langchain-community`." + ) stateful_documents = get_stateful_documents(documents) embedded_documents = _get_embeddings_from_stateful_docs( self.embeddings, stateful_documents diff --git a/libs/langchain/langchain/retrievers/pupmed.py b/libs/langchain/langchain/retrievers/pupmed.py index 0d1029f516..ae03d4b1e2 100644 --- a/libs/langchain/langchain/retrievers/pupmed.py +++ b/libs/langchain/langchain/retrievers/pupmed.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any from langchain._api import create_importer -from langchain.retrievers.pubmed import PubMedRetriever if TYPE_CHECKING: from langchain_community.retrievers import PubMedRetriever diff --git a/libs/langchain/langchain/retrievers/self_query/astradb.py b/libs/langchain/langchain/retrievers/self_query/astradb.py index 006972935f..a0d7b4b83f 100644 --- a/libs/langchain/langchain/retrievers/self_query/astradb.py +++ b/libs/langchain/langchain/retrievers/self_query/astradb.py @@ -1,70 +1,23 @@ -"""Logic for converting internal query language to a valid AstraDB query.""" -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer -MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN] +if TYPE_CHECKING: + from langchain_community.query_constructors.astradb import AstraDBTranslator +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports.
+DEPRECATED_LOOKUP = { + "AstraDBTranslator": "langchain_community.query_constructors.astradb", +} -class AstraDBTranslator(Visitor): - """Translate AstraDB internal query language elements to valid filters.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - """Subset of allowed logical comparators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.IN, - Comparator.NIN, - ] - """Subset of allowed logical operators.""" - allowed_operators = [Operator.AND, Operator.OR] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - map_dict = { - Operator.AND: "$and", - Operator.OR: "$or", - Comparator.EQ: "$eq", - Comparator.NE: "$ne", - Comparator.GTE: "$gte", - Comparator.LTE: "$lte", - Comparator.LT: "$lt", - Comparator.GT: "$gt", - Comparator.IN: "$in", - Comparator.NIN: "$nin", - } - return map_dict[func] - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {self._format_func(operation.operator): args} - - def visit_comparison(self, comparison: Comparison) -> Dict: - if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance( - comparison.value, list - ): - comparison.value = [comparison.value] - - comparator = self._format_func(comparison.comparator) - return {comparison.attribute: {comparator: comparison.value}} - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["AstraDBTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 9d1e79eb61..ce6dc6b68d 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -3,32 +3,6 @@ import logging from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union -from langchain_community.vectorstores import ( - AstraDB, - Chroma, - DashVector, - DatabricksVectorSearch, - DeepLake, - Dingo, - Milvus, - MongoDBAtlasVectorSearch, - MyScale, - OpenSearchVectorSearch, - PGVector, - Qdrant, - Redis, - SupabaseVectorStore, - TencentVectorDB, - TimescaleVector, - Vectara, - Weaviate, -) -from langchain_community.vectorstores import ( - ElasticsearchStore as ElasticsearchStoreCommunity, -) -from langchain_community.vectorstores import ( - Pinecone as CommunityPinecone, -) from langchain_core.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, @@ -43,28 +17,6 @@ from langchain_core.vectorstores import VectorStore from langchain.chains.query_constructor.base import load_query_constructor_runnable from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.retrievers.self_query.astradb import AstraDBTranslator -from langchain.retrievers.self_query.chroma import ChromaTranslator -from langchain.retrievers.self_query.dashvector import DashvectorTranslator -from langchain.retrievers.self_query.databricks_vector_search import ( - DatabricksVectorSearchTranslator, -) -from langchain.retrievers.self_query.deeplake import DeepLakeTranslator -from 
langchain.retrievers.self_query.dingo import DingoDBTranslator -from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator -from langchain.retrievers.self_query.milvus import MilvusTranslator -from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator -from langchain.retrievers.self_query.myscale import MyScaleTranslator -from langchain.retrievers.self_query.opensearch import OpenSearchTranslator -from langchain.retrievers.self_query.pgvector import PGVectorTranslator -from langchain.retrievers.self_query.pinecone import PineconeTranslator -from langchain.retrievers.self_query.qdrant import QdrantTranslator -from langchain.retrievers.self_query.redis import RedisTranslator -from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator -from langchain.retrievers.self_query.tencentvectordb import TencentVectorDBTranslator -from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator -from langchain.retrievers.self_query.vectara import VectaraTranslator -from langchain.retrievers.self_query.weaviate import WeaviateTranslator logger = logging.getLogger(__name__) QUERY_CONSTRUCTOR_RUN_NAME = "query_constructor" @@ -72,6 +24,71 @@ QUERY_CONSTRUCTOR_RUN_NAME = "query_constructor" def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: """Get the translator class corresponding to the vector store class.""" + try: + import langchain_community # noqa: F401 + except ImportError: + raise ImportError( + "The langchain-community package must be installed to use this feature." + " Please install it using `pip install langchain-community`." + ) + + from langchain_community.query_constructors.astradb import AstraDBTranslator + from langchain_community.query_constructors.chroma import ChromaTranslator + from langchain_community.query_constructors.dashvector import DashvectorTranslator + from langchain_community.query_constructors.databricks_vector_search import ( + DatabricksVectorSearchTranslator, + ) + from langchain_community.query_constructors.deeplake import DeepLakeTranslator + from langchain_community.query_constructors.dingo import DingoDBTranslator + from langchain_community.query_constructors.elasticsearch import ( + ElasticsearchTranslator, + ) + from langchain_community.query_constructors.milvus import MilvusTranslator + from langchain_community.query_constructors.mongodb_atlas import ( + MongoDBAtlasTranslator, + ) + from langchain_community.query_constructors.myscale import MyScaleTranslator + from langchain_community.query_constructors.opensearch import OpenSearchTranslator + from langchain_community.query_constructors.pgvector import PGVectorTranslator + from langchain_community.query_constructors.pinecone import PineconeTranslator + from langchain_community.query_constructors.qdrant import QdrantTranslator + from langchain_community.query_constructors.redis import RedisTranslator + from langchain_community.query_constructors.supabase import SupabaseVectorTranslator + from langchain_community.query_constructors.tencentvectordb import ( + TencentVectorDBTranslator, + ) + from langchain_community.query_constructors.timescalevector import ( + TimescaleVectorTranslator, + ) + from langchain_community.query_constructors.vectara import VectaraTranslator + from langchain_community.query_constructors.weaviate import WeaviateTranslator + from langchain_community.vectorstores import ( + AstraDB, + Chroma, + DashVector, + DatabricksVectorSearch, + DeepLake, + Dingo, + Milvus, + MongoDBAtlasVectorSearch, + 
MyScale, + OpenSearchVectorSearch, + PGVector, + Qdrant, + Redis, + SupabaseVectorStore, + TencentVectorDB, + TimescaleVector, + Vectara, + Weaviate, + ) + from langchain_community.vectorstores import ( + ElasticsearchStore as ElasticsearchStoreCommunity, + ) + from langchain_community.vectorstores import ( + Pinecone as CommunityPinecone, + ) + BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = { AstraDB: AstraDBTranslator, PGVector: PGVectorTranslator, diff --git a/libs/langchain/langchain/retrievers/self_query/chroma.py b/libs/langchain/langchain/retrievers/self_query/chroma.py index 6f766e7e13..4e1d3217bf 100644 --- a/libs/langchain/langchain/retrievers/self_query/chroma.py +++ b/libs/langchain/langchain/retrievers/self_query/chroma.py @@ -1,50 +1,23 @@ -from typing import Dict, Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) - - -class ChromaTranslator(Visitor): - """Translate `Chroma` internal query language elements to valid filters.""" - - allowed_operators = [Operator.AND, Operator.OR] - """Subset of allowed logical operators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - ] - """Subset of allowed logical comparators.""" - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - return f"${func.value}" - - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {self._format_func(operation.operator): args} - - def visit_comparison(self, comparison: Comparison) -> Dict: - return { - comparison.attribute: { - self._format_func(comparison.comparator): comparison.value - } - } - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.chroma import ChromaTranslator + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = { + "ChromaTranslator": "langchain_community.query_constructors.chroma", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = ["ChromaTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/dashvector.py b/libs/langchain/langchain/retrievers/self_query/dashvector.py index c1d63d1aae..f4067baae2 100644 --- a/libs/langchain/langchain/retrievers/self_query/dashvector.py +++ b/libs/langchain/langchain/retrievers/self_query/dashvector.py @@ -1,64 +1,23 @@ -"""Logic for converting internal query language to a valid DashVector query.""" -from typing import Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.dashvector import DashvectorTranslator -class DashvectorTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "DashvectorTranslator": "langchain_community.query_constructors.dashvector", +} - allowed_operators = [Operator.AND, Operator.OR] - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.LIKE, - ] +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - map_dict = { - Operator.AND: " AND ", - Operator.OR: " OR ", - Comparator.EQ: " = ", - Comparator.GT: " > ", - Comparator.GTE: " >= ", - Comparator.LT: " < ", - Comparator.LTE: " <= ", - Comparator.LIKE: " LIKE ", - } - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - return self.map_dict[func] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_operation(self, operation: Operation) -> str: - args = [arg.accept(self) for arg in operation.arguments] - return self._format_func(operation.operator).join(args) - def visit_comparison(self, comparison: Comparison) -> str: - value = comparison.value - if isinstance(value, str): - if comparison.comparator == Comparator.LIKE: - value = f"'%{value}%'" - else: - value = f"'{value}'" - return ( - f"{comparison.attribute}{self._format_func(comparison.comparator)}{value}" - ) - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["DashvectorTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/databricks_vector_search.py b/libs/langchain/langchain/retrievers/self_query/databricks_vector_search.py index 78dc17c7fc..ece66269b4 100644 --- a/libs/langchain/langchain/retrievers/self_query/databricks_vector_search.py +++ b/libs/langchain/langchain/retrievers/self_query/databricks_vector_search.py @@ -1,90 +1,27 @@ -from collections import ChainMap -from itertools import chain -from typing import Dict, Tuple - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - 
Operator, - StructuredQuery, - Visitor, -) - -_COMPARATOR_TO_SYMBOL = { - Comparator.EQ: "", - Comparator.GT: " >", - Comparator.GTE: " >=", - Comparator.LT: " <", - Comparator.LTE: " <=", - Comparator.IN: "", - Comparator.LIKE: " LIKE", +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.databricks_vector_search import ( + DatabricksVectorSearchTranslator, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "DatabricksVectorSearchTranslator": ( + "langchain_community.query_constructors.databricks_vector_search" + ), } +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class DatabricksVectorSearchTranslator(Visitor): - """Translate `Databricks vector search` internal query language elements to - valid filters.""" - - """Subset of allowed logical operators.""" - allowed_operators = [Operator.AND, Operator.NOT, Operator.OR] - - """Subset of allowed logical comparators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.IN, - Comparator.LIKE, - ] - - def _visit_and_operation(self, operation: Operation) -> Dict: - return dict(ChainMap(*[arg.accept(self) for arg in operation.arguments])) - - def _visit_or_operation(self, operation: Operation) -> Dict: - filter_args = [arg.accept(self) for arg in operation.arguments] - flattened_args = list( - chain.from_iterable(filter_arg.items() for filter_arg in filter_args) - ) - return { - " OR ".join(key for key, _ in flattened_args): [ - value for _, value in flattened_args - ] - } - - def _visit_not_operation(self, operation: Operation) -> Dict: - if len(operation.arguments) > 1: - raise ValueError( - f'"{operation.operator.value}" can have only one argument ' - f"in Databricks vector search" - ) - filter_arg = operation.arguments[0].accept(self) - return { - f"{colum_with_bool_expression} NOT": value - for colum_with_bool_expression, value in filter_arg.items() - } - def visit_operation(self, operation: Operation) -> Dict: - self._validate_func(operation.operator) - if operation.operator == Operator.AND: - return self._visit_and_operation(operation) - elif operation.operator == Operator.OR: - return self._visit_or_operation(operation) - elif operation.operator == Operator.NOT: - return self._visit_not_operation(operation) +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_comparison(self, comparison: Comparison) -> Dict: - self._validate_func(comparison.comparator) - comparator_symbol = _COMPARATOR_TO_SYMBOL[comparison.comparator] - return {f"{comparison.attribute}{comparator_symbol}": comparison.value} - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filters": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["DatabricksVectorSearchTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/deeplake.py b/libs/langchain/langchain/retrievers/self_query/deeplake.py index d7e2ab87d6..71f31716b5 100644 --- a/libs/langchain/langchain/retrievers/self_query/deeplake.py +++ b/libs/langchain/langchain/retrievers/self_query/deeplake.py @@ -1,88 +1,27 @@ -"""Logic for converting 
internal query language to a valid Chroma query.""" -from typing import Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) - -COMPARATOR_TO_TQL = { - Comparator.EQ: "==", - Comparator.GT: ">", - Comparator.GTE: ">=", - Comparator.LT: "<", - Comparator.LTE: "<=", -} - - -OPERATOR_TO_TQL = { - Operator.AND: "and", - Operator.OR: "or", - Operator.NOT: "NOT", +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.deeplake import ( + DeepLakeTranslator, + can_cast_to_float, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "DeepLakeTranslator": "langchain_community.query_constructors.deeplake", + "can_cast_to_float": "langchain_community.query_constructors.deeplake", } +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -def can_cast_to_float(string: str) -> bool: - """Check if a string can be cast to a float.""" - try: - float(string) - return True - except ValueError: - return False - - -class DeepLakeTranslator(Visitor): - """Translate `DeepLake` internal query language elements to valid filters.""" - - allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] - """Subset of allowed logical operators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - ] - """Subset of allowed logical comparators.""" - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - if isinstance(func, Operator): - value = OPERATOR_TO_TQL[func.value] # type: ignore - elif isinstance(func, Comparator): - value = COMPARATOR_TO_TQL[func.value] # type: ignore - return f"{value}" - - def visit_operation(self, operation: Operation) -> str: - args = [arg.accept(self) for arg in operation.arguments] - operator = self._format_func(operation.operator) - return "(" + (" " + operator + " ").join(args) + ")" - - def visit_comparison(self, comparison: Comparison) -> str: - comparator = self._format_func(comparison.comparator) - values = comparison.value - if isinstance(values, list): - tql = [] - for value in values: - comparison.value = value - tql.append(self.visit_comparison(comparison)) - return "(" + (" or ").join(tql) + ")" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - if not can_cast_to_float(comparison.value): - values = f"'{values}'" - return f"metadata['{comparison.attribute}'] {comparator} {values}" - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - tqL = f"SELECT * WHERE {structured_query.filter.accept(self)}" - kwargs = {"tql": tqL} - return structured_query.query, kwargs +__all__ = ["DeepLakeTranslator", "can_cast_to_float"] diff --git a/libs/langchain/langchain/retrievers/self_query/dingo.py b/libs/langchain/langchain/retrievers/self_query/dingo.py index 6c2402f65c..2acfa95acf 100644 --- a/libs/langchain/langchain/retrievers/self_query/dingo.py +++ b/libs/langchain/langchain/retrievers/self_query/dingo.py @@ -1,49 +1,23 @@ -from typing import Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - 
Visitor, -) - - -class DingoDBTranslator(Visitor): - """Translate `DingoDB` internal query language elements to valid filters.""" - - allowed_comparators = ( - Comparator.EQ, - Comparator.NE, - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - ) - """Subset of allowed logical comparators.""" - allowed_operators = (Operator.AND, Operator.OR) - """Subset of allowed logical operators.""" - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - return f"${func.value}" - - def visit_operation(self, operation: Operation) -> Operation: - return operation - - def visit_comparison(self, comparison: Comparison) -> Comparison: - return comparison - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = { - "search_params": { - "langchain_expr": structured_query.filter.accept(self) - } - } - return structured_query.query, kwargs +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.dingo import DingoDBTranslator + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "DingoDBTranslator": "langchain_community.query_constructors.dingo", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = ["DingoDBTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/elasticsearch.py b/libs/langchain/langchain/retrievers/self_query/elasticsearch.py index d07c284b12..868524cb11 100644 --- a/libs/langchain/langchain/retrievers/self_query/elasticsearch.py +++ b/libs/langchain/langchain/retrievers/self_query/elasticsearch.py @@ -1,100 +1,25 @@ -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.elasticsearch import ( + ElasticsearchTranslator, + ) -class ElasticsearchTranslator(Visitor): - """Translate `Elasticsearch` internal query language elements to valid filters.""" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = { + "ElasticsearchTranslator": "langchain_community.query_constructors.elasticsearch", +} - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.CONTAIN, - Comparator.LIKE, - ] - """Subset of allowed logical comparators.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] - """Subset of allowed logical operators.""" - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - map_dict = { - Operator.OR: "should", - Operator.NOT: "must_not", - Operator.AND: "must", - Comparator.EQ: "term", - Comparator.GT: "gt", - Comparator.GTE: "gte", - Comparator.LT: "lt", - Comparator.LTE: "lte", - Comparator.CONTAIN: "match", - Comparator.LIKE: "match", - } - return map_dict[func] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {"bool": {self._format_func(operation.operator): args}} - - def visit_comparison(self, comparison: Comparison) -> Dict: - # ElasticsearchStore filters require to target - # the metadata object field - field = f"metadata.{comparison.attribute}" - - is_range_comparator = comparison.comparator in [ - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - ] - - if is_range_comparator: - value = comparison.value - if isinstance(comparison.value, dict) and "date" in comparison.value: - value = comparison.value["date"] - return {"range": {field: {self._format_func(comparison.comparator): value}}} - - if comparison.comparator == Comparator.CONTAIN: - return { - self._format_func(comparison.comparator): { - field: {"query": comparison.value} - } - } - - if comparison.comparator == Comparator.LIKE: - return { - self._format_func(comparison.comparator): { - field: {"query": comparison.value, "fuzziness": "AUTO"} - } - } - - # we assume that if the value is a string, - # we want to use the keyword field - field = f"{field}.keyword" if isinstance(comparison.value, str) else field - - if isinstance(comparison.value, dict): - if "date" in comparison.value: - comparison.value = comparison.value["date"] - - return {self._format_func(comparison.comparator): {field: comparison.value}} - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": [structured_query.filter.accept(self)]} - return structured_query.query, kwargs +__all__ = ["ElasticsearchTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/milvus.py b/libs/langchain/langchain/retrievers/self_query/milvus.py index 6fb1cc5c4e..17b8934caf 100644 --- a/libs/langchain/langchain/retrievers/self_query/milvus.py +++ b/libs/langchain/langchain/retrievers/self_query/milvus.py @@ -1,103 +1,27 @@ -"""Logic for converting internal query language to a valid Milvus query.""" -from typing import Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) - -COMPARATOR_TO_BER = { - Comparator.EQ: "==", - Comparator.GT: ">", - Comparator.GTE: ">=", - Comparator.LT: "<", - Comparator.LTE: "<=", - Comparator.IN: "in", - Comparator.LIKE: "like", +from typing import TYPE_CHECKING, Any + +from langchain._api 
import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.milvus import ( + MilvusTranslator, + process_value, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "MilvusTranslator": "langchain_community.query_constructors.milvus", + "process_value": "langchain_community.query_constructors.milvus", } -UNARY_OPERATORS = [Operator.NOT] - - -def process_value(value: Union[int, float, str], comparator: Comparator) -> str: - """Convert a value to a string and add double quotes if it is a string. - - It required for comparators involving strings. - - Args: - value: The value to convert. - comparator: The comparator. - - Returns: - The converted value as a string. - """ - # - if isinstance(value, str): - if comparator is Comparator.LIKE: - # If the comparator is LIKE, add a percent sign after it for prefix matching - # and add double quotes - return f'"{value}%"' - else: - # If the value is already a string, add double quotes - return f'"{value}"' - else: - # If the value is not a string, convert it to a string without double quotes - return str(value) - - -class MilvusTranslator(Visitor): - """Translate Milvus internal query language elements to valid filters.""" - - """Subset of allowed logical operators.""" - allowed_operators = [Operator.AND, Operator.NOT, Operator.OR] - - """Subset of allowed logical comparators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.IN, - Comparator.LIKE, - ] - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - value = func.value - if isinstance(func, Comparator): - value = COMPARATOR_TO_BER[func] - return f"{value}" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - def visit_operation(self, operation: Operation) -> str: - if operation.operator in UNARY_OPERATORS and len(operation.arguments) == 1: - operator = self._format_func(operation.operator) - return operator + "(" + operation.arguments[0].accept(self) + ")" - elif operation.operator in UNARY_OPERATORS: - raise ValueError( - f'"{operation.operator.value}" can have only one argument in Milvus' - ) - else: - args = [arg.accept(self) for arg in operation.arguments] - operator = self._format_func(operation.operator) - return "(" + (" " + operator + " ").join(args) + ")" - def visit_comparison(self, comparison: Comparison) -> str: - comparator = self._format_func(comparison.comparator) - processed_value = process_value(comparison.value, comparison.comparator) - attribute = comparison.attribute +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - return "( " + attribute + " " + comparator + " " + processed_value + " )" - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"expr": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["MilvusTranslator", "process_value"] diff --git a/libs/langchain/langchain/retrievers/self_query/mongodb_atlas.py b/libs/langchain/langchain/retrievers/self_query/mongodb_atlas.py index ebef2163be..81196772a2 100644 --- a/libs/langchain/langchain/retrievers/self_query/mongodb_atlas.py +++ b/libs/langchain/langchain/retrievers/self_query/mongodb_atlas.py 
@@ -1,74 +1,25 @@ -"""Logic for converting internal query language to a valid MongoDB Atlas query.""" -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer -MULTIPLE_ARITY_COMPARATORS = [Comparator.IN, Comparator.NIN] +if TYPE_CHECKING: + from langchain_community.query_constructors.mongodb_atlas import ( + MongoDBAtlasTranslator, + ) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "MongoDBAtlasTranslator": "langchain_community.query_constructors.mongodb_atlas", +} -class MongoDBAtlasTranslator(Visitor): - """Translate Mongo internal query language elements to valid filters.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - """Subset of allowed logical comparators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.IN, - Comparator.NIN, - ] - """Subset of allowed logical operators.""" - allowed_operators = [Operator.AND, Operator.OR] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - ## Convert a operator or a comparator to Mongo Query Format - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - map_dict = { - Operator.AND: "$and", - Operator.OR: "$or", - Comparator.EQ: "$eq", - Comparator.NE: "$ne", - Comparator.GTE: "$gte", - Comparator.LTE: "$lte", - Comparator.LT: "$lt", - Comparator.GT: "$gt", - Comparator.IN: "$in", - Comparator.NIN: "$nin", - } - return map_dict[func] - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {self._format_func(operation.operator): args} - - def visit_comparison(self, comparison: Comparison) -> Dict: - if comparison.comparator in MULTIPLE_ARITY_COMPARATORS and not isinstance( - comparison.value, list - ): - comparison.value = [comparison.value] - - comparator = self._format_func(comparison.comparator) - - attribute = comparison.attribute - - return {attribute: {comparator: comparison.value}} - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"pre_filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["MongoDBAtlasTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/myscale.py b/libs/langchain/langchain/retrievers/self_query/myscale.py index 50a74c568b..a5bfcadc5c 100644 --- a/libs/langchain/langchain/retrievers/self_query/myscale.py +++ b/libs/langchain/langchain/retrievers/self_query/myscale.py @@ -1,125 +1,23 @@ -import re -from typing import Any, Callable, Dict, Tuple +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.myscale import MyScaleTranslator -def _DEFAULT_COMPOSER(op_name: str) -> Callable: - """ - Default composer for logical operators. +# Create a way to dynamically look up deprecated imports. 
+# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "MyScaleTranslator": "langchain_community.query_constructors.myscale", +} - Args: - op_name: Name of the operator. +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - Returns: - Callable that takes a list of arguments and returns a string. - """ - def f(*args: Any) -> str: - args_: map[str] = map(str, args) - return f" {op_name} ".join(args_) +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - return f - -def _FUNCTION_COMPOSER(op_name: str) -> Callable: - """ - Composer for functions. - - Args: - op_name: Name of the function. - - Returns: - Callable that takes a list of arguments and returns a string. - """ - - def f(*args: Any) -> str: - args_: map[str] = map(str, args) - return f"{op_name}({','.join(args_)})" - - return f - - -class MyScaleTranslator(Visitor): - """Translate `MyScale` internal query language elements to valid filters.""" - - allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] - """Subset of allowed logical operators.""" - - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.CONTAIN, - Comparator.LIKE, - ] - - map_dict = { - Operator.AND: _DEFAULT_COMPOSER("AND"), - Operator.OR: _DEFAULT_COMPOSER("OR"), - Operator.NOT: _DEFAULT_COMPOSER("NOT"), - Comparator.EQ: _DEFAULT_COMPOSER("="), - Comparator.GT: _DEFAULT_COMPOSER(">"), - Comparator.GTE: _DEFAULT_COMPOSER(">="), - Comparator.LT: _DEFAULT_COMPOSER("<"), - Comparator.LTE: _DEFAULT_COMPOSER("<="), - Comparator.CONTAIN: _FUNCTION_COMPOSER("has"), - Comparator.LIKE: _DEFAULT_COMPOSER("ILIKE"), - } - - def __init__(self, metadata_key: str = "metadata") -> None: - super().__init__() - self.metadata_key = metadata_key - - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - func = operation.operator - self._validate_func(func) - return self.map_dict[func](*args) - - def visit_comparison(self, comparison: Comparison) -> Dict: - regex = r"\((.*?)\)" - matched = re.search(r"\(\w+\)", comparison.attribute) - - # If arbitrary function is applied to an attribute - if matched: - attr = re.sub( - regex, - f"({self.metadata_key}.{matched.group(0)[1:-1]})", - comparison.attribute, - ) - else: - attr = f"{self.metadata_key}.{comparison.attribute}" - value = comparison.value - comp = comparison.comparator - - value = f"'{value}'" if isinstance(value, str) else value - - # convert timestamp for datetime objects - if isinstance(value, dict) and value.get("type") == "date": - attr = f"parseDateTime32BestEffort({attr})" - value = f"parseDateTime32BestEffort('{value['date']}')" - - # string pattern match - if comp is Comparator.LIKE: - value = f"'%{value[1:-1]}%'" - return self.map_dict[comp](attr, value) - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - print(structured_query) # noqa: T201 - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"where_str": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["MyScaleTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/opensearch.py b/libs/langchain/langchain/retrievers/self_query/opensearch.py index e01ec66639..519cf8518e 100644 --- a/libs/langchain/langchain/retrievers/self_query/opensearch.py +++ 
b/libs/langchain/langchain/retrievers/self_query/opensearch.py @@ -1,104 +1,23 @@ -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.opensearch import OpenSearchTranslator -class OpenSearchTranslator(Visitor): - """Translate `OpenSearch` internal query domain-specific - language elements to valid filters.""" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "OpenSearchTranslator": "langchain_community.query_constructors.opensearch", +} - allowed_comparators = [ - Comparator.EQ, - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - Comparator.CONTAIN, - Comparator.LIKE, - ] - """Subset of allowed logical comparators.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] - """Subset of allowed logical operators.""" - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - comp_operator_map = { - Comparator.EQ: "term", - Comparator.LT: "lt", - Comparator.LTE: "lte", - Comparator.GT: "gt", - Comparator.GTE: "gte", - Comparator.CONTAIN: "match", - Comparator.LIKE: "fuzzy", - Operator.AND: "must", - Operator.OR: "should", - Operator.NOT: "must_not", - } - return comp_operator_map[func] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {"bool": {self._format_func(operation.operator): args}} - - def visit_comparison(self, comparison: Comparison) -> Dict: - field = f"metadata.{comparison.attribute}" - - if comparison.comparator in [ - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - ]: - if isinstance(comparison.value, dict): - if "date" in comparison.value: - return { - "range": { - field: { - self._format_func( - comparison.comparator - ): comparison.value["date"] - } - } - } - else: - return { - "range": { - field: { - self._format_func(comparison.comparator): comparison.value - } - } - } - - if comparison.comparator == Comparator.LIKE: - return { - self._format_func(comparison.comparator): { - field: {"value": comparison.value} - } - } - - field = f"{field}.keyword" if isinstance(comparison.value, str) else field - - if isinstance(comparison.value, dict): - if "date" in comparison.value: - comparison.value = comparison.value["date"] - - return {self._format_func(comparison.comparator): {field: comparison.value}} - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - - return structured_query.query, kwargs +__all__ = ["OpenSearchTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/pgvector.py b/libs/langchain/langchain/retrievers/self_query/pgvector.py index 5fea65b01c..1355b6b830 100644 --- a/libs/langchain/langchain/retrievers/self_query/pgvector.py +++ b/libs/langchain/langchain/retrievers/self_query/pgvector.py @@ -1,52 +1,23 @@ -from typing import Dict, 
Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) - - -class PGVectorTranslator(Visitor): - """Translate `PGVector` internal query language elements to valid filters.""" - - allowed_operators = [Operator.AND, Operator.OR] - """Subset of allowed logical operators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.LT, - Comparator.IN, - Comparator.NIN, - Comparator.CONTAIN, - Comparator.LIKE, - ] - """Subset of allowed logical comparators.""" - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - return f"{func.value}" - - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {self._format_func(operation.operator): args} - - def visit_comparison(self, comparison: Comparison) -> Dict: - return { - comparison.attribute: { - self._format_func(comparison.comparator): comparison.value - } - } - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.pgvector import PGVectorTranslator + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "PGVectorTranslator": "langchain_community.query_constructors.pgvector", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = ["PGVectorTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/pinecone.py b/libs/langchain/langchain/retrievers/self_query/pinecone.py index 99c42f393b..43d299e3b5 100644 --- a/libs/langchain/langchain/retrievers/self_query/pinecone.py +++ b/libs/langchain/langchain/retrievers/self_query/pinecone.py @@ -1,57 +1,23 @@ -from typing import Dict, Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) - - -class PineconeTranslator(Visitor): - """Translate `Pinecone` internal query language elements to valid filters.""" - - allowed_comparators = ( - Comparator.EQ, - Comparator.NE, - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - Comparator.IN, - Comparator.NIN, - ) - """Subset of allowed logical comparators.""" - allowed_operators = (Operator.AND, Operator.OR) - """Subset of allowed logical operators.""" - - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - return f"${func.value}" - - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {self._format_func(operation.operator): args} - - def visit_comparison(self, comparison: Comparison) -> Dict: - if comparison.comparator in (Comparator.IN, Comparator.NIN) and not isinstance( - comparison.value, list - ): - comparison.value = [comparison.value] - - return { - comparison.attribute: { - self._format_func(comparison.comparator): comparison.value - } - } - - def 
visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.query_constructors.pinecone import PineconeTranslator + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "PineconeTranslator": "langchain_community.query_constructors.pinecone", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = ["PineconeTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/qdrant.py b/libs/langchain/langchain/retrievers/self_query/qdrant.py index f4c3298b66..aa0f81971d 100644 --- a/libs/langchain/langchain/retrievers/self_query/qdrant.py +++ b/libs/langchain/langchain/retrievers/self_query/qdrant.py @@ -1,98 +1,23 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING, Tuple - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer if TYPE_CHECKING: - from qdrant_client.http import models as rest - - -class QdrantTranslator(Visitor): - """Translate `Qdrant` internal query language elements to valid filters.""" - - allowed_operators = ( - Operator.AND, - Operator.OR, - Operator.NOT, - ) - """Subset of allowed logical operators.""" - - allowed_comparators = ( - Comparator.EQ, - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - Comparator.LIKE, - ) - """Subset of allowed logical comparators.""" - - def __init__(self, metadata_key: str): - self.metadata_key = metadata_key + from langchain_community.query_constructors.qdrant import QdrantTranslator - def visit_operation(self, operation: Operation) -> rest.Filter: - try: - from qdrant_client.http import models as rest - except ImportError as e: - raise ImportError( - "Cannot import qdrant_client. Please install with `pip install " - "qdrant-client`." - ) from e +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "QdrantTranslator": "langchain_community.query_constructors.qdrant", +} - args = [arg.accept(self) for arg in operation.arguments] - operator = { - Operator.AND: "must", - Operator.OR: "should", - Operator.NOT: "must_not", - }[operation.operator] - return rest.Filter(**{operator: args}) +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - def visit_comparison(self, comparison: Comparison) -> rest.FieldCondition: - try: - from qdrant_client.http import models as rest - except ImportError as e: - raise ImportError( - "Cannot import qdrant_client. Please install with `pip install " - "qdrant-client`." - ) from e - self._validate_func(comparison.comparator) - attribute = self.metadata_key + "." 
+ comparison.attribute - if comparison.comparator == Comparator.EQ: - return rest.FieldCondition( - key=attribute, match=rest.MatchValue(value=comparison.value) - ) - if comparison.comparator == Comparator.LIKE: - return rest.FieldCondition( - key=attribute, match=rest.MatchText(text=comparison.value) - ) - kwargs = {comparison.comparator.value: comparison.value} - return rest.FieldCondition(key=attribute, range=rest.Range(**kwargs)) +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - try: - from qdrant_client.http import models as rest - except ImportError as e: - raise ImportError( - "Cannot import qdrant_client. Please install with `pip install " - "qdrant-client`." - ) from e - if structured_query.filter is None: - kwargs = {} - else: - filter = structured_query.filter.accept(self) - if isinstance(filter, rest.FieldCondition): - filter = rest.Filter(must=[filter]) - kwargs = {"filter": filter} - return structured_query.query, kwargs +__all__ = ["QdrantTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/redis.py b/libs/langchain/langchain/retrievers/self_query/redis.py index 56d3a17b55..69e7d5c90b 100644 --- a/libs/langchain/langchain/retrievers/self_query/redis.py +++ b/libs/langchain/langchain/retrievers/self_query/redis.py @@ -1,102 +1,23 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Any, Tuple +from langchain._api import create_importer -from langchain_community.vectorstores.redis import Redis -from langchain_community.vectorstores.redis.filters import ( - RedisFilterExpression, - RedisFilterField, - RedisFilterOperator, - RedisNum, - RedisTag, - RedisText, -) -from langchain_community.vectorstores.redis.schema import RedisModel -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +if TYPE_CHECKING: + from langchain_community.query_constructors.redis import RedisTranslator -_COMPARATOR_TO_BUILTIN_METHOD = { - Comparator.EQ: "__eq__", - Comparator.NE: "__ne__", - Comparator.LT: "__lt__", - Comparator.GT: "__gt__", - Comparator.LTE: "__le__", - Comparator.GTE: "__ge__", - Comparator.CONTAIN: "__eq__", - Comparator.LIKE: "__mod__", +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
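+# For example, `from langchain.retrievers.self_query.redis import
+# RedisTranslator` is expected to keep working, now emitting a
+# DeprecationWarning that points at the mapping below.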
+DEPRECATED_LOOKUP = { + "RedisTranslator": "langchain_community.query_constructors.redis", } +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class RedisTranslator(Visitor): - """Visitor for translating structured queries to Redis filter expressions.""" - allowed_comparators = ( - Comparator.EQ, - Comparator.NE, - Comparator.LT, - Comparator.LTE, - Comparator.GT, - Comparator.GTE, - Comparator.CONTAIN, - Comparator.LIKE, - ) - """Subset of allowed logical comparators.""" - allowed_operators = (Operator.AND, Operator.OR) - """Subset of allowed logical operators.""" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def __init__(self, schema: RedisModel) -> None: - self._schema = schema - def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField: - if attribute in [tf.name for tf in self._schema.text]: - return RedisText(attribute) - elif attribute in [tf.name for tf in self._schema.tag or []]: - return RedisTag(attribute) - elif attribute in [tf.name for tf in self._schema.numeric or []]: - return RedisNum(attribute) - else: - raise ValueError( - f"Invalid attribute {attribute} not in vector store schema. Schema is:" - f"\n{self._schema.as_dict()}" - ) - - def visit_comparison(self, comparison: Comparison) -> RedisFilterExpression: - filter_field = self._attribute_to_filter_field(comparison.attribute) - comparison_method = _COMPARATOR_TO_BUILTIN_METHOD[comparison.comparator] - return getattr(filter_field, comparison_method)(comparison.value) - - def visit_operation(self, operation: Operation) -> Any: - left = operation.arguments[0].accept(self) - if len(operation.arguments) > 2: - right = self.visit_operation( - Operation( - operator=operation.operator, arguments=operation.arguments[1:] - ) - ) - else: - right = operation.arguments[1].accept(self) - redis_operator = ( - RedisFilterOperator.OR - if operation.operator == Operator.OR - else RedisFilterOperator.AND - ) - return RedisFilterExpression(operator=redis_operator, left=left, right=right) - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs - - @classmethod - def from_vectorstore(cls, vectorstore: Redis) -> RedisTranslator: - return cls(vectorstore._schema) +__all__ = ["RedisTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/supabase.py b/libs/langchain/langchain/retrievers/self_query/supabase.py index 63794cf378..4941fec03e 100644 --- a/libs/langchain/langchain/retrievers/self_query/supabase.py +++ b/libs/langchain/langchain/retrievers/self_query/supabase.py @@ -1,97 +1,23 @@ -from typing import Any, Dict, Tuple +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.supabase import SupabaseVectorTranslator -class SupabaseVectorTranslator(Visitor): - """Translate Langchain filters to Supabase PostgREST filters.""" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = { + "SupabaseVectorTranslator": "langchain_community.query_constructors.supabase", +} - allowed_operators = [Operator.AND, Operator.OR] - """Subset of allowed logical operators.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - Comparator.LIKE, - ] - """Subset of allowed logical comparators.""" - metadata_column = "metadata" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def _map_comparator(self, comparator: Comparator) -> str: - """ - Maps Langchain comparator to PostgREST comparator: - https://postgrest.org/en/stable/references/api/tables_views.html#operators - """ - postgrest_comparator = { - Comparator.EQ: "eq", - Comparator.NE: "neq", - Comparator.GT: "gt", - Comparator.GTE: "gte", - Comparator.LT: "lt", - Comparator.LTE: "lte", - Comparator.LIKE: "like", - }.get(comparator) - - if postgrest_comparator is None: - raise Exception( - f"Comparator '{comparator}' is not currently " - "supported in Supabase Vector" - ) - - return postgrest_comparator - - def _get_json_operator(self, value: Any) -> str: - if isinstance(value, str): - return "->>" - else: - return "->" - - def visit_operation(self, operation: Operation) -> str: - args = [arg.accept(self) for arg in operation.arguments] - return f"{operation.operator.value}({','.join(args)})" - - def visit_comparison(self, comparison: Comparison) -> str: - if isinstance(comparison.value, list): - return self.visit_operation( - Operation( - operator=Operator.AND, - arguments=[ - Comparison( - comparator=comparison.comparator, - attribute=comparison.attribute, - value=value, - ) - for value in comparison.value - ], - ) - ) - - return ".".join( - [ - f"{self.metadata_column}{self._get_json_operator(comparison.value)}{comparison.attribute}", - f"{self._map_comparator(comparison.comparator)}", - f"{comparison.value}", - ] - ) - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, Dict[str, str]]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"postgrest_filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["SupabaseVectorTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/tencentvectordb.py b/libs/langchain/langchain/retrievers/self_query/tencentvectordb.py index b1ec31a1a2..c971a1c1eb 100644 --- a/libs/langchain/langchain/retrievers/self_query/tencentvectordb.py +++ b/libs/langchain/langchain/retrievers/self_query/tencentvectordb.py @@ -1,116 +1,27 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import Optional, Sequence, Tuple +from langchain._api import create_importer -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +if TYPE_CHECKING: + from langchain_community.query_constructors.tencentvectordb import ( + TencentVectorDBTranslator, + ) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = { + "TencentVectorDBTranslator": ( + "langchain_community.query_constructors.tencentvectordb" + ), +} -class TencentVectorDBTranslator(Visitor): - """Translate StructuredQuery to Tencent VectorDB query.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - COMPARATOR_MAP = { - Comparator.EQ: "=", - Comparator.NE: "!=", - Comparator.GT: ">", - Comparator.GTE: ">=", - Comparator.LT: "<", - Comparator.LTE: "<=", - Comparator.IN: "in", - Comparator.NIN: "not in", - } - allowed_comparators: Optional[Sequence[Comparator]] = list(COMPARATOR_MAP.keys()) - allowed_operators: Optional[Sequence[Operator]] = [ - Operator.AND, - Operator.OR, - Operator.NOT, - ] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def __init__(self, meta_keys: Optional[Sequence[str]] = None): - """Initialize the translator. - Args: - meta_keys: List of meta keys to be used in the query. Default: []. - """ - self.meta_keys = meta_keys or [] - - def visit_operation(self, operation: Operation) -> str: - """Visit an operation node and return the translated query. - - Args: - operation: Operation node to be visited. - - Returns: - Translated query. - """ - if operation.operator in (Operator.AND, Operator.OR): - ret = f" {operation.operator.value} ".join( - [arg.accept(self) for arg in operation.arguments] - ) - if operation.operator == Operator.OR: - ret = f"({ret})" - return ret - else: - return f"not ({operation.arguments[0].accept(self)})" - - def visit_comparison(self, comparison: Comparison) -> str: - """Visit a comparison node and return the translated query. - - Args: - comparison: Comparison node to be visited. - - Returns: - Translated query. - """ - if self.meta_keys and comparison.attribute not in self.meta_keys: - raise ValueError( - f"Expr Filtering found Unsupported attribute: {comparison.attribute}" - ) - - if comparison.comparator in self.COMPARATOR_MAP: - if comparison.comparator in [Comparator.IN, Comparator.NIN]: - value = map( - lambda x: f'"{x}"' if isinstance(x, str) else x, comparison.value - ) - return ( - f"{comparison.attribute}" - f" {self.COMPARATOR_MAP[comparison.comparator]} " - f"({', '.join(value)})" - ) - if isinstance(comparison.value, str): - return ( - f"{comparison.attribute} " - f"{self.COMPARATOR_MAP[comparison.comparator]}" - f' "{comparison.value}"' - ) - return ( - f"{comparison.attribute}" - f" {self.COMPARATOR_MAP[comparison.comparator]} " - f"{comparison.value}" - ) - else: - raise ValueError(f"Unsupported comparator {comparison.comparator}") - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - """Visit a structured query node and return the translated query. - - Args: - structured_query: StructuredQuery node to be visited. - - Returns: - Translated query and query kwargs. 
- """ - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"expr": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["TencentVectorDBTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/timescalevector.py b/libs/langchain/langchain/retrievers/self_query/timescalevector.py index bfac120bde..623ca390b4 100644 --- a/libs/langchain/langchain/retrievers/self_query/timescalevector.py +++ b/libs/langchain/langchain/retrievers/self_query/timescalevector.py @@ -1,84 +1,27 @@ -from __future__ import annotations +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING, Tuple, Union - -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer if TYPE_CHECKING: - from timescale_vector import client - - -class TimescaleVectorTranslator(Visitor): - """Translate the internal query language elements to valid filters.""" - - allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] - """Subset of allowed logical operators.""" - - allowed_comparators = [ - Comparator.EQ, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - ] + from langchain_community.query_constructors.timescalevector import ( + TimescaleVectorTranslator, + ) - COMPARATOR_MAP = { - Comparator.EQ: "==", - Comparator.GT: ">", - Comparator.GTE: ">=", - Comparator.LT: "<", - Comparator.LTE: "<=", - } +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "TimescaleVectorTranslator": ( + "langchain_community.query_constructors.timescalevector" + ), +} - OPERATOR_MAP = {Operator.AND: "AND", Operator.OR: "OR", Operator.NOT: "NOT"} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - if isinstance(func, Operator): - value = self.OPERATOR_MAP[func.value] # type: ignore - elif isinstance(func, Comparator): - value = self.COMPARATOR_MAP[func.value] # type: ignore - return f"{value}" - def visit_operation(self, operation: Operation) -> client.Predicates: - try: - from timescale_vector import client - except ImportError as e: - raise ImportError( - "Cannot import timescale-vector. Please install with `pip install " - "timescale-vector`." - ) from e - args = [arg.accept(self) for arg in operation.arguments] - return client.Predicates(*args, operator=self._format_func(operation.operator)) +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_comparison(self, comparison: Comparison) -> client.Predicates: - try: - from timescale_vector import client - except ImportError as e: - raise ImportError( - "Cannot import timescale-vector. Please install with `pip install " - "timescale-vector`." 
- ) from e - return client.Predicates( - ( - comparison.attribute, - self._format_func(comparison.comparator), - comparison.value, - ) - ) - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"predicates": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["TimescaleVectorTranslator"] diff --git a/libs/langchain/langchain/retrievers/self_query/vectara.py b/libs/langchain/langchain/retrievers/self_query/vectara.py index 24886a1af9..0fa15959b7 100644 --- a/libs/langchain/langchain/retrievers/self_query/vectara.py +++ b/libs/langchain/langchain/retrievers/self_query/vectara.py @@ -1,70 +1,27 @@ -from typing import Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.vectara import ( + VectaraTranslator, + process_value, + ) -def process_value(value: Union[int, float, str]) -> str: - """Convert a value to a string and add single quotes if it is a string.""" - if isinstance(value, str): - return f"'{value}'" - else: - return str(value) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "VectaraTranslator": "langchain_community.query_constructors.vectara", + "process_value": "langchain_community.query_constructors.vectara", +} +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) -class VectaraTranslator(Visitor): - """Translate `Vectara` internal query language elements to valid filters.""" - allowed_operators = [Operator.AND, Operator.OR] - """Subset of allowed logical operators.""" - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GT, - Comparator.GTE, - Comparator.LT, - Comparator.LTE, - ] - """Subset of allowed logical comparators.""" +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def _format_func(self, func: Union[Operator, Comparator]) -> str: - map_dict = { - Operator.AND: " and ", - Operator.OR: " or ", - Comparator.EQ: "=", - Comparator.NE: "!=", - Comparator.GT: ">", - Comparator.GTE: ">=", - Comparator.LT: "<", - Comparator.LTE: "<=", - } - self._validate_func(func) - return map_dict[func] - def visit_operation(self, operation: Operation) -> str: - args = [arg.accept(self) for arg in operation.arguments] - operator = self._format_func(operation.operator) - return "( " + operator.join(args) + " )" - - def visit_comparison(self, comparison: Comparison) -> str: - comparator = self._format_func(comparison.comparator) - processed_value = process_value(comparison.value) - attribute = comparison.attribute - return ( - "( " + "doc." 
+ attribute + " " + comparator + " " + processed_value + " )" - ) - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = ["VectaraTranslator", "process_value"] diff --git a/libs/langchain/langchain/retrievers/self_query/weaviate.py b/libs/langchain/langchain/retrievers/self_query/weaviate.py index 2e5e3e691e..5385258d53 100644 --- a/libs/langchain/langchain/retrievers/self_query/weaviate.py +++ b/libs/langchain/langchain/retrievers/self_query/weaviate.py @@ -1,79 +1,23 @@ -from datetime import datetime -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Any -from langchain_core.structured_query import ( - Comparator, - Comparison, - Operation, - Operator, - StructuredQuery, - Visitor, -) +from langchain._api import create_importer +if TYPE_CHECKING: + from langchain_community.query_constructors.weaviate import WeaviateTranslator -class WeaviateTranslator(Visitor): - """Translate `Weaviate` internal query language elements to valid filters.""" +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "WeaviateTranslator": "langchain_community.query_constructors.weaviate", +} - allowed_operators = [Operator.AND, Operator.OR] - """Subset of allowed logical operators.""" +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - allowed_comparators = [ - Comparator.EQ, - Comparator.NE, - Comparator.GTE, - Comparator.LTE, - Comparator.LT, - Comparator.GT, - ] - def _format_func(self, func: Union[Operator, Comparator]) -> str: - self._validate_func(func) - # https://weaviate.io/developers/weaviate/api/graphql/filters - map_dict = { - Operator.AND: "And", - Operator.OR: "Or", - Comparator.EQ: "Equal", - Comparator.NE: "NotEqual", - Comparator.GTE: "GreaterThanEqual", - Comparator.LTE: "LessThanEqual", - Comparator.LT: "LessThan", - Comparator.GT: "GreaterThan", - } - return map_dict[func] +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - def visit_operation(self, operation: Operation) -> Dict: - args = [arg.accept(self) for arg in operation.arguments] - return {"operator": self._format_func(operation.operator), "operands": args} - def visit_comparison(self, comparison: Comparison) -> Dict: - value_type = "valueText" - value = comparison.value - if isinstance(comparison.value, bool): - value_type = "valueBoolean" - elif isinstance(comparison.value, float): - value_type = "valueNumber" - elif isinstance(comparison.value, int): - value_type = "valueInt" - elif ( - isinstance(comparison.value, dict) - and comparison.value.get("type") == "date" - ): - value_type = "valueDate" - # ISO 8601 timestamp, formatted as RFC3339 - date = datetime.strptime(comparison.value["date"], "%Y-%m-%d") - value = date.strftime("%Y-%m-%dT%H:%M:%SZ") - filter = { - "path": [comparison.attribute], - "operator": self._format_func(comparison.comparator), - value_type: value, - } - return filter - - def visit_structured_query( - self, structured_query: StructuredQuery - ) -> Tuple[str, dict]: - if structured_query.filter is None: - kwargs = {} - else: - kwargs = {"where_filter": structured_query.filter.accept(self)} - return structured_query.query, kwargs +__all__ = 
["WeaviateTranslator"] diff --git a/libs/langchain/langchain/retrievers/web_research.py b/libs/langchain/langchain/retrievers/web_research.py index 378c77a267..5d853752b7 100644 --- a/libs/langchain/langchain/retrievers/web_research.py +++ b/libs/langchain/langchain/retrievers/web_research.py @@ -1,223 +1,29 @@ -import logging -import re -from typing import List, Optional +from typing import TYPE_CHECKING, Any -from langchain_community.document_loaders import AsyncHtmlLoader -from langchain_community.document_transformers import Html2TextTransformer -from langchain_community.llms import LlamaCpp -from langchain_community.utilities import GoogleSearchAPIWrapper -from langchain_core.callbacks import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) -from langchain_core.documents import Document -from langchain_core.language_models import BaseLLM -from langchain_core.output_parsers import BaseOutputParser -from langchain_core.prompts import BasePromptTemplate, PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.retrievers import BaseRetriever -from langchain_core.vectorstores import VectorStore -from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from langchain._api import create_importer -from langchain.chains import LLMChain -from langchain.chains.prompt_selector import ConditionalPromptSelector - -logger = logging.getLogger(__name__) - - -class SearchQueries(BaseModel): - """Search queries to research for the user's goal.""" - - queries: List[str] = Field( - ..., description="List of search queries to look up on Google" - ) - - -DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate( - input_variables=["question"], - template="""<> \n You are an assistant tasked with improving Google search \ -results. \n <> \n\n [INST] Generate THREE Google search queries that \ -are similar to this question. The output should be a numbered list of questions \ -and each should have a question mark at the end: \n\n {question} [/INST]""", -) - -DEFAULT_SEARCH_PROMPT = PromptTemplate( - input_variables=["question"], - template="""You are an assistant tasked with improving Google search \ -results. Generate THREE Google search queries that are similar to \ -this question. 
The output should be a numbered list of questions and each \ -should have a question mark at the end: {question}""", -) - - -class QuestionListOutputParser(BaseOutputParser[List[str]]): - """Output parser for a list of numbered questions.""" - - def parse(self, text: str) -> List[str]: - lines = re.findall(r"\d+\..*?(?:\n|$)", text) - return lines - - -class WebResearchRetriever(BaseRetriever): - """`Google Search API` retriever.""" - - # Inputs - vectorstore: VectorStore = Field( - ..., description="Vector store for storing web pages" - ) - llm_chain: LLMChain - search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper") - num_search_results: int = Field(1, description="Number of pages per Google search") - text_splitter: TextSplitter = Field( - RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50), - description="Text splitter for splitting web pages into chunks", +if TYPE_CHECKING: + from langchain_community.retrievers.web_research import ( + QuestionListOutputParser, + SearchQueries, + WebResearchRetriever, ) - url_database: List[str] = Field( - default_factory=list, description="List of processed URLs" - ) - - @classmethod - def from_llm( - cls, - vectorstore: VectorStore, - llm: BaseLLM, - search: GoogleSearchAPIWrapper, - prompt: Optional[BasePromptTemplate] = None, - num_search_results: int = 1, - text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter( - chunk_size=1500, chunk_overlap=150 - ), - ) -> "WebResearchRetriever": - """Initialize from llm using default template. - - Args: - vectorstore: Vector store for storing web pages - llm: llm for search question generation - search: GoogleSearchAPIWrapper - prompt: prompt to generating search questions - num_search_results: Number of pages per Google search - text_splitter: Text splitter for splitting web pages into chunks - - Returns: - WebResearchRetriever - """ - - if not prompt: - QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector( - default_prompt=DEFAULT_SEARCH_PROMPT, - conditionals=[ - (lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT) - ], - ) - prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm) - - # Use chat model prompt - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - output_parser=QuestionListOutputParser(), - ) - - return cls( - vectorstore=vectorstore, - llm_chain=llm_chain, - search=search, - num_search_results=num_search_results, - text_splitter=text_splitter, - ) - - def clean_search_query(self, query: str) -> str: - # Some search tools (e.g., Google) will - # fail to return results if query has a - # leading digit: 1. "LangCh..." - # Check if the first character is a digit - if query[0].isdigit(): - # Find the position of the first quote - first_quote_pos = query.find('"') - if first_quote_pos != -1: - # Extract the part of the string after the quote - query = query[first_quote_pos + 1 :] - # Remove the trailing quote if present - if query.endswith('"'): - query = query[:-1] - return query.strip() - - def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]: - """Returns num_search_results pages per Google search.""" - query_clean = self.clean_search_query(query) - result = self.search.results(query_clean, num_search_results) - return result - - def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - ) -> List[Document]: - """Search Google for documents related to the query input. 
- - Args: - query: user query - - Returns: - Relevant documents from all various urls. - """ - - # Get search questions - logger.info("Generating questions for Google Search ...") - result = self.llm_chain({"question": query}) - logger.info(f"Questions for Google Search (raw): {result}") - questions = result["text"] - logger.info(f"Questions for Google Search: {questions}") - - # Get urls - logger.info("Searching for relevant urls...") - urls_to_look = [] - for query in questions: - # Google search - search_results = self.search_tool(query, self.num_search_results) - logger.info("Searching for relevant urls...") - logger.info(f"Search results: {search_results}") - for res in search_results: - if res.get("link", None): - urls_to_look.append(res["link"]) - # Relevant urls - urls = set(urls_to_look) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "QuestionListOutputParser": "langchain_community.retrievers.web_research", + "SearchQueries": "langchain_community.retrievers.web_research", + "WebResearchRetriever": "langchain_community.retrievers.web_research", +} - # Check for any new urls that we have not processed - new_urls = list(urls.difference(self.url_database)) +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - logger.info(f"New URLs to load: {new_urls}") - # Load, split, and add new urls to vectorstore - if new_urls: - loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True) - html2text = Html2TextTransformer() - logger.info("Indexing new urls...") - docs = loader.load() - docs = list(html2text.transform_documents(docs)) - docs = self.text_splitter.split_documents(docs) - self.vectorstore.add_documents(docs) - self.url_database.extend(new_urls) - # Search for relevant splits - # TODO: make this async - logger.info("Grabbing most relevant splits from urls...") - docs = [] - for query in questions: - docs.extend(self.vectorstore.similarity_search(query)) +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - # Get unique docs - unique_documents_dict = { - (doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs - } - unique_documents = list(unique_documents_dict.values()) - return unique_documents - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - ) -> List[Document]: - raise NotImplementedError +__all__ = ["QuestionListOutputParser", "SearchQueries", "WebResearchRetriever"] diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 8ab41d74eb..91f93ad58b 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aiodns" @@ -3072,6 +3072,7 @@ files = [ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"}, {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"}, {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"}, + {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -3467,37 +3468,9 @@ files = [ {file = "jupyterlab_widgets-3.0.10.tar.gz", hash = "sha256:04f2ac04976727e4f9d0fa91cdc2f1ab860f965e504c29dbd6a65c882c9d04c0"}, ] -[[package]] -name = "langchain-community" -version = "0.0.37" -description = "Community contributed LangChain integrations." -optional = false -python-versions = ">=3.8.1,<4.0" -files = [] -develop = true - -[package.dependencies] -aiohttp = "^3.8.3" -dataclasses-json = ">= 0.5.7, < 0.7" -langchain-core = "^0.1.51" -langsmith = "^0.1.0" -numpy = "^1" -PyYAML = ">=5.3" -requests = "^2" -SQLAlchemy = ">=1.4,<3" -tenacity = "^8.1.0" - -[package.extras] -cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "azure-identity (>=1.15.0,<2.0.0)", "azure-search-documents (==11.4.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.6,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "httpx-sse (>=0.4.0,<0.5.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "oracledb (>=2.2.0,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai 
(>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pyjwt (>=2.8.0,<3.0.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] - -[package.source] -type = "directory" -url = "../community" - [[package]] name = "langchain-core" -version = "0.1.51" +version = "0.1.52" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -4727,6 +4700,7 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"}, {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, ] @@ -5471,8 +5445,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -5515,7 +5487,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -5524,8 +5495,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -6074,26 +6043,31 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"}, {file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"}, + {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"}, {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"}, {file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"}, {file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"}, {file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"}, {file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"}, + {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"}, {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"}, {file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"}, {file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"}, {file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = 
"sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"}, {file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"}, + {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"}, {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"}, {file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"}, {file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"}, {file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"}, {file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"}, + {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"}, {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"}, {file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"}, {file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"}, {file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"}, {file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"}, + {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"}, {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"}, {file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"}, {file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"}, @@ -6569,7 +6543,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -9404,4 +9377,4 @@ 
text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "9ed4d0b11749d1f98e8fbe2895a94e4bc90975817873e52a70f2bbcee934ce19" +content-hash = "99c6692964eb665e7746911ab0c834a4893215b40d3d4733c0c5fa5904669769" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index f1178f52eb..a7240614d0 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.1.18" +version = "0.2.0rc1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" @@ -14,7 +14,6 @@ langchain-server = "langchain.server:main" python = ">=3.8.1,<4.0" langchain-core = "^0.1.48" langchain-text-splitters = ">=0.0.1,<0.1" -langchain-community = ">=0.0.37,<0.1" langsmith = "^0.1.17" pydantic = ">=1,<3" SQLAlchemy = ">=1.4,<3" @@ -169,7 +168,6 @@ cassio = "^0.1.0" tiktoken = ">=0.3.2,<0.6.0" anthropic = "^0.3.11" langchain-core = {path = "../core", develop = true} -langchain-community = {path = "../community", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} langchainhub = "^0.1.15" @@ -192,7 +190,6 @@ types-pytz = "^2023.3.0.0" types-chardet = "^5.0.4.6" mypy-protobuf = "^3.0.0" langchain-core = {path = "../core", develop = true} -langchain-community = {path = "../community", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} [tool.poetry.group.dev] @@ -203,7 +200,6 @@ jupyter = "^1.0.0" playwright = "^1.28.0" setuptools = "^67.6.1" langchain-core = {path = "../core", develop = true} -langchain-community = {path = "../community", develop = true} langchain-text-splitters = {path = "../text-splitters", develop = true} [tool.poetry.extras] @@ -346,6 +342,7 @@ addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused markers = [ "requires: mark tests as requiring a specific library", "scheduled: mark tests to run in scheduled testing", + "community: mark tests that require langchain-community to be installed", "compile: mark placeholder test used to compile integration tests without running them" ] asyncio_mode = "auto" diff --git a/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py b/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py index c81ce791ea..378764bc48 100644 --- a/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py +++ b/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py @@ -1,5 +1,4 @@ from langchain.agents import agent_toolkits -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AINetworkToolkit", @@ -39,4 +38,3 @@ EXPECTED_ALL = [ def test_imports() -> None: assert sorted(agent_toolkits.__all__) == sorted(EXPECTED_ALL) - assert_all_importable(agent_toolkits) diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 060e64338a..6db94330a1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -21,7 +21,7 @@ from langchain_core.messages import ( ) from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables.utils import add -from langchain_core.tools import Tool +from langchain_core.tools import Tool, tool from langchain_core.tracers import RunLog, RunLogPatch from langchain.agents import ( @@ -33,7 +33,6 @@ from langchain.agents import ( initialize_agent, ) 
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction -from langchain.tools import tool from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel from tests.unit_tests.stubs import AnyStr diff --git a/libs/langchain/tests/unit_tests/agents/test_imports.py b/libs/langchain/tests/unit_tests/agents/test_imports.py index ad092318dc..163dd022b6 100644 --- a/libs/langchain/tests/unit_tests/agents/test_imports.py +++ b/libs/langchain/tests/unit_tests/agents/test_imports.py @@ -1,5 +1,4 @@ from langchain import agents -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "Agent", @@ -49,4 +48,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(agents.__all__) == set(EXPECTED_ALL) - assert_all_importable(agents) diff --git a/libs/langchain/tests/unit_tests/agents/test_serialization.py b/libs/langchain/tests/unit_tests/agents/test_serialization.py index ece1dbcc9a..f39cd16833 100644 --- a/libs/langchain/tests/unit_tests/agents/test_serialization.py +++ b/libs/langchain/tests/unit_tests/agents/test_serialization.py @@ -1,12 +1,15 @@ from pathlib import Path from tempfile import TemporaryDirectory +import pytest from langchain_core.language_models import FakeListLLM from langchain_core.tools import Tool from langchain.agents.agent_types import AgentType from langchain.agents.initialize import initialize_agent, load_agent +pytest.importorskip("langchain_community") + def test_mrkl_serialization() -> None: agent = initialize_agent( diff --git a/libs/langchain/tests/unit_tests/callbacks/test_imports.py b/libs/langchain/tests/unit_tests/callbacks/test_imports.py index 3e01ae4953..a7bd7d366d 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_imports.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_imports.py @@ -1,5 +1,4 @@ from langchain import callbacks -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AimCallbackHandler", @@ -39,4 +38,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(callbacks.__all__) == set(EXPECTED_ALL) - assert_all_importable(callbacks) diff --git a/libs/langchain/tests/unit_tests/chains/test_imports.py b/libs/langchain/tests/unit_tests/chains/test_imports.py index 4da79031fe..797b81b440 100644 --- a/libs/langchain/tests/unit_tests/chains/test_imports.py +++ b/libs/langchain/tests/unit_tests/chains/test_imports.py @@ -1,5 +1,4 @@ from langchain import chains -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "APIChain", @@ -68,4 +67,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(chains.__all__) == set(EXPECTED_ALL) - assert_all_importable(chains) diff --git a/libs/langchain/tests/unit_tests/chains/test_neptune_cypher_qa.py b/libs/langchain/tests/unit_tests/chains/test_neptune_cypher_qa.py deleted file mode 100644 index 5685e2c7d9..0000000000 --- a/libs/langchain/tests/unit_tests/chains/test_neptune_cypher_qa.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_import() -> None: - from langchain.chains import NeptuneOpenCypherQAChain # noqa: F401 diff --git a/libs/langchain/tests/unit_tests/chains/test_ontotext_graphdb_qa.py b/libs/langchain/tests/unit_tests/chains/test_ontotext_graphdb_qa.py deleted file mode 100644 index 46917abdab..0000000000 --- a/libs/langchain/tests/unit_tests/chains/test_ontotext_graphdb_qa.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_import() -> None: - from langchain.chains import OntotextGraphDBQAChain # noqa: F401 diff --git 
a/libs/langchain/tests/unit_tests/chat_models/test_imports.py b/libs/langchain/tests/unit_tests/chat_models/test_imports.py index e27df46d55..9fc196246d 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_imports.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_imports.py @@ -1,5 +1,4 @@ from langchain import chat_models -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "ChatOpenAI", @@ -37,4 +36,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(chat_models.__all__) == set(EXPECTED_ALL) - assert_all_importable(chat_models) diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py index f1746902fc..d7de32ebde 100644 --- a/libs/langchain/tests/unit_tests/conftest.py +++ b/libs/langchain/tests/unit_tests/conftest.py @@ -19,6 +19,14 @@ def pytest_addoption(parser: Parser) -> None: help="Only run core tests. Never runs any extended tests.", ) + parser.addoption( + "--community", + action="store_true", + dest="community", + default=False, + help="enable running unit tests that require community", + ) + def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: """Add implementations for handling custom markers. @@ -43,6 +51,12 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> only_extended = config.getoption("--only-extended") or False only_core = config.getoption("--only-core") or False + if not config.getoption("--community"): + skip_community = pytest.mark.skip(reason="need --community option to run") + for item in items: + if "community" in item.keywords: + item.add_marker(skip_community) + if only_extended and only_core: raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") diff --git a/libs/langchain/tests/unit_tests/docstore/test_imports.py b/libs/langchain/tests/unit_tests/docstore/test_imports.py index 763e6d58ff..88fc92b464 100644 --- a/libs/langchain/tests/unit_tests/docstore/test_imports.py +++ b/libs/langchain/tests/unit_tests/docstore/test_imports.py @@ -1,9 +1,7 @@ from langchain import docstore -from tests.unit_tests import assert_all_importable EXPECTED_ALL = ["DocstoreFn", "InMemoryDocstore", "Wikipedia"] def test_all_imports() -> None: assert set(docstore.__all__) == set(EXPECTED_ALL) - assert_all_importable(docstore) diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_base.py b/libs/langchain/tests/unit_tests/document_loaders/test_base.py index 682978eaba..a293a51d55 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_base.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_base.py @@ -1,8 +1,7 @@ """Test Base Schema of documents.""" from typing import Iterator -from langchain_community.document_loaders.base import BaseBlobParser -from langchain_community.document_loaders.blob_loaders import Blob +from langchain_core.document_loaders import BaseBlobParser, Blob from langchain_core.documents import Document diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py index 86d5b115da..3b5736e4ee 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py @@ -1,5 +1,4 @@ from langchain import document_loaders -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AcreomLoader", @@ -177,4 +176,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert
set(document_loaders.__all__) == set(EXPECTED_ALL) - assert_all_importable(document_loaders) diff --git a/libs/langchain/tests/unit_tests/document_transformers/test_imports.py b/libs/langchain/tests/unit_tests/document_transformers/test_imports.py index eadac71fde..557bc00caf 100644 --- a/libs/langchain/tests/unit_tests/document_transformers/test_imports.py +++ b/libs/langchain/tests/unit_tests/document_transformers/test_imports.py @@ -1,5 +1,4 @@ from langchain import document_transformers -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "BeautifulSoupTransformer", @@ -19,4 +18,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(document_transformers.__all__) == set(EXPECTED_ALL) - assert_all_importable(document_transformers) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py index 0a5c67c704..c6d7a8207d 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py @@ -1,5 +1,4 @@ from langchain import embeddings -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "OpenAIEmbeddings", @@ -61,4 +60,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(embeddings.__all__) == set(EXPECTED_ALL) - assert_all_importable(embeddings) diff --git a/libs/langchain/tests/unit_tests/evaluation/test_loading.py b/libs/langchain/tests/unit_tests/evaluation/test_loading.py index e6ef4066fb..9748bbaaa9 100644 --- a/libs/langchain/tests/unit_tests/evaluation/test_loading.py +++ b/libs/langchain/tests/unit_tests/evaluation/test_loading.py @@ -26,6 +26,7 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None: ) +@pytest.mark.community @pytest.mark.parametrize( "evaluator_types", [ diff --git a/libs/langchain/tests/unit_tests/graphs/test_imports.py b/libs/langchain/tests/unit_tests/graphs/test_imports.py index 5287c42285..7c59309f45 100644 --- a/libs/langchain/tests/unit_tests/graphs/test_imports.py +++ b/libs/langchain/tests/unit_tests/graphs/test_imports.py @@ -1,5 +1,4 @@ from langchain import graphs -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "MemgraphGraph", @@ -17,4 +16,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(graphs.__all__) == set(EXPECTED_ALL) - assert_all_importable(graphs) diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index 5826687f28..454b1b126d 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -14,7 +14,7 @@ from unittest.mock import patch import pytest import pytest_asyncio -from langchain_community.document_loaders.base import BaseLoader +from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VST, VectorStore diff --git a/libs/langchain/tests/unit_tests/llms/test_base.py b/libs/langchain/tests/unit_tests/llms/test_base.py index 1b19b88ee7..b58b1a116a 100644 --- a/libs/langchain/tests/unit_tests/llms/test_base.py +++ b/libs/langchain/tests/unit_tests/llms/test_base.py @@ -1,4 +1,7 @@ """Test base LLM functionality.""" +import importlib + +import pytest from sqlalchemy import Column, Integer, Sequence, String, create_engine try: @@ -9,7 +12,6 @@ except ImportError: from langchain_core.caches import InMemoryCache from 
langchain_core.outputs import Generation, LLMResult -from langchain.cache import SQLAlchemyCache from langchain.globals import get_llm_cache, set_llm_cache from langchain.llms.base import __all__ from tests.unit_tests.llms.fake_llm import FakeLLM @@ -50,6 +52,10 @@ def test_caching() -> None: assert output == expected_output +@pytest.mark.skipif( + importlib.util.find_spec("langchain_community") is None, + reason="langchain_community not installed", +) def test_custom_caching() -> None: """Test custom_caching behavior.""" Base = declarative_base() @@ -65,6 +71,9 @@ def test_custom_caching() -> None: response = Column(String) engine = create_engine("sqlite://") + + from langchain_community.cache import SQLAlchemyCache + set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache)) llm = FakeLLM() params = llm.dict() diff --git a/libs/langchain/tests/unit_tests/llms/test_imports.py b/libs/langchain/tests/unit_tests/llms/test_imports.py index 7585447b7f..6e57e22e66 100644 --- a/libs/langchain/tests/unit_tests/llms/test_imports.py +++ b/libs/langchain/tests/unit_tests/llms/test_imports.py @@ -1,5 +1,7 @@ +import pytest +from langchain_core.language_models import BaseLLM + from langchain import llms -from langchain.llms.base import BaseLLM EXPECT_ALL = [ "AI21", @@ -88,6 +90,11 @@ EXPECT_ALL = [ def test_all_imports() -> None: """Simple test to make sure all things can be imported.""" + assert set(llms.__all__) == set(EXPECT_ALL) + + +@pytest.mark.community +def test_all_subclasses() -> None: + """Simple test to make sure all things are subclasses of BaseLLM.""" for cls in llms.__all__: assert issubclass(getattr(llms, cls), BaseLLM) - assert set(llms.__all__) == set(EXPECT_ALL) diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index f82aef4635..e4d2c2fb55 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -6,8 +6,6 @@ from typing import Any, Dict, List from unittest.mock import patch import pytest -from langchain_community.chat_models.openai import ChatOpenAI -from langchain_community.llms.openai import OpenAI from langchain_core.load.dump import dumps from langchain_core.load.serializable import Serializable from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate @@ -76,8 +74,11 @@ def test_typeerror() -> None: ) +@pytest.mark.community @pytest.mark.requires("openai") def test_serialize_openai_llm(snapshot: Any) -> None: + from langchain_community.llms.openai import OpenAI + with patch.dict(os.environ, {"LANGCHAIN_API_KEY": "test-api-key"}): llm = OpenAI( # type: ignore[call-arg] model="davinci", @@ -90,16 +91,22 @@ def test_serialize_openai_llm(snapshot: Any) -> None: assert dumps(llm, pretty=True) == snapshot +@pytest.mark.community @pytest.mark.requires("openai") def test_serialize_llmchain(snapshot: Any) -> None: + from langchain_community.llms.openai import OpenAI + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] prompt = PromptTemplate.from_template("hello {name}!") chain = LLMChain(llm=llm, prompt=prompt) assert dumps(chain, pretty=True) == snapshot +@pytest.mark.community @pytest.mark.requires("openai") def test_serialize_llmchain_env() -> None: + from langchain_community.llms.openai import OpenAI + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] prompt = PromptTemplate.from_template("hello {name}!") chain = LLMChain(llm=llm, 
prompt=prompt) @@ -120,8 +127,11 @@ def test_serialize_llmchain_env() -> None: del os.environ["OPENAI_API_KEY"] +@pytest.mark.community @pytest.mark.requires("openai") def test_serialize_llmchain_chat(snapshot: Any) -> None: + from langchain_community.chat_models.openai import ChatOpenAI + llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg] prompt = ChatPromptTemplate.from_messages( [HumanMessagePromptTemplate.from_template("hello {name}!")] @@ -147,8 +157,11 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None: del os.environ["OPENAI_API_KEY"] +@pytest.mark.community @pytest.mark.requires("openai") def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None: + from langchain_community.llms.openai import OpenAI + llm = OpenAI( # type: ignore[call-arg] model="davinci", temperature=0.5, diff --git a/libs/langchain/tests/unit_tests/load/test_load.py b/libs/langchain/tests/unit_tests/load/test_load.py index b3614b05ed..beb303635e 100644 --- a/libs/langchain/tests/unit_tests/load/test_load.py +++ b/libs/langchain/tests/unit_tests/load/test_load.py @@ -1,13 +1,19 @@ """Test for Serializable base class""" import pytest -from langchain_community.llms.openai import OpenAI as CommunityOpenAI from langchain_core.load.dump import dumpd, dumps from langchain_core.load.load import load, loads from langchain_core.prompts.prompt import PromptTemplate from langchain.chains.llm import LLMChain +pytest.importorskip( + "langchain_community", +) + + +from langchain_community.llms.openai import OpenAI as CommunityOpenAI # noqa: E402 + class NotSerializable: pass diff --git a/libs/langchain/tests/unit_tests/load/test_serializable.py b/libs/langchain/tests/unit_tests/load/test_serializable.py index 613e3f01d8..d4730f30d6 100644 --- a/libs/langchain/tests/unit_tests/load/test_serializable.py +++ b/libs/langchain/tests/unit_tests/load/test_serializable.py @@ -3,6 +3,7 @@ import inspect import pkgutil from types import ModuleType +import pytest from langchain_core.load.mapping import SERIALIZABLE_MAPPING @@ -54,6 +55,7 @@ def import_all_modules(package_name: str) -> dict: return classes +@pytest.mark.community def test_import_all_modules() -> None: """Test import all modules works as expected""" all_modules = import_all_modules("langchain") @@ -77,6 +79,7 @@ def test_import_all_modules() -> None: ) +@pytest.mark.community def test_serializable_mapping() -> None: to_skip = { # This should have had a different namespace, as it was never diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_imports.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_imports.py index 270014b6f2..bb2bcd1915 100644 --- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_imports.py +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_imports.py @@ -1,5 +1,4 @@ from langchain.memory import chat_message_histories -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AstraDBChatMessageHistory", @@ -27,4 +26,3 @@ EXPECTED_ALL = [ def test_imports() -> None: assert sorted(chat_message_histories.__all__) == sorted(EXPECTED_ALL) - assert_all_importable(chat_message_histories) diff --git a/libs/langchain/tests/unit_tests/memory/test_imports.py b/libs/langchain/tests/unit_tests/memory/test_imports.py index e4d819ab89..a42684a46e 100644 --- a/libs/langchain/tests/unit_tests/memory/test_imports.py +++ b/libs/langchain/tests/unit_tests/memory/test_imports.py @@ -1,5 +1,4 @@ from 
langchain import memory -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AstraDBChatMessageHistory", @@ -42,4 +41,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(memory.__all__) == set(EXPECTED_ALL) - assert_all_importable(memory) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_imports.py b/libs/langchain/tests/unit_tests/output_parsers/test_imports.py index a8d1e50f2c..7e3fc89e86 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_imports.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_imports.py @@ -1,5 +1,4 @@ from langchain import output_parsers -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "BooleanOutputParser", @@ -30,4 +29,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(output_parsers.__all__) == set(EXPECTED_ALL) - assert_all_importable(output_parsers) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py index 4d80922064..1a61ba2c49 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py @@ -1,8 +1,8 @@ """Test PandasDataframeParser""" import pandas as pd +from langchain_core.exceptions import OutputParserException from langchain.output_parsers.pandas_dataframe import PandasDataFrameOutputParser -from langchain.schema import OutputParserException df = pd.DataFrame( {"chicken": [1, 2, 3, 4], "veggies": [5, 4, 3, 2], "steak": [9, 8, 7, 6]} diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py index 864478f73f..ed23b237cd 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -4,10 +4,9 @@ from enum import Enum from typing import Optional from langchain_core.exceptions import OutputParserException +from langchain_core.output_parsers import PydanticOutputParser from langchain_core.pydantic_v1 import BaseModel, Field -from langchain.output_parsers.pydantic import PydanticOutputParser - class Actions(Enum): SEARCH = "Search" diff --git a/libs/langchain/tests/unit_tests/prompts/test_imports.py b/libs/langchain/tests/unit_tests/prompts/test_imports.py index b66e1a4758..e32722ac2b 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_imports.py +++ b/libs/langchain/tests/unit_tests/prompts/test_imports.py @@ -1,5 +1,4 @@ from langchain import prompts -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AIMessagePromptTemplate", @@ -27,4 +26,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(prompts.__all__) == set(EXPECTED_ALL) - assert_all_importable(prompts) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index 0eb49e612a..fe46768c10 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -1,11 +1,11 @@ import pytest from langchain_core.documents import Document +from langchain_core.embeddings import FakeEmbeddings -from langchain.embeddings import FakeEmbeddings -from langchain.retrievers import KNNRetriever, TFIDFRetriever -from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.ensemble import 
EnsembleRetriever +pytest.importorskip("langchain_community") + @pytest.mark.requires("rank_bm25") def test_ensemble_retriever_get_relevant_docs() -> None: @@ -15,6 +15,8 @@ def test_ensemble_retriever_get_relevant_docs() -> None: "Apples and oranges are fruits", ] + from langchain_community.retrievers import BM25Retriever + dummy_retriever = BM25Retriever.from_texts(doc_list) dummy_retriever.k = 1 @@ -30,6 +32,8 @@ def test_weighted_reciprocal_rank() -> None: doc1 = Document(page_content="1") doc2 = Document(page_content="2") + from langchain_community.retrievers import BM25Retriever + dummy_retriever = BM25Retriever.from_texts(["1", "2"]) ensemble_retriever = EnsembleRetriever( retrievers=[dummy_retriever, dummy_retriever], weights=[0.4, 0.5], c=0 @@ -62,6 +66,12 @@ def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None "Avocados and strawberries are fruits", ] + from langchain_community.retrievers import ( + BM25Retriever, + KNNRetriever, + TFIDFRetriever, + ) + dummy_retriever = BM25Retriever.from_texts(doc_list_a) dummy_retriever.k = 1 tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_imports.py b/libs/langchain/tests/unit_tests/retrievers/test_imports.py index 935f62bfa2..e6a40058f4 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_imports.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_imports.py @@ -1,5 +1,4 @@ from langchain import retrievers -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AmazonKendraRetriever", @@ -54,4 +53,3 @@ EXPECTED_ALL = [ def test_imports() -> None: assert sorted(retrievers.__all__) == sorted(EXPECTED_ALL) - assert_all_importable(retrievers) diff --git a/libs/langchain/tests/unit_tests/smith/test_imports.py b/libs/langchain/tests/unit_tests/smith/test_imports.py index c8300d8a69..aee16fb608 100644 --- a/libs/langchain/tests/unit_tests/smith/test_imports.py +++ b/libs/langchain/tests/unit_tests/smith/test_imports.py @@ -1,5 +1,4 @@ from langchain import smith -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "arun_on_dataset", @@ -10,4 +9,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(smith.__all__) == set(EXPECTED_ALL) - assert_all_importable(smith) diff --git a/libs/langchain/tests/unit_tests/storage/test_imports.py b/libs/langchain/tests/unit_tests/storage/test_imports.py index 33f74105f9..83510a8ff5 100644 --- a/libs/langchain/tests/unit_tests/storage/test_imports.py +++ b/libs/langchain/tests/unit_tests/storage/test_imports.py @@ -1,5 +1,4 @@ from langchain import storage -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "EncoderBackedStore", @@ -17,4 +16,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(storage.__all__) == set(EXPECTED_ALL) - assert_all_importable(storage) diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index e3fb8bac5f..c388778a0f 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -50,7 +50,6 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: "python", "requests", "tenacity", - "langchain-community", ] ) @@ -94,6 +93,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: ) +@pytest.mark.community def test_imports() -> None: """Test that you can import all top level things okay.""" from langchain_community.callbacks import 
OpenAICallbackHandler # noqa: F401 @@ -101,6 +101,8 @@ def test_imports() -> None: from langchain_community.document_loaders import BSHTMLLoader # noqa: F401 from langchain_community.embeddings import OpenAIEmbeddings # noqa: F401 from langchain_community.llms import OpenAI # noqa: F401 + from langchain_community.retrievers import VespaRetriever # noqa: F401 + from langchain_community.tools import DuckDuckGoSearchResults # noqa: F401 from langchain_community.utilities import ( SearchApiAPIWrapper, # noqa: F401 SerpAPIWrapper, # noqa: F401 @@ -110,5 +112,3 @@ from langchain.agents import OpenAIFunctionsAgent # noqa: F401 from langchain.chains import LLMChain # noqa: F401 - from langchain.retrievers import VespaRetriever # noqa: F401 - from langchain.tools import DuckDuckGoSearchResults # noqa: F401 diff --git a/libs/langchain/tests/unit_tests/test_imports.py b/libs/langchain/tests/unit_tests/test_imports.py index 62e17021d6..d62e0a1c1e 100644 --- a/libs/langchain/tests/unit_tests/test_imports.py +++ b/libs/langchain/tests/unit_tests/test_imports.py @@ -1,10 +1,13 @@ import importlib from pathlib import Path +import pytest + # Attempt to recursively import all modules in langchain PKG_ROOT = Path(__file__).parent.parent.parent +@pytest.mark.community def test_import_all() -> None: """Generate the public API for this package.""" library_code = PKG_ROOT / "langchain" @@ -25,3 +28,28 @@ def test_import_all() -> None: # Attempt to import the name from the module obj = getattr(mod, name) assert obj is not None + + +def test_import_all_using_dir() -> None: + """Attempt to access every public name exposed by each module via dir().""" + library_code = PKG_ROOT / "langchain" + for path in library_code.rglob("*.py"): + # Calculate the relative path to the module + module_name = ( + path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".") + ) + if module_name.endswith("__init__"): + # Without init + module_name = module_name.rsplit(".", 1)[0] + + try: + mod = importlib.import_module(module_name) + except ModuleNotFoundError as e: + raise ModuleNotFoundError(f"Could not import {module_name}") from e + attrs = dir(mod) + + for name in attrs: + if name.strip().startswith("_"): + continue + # Attempt to import the name from the module + getattr(mod, name) diff --git a/libs/langchain/tests/unit_tests/tools/test_imports.py b/libs/langchain/tests/unit_tests/tools/test_imports.py index 79eea82a5e..708241806e 100644 --- a/libs/langchain/tests/unit_tests/tools/test_imports.py +++ b/libs/langchain/tests/unit_tests/tools/test_imports.py @@ -1,5 +1,4 @@ from langchain import tools -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AINAppOps", @@ -126,4 +125,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(tools.__all__) == set(EXPECTED_ALL) - assert_all_importable(tools) diff --git a/libs/langchain/tests/unit_tests/utilities/test_imports.py b/libs/langchain/tests/unit_tests/utilities/test_imports.py index dcfb2be317..895988ed09 100644 --- a/libs/langchain/tests/unit_tests/utilities/test_imports.py +++ b/libs/langchain/tests/unit_tests/utilities/test_imports.py @@ -1,5 +1,4 @@ from langchain import utilities -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "AlphaVantageAPIWrapper", @@ -53,4 +52,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(utilities.__all__) == set(EXPECTED_ALL) - assert_all_importable(utilities) diff --git a/libs/langchain/tests/unit_tests/utils/test_imports.py
b/libs/langchain/tests/unit_tests/utils/test_imports.py index fdf4a95708..c9d3453bea 100644 --- a/libs/langchain/tests/unit_tests/utils/test_imports.py +++ b/libs/langchain/tests/unit_tests/utils/test_imports.py @@ -1,5 +1,4 @@ from langchain import utils -from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ "StrictFormatter", @@ -27,4 +26,3 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(utils.__all__) == set(EXPECTED_ALL) - assert_all_importable(utils) diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py index 097a41deb0..34a0b8126f 100644 --- a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -1,6 +1,5 @@ from langchain_core.pydantic_v1 import BaseModel, Field - -from langchain.utils.openai_functions import convert_pydantic_to_openai_function +from langchain_core.utils.function_calling import convert_pydantic_to_openai_function def test_convert_pydantic_to_openai_function() -> None: diff --git a/libs/langchain/tests/unit_tests/vectorstores/test_imports.py b/libs/langchain/tests/unit_tests/vectorstores/test_imports.py index f8dd5dc977..e911173411 100644 --- a/libs/langchain/tests/unit_tests/vectorstores/test_imports.py +++ b/libs/langchain/tests/unit_tests/vectorstores/test_imports.py @@ -1,8 +1,10 @@ +import pytest from langchain_core.vectorstores import VectorStore from langchain import vectorstores +@pytest.mark.community def test_all_imports() -> None: """Simple test to make sure all things can be imported.""" for cls in vectorstores.__all__: