Agustí Dosaiguas 2 weeks ago committed by GitHub
commit 5642ad60db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Persistence with DynamoDB\n",
"\n",
"You have your langraph persistence setup as described in the [notebook](./persistence.ipynb). Now you want to use DynamoDB as your persistence layer. This notebook will guide you through the process of setting up DynamoDB as your persistence layer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First we need to install the packages required"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install --quiet -U langchain boto3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Persistence\n",
"\n",
"To add in persistence with DynamoDB, we pass in a checkpoint when compiling the graph"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"from langchain_community.checkpoints.dynamodb import DynamoDbSaver\n",
"\n",
"boto3_session = boto3.Session(region_name=\"us-west-2\")\n",
"\n",
"checkpointDynamoDbParams = {\n",
" \"table_name\": \"langgraph-checkpoints\",\n",
" \"boto3_session\": boto3_session,\n",
"}\n",
"\n",
"memory = DynamoDbSaver.from_params(**checkpointDynamoDbParams)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langgraph.graph import END, MessageGraph\n",
"\n",
"\n",
"# Define the function that calls the model\n",
"def call_model(messages):\n",
" return \"model output\"\n",
"\n",
"\n",
"# Define the function to execute tools\n",
"def call_tool(messages):\n",
" return \"tool output\"\n",
"\n",
"\n",
"# Define a new graph\n",
"workflow = MessageGraph()\n",
"\n",
"# add FAKE nodes and edges to the graph...\n",
"workflow.add_node(\"agent\", call_model)\n",
"workflow.add_node(\"action\", call_tool)\n",
"workflow.set_entry_point(\"agent\")\n",
"workflow.add_conditional_edges(\n",
" \"agent\",\n",
" False,\n",
" {\n",
" \"continue\": \"action\",\n",
" \"end\": END,\n",
" },\n",
")\n",
"workflow.add_edge(\"action\", \"agent\")\n",
"\n",
"# Finally, we compile it!\n",
"# This compiles it into a LangChain Runnable,\n",
"# meaning you can use it as you would any other runnable\n",
"app = workflow.compile(checkpointer=memory)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interacting with the agent\n",
"\n",
"Ok, yes, this graph is not going to work, but this is just to show you how to setup the DynamoDB checkpoint!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,3 @@
# Re-export the DynamoDB-backed checkpoint saver as this package's public name.
from langchain_community.checkpoints.dynamodb import DynamoDbSaver
__all__ = ["DynamoDbSaver"]

