mirror of https://github.com/hwchase17/langchain
community: Add document manager and mongo document manager (#17320)
- **Description:** - Add the DocumentManager class, a NoSQL record manager. - In order to use index and aindex in libs/langchain/langchain/indexes/_api.py, DocumentManager inherits RecordManager. - Also add a MongoDB implementation of DocumentManager. - **Dependencies:** pymongo, motor <!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** Add DocumentManager class, which is a NoSQL record manager. To use index method and aindex method in indexes._api.py, Document Manager inherits RecordManager. Add the MongoDB implementation of Document Manager. - **Dependencies:** pymongo, motor Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> pull/18016/head^2
parent
3f6bf852ea
commit
7fc903464a
@ -0,0 +1,231 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain_community.indexes.base import RecordManager
|
||||
|
||||
# Messages used when the optional MongoDB driver packages cannot be imported.
IMPORT_PYMONGO_ERROR = (
    "Could not import MongoClient. "
    "Please install it with `pip install pymongo`."
)
IMPORT_MOTOR_ASYNCIO_ERROR = (
    "Could not import AsyncIOMotorClient. "
    "Please install it with `pip install motor`."
)
|
||||
|
||||
|
||||
def _import_pymongo() -> Any:
    """Return the ``MongoClient`` class from PyMongo.

    The import is done lazily so that this module can be imported without
    ``pymongo`` being installed.

    Returns:
        The ``pymongo.MongoClient`` class (not an instance).

    Raises:
        ImportError: If ``pymongo`` cannot be imported. The message contains
            the install hint and the original error is chained as the cause.
    """
    try:
        from pymongo import MongoClient
    except ImportError as e:
        # Chain the original failure so the real reason stays in the traceback.
        raise ImportError(IMPORT_PYMONGO_ERROR) from e
    return MongoClient
|
||||
|
||||
|
||||
def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Build a synchronous ``MongoClient`` for ``mongodb_url``.

    Args:
        mongodb_url: Connection string handed to ``MongoClient``.
        **kwargs: Extra keyword arguments forwarded to ``MongoClient``.

    Returns:
        The constructed ``MongoClient`` instance.

    Raises:
        ImportError: If ``pymongo`` is not installed, or if constructing the
            client raises ``ValueError``. The ``ImportError`` type is kept for
            the latter case so existing callers keep working; the original
            ``ValueError`` is chained as the cause.
    """
    client_cls = _import_pymongo()
    try:
        return client_cls(mongodb_url, **kwargs)
    except ValueError as e:
        raise ImportError(
            "MongoClient string provided is not in proper format. "
            f"Got error: {e} "
        ) from e
|
||||
|
||||
|
||||
def _import_motor_asyncio() -> Any:
    """Return the ``AsyncIOMotorClient`` class from Motor.

    The import is done lazily so that this module can be imported without
    ``motor`` being installed.

    Returns:
        The ``motor.motor_asyncio.AsyncIOMotorClient`` class (not an instance).

    Raises:
        ImportError: If ``motor`` cannot be imported. The message contains
            the install hint and the original error is chained as the cause.
    """
    try:
        from motor.motor_asyncio import AsyncIOMotorClient
    except ImportError as e:
        # Chain the original failure so the real reason stays in the traceback.
        raise ImportError(IMPORT_MOTOR_ASYNCIO_ERROR) from e
    return AsyncIOMotorClient
|
||||
|
||||
|
||||
def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Build an ``AsyncIOMotorClient`` for ``mongodb_url``.

    Args:
        mongodb_url: Connection string handed to ``AsyncIOMotorClient``.
        **kwargs: Extra keyword arguments forwarded to ``AsyncIOMotorClient``.

    Returns:
        The constructed ``AsyncIOMotorClient`` instance.

    Raises:
        ImportError: If ``motor`` is not installed, or if constructing the
            client raises ``ValueError``. The ``ImportError`` type is kept for
            the latter case so existing callers keep working; the original
            ``ValueError`` is chained as the cause.
    """
    client_cls = _import_motor_asyncio()
    try:
        return client_cls(mongodb_url, **kwargs)
    except ValueError as e:
        raise ImportError(
            "AsyncIOMotorClient string provided is not in proper format. "
            f"Got error: {e} "
        ) from e
|
||||
|
||||
|
||||
class MongoDocumentManager(RecordManager):
    """A MongoDB based implementation of the document manager.

    Every record is stored as one MongoDB document with the fields
    ``namespace``, ``key``, ``group_id`` and ``updated_at``. All queries are
    scoped to ``self.namespace``. Synchronous methods go through a PyMongo
    client, asynchronous ones through a Motor client; both clients point at
    the same database and collection.
    """

    def __init__(
        self,
        namespace: str,
        *,
        mongodb_url: str,
        db_name: str,
        collection_name: str = "documentMetadata",
    ) -> None:
        """Initialize the MongoDocumentManager.

        Args:
            namespace: The namespace associated with this document manager.
            mongodb_url: Connection string used for both the sync and the
                async client.
            db_name: The name of the database to use.
            collection_name: The name of the collection to use.
                Default is 'documentMetadata'.

        Raises:
            ImportError: If ``pymongo`` or ``motor`` is not installed, or if
                a client constructor rejects ``mongodb_url`` with ValueError.
        """
        super().__init__(namespace=namespace)
        self.sync_client = _get_pymongo_client(mongodb_url)
        self.sync_db = self.sync_client[db_name]
        self.sync_collection = self.sync_db[collection_name]
        self.async_client = _get_motor_client(mongodb_url)
        self.async_db = self.async_client[db_name]
        self.async_collection = self.async_db[collection_name]

    def create_schema(self) -> None:
        """Create the database schema for the document manager.

        This implementation does nothing.
        """

    async def acreate_schema(self) -> None:
        """Create the database schema for the document manager.

        This implementation does nothing.
        """

    @staticmethod
    def _resolve_group_ids(
        keys: Sequence[str],
        group_ids: Optional[Sequence[Optional[str]]],
    ) -> Sequence[Optional[str]]:
        """Return one group id per key, defaulting to ``None`` for each key.

        Raises:
            ValueError: If ``group_ids`` is given but its length differs from
                the length of ``keys``.
        """
        if group_ids is None:
            return [None] * len(keys)
        if len(keys) != len(group_ids):
            raise ValueError("Number of keys does not match number of group_ids")
        return group_ids

    @staticmethod
    def _check_time_at_least(
        update_time: float, time_at_least: Optional[float]
    ) -> None:
        """Raise ValueError if the server time is older than ``time_at_least``."""
        if time_at_least is not None and update_time < time_at_least:
            raise ValueError("Server time is behind the expected time_at_least")

    def _list_keys_query(
        self,
        before: Optional[float],
        after: Optional[float],
        group_ids: Optional[Sequence[str]],
    ) -> Dict[str, Any]:
        """Build the MongoDB filter shared by ``list_keys`` and ``alist_keys``."""
        query: Dict[str, Any] = {"namespace": self.namespace}
        # Both bounds go into ONE sub-document so that passing ``before`` and
        # ``after`` together yields a range instead of one overwriting the other.
        time_range: Dict[str, float] = {}
        if before is not None:
            time_range["$lt"] = before
        if after is not None:
            time_range["$gt"] = after
        if time_range:
            query["updated_at"] = time_range
        if group_ids:
            query["group_id"] = {"$in": list(group_ids)}
        return query

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert documents into the MongoDB collection.

        Args:
            keys: Keys of the records to insert or refresh.
            group_ids: Optional group id per key; must match ``keys`` in
                length. Defaults to ``None`` for every key.
            time_at_least: If given, the server time used as ``updated_at``
                must not be earlier than this timestamp.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ,
                or if the server time is behind ``time_at_least``.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        # Read the server clock once so the whole batch shares one timestamp
        # and costs a single extra round trip, mirroring ``aupdate``.
        update_time = self.get_time()
        self._check_time_at_least(update_time, time_at_least)

        for key, group_id in zip(keys, group_ids):
            self.sync_collection.find_one_and_update(
                {"namespace": self.namespace, "key": key},
                {"$set": {"group_id": group_id, "updated_at": update_time}},
                upsert=True,
            )

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Asynchronously upsert documents into the MongoDB collection.

        Args:
            keys: Keys of the records to insert or refresh.
            group_ids: Optional group id per key; must match ``keys`` in
                length. Defaults to ``None`` for every key.
            time_at_least: If given, the server time used as ``updated_at``
                must not be earlier than this timestamp.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ,
                or if the server time is behind ``time_at_least``.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        update_time = await self.aget_time()
        self._check_time_at_least(update_time, time_at_least)

        for key, group_id in zip(keys, group_ids):
            await self.async_collection.find_one_and_update(
                {"namespace": self.namespace, "key": key},
                {"$set": {"group_id": group_id, "updated_at": update_time}},
                upsert=True,
            )

    def get_time(self) -> float:
        """Get the current server time as a timestamp.

        The value is read from ``system.currentTime`` in the result of the
        ``hostInfo`` database command and converted with ``.timestamp()``.
        """
        server_info = self.sync_db.command("hostInfo")
        local_time = server_info["system"]["currentTime"]
        return local_time.timestamp()

    async def aget_time(self) -> float:
        """Asynchronously get the current server time as a timestamp.

        The value is read from ``system.currentTime`` in the result of the
        ``hostInfo`` database command and converted with ``.timestamp()``.
        """
        host_info = await self.async_collection.database.command("hostInfo")
        local_time = host_info["system"]["currentTime"]
        return local_time.timestamp()

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the MongoDB collection.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        cursor = self.sync_collection.find(
            {"namespace": self.namespace, "key": {"$in": list(keys)}}, {"key": 1}
        )
        existing_keys = {doc["key"] for doc in cursor}
        return [key in existing_keys for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Asynchronously check if the given keys exist in the MongoDB collection.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        cursor = self.async_collection.find(
            {"namespace": self.namespace, "key": {"$in": list(keys)}}, {"key": 1}
        )
        existing_keys = {doc["key"] async for doc in cursor}
        return [key in existing_keys for key in keys]

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List documents in the MongoDB collection based on the provided date range.

        Args:
            before: Only return keys whose ``updated_at`` is strictly less
                than this timestamp.
            after: Only return keys whose ``updated_at`` is strictly greater
                than this timestamp.
            group_ids: Only return keys whose ``group_id`` is one of these.
            limit: Maximum number of keys to return; ``None`` or ``0`` applies
                no limit.

        Returns:
            The matching keys within this manager's namespace.
        """
        query = self._list_keys_query(before, after, group_ids)
        cursor = self.sync_collection.find(query, {"key": 1})
        if limit:
            cursor = cursor.limit(limit)
        return [doc["key"] for doc in cursor]

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """
        Asynchronously list documents in the MongoDB collection
        based on the provided date range.

        Args:
            before: Only return keys whose ``updated_at`` is strictly less
                than this timestamp.
            after: Only return keys whose ``updated_at`` is strictly greater
                than this timestamp.
            group_ids: Only return keys whose ``group_id`` is one of these.
            limit: Maximum number of keys to return; ``None`` or ``0`` applies
                no limit.

        Returns:
            The matching keys within this manager's namespace.
        """
        query = self._list_keys_query(before, after, group_ids)
        cursor = self.async_collection.find(query, {"key": 1})
        if limit:
            cursor = cursor.limit(limit)
        return [doc["key"] async for doc in cursor]

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Delete documents from the MongoDB collection.

        Only documents in this manager's namespace are removed.
        """
        self.sync_collection.delete_many(
            {"namespace": self.namespace, "key": {"$in": list(keys)}}
        )

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Asynchronously delete documents from the MongoDB collection.

        Only documents in this manager's namespace are removed.
        """
        await self.async_collection.delete_many(
            {"namespace": self.namespace, "key": {"$in": list(keys)}}
        )
|
@ -0,0 +1,268 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain_community.indexes._document_manager import MongoDocumentManager
|
||||
|
||||
|
||||
# NOTE: no ``@pytest.mark.requires`` here. pytest ignores marks applied to
# fixture functions (and deprecates doing so); the tests carry the marks.
@pytest.fixture
def manager() -> MongoDocumentManager:
    """Return a MongoDocumentManager bound to the local test MongoDB.

    The instance uses the "kittens" namespace in ``test_db.test_collection``.
    Nothing is dropped or cleaned up here, so data persists between tests.
    """
    return MongoDocumentManager(
        namespace="kittens",
        mongodb_url="mongodb://langchain:langchain@localhost:6022/",
        db_name="test_db",
        collection_name="test_collection",
    )
|
||||
|
||||
|
||||
# NOTE: no ``@pytest.mark.requires`` here. pytest ignores marks applied to
# fixture functions (and deprecates doing so); the tests carry the marks.
@pytest_asyncio.fixture
async def amanager() -> MongoDocumentManager:
    """Return a MongoDocumentManager bound to the local test MongoDB.

    The instance uses the "kittens" namespace in ``test_db.test_collection``.
    Nothing is dropped or cleaned up here, so data persists between tests.
    """
    return MongoDocumentManager(
        namespace="kittens",
        mongodb_url="mongodb://langchain:langchain@localhost:6022/",
        db_name="test_db",
        collection_name="test_collection",
    )
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_update(manager: MongoDocumentManager) -> None:
    """Test updating records in the MongoDB."""
    updated_keys = ["update_key1", "update_key2", "update_key3"]
    # The fixture never drops the collection, so keys written by an earlier
    # run may still be present. Remove them first; otherwise the upsert adds
    # nothing new and ``read_keys + updated_keys`` would contain duplicates.
    manager.delete_keys(updated_keys)
    read_keys = manager.list_keys()

    manager.update(updated_keys)

    all_keys = manager.list_keys()
    assert sorted(all_keys) == sorted(read_keys + updated_keys)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_aupdate(amanager: MongoDocumentManager) -> None:
    """Test asynchronously updating records in the MongoDB."""
    aupdated_keys = ["aupdate_key1", "aupdate_key2", "aupdate_key3"]
    # The fixture never drops the collection, so keys written by an earlier
    # run may still be present. Remove them first; otherwise the upsert adds
    # nothing new and ``read_keys + aupdated_keys`` would contain duplicates.
    await amanager.adelete_keys(aupdated_keys)
    read_keys = await amanager.alist_keys()

    await amanager.aupdate(aupdated_keys)

    all_keys = await amanager.alist_keys()
    assert sorted(all_keys) == sorted(read_keys + aupdated_keys)
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_update_timestamp(manager: MongoDocumentManager) -> None:
    """Check that ``update`` stores the time reported by ``get_time``."""
    frozen_time = datetime(2024, 2, 23).timestamp()
    with patch.object(manager, "get_time", return_value=frozen_time):
        manager.update(["key1"])

    lookup = {"namespace": manager.namespace, "key": "key1"}
    stored = []
    for record in manager.sync_collection.find(lookup):
        # Project away ``_id`` and anything else not under test.
        stored.append(
            {
                "key": record["key"],
                "namespace": record["namespace"],
                "updated_at": record["updated_at"],
                "group_id": record.get("group_id"),
            }
        )

    expected = {
        "group_id": None,
        "key": "key1",
        "namespace": "kittens",
        "updated_at": datetime(2024, 2, 23).timestamp(),
    }
    assert stored == [expected]
|
||||
|
||||
|
||||
# Marked explicitly as asyncio, like ``test_aupdate``, so the coroutine is
# also executed when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_aupdate_timestamp(amanager: MongoDocumentManager) -> None:
    """Test asynchronously updating records with timestamps in MongoDB."""
    frozen_time = datetime(2024, 2, 23).timestamp()
    with patch.object(amanager, "aget_time", return_value=frozen_time):
        await amanager.aupdate(["key1"])

    records = [
        doc
        async for doc in amanager.async_collection.find(
            {"namespace": amanager.namespace, "key": "key1"}
        )
    ]

    # Compare only the fields under test; ``_id`` is left out.
    assert [
        {
            "key": record["key"],
            "namespace": record["namespace"],
            "updated_at": record["updated_at"],
            "group_id": record.get("group_id"),
        }
        for record in records
    ] == [
        {
            "group_id": None,
            "key": "key1",
            "namespace": "kittens",
            "updated_at": frozen_time,
        }
    ]
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_exists(manager: MongoDocumentManager) -> None:
    """Test checking if keys exist in MongoDB."""
    # "key4" is asserted to be missing below, but test_list_keys creates it in
    # the same namespace and the fixture never drops the collection.
    manager.delete_keys(["key4"])

    keys = ["key1", "key2", "key3"]
    manager.update(keys)
    exists = manager.exists(keys)
    assert len(exists) == len(keys)
    assert all(exists)

    exists = manager.exists(["key1", "key4"])
    assert len(exists) == 2
    assert exists == [True, False]
|
||||
|
||||
|
||||
# Marked explicitly as asyncio, like ``test_aupdate``, so the coroutine is
# also executed when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_aexists(amanager: MongoDocumentManager) -> None:
    """Test asynchronously checking if keys exist in MongoDB."""
    # "key4" is asserted to be missing below, but test_alist_keys creates it
    # in the same namespace and the fixture never drops the collection.
    await amanager.adelete_keys(["key4"])

    keys = ["key1", "key2", "key3"]
    await amanager.aupdate(keys)
    exists = await amanager.aexists(keys)
    assert len(exists) == len(keys)
    assert all(exists)

    exists = await amanager.aexists(["key1", "key4"])
    assert len(exists) == 2
    assert exists == [True, False]
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_list_keys(manager: MongoDocumentManager) -> None:
    """Exercise ``list_keys`` with no filter, an ``after`` filter and group ids."""
    # Start from an empty namespace.
    manager.delete_keys(manager.list_keys())

    # (key, year of the frozen server time, group ids passed to update)
    seeded = [
        ("key1", 2021, None),
        ("key2", 2022, None),
        ("key3", 2023, None),
        ("key4", 2024, ["group1"]),
    ]
    for key, year, groups in seeded:
        frozen_time = datetime(year, 1, 1).timestamp()
        with patch.object(manager, "get_time", return_value=frozen_time):
            manager.update([key], group_ids=groups)

    everything = manager.list_keys()
    assert sorted(everything) == ["key1", "key2", "key3", "key4"]

    cutoff = datetime(2022, 2, 1).timestamp()
    recent = manager.list_keys(after=cutoff)
    assert sorted(recent) == ["key3", "key4"]

    grouped = manager.list_keys(group_ids=["group1", "group2"])
    assert sorted(grouped) == ["key4"]
|
||||
|
||||
|
||||
# Marked explicitly as asyncio, like ``test_aupdate``, so the coroutine is
# also executed when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_alist_keys(amanager: MongoDocumentManager) -> None:
    """Test asynchronously listing keys in MongoDB."""
    # Start from an empty namespace.
    await amanager.adelete_keys(await amanager.alist_keys())

    # (key, year of the frozen server time, group ids passed to aupdate)
    seeded = [
        ("key1", 2021, None),
        ("key2", 2022, None),
        ("key3", 2023, None),
        ("key4", 2024, ["group1"]),
    ]
    for key, year, groups in seeded:
        frozen_time = datetime(year, 1, 1).timestamp()
        with patch.object(amanager, "aget_time", return_value=frozen_time):
            await amanager.aupdate([key], group_ids=groups)

    assert sorted(await amanager.alist_keys()) == sorted(
        ["key1", "key2", "key3", "key4"]
    )
    assert sorted(
        await amanager.alist_keys(after=datetime(2022, 2, 1).timestamp())
    ) == sorted(["key3", "key4"])
    assert sorted(await amanager.alist_keys(group_ids=["group1", "group2"])) == sorted(
        ["key4"]
    )
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_namespace_is_used(manager: MongoDocumentManager) -> None:
    """Check that every operation is scoped to the manager's namespace."""
    # Start from an empty "kittens" namespace.
    manager.delete_keys(manager.list_keys())
    manager.update(["key1", "key2"], group_ids=["group1", "group2"])

    # Records of a foreign namespace, written behind the manager's back.
    foreign_records = [
        {"key": key, "namespace": "puppies", "group_id": None}
        for key in ("key1", "key3")
    ]
    manager.sync_collection.insert_many(foreign_records)

    listed = manager.list_keys()
    assert sorted(listed) == ["key1", "key2"]

    manager.delete_keys(["key1"])
    listed = manager.list_keys()
    assert sorted(listed) == ["key2"]

    manager.update(["key3"], group_ids=["group3"])
    kitten_key3 = manager.sync_collection.find_one(
        {"key": "key3", "namespace": "kittens"}
    )
    assert kitten_key3["group_id"] == "group3"
|
||||
|
||||
|
||||
# Marked explicitly as asyncio, like ``test_aupdate``, so the coroutine is
# also executed when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_anamespace_is_used(amanager: MongoDocumentManager) -> None:
    """
    Verify that namespace is taken into account for all operations
    in MongoDB asynchronously.
    """
    # Start from an empty "kittens" namespace.
    await amanager.adelete_keys(await amanager.alist_keys())
    await amanager.aupdate(["key1", "key2"], group_ids=["group1", "group2"])

    # Records of a foreign namespace, written behind the manager's back.
    await amanager.async_collection.insert_many(
        [
            {"key": "key1", "namespace": "puppies", "group_id": None},
            {"key": "key3", "namespace": "puppies", "group_id": None},
        ]
    )
    assert sorted(await amanager.alist_keys()) == sorted(["key1", "key2"])

    await amanager.adelete_keys(["key1"])
    assert sorted(await amanager.alist_keys()) == sorted(["key2"])

    await amanager.aupdate(["key3"], group_ids=["group3"])
    kitten_key3 = await amanager.async_collection.find_one(
        {"key": "key3", "namespace": "kittens"}
    )
    assert kitten_key3["group_id"] == "group3"
|
||||
|
||||
|
||||
@pytest.mark.requires("pymongo")
def test_delete_keys(manager: MongoDocumentManager) -> None:
    """Test deleting keys from MongoDB."""
    # The final assertion compares the complete key list, so clear the
    # namespace first instead of relying on what earlier tests left behind.
    manager.delete_keys(manager.list_keys())

    manager.update(["key1", "key2", "key3"])
    manager.delete_keys(["key1", "key2"])
    remaining_keys = manager.list_keys()
    assert sorted(remaining_keys) == sorted(["key3"])
|
||||
|
||||
|
||||
# Marked explicitly as asyncio, like ``test_aupdate``, so the coroutine is
# also executed when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
@pytest.mark.requires("motor")
async def test_adelete_keys(amanager: MongoDocumentManager) -> None:
    """Test asynchronously deleting keys from MongoDB."""
    # The final assertion compares the complete key list, so clear the
    # namespace first instead of relying on what earlier tests left behind.
    await amanager.adelete_keys(await amanager.alist_keys())

    await amanager.aupdate(["key1", "key2", "key3"])
    await amanager.adelete_keys(["key1", "key2"])
    remaining_keys = await amanager.alist_keys()
    assert sorted(remaining_keys) == sorted(["key3"])
|
Loading…
Reference in New Issue