mirror of https://github.com/hwchase17/langchain
Merge 4a7d077566
into 242eeb537f
commit
5642ad60db
@ -0,0 +1,141 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Persistence with DynamoDB\n",
|
||||
"\n",
|
||||
"You have your LangGraph persistence set up as described in the [persistence notebook](./persistence.ipynb). Now you want to use DynamoDB as your persistence layer. This notebook walks you through configuring DynamoDB as that layer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"First, we need to install the required packages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install --quiet -U langchain boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Persistence\n",
|
||||
"\n",
|
||||
"To add persistence with DynamoDB, we pass in a checkpointer when compiling the graph."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import boto3\n",
|
||||
"from langchain_community.checkpoints.dynamodb import DynamoDbSaver\n",
|
||||
"\n",
|
||||
"boto3_session = boto3.Session(region_name=\"us-west-2\")\n",
|
||||
"\n",
|
||||
"checkpointDynamoDbParams = {\n",
|
||||
" \"table_name\": \"langgraph-checkpoints\",\n",
|
||||
" \"boto3_session\": boto3_session,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"memory = DynamoDbSaver.from_params(**checkpointDynamoDbParams)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langgraph.graph import END, MessageGraph\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define the function that calls the model\n",
|
||||
"def call_model(messages):\n",
|
||||
" return \"model output\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define the function to execute tools\n",
|
||||
"def call_tool(messages):\n",
|
||||
" return \"tool output\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Define a new graph\n",
|
||||
"workflow = MessageGraph()\n",
|
||||
"\n",
|
||||
"# add FAKE nodes and edges to the graph...\n",
|
||||
"workflow.add_node(\"agent\", call_model)\n",
|
||||
"workflow.add_node(\"action\", call_tool)\n",
|
||||
"workflow.set_entry_point(\"agent\")\n",
|
||||
"workflow.add_conditional_edges(\n",
|
||||
" \"agent\",\n",
|
||||
" False,\n",
|
||||
" {\n",
|
||||
" \"continue\": \"action\",\n",
|
||||
" \"end\": END,\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"workflow.add_edge(\"action\", \"agent\")\n",
|
||||
"\n",
|
||||
"# Finally, we compile it!\n",
|
||||
"# This compiles it into a LangChain Runnable,\n",
|
||||
"# meaning you can use it as you would any other runnable\n",
|
||||
"app = workflow.compile(checkpointer=memory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Interacting with the agent\n",
|
||||
"\n",
|
||||
"OK, yes, this graph is not going to run as written (the conditional edge uses a placeholder condition); it is only here to show you how to set up the DynamoDB checkpointer!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
from langchain_community.checkpoints.dynamodb import DynamoDbSaver
|
||||
|
||||
# Names exported by this module on `from ... import *`.
__all__ = ["DynamoDbSaver"]
|
@ -0,0 +1,179 @@
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, Optional
|
||||
|
||||
from boto3.session import Session
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamoDbSaver(BaseCheckpointSaver):
    """Checkpoint saver that stores checkpoints in AWS DynamoDB.

    One item is stored per thread: the partition key holds the ``thread_id``
    taken from the runnable config and a single binary attribute holds the
    pickled checkpoint. The table named ``table_name`` is created on first
    use (see :meth:`setup`) if it does not exist.

    Instances are normally built with :meth:`from_params`, which accepts:

    Args:
        table_name: name of the DynamoDB table
        endpoint_url: URL of the AWS endpoint to connect to. This argument
            is optional and useful for test purposes, like using Localstack.
            If you plan to use AWS cloud service, you normally don't have to
            worry about setting the endpoint_url.
        primary_key_name: name of the primary key of the DynamoDB table. This
            argument is optional, defaulting to "ThreadId".
        checkpoint_field_name: name of the field in the DynamoDB table where to
            store the checkpoint data. This argument is optional, defaulting to
            "Checkpoint".
        boto3_session: an optional boto3 session to use for the DynamoDB client
        kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias
            for client-side encryption
    """

    # boto3 DynamoDB service resource (used for list/create table calls).
    client: Any
    # Table handle used for get_item/put_item; wrapped in an EncryptedTable
    # when a KMS key was supplied to `from_params`.
    table: Any
    table_name: str
    primary_key_name: str
    checkpoint_field_name: str

    # Set once the table is known to exist, so `setup` becomes a no-op.
    is_setup: bool = Field(False, init=False, repr=False)

    @classmethod
    def from_params(
        cls,
        table_name: str,
        endpoint_url: Optional[str] = None,
        primary_key_name: str = "ThreadId",
        checkpoint_field_name: str = "Checkpoint",
        boto3_session: Optional[Session] = None,
        kms_key_id: Optional[str] = None,
    ) -> "DynamoDbSaver":
        """Build a saver from connection parameters.

        Args:
            table_name: name of the DynamoDB table.
            endpoint_url: optional endpoint override (e.g. for Localstack).
            primary_key_name: name of the partition-key attribute.
            checkpoint_field_name: name of the attribute holding the checkpoint.
            boto3_session: optional session used to create the DynamoDB resource;
                when omitted the default boto3 session is used.
            kms_key_id: optional KMS key identifier; when given, the checkpoint
                attribute is encrypted client-side via ``dynamodb_encryption_sdk``.

        Returns:
            A configured instance of ``cls``.

        Raises:
            ImportError: if boto3 (or, when ``kms_key_id`` is set,
                dynamodb_encryption_sdk) cannot be imported.
        """
        if boto3_session:
            client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
        else:
            try:
                import boto3
            except ImportError as e:
                raise ImportError(
                    "Unable to import boto3, please install with `pip install boto3`."
                ) from e
            # endpoint_url=None is the default, so no branching is needed.
            client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
        table = client.Table(table_name)

        if kms_key_id:
            try:
                from dynamodb_encryption_sdk.encrypted.table import EncryptedTable
                from dynamodb_encryption_sdk.identifiers import CryptoAction
                from dynamodb_encryption_sdk.material_providers.aws_kms import (
                    AwsKmsCryptographicMaterialsProvider,
                )
                from dynamodb_encryption_sdk.structures import AttributeActions
            except ImportError as e:
                raise ImportError(
                    "Unable to import dynamodb_encryption_sdk, please install with "
                    "`pip install dynamodb-encryption-sdk`."
                ) from e

            # Encrypt the attribute that actually stores the checkpoint. The
            # previous hard-coded "History" name never matched any attribute
            # written by `put`, leaving the checkpoint in plaintext.
            actions = AttributeActions(
                default_action=CryptoAction.DO_NOTHING,
                attribute_actions={
                    checkpoint_field_name: CryptoAction.ENCRYPT_AND_SIGN
                },
            )
            aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id)
            table = EncryptedTable(
                table=table,
                materials_provider=aws_kms_cmp,
                attribute_actions=actions,
                auto_refresh_table_indexes=False,
            )

        # Use `cls` so subclasses get an instance of their own type.
        return cls(
            client=client,
            table=table,
            table_name=table_name,
            primary_key_name=primary_key_name,
            checkpoint_field_name=checkpoint_field_name,
        )

    @property
    def config_specs(self) -> list[ConfigurableFieldSpec]:
        """Configurable fields this saver reads: a shared ``thread_id`` string."""
        return [
            ConfigurableFieldSpec(
                id="thread_id",
                annotation=str,
                name="Thread ID",
                description=None,
                default="",
                is_shared=True,
            ),
        ]

    def setup(self) -> None:
        """Ensure the backing table exists, creating it if necessary.

        Does nothing once a previous call has confirmed the table exists.
        If table creation fails the error is logged and the saver is left
        un-set-up, so the next `get`/`put` will try again instead of
        permanently assuming a table that was never created.
        """
        if self.is_setup:
            return

        table_name = self.table_name

        # ListTables is paginated, so walk every page; checking only the first
        # response can miss an existing table.
        paginator = self.client.meta.client.get_paginator("list_tables")
        for page in paginator.paginate():
            if table_name in page["TableNames"]:
                self.is_setup = True
                return

        # Create the table with a single string partition key (no sort key,
        # no secondary indexes).
        try:
            table = self.client.create_table(
                TableName=table_name,
                KeySchema=[
                    {
                        "AttributeName": self.primary_key_name,
                        "KeyType": "HASH",  # Partition key
                    }
                ],
                AttributeDefinitions=[
                    {
                        "AttributeName": self.primary_key_name,
                        "AttributeType": "S",  # 'S' for string type
                    }
                ],
                ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
            )

            # Wait for the table to be created
            table.meta.client.get_waiter("table_exists").wait(TableName=table_name)
            logger.info(f"Table {table_name} created successfully.")
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            # Leave is_setup False so a later call retries.
            return

        self.is_setup = True

    def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """Load the checkpoint stored for ``config["configurable"]["thread_id"]``.

        Returns:
            The unpickled checkpoint, or None when no item (or no checkpoint
            attribute) exists for that thread.
        """
        self.setup()
        response = self.table.get_item(
            Key={self.primary_key_name: config["configurable"]["thread_id"]}
        )
        item = response.get("Item", None)
        if not item or self.checkpoint_field_name not in item:
            return None

        stored = item[self.checkpoint_field_name]
        # Binary attributes come back as a wrapper exposing `.value`; fall back
        # to the object itself if raw bytes are returned.
        payload = getattr(stored, "value", stored)
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # payload. Only point this saver at tables whose writers you trust.
        return pickle.loads(payload)

    def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
        """Pickle ``checkpoint`` and store it under the config's ``thread_id``.

        Any existing item for the same thread is overwritten.
        """
        self.setup()
        self.table.put_item(
            Item={
                self.primary_key_name: config["configurable"]["thread_id"],
                self.checkpoint_field_name: pickle.dumps(checkpoint),
            }
        )

    async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """Async variant of `get`; not implemented."""
        raise NotImplementedError

    async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
        """Async variant of `put`; not implemented."""
        raise NotImplementedError