@ -0,0 +1,179 @@
import logging
import pickle
from typing import Any, Optional
from boto3.session import Session
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint
# Module-level logger; used below to report the outcome of table creation.
logger = logging.getLogger(__name__)
class DynamoDbSaver(BaseCheckpointSaver):
    """Checkpoint saver that stores history in AWS DynamoDB.

    This class creates a DynamoDB table with name `table_name` if it does not exist.

    Instances are normally built with :meth:`from_params`, which accepts the
    arguments below.

    Args:
        table_name: name of the DynamoDB table
        endpoint_url: URL of the AWS endpoint to connect to. This argument
            is optional and useful for test purposes, like using Localstack.
            If you plan to use AWS cloud service, you normally don't have to
            worry about setting the endpoint_url.
        primary_key_name: name of the primary key of the DynamoDB table. This
            argument is optional, defaulting to "ThreadId".
        checkpoint_field_name: name of the field in the DynamoDB table where to
            store the checkpoint data. This argument is optional, defaulting to
            "Checkpoint".
        boto3_session: an optional boto3 session to use for the DynamoDB client
        kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias
            for client-side encryption
    """

    # boto3 DynamoDB service resource; used to list and create tables.
    client: Any
    # Table handle used for get_item / put_item (an EncryptedTable when a KMS
    # key is supplied to `from_params`).
    table: Any
    table_name: str
    primary_key_name: str
    checkpoint_field_name: str
    # Set once `setup` has run so the existence check is not repeated.
    is_setup: bool = Field(False, init=False, repr=False)

    @classmethod
    def from_params(
        cls,
        table_name: str,
        endpoint_url: Optional[str] = None,
        primary_key_name: str = "ThreadId",
        checkpoint_field_name: str = "Checkpoint",
        boto3_session: Optional[Session] = None,
        kms_key_id: Optional[str] = None,
    ) -> "DynamoDbSaver":
        """Build a saver from connection parameters.

        Args:
            table_name: name of the DynamoDB table.
            endpoint_url: optional endpoint override (e.g. for Localstack).
            primary_key_name: name of the table's partition key attribute.
            checkpoint_field_name: attribute that holds the pickled checkpoint.
            boto3_session: optional session used to create the DynamoDB resource.
            kms_key_id: optional KMS key identifier; when given, the checkpoint
                attribute is encrypted and signed client-side.

        Returns:
            A saver bound to the DynamoDB resource and table handle.

        Raises:
            ImportError: if `boto3` (or `dynamodb_encryption_sdk` when
                `kms_key_id` is set) is not installed.
        """
        if boto3_session:
            client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
        else:
            try:
                import boto3
            except ImportError as e:
                raise ImportError(
                    "Unable to import boto3, please install with `pip install boto3`."
                ) from e
            # A falsy endpoint_url (None or "") means "use the default endpoint".
            client = boto3.resource("dynamodb", endpoint_url=endpoint_url or None)

        table = client.Table(table_name)

        if kms_key_id:
            try:
                from dynamodb_encryption_sdk.encrypted.table import EncryptedTable
                from dynamodb_encryption_sdk.identifiers import CryptoAction
                from dynamodb_encryption_sdk.material_providers.aws_kms import (
                    AwsKmsCryptographicMaterialsProvider,
                )
                from dynamodb_encryption_sdk.structures import AttributeActions
            except ImportError as e:
                raise ImportError(
                    "Unable to import dynamodb_encryption_sdk, please install with "
                    "`pip install dynamodb-encryption-sdk`."
                ) from e

            # Encrypt the attribute that actually stores the checkpoint. A
            # hard-coded attribute name here would leave the checkpoint payload
            # in plaintext, because every other attribute is DO_NOTHING.
            actions = AttributeActions(
                default_action=CryptoAction.DO_NOTHING,
                attribute_actions={
                    checkpoint_field_name: CryptoAction.ENCRYPT_AND_SIGN
                },
            )
            aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id)
            table = EncryptedTable(
                table=table,
                materials_provider=aws_kms_cmp,
                attribute_actions=actions,
                auto_refresh_table_indexes=False,
            )

        # Use `cls` so subclasses get an instance of their own type.
        return cls(
            client=client,
            table=table,
            table_name=table_name,
            primary_key_name=primary_key_name,
            checkpoint_field_name=checkpoint_field_name,
        )

    @property
    def config_specs(self) -> list[ConfigurableFieldSpec]:
        """Configurable fields this saver reads: a shared string `thread_id`."""
        return [
            ConfigurableFieldSpec(
                id="thread_id",
                annotation=str,
                name="Thread ID",
                description=None,
                default="",
                is_shared=True,
            ),
        ]

    def setup(self) -> None:
        """Create the table if it is missing; a no-op after the first call.

        Table creation is best-effort: a failure is logged and the saver is
        still marked as set up, so the error surfaces on the next table
        operation rather than here.
        """
        if self.is_setup:
            return

        table_name = self.table_name

        # ListTables is paginated, so every page must be inspected; reading
        # only the first response can miss an existing table.
        paginator = self.client.meta.client.get_paginator("list_tables")
        table_exists = any(
            table_name in page["TableNames"] for page in paginator.paginate()
        )
        if table_exists:
            self.is_setup = True
            return

        # Create the table keyed solely by the string partition key.
        try:
            table = self.client.create_table(
                TableName=table_name,
                KeySchema=[
                    {
                        "AttributeName": self.primary_key_name,
                        "KeyType": "HASH",  # Partition key
                    }
                ],
                AttributeDefinitions=[
                    {
                        "AttributeName": self.primary_key_name,
                        "AttributeType": "S",  # 'S' for string type
                    }
                ],
                ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
            )
            # Wait for the table to be created
            table.meta.client.get_waiter("table_exists").wait(TableName=table_name)
            logger.info(f"Table {table_name} created successfully.")
        except Exception as e:
            logger.error(f"Error creating table: {e}")

        self.is_setup = True

    def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """Load the checkpoint stored for `config["configurable"]["thread_id"]`.

        Returns:
            The unpickled checkpoint, or None when no item (or no checkpoint
            attribute) exists for the thread.
        """
        self.setup()
        response = self.table.get_item(
            Key={self.primary_key_name: config["configurable"]["thread_id"]}
        )
        item = response.get("Item", None)
        if item and self.checkpoint_field_name in item:
            # NOTE(security): pickle.loads can execute arbitrary code if the
            # stored bytes are attacker-controlled; only use this saver with
            # tables whose writers are trusted.
            # `.value` unwraps the raw bytes from the binary attribute wrapper
            # returned by the table resource.
            return pickle.loads(item[self.checkpoint_field_name].value)
        return None

    def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
        """Pickle `checkpoint` and store it under the config's thread id.

        The item is written with put_item, so any existing item for the same
        thread id is replaced.
        """
        self.setup()
        self.table.put_item(
            Item={
                self.primary_key_name: config["configurable"]["thread_id"],
                self.checkpoint_field_name: pickle.dumps(checkpoint),
            }
        )

    async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
        """Async variant of `get`; not implemented for this saver."""
        raise NotImplementedError

    async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
        """Async variant of `put`; not implemented for this saver."""
        raise NotImplementedError

