azure-dynamic-sessions: add Python REPL tool (#21264)

Adds a Python REPL that executes code in a code interpreter session
using Azure Container Apps dynamic sessions.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/21507/head langchain-azure-dynamic-sessions==0.1.0rc0
Anthony Chu 3 weeks ago committed by GitHub
parent 02701c277f
commit c735849e76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,59 @@
# Command targets (not files) — declared phony so stray files with these
# names can never shadow them.
.PHONY: all format lint lint_diff lint_package lint_tests format_diff spell_check spell_fix check_imports test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path (overridable: make test TEST_FILE=...).
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'

@ -0,0 +1,36 @@
# langchain-azure-dynamic-sessions
This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
## Installation
```bash
pip install -U langchain-azure-dynamic-sessions
```
## Usage
You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
```python
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

# get the management endpoint from the session pool in the Azure portal
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
prompt = hub.pull("hwchase17/react")
tools = [tool]
react_agent = create_react_agent(
    llm=llm,
    tools=tools,
    prompt=prompt,
)
react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
```
By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.

@ -0,0 +1,169 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Azure Container Apps dynamic sessions\n",
"\n",
"Azure Container Apps dynamic sessions provides a secure and scalable way to run a Python code interpreter in Hyper-V isolated sandboxes. This allows your agents to run potentially untrusted code in a secure environment. The code interpreter environment includes many popular Python packages, such as NumPy, pandas, and scikit-learn.\n",
"\n",
"## Pre-requisites\n",
"\n",
"By default, the `SessionsPythonREPLTool` tool uses `DefaultAzureCredential` to authenticate with Azure. Locally, it'll use your credentials from the Azure CLI or VS Code. Install the Azure CLI and log in with `az login` to authenticate.\n",
"\n",
"## Using the tool\n",
"\n",
"Set variables:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import dotenv\n",
"dotenv.load_dotenv()\n",
"\n",
"POOL_MANAGEMENT_ENDPOINT = os.getenv(\"POOL_MANAGEMENT_ENDPOINT\")\n",
"AZURE_OPENAI_ENDPOINT = os.getenv(\"AZURE_OPENAI_ENDPOINT\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\\n \"result\": 42,\\n \"stdout\": \"\",\\n \"stderr\": \"\"\\n}'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
"\n",
"\n",
"tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)\n",
"tool.run(\"6 * 7\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full agent example"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to calculate the compound interest on the initial amount over 6 years.\n",
"Action: Python_REPL\n",
"Action Input: \n",
"```python\n",
"initial_amount = 500\n",
"interest_rate = 0.05\n",
"time_period = 6\n",
"final_amount = initial_amount * (1 + interest_rate)**time_period\n",
"final_amount\n",
"```\u001b[0m\u001b[36;1m\u001b[1;3m{\n",
" \"result\": 670.0478203125002,\n",
" \"stdout\": \"\",\n",
" \"stderr\": \"\"\n",
"}\u001b[0m\u001b[32;1m\u001b[1;3mThe final amount after 6 years will be $670.05\n",
"Final Answer: $670.05\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?',\n",
" 'output': '$670.05'}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"from azure.identity import DefaultAzureCredential\n",
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
"from langchain_openai import AzureChatOpenAI\n",
"from langchain import agents, hub\n",
"\n",
"\n",
"credential = DefaultAzureCredential()\n",
"os.environ[\"OPENAI_API_TYPE\"] = \"azure_ad\"\n",
"os.environ[\"OPENAI_API_KEY\"] = credential.get_token(\"https://cognitiveservices.azure.com/.default\").token\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = AZURE_OPENAI_ENDPOINT\n",
"\n",
"llm = AzureChatOpenAI(\n",
" azure_deployment=\"gpt-35-turbo\",\n",
" openai_api_version=\"2023-09-15-preview\",\n",
" streaming=True,\n",
" temperature=0,\n",
")\n",
"\n",
"repl = SessionsPythonREPLTool(\n",
" pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,\n",
")\n",
"\n",
"tools = [repl]\n",
"react_agent = agents.create_react_agent(\n",
" llm=llm,\n",
" tools=tools,\n",
" prompt=hub.pull(\"hwchase17/react\"),\n",
")\n",
"\n",
"react_agent_executor = agents.AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)\n",
"\n",
"react_agent_executor.invoke({\"input\": \"If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?\"})"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

@ -0,0 +1,5 @@
"""LangChain integration for Azure Container Apps dynamic sessions.

Re-exports ``SessionsPythonREPLTool`` as the package's public API.
"""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = [
    "SessionsPythonREPLTool",
]

@ -0,0 +1,5 @@
"""Tools subpackage: re-exports ``SessionsPythonREPLTool``."""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = [
    "SessionsPythonREPLTool",
]

@ -0,0 +1,273 @@
import importlib.metadata
import json
import os
import re
import urllib
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Optional
from uuid import uuid4

import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool
# Resolve the installed package version for the User-Agent header sent with
# every API request; fall back to "0.0.0" when running from a source checkout
# where the distribution is not installed.
try:
    _package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
except importlib.metadata.PackageNotFoundError:
    _package_version = "0.0.0"
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
"""Factory function for creating an access token provider function.
Returns:
Callable[[], Optional[str]]: The access token provider function
"""
access_token: Optional[AccessToken] = None
def access_token_provider() -> Optional[str]:
nonlocal access_token
if access_token is None or datetime.fromtimestamp(
access_token.expires_on, timezone.utc
) < datetime.now(timezone.utc) + timedelta(minutes=5):
credential = DefaultAzureCredential()
access_token = credential.get_token("https://dynamicsessions.io/.default")
return access_token.token
return access_token_provider
def _sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
Args:
query: The query to sanitize
Returns:
str: The sanitized query
"""
# Removes `, whitespace & python from start
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
# Removes whitespace & ` from end
query = re.sub(r"(\s|`)*$", "", query)
return query
@dataclass
class RemoteFileMetadata:
    """Metadata for a file in the session."""

    filename: str  # filename relative to `/mnt/data`
    size_in_bytes: int  # size of the file in bytes

    @property
    def full_path(self) -> str:
        """Get the full path of the file."""
        return f"/mnt/data/{self.filename}"

    @staticmethod
    def from_dict(data: dict) -> "RemoteFileMetadata":
        """Create a RemoteFileMetadata object from a dictionary."""
        props = data.get("properties", {})
        return RemoteFileMetadata(
            filename=props.get("filename"),
            size_in_bytes=props.get("size"),
        )
class SessionsPythonREPLTool(BaseTool):
    """A tool for running Python code in an Azure Container Apps dynamic sessions
    code interpreter.

    Example:
        .. code-block:: python

            from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

            tool = SessionsPythonREPLTool(pool_management_endpoint="...")
            result = tool.run("6 * 7")
    """

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands "
        "when you need to perform calculations or computations. "
        "Input should be a valid python command. "
        "Returns a JSON object with the result, stdout, and stderr. "
    )

    sanitize_input: bool = True
    """Whether to sanitize input to the python REPL."""

    pool_management_endpoint: str
    """The management endpoint of the session pool. Should end with a '/'."""

    access_token_provider: Callable[
        [], Optional[str]
    ] = _access_token_provider_factory()
    """A function that returns the access token to use for the session pool."""

    # default_factory gives each instance its own session; a plain
    # `str(uuid4())` default would be evaluated once at class-creation time
    # and silently shared by every tool instance.
    session_id: str = Field(default_factory=lambda: str(uuid4()))
    """The session ID to use for the code interpreter. Defaults to a random UUID."""

    def _build_url(self, path: str) -> str:
        """Build the full API URL for ``path`` on the session pool, appending
        the session identifier and API version as query parameters."""
        pool_management_endpoint = self.pool_management_endpoint
        if not pool_management_endpoint:
            raise ValueError("pool_management_endpoint is not set")
        if not pool_management_endpoint.endswith("/"):
            pool_management_endpoint += "/"
        encoded_session_id = urllib.parse.quote(self.session_id)
        query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
        # The configured endpoint may already carry its own query string.
        query_separator = "&" if "?" in pool_management_endpoint else "?"
        full_url = pool_management_endpoint + path + query_separator + query
        return full_url

    def execute(self, python_code: str) -> Any:
        """Execute Python code in the session.

        Args:
            python_code: The code to run. Sanitized first (markdown fences
                stripped) when ``sanitize_input`` is True.

        Returns:
            The ``properties`` object of the API response (contains
            ``result``, ``stdout`` and ``stderr``).
        """
        if self.sanitize_input:
            python_code = _sanitize_input(python_code)

        access_token = self.access_token_provider()
        api_url = self._build_url("code/execute")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "User-Agent": USER_AGENT,
        }
        body = {
            "properties": {
                "codeInputType": "inline",
                "executionType": "synchronous",
                "code": python_code,
            }
        }

        response = requests.post(api_url, headers=headers, json=body)
        response.raise_for_status()
        response_json = response.json()
        properties = response_json.get("properties", {})
        return properties

    def _run(self, python_code: str) -> Any:
        """Run the code and return a JSON string of result/stdout/stderr."""
        response = self.execute(python_code)

        # if the result is an image, remove the base64 data so the output
        # stays small enough to be fed back to an LLM
        result = response.get("result")
        if isinstance(result, dict):
            if result.get("type") == "image" and "base64_data" in result:
                result.pop("base64_data")

        return json.dumps(
            {
                "result": result,
                "stdout": response.get("stdout"),
                "stderr": response.get("stderr"),
            },
            indent=2,
        )

    def upload_file(
        self,
        *,
        data: Optional[BinaryIO] = None,
        remote_file_path: Optional[str] = None,
        local_file_path: Optional[str] = None,
    ) -> RemoteFileMetadata:
        """Upload a file to the session.

        Args:
            data: The data to upload.
            remote_file_path: The path to upload the file to, relative to
                `/mnt/data`. If local_file_path is provided, this is defaulted
                to its filename.
            local_file_path: The path to the local file to upload.

        Returns:
            RemoteFileMetadata: The metadata for the uploaded file

        Raises:
            ValueError: If both or neither of ``data`` and
                ``local_file_path`` are provided.
        """
        if data and local_file_path:
            raise ValueError("data and local_file_path cannot be provided together")
        if not data and not local_file_path:
            # Previously this fell through to an unbound-variable NameError.
            raise ValueError("data or local_file_path must be provided")

        close_after_upload = False
        if data:
            file_data: BinaryIO = data
        else:
            if not remote_file_path:
                remote_file_path = os.path.basename(local_file_path)
            # Opened locally, so we are responsible for closing it.
            file_data = open(local_file_path, "rb")
            close_after_upload = True

        try:
            access_token = self.access_token_provider()
            api_url = self._build_url("files/upload")
            headers = {
                "Authorization": f"Bearer {access_token}",
                "User-Agent": USER_AGENT,
            }
            files = [
                ("file", (remote_file_path, file_data, "application/octet-stream"))
            ]
            response = requests.request(
                "POST", api_url, headers=headers, data={}, files=files
            )
            response.raise_for_status()
        finally:
            if close_after_upload:
                file_data.close()

        response_json = response.json()
        return RemoteFileMetadata.from_dict(response_json["value"][0])

    def download_file(
        self, *, remote_file_path: str, local_file_path: Optional[str] = None
    ) -> BinaryIO:
        """Download a file from the session.

        Args:
            remote_file_path: The path to download the file from,
                relative to `/mnt/data`.
            local_file_path: The path to save the downloaded file to.
                If not provided, the file is returned as a BufferedReader.

        Returns:
            BinaryIO: The data of the downloaded file.
        """
        access_token = self.access_token_provider()
        encoded_remote_file_path = urllib.parse.quote(remote_file_path)
        api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        # Optionally persist to disk; the in-memory copy is returned either way.
        if local_file_path:
            with open(local_file_path, "wb") as f:
                f.write(response.content)

        return BytesIO(response.content)

    def list_files(self) -> List[RemoteFileMetadata]:
        """List the files in the session.

        Returns:
            list[RemoteFileMetadata]: The metadata for the files in the session
        """
        access_token = self.access_token_provider()
        api_url = self._build_url("files")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()
        response_json = response.json()
        return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]

File diff suppressed because it is too large Load Diff

@ -0,0 +1,104 @@
[tool.poetry]
name = "langchain-azure-dynamic-sessions"
version = "0.1.0rc0"
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.52"
azure-identity = "^1.16.0"
requests = "^2.31.0"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}
python-dotenv = "^1.0.1"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
pytest = "^7.3.0"
python-dotenv = "^1.0.1"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
python-dotenv = "^1.0.1"
pytest = "^7.3.0"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}
types-requests = "^2.31.0.20240406"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}
ipykernel = "^6.29.4"
langchain-openai = {path = "../openai", develop = true}
langchainhub = "^0.1.15"
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
def check_files(files: list) -> bool:
    """Try to import each file; return True if any import failed.

    Each failure is reported with the offending filename and a traceback.
    """
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            # BUG FIX: this previously assigned to a misspelled variable
            # ("has_faillure"), so the script always exited 0 even when an
            # import failed.
            has_failure = True
            print(file)
            traceback.print_exc()
            print()
    return has_failure


if __name__ == "__main__":
    # Exit non-zero when any of the given files fails to import.
    sys.exit(1 if check_files(sys.argv[1:]) else 0)

@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Require exactly one argument: the repository path.
if [ "$#" -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repo_dir="$1"

# Grep tracked files for direct pydantic imports.
matches=$(git -C "$repo_dir" grep -E '^import pydantic|^from pydantic')

# Fail with guidance when any direct import was found.
if [ -n "$matches" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$matches"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi

@ -0,0 +1,17 @@
#!/bin/bash

set -eu

# Track how many forbidden import patterns were found.
errors=0

# This partner package must not import from the monolithic langchain or
# langchain_experimental packages.
for pattern in '^from langchain\.' '^from langchain_experimental\.'; do
  git --no-pager grep "$pattern" . && errors=$((errors+1))
done

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi

@ -0,0 +1,7 @@
import pytest


# Selected via `pytest -m compile`; exists so integration tests can be
# import-checked without hitting any live Azure resources.
@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass

@ -0,0 +1,68 @@
import json
import os
from io import BytesIO

import dotenv

from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

dotenv.load_dotenv()

# NOTE(review): read at import time — AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT
# must be set (e.g. via a .env file) for this test to reach a real session pool.
POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()


def test_end_to_end() -> None:
    """Exercise code execution plus file upload/download/list against a live pool.

    The steps are order-dependent: the final list_files assertion expects
    exactly the four files uploaded earlier in this test.
    """
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    result = tool.run("print('hello world')\n1 + 1")
    assert json.loads(result) == {
        "result": 2,
        "stdout": "hello world\n",
        "stderr": "",
    }

    # upload file content
    uploaded_file1_metadata = tool.upload_file(
        remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
    )
    assert uploaded_file1_metadata.filename == "test1.txt"
    assert uploaded_file1_metadata.size_in_bytes == 16
    assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
    downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
    assert downloaded_file1.read() == b"hello world!!!!!"

    # upload file from buffer
    with open(TEST_DATA_PATH, "rb") as f:
        uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
        assert uploaded_file2_metadata.filename == "test2.txt"
        downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
        assert downloaded_file2.read() == TEST_DATA_CONTENT

    # upload file from disk, specifying remote file path
    uploaded_file3_metadata = tool.upload_file(
        remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
    )
    assert uploaded_file3_metadata.filename == "test3.txt"
    downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
    assert downloaded_file3.read() == TEST_DATA_CONTENT

    # upload file from disk, without specifying remote file path
    uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
    assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
    downloaded_file4 = tool.download_file(
        remote_file_path=uploaded_file4_metadata.filename
    )
    assert downloaded_file4.read() == TEST_DATA_CONTENT

    # list files
    remote_files_metadata = tool.list_files()
    assert len(remote_files_metadata) == 4
    remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
    expected_filenames = [
        "test1.txt",
        "test2.txt",
        "test3.txt",
        os.path.basename(TEST_DATA_PATH),
    ]
    assert set(remote_file_paths) == set(expected_filenames)

@ -0,0 +1,9 @@
from langchain_azure_dynamic_sessions import __all__

EXPECTED_ALL = [
    "SessionsPythonREPLTool",
]


def test_all_imports() -> None:
    """The package exports exactly the expected public names."""
    assert sorted(__all__) == sorted(EXPECTED_ALL)

@ -0,0 +1,208 @@
import json
import re
import time
from unittest import mock
from urllib.parse import parse_qs, urlparse
from azure.core.credentials import AccessToken
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
from langchain_azure_dynamic_sessions.tools.sessions import (
_access_token_provider_factory,
)
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
def test_default_access_token_provider_returns_token() -> None:
    """The default provider returns the token produced by DefaultAzureCredential."""
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken("token_value", 0)
        access_token = access_token_provider()
        assert access_token == "token_value"


def test_default_access_token_provider_returns_cached_token() -> None:
    """A token far from expiry is served from cache without re-fetching."""
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken(
            "token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        # A second call must reuse the cached token, not the new mock value.
        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1


def test_default_access_token_provider_refreshes_expiring_token() -> None:
    """An already-expired cached token triggers a new get_token call."""
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        # The first token is expired, so the provider must fetch a fresh one.
        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "new_token_value"
        assert mock_get_token.call_count == 2
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    """run() POSTs the code to the pool's code/execute endpoint and
    summarizes the response as JSON."""
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "hello world\n",
            "stderr": "",
            "result": "",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))

    result = tool.run("print('hello world')")

    # The tool returns only result/stdout/stderr from the API response.
    assert json.loads(result) == {
        "result": "",
        "stdout": "hello world\n",
        "stderr": "",
    }

    api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
    headers = {
        "Authorization": "Bearer token_value",
        "Content-Type": "application/json",
        "User-Agent": mock.ANY,
    }
    body = {
        "properties": {
            "codeInputType": "inline",
            "executionType": "synchronous",
            "code": "print('hello world')",
        }
    }
    mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)

    # The User-Agent header advertises this package and its version.
    called_headers = mock_post.call_args.kwargs["headers"]
    assert re.match(
        r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
        called_headers["User-Agent"],
    )

    # The request URL targets the configured pool (query string ignored).
    called_api_url = mock_post.call_args.args[0]
    assert called_api_url.startswith(api_url)
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    """An explicit session_id is sent as the `identifier` query parameter."""
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        session_id="00000000-0000-0000-0000-000000000003",
    )
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "",
            "stderr": "",
            "result": "2",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
    tool.run("1 + 1")

    # Extract the identifier from the request URL's query string.
    call_url = mock_post.call_args.args[0]
    parsed_url = urlparse(call_url)
    call_identifier = parse_qs(parsed_url.query)["identifier"][0]
    assert call_identifier == "00000000-0000-0000-0000-000000000003"
def test_sanitizes_input() -> None:
    """Markdown code fences are stripped from the code by default."""
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        # The fenced wrapper must not reach the API.
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "print('hello world')"


def test_does_not_sanitize_input() -> None:
    """With sanitize_input=False the code is sent to the API verbatim."""
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
    )
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        # The fences are preserved when sanitization is disabled.
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "```python\nprint('hello world')\n```"
def test_uses_custom_access_token_provider() -> None:
    """A caller-supplied token provider overrides the default credential flow."""

    def custom_access_token_provider() -> str:
        return "custom_token"

    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        access_token_provider=custom_access_token_provider,
    )

    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("print('hello world')")
        # The Authorization header must carry the custom provider's token.
        headers = mock_post.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer custom_token"
Loading…
Cancel
Save