|
@ -0,0 +1 @@
|
||||
"""Test functionality related to the checkpoints objects."""
|
@ -0,0 +1,129 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def aws_credentials():
    """Populate the process environment with dummy AWS credentials for moto."""
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    os.environ.update(fake_env)
|
||||
|
||||
|
||||
@mock_aws
def test_dynamodb_init(aws_credentials):
    """`from_params` stores its arguments and builds boto3 client/table objects."""
    import boto3

    from langchain_community.checkpoints.dynamodb import DynamoDbSaver

    table_name = "test_table"
    custom_key = "test_primary_key"
    custom_field = "test_checkpoint_field"
    endpoint_url = "https://test.endpoint.url"
    session = boto3.Session()

    def check(saver, make_reference_resource, expected_key, expected_field):
        # Shared assertions: stored names, then client/table class identity
        # compared by string form against freshly built boto3 objects.
        assert saver.table_name == table_name
        assert saver.primary_key_name == expected_key
        assert saver.checkpoint_field_name == expected_field
        assert str(saver.client.__class__) == str(
            make_reference_resource().__class__
        )
        assert str(saver.table.__class__) == str(
            saver.client.Table(table_name).__class__
        )

    def from_session():
        return session.resource("dynamodb", endpoint_url=endpoint_url)

    def from_default():
        return boto3.resource("dynamodb", endpoint_url=endpoint_url)

    # Explicit session + endpoint.
    saver = DynamoDbSaver.from_params(
        table_name=table_name,
        primary_key_name=custom_key,
        checkpoint_field_name=custom_field,
        endpoint_url=endpoint_url,
        boto3_session=session,
    )
    check(saver, from_session, custom_key, custom_field)

    # Default session + endpoint.
    saver = DynamoDbSaver.from_params(
        table_name=table_name,
        primary_key_name=custom_key,
        checkpoint_field_name=custom_field,
        endpoint_url=endpoint_url,
    )
    check(saver, from_default, custom_key, custom_field)

    # Default session, no endpoint.
    saver = DynamoDbSaver.from_params(
        table_name=table_name,
        primary_key_name=custom_key,
        checkpoint_field_name=custom_field,
    )
    check(saver, from_default, custom_key, custom_field)

    # Only the table name: key and field fall back to their defaults.
    saver = DynamoDbSaver.from_params(
        table_name=table_name,
    )
    check(saver, from_default, "ThreadId", "Checkpoint")
|
||||
|
||||
|
||||
@mock_aws
def test_dynamodb_put_get(aws_credentials):
    """A checkpoint written with `put` is returned unchanged by `get`."""
    import boto3

    from langchain_community.checkpoints.dynamodb import DynamoDbSaver

    table_name = "test_table"
    saver_kwargs = {
        "table_name": table_name,
        "primary_key_name": "test_primary_key",
        "checkpoint_field_name": "test_checkpoint_field",
        "endpoint_url": "https://test.endpoint.url",
        "boto3_session": boto3.Session(),
    }
    saver = DynamoDbSaver.from_params(**saver_kwargs)

    # Swap in a resource built without the custom endpoint, and a table
    # handle derived from it, before exercising put/get.
    mocked_resource = boto3.resource("dynamodb")
    saver.client = mocked_resource
    saver.table = mocked_resource.Table(table_name)

    run_config = {"configurable": {"thread_id": "test_thread_id"}}
    expected = {
        "v": 1,
        "ts": "2021-01-01T00:00:00+00:00",
        "channel_values": {},
        "channel_versions": {},
        "versions_seen": {},
    }

    saver.put(run_config, expected)
    assert saver.get(run_config) == expected
|
Loading…
Reference in New Issue