@ -0,0 +1 @@
"""Test functionality related to the checkpoints objects."""

@ -0,0 +1,129 @@
import os
import pytest
from moto import mock_aws
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
    """Mocked AWS Credentials for moto.

    The variables are set through pytest's `monkeypatch` fixture so they are
    restored when the test finishes, instead of leaking fake credentials into
    the process environment for every later test.
    """
    fake_environment = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    for name, value in fake_environment.items():
        monkeypatch.setenv(name, value)
@mock_aws
def test_dynamodb_init(aws_credentials):
    """Check that `from_params` wires names and boto3 handles for each argument mix."""
    import boto3

    from langchain_community.checkpoints.dynamodb import DynamoDbSaver

    table_name = "test_table"
    primary_key_name = "test_primary_key"
    checkpoint_field_name = "test_checkpoint_field"
    endpoint_url = "https://test.endpoint.url"
    boto3_session = boto3.Session()

    custom_names = {
        "primary_key_name": primary_key_name,
        "checkpoint_field_name": checkpoint_field_name,
    }

    # Each scenario: (extra from_params kwargs, expected key name,
    # expected checkpoint field name, factory for the reference resource).
    scenarios = [
        (
            {
                **custom_names,
                "endpoint_url": endpoint_url,
                "boto3_session": boto3_session,
            },
            primary_key_name,
            checkpoint_field_name,
            boto3_session.resource,
        ),
        (
            {**custom_names, "endpoint_url": endpoint_url},
            primary_key_name,
            checkpoint_field_name,
            boto3.resource,
        ),
        (custom_names, primary_key_name, checkpoint_field_name, boto3.resource),
        ({}, "ThreadId", "Checkpoint", boto3.resource),
    ]

    for extra_kwargs, expected_key, expected_field, make_resource in scenarios:
        saver = DynamoDbSaver.from_params(table_name=table_name, **extra_kwargs)

        assert saver.table_name == table_name
        assert saver.primary_key_name == expected_key
        assert saver.checkpoint_field_name == expected_field

        # Only the printed class names are compared, not the class objects.
        reference = make_resource("dynamodb", endpoint_url=endpoint_url)
        assert str(saver.client.__class__) == str(reference.__class__)
        assert str(saver.table.__class__) == str(
            saver.client.Table(table_name).__class__
        )
@mock_aws
def test_dynamodb_put_get(aws_credentials):
    """Round-trip a checkpoint through `put` and `get` for one thread id."""
    import boto3

    from langchain_community.checkpoints.dynamodb import DynamoDbSaver

    table_name = "test_table"

    saver = DynamoDbSaver.from_params(
        table_name=table_name,
        primary_key_name="test_primary_key",
        checkpoint_field_name="test_checkpoint_field",
        endpoint_url="https://test.endpoint.url",
        boto3_session=boto3.Session(),
    )

    # Swap in a resource created without the custom endpoint URL
    # (presumably so moto serves the calls -- TODO confirm).
    default_resource = boto3.resource("dynamodb")
    saver.client = default_resource
    saver.table = default_resource.Table(table_name)

    thread_config = {"configurable": {"thread_id": "test_thread_id"}}
    stored_checkpoint = {
        "v": 1,
        "ts": "2021-01-01T00:00:00+00:00",
        "channel_values": {},
        "channel_versions": {},
        "versions_seen": {},
    }

    saver.put(thread_config, stored_checkpoint)
    assert saver.get(thread_config) == stored_checkpoint

@ -1,4 +1,5 @@
"""A unit test meant to catch accidental introduction of non-optional dependencies."""
from pathlib import Path
from typing import Any, Dict, Mapping

Loading…
Cancel
Save