mirror of https://github.com/hwchase17/langchain
azure-dynamic-sessions: add Python REPL tool (#21264)
Adds a Python REPL that executes code in a code interpreter session using Azure Container Apps dynamic sessions. --------- Co-authored-by: Erick Friis <erick@langchain.dev>pull/21507/head langchain-azure-dynamic-sessions==0.1.0rc0
parent
02701c277f
commit
c735849e76
@ -0,0 +1 @@
|
||||
__pycache__
|
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -0,0 +1,59 @@
|
||||
.PHONY: all format lint lint_diff format_diff lint_package lint_tests test tests integration_tests docker_tests help extended_tests spell_check spell_fix check_imports

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
# FIX: --relative previously pointed at libs/partners/azure; this package
# lives under libs/partners/azure-dynamic-sessions.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure-dynamic-sessions --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

# Import-check every package module; prerequisites are the package sources.
check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
@ -0,0 +1,36 @@
|
||||
# langchain-azure-dynamic-sessions
|
||||
|
||||
This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -U langchain-azure-dynamic-sessions
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
|
||||
|
||||
```python
|
||||
"from langchain import hub\n",
"from langchain.agents import AgentExecutor, create_react_agent\n",
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
|
||||
|
||||
|
||||
# get the management endpoint from the session pool in the Azure portal
|
||||
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
|
||||
|
||||
prompt = hub.pull("hwchase17/react")
|
||||
tools=[tool]
|
||||
react_agent = create_react_agent(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
||||
|
||||
react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
|
||||
```
|
||||
|
||||
By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
|
||||
|
@ -0,0 +1,169 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Azure Container Apps dynamic sessions\n",
|
||||
"\n",
|
||||
"Azure Container Apps dynamic sessions provides a secure and scalable way to run a Python code interpreter in Hyper-V isolated sandboxes. This allows your agents to run potentially untrusted code in a secure environment. The code interpreter environment includes many popular Python packages, such as NumPy, pandas, and scikit-learn.\n",
|
||||
"\n",
|
||||
"## Pre-requisites\n",
|
||||
"\n",
|
||||
"By default, the `SessionsPythonREPLTool` tool uses `DefaultAzureCredential` to authenticate with Azure. Locally, it'll use your credentials from the Azure CLI or VS Code. Install the Azure CLI and log in with `az login` to authenticate.\n",
|
||||
"\n",
|
||||
"## Using the tool\n",
|
||||
"\n",
|
||||
"Set variables:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import dotenv\n",
|
||||
"dotenv.load_dotenv()\n",
|
||||
"\n",
|
||||
"POOL_MANAGEMENT_ENDPOINT = os.getenv(\"POOL_MANAGEMENT_ENDPOINT\")\n",
|
||||
"AZURE_OPENAI_ENDPOINT = os.getenv(\"AZURE_OPENAI_ENDPOINT\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'{\\n \"result\": 42,\\n \"stdout\": \"\",\\n \"stderr\": \"\"\\n}'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)\n",
|
||||
"tool.run(\"6 * 7\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Full agent example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to calculate the compound interest on the initial amount over 6 years.\n",
|
||||
"Action: Python_REPL\n",
|
||||
"Action Input: \n",
|
||||
"```python\n",
|
||||
"initial_amount = 500\n",
|
||||
"interest_rate = 0.05\n",
|
||||
"time_period = 6\n",
|
||||
"final_amount = initial_amount * (1 + interest_rate)**time_period\n",
|
||||
"final_amount\n",
|
||||
"```\u001b[0m\u001b[36;1m\u001b[1;3m{\n",
|
||||
" \"result\": 670.0478203125002,\n",
|
||||
" \"stdout\": \"\",\n",
|
||||
" \"stderr\": \"\"\n",
|
||||
"}\u001b[0m\u001b[32;1m\u001b[1;3mThe final amount after 6 years will be $670.05\n",
|
||||
"Final Answer: $670.05\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': 'If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?',\n",
|
||||
" 'output': '$670.05'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from azure.identity import DefaultAzureCredential\n",
|
||||
"from langchain_azure_dynamic_sessions import SessionsPythonREPLTool\n",
|
||||
"from langchain_openai import AzureChatOpenAI\n",
|
||||
"from langchain import agents, hub\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"credential = DefaultAzureCredential()\n",
|
||||
"os.environ[\"OPENAI_API_TYPE\"] = \"azure_ad\"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = credential.get_token(\"https://cognitiveservices.azure.com/.default\").token\n",
|
||||
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = AZURE_OPENAI_ENDPOINT\n",
|
||||
"\n",
|
||||
"llm = AzureChatOpenAI(\n",
|
||||
" azure_deployment=\"gpt-35-turbo\",\n",
|
||||
" openai_api_version=\"2023-09-15-preview\",\n",
|
||||
" streaming=True,\n",
|
||||
" temperature=0,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"repl = SessionsPythonREPLTool(\n",
|
||||
" pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"tools = [repl]\n",
|
||||
"react_agent = agents.create_react_agent(\n",
|
||||
" llm=llm,\n",
|
||||
" tools=tools,\n",
|
||||
" prompt=hub.pull(\"hwchase17/react\"),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"react_agent_executor = agents.AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)\n",
|
||||
"\n",
|
||||
"react_agent_executor.invoke({\"input\": \"If I put $500 in a bank account with a 5% interest rate, how much money will I have in the account after 6 years?\"})"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
"""Public API for the langchain-azure-dynamic-sessions package."""

from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = ["SessionsPythonREPLTool"]
|
@ -0,0 +1,5 @@
|
||||
"""Tool exports for langchain_azure_dynamic_sessions.tools."""

from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool

__all__ = ["SessionsPythonREPLTool"]
|
@ -0,0 +1,273 @@
|
||||
import importlib.metadata
import json
import os
import re
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Optional
from uuid import uuid4

import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.tools import BaseTool
|
||||
|
||||
try:
|
||||
_package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_package_version = "0.0.0"
|
||||
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
|
||||
|
||||
|
||||
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
    """Factory function for creating an access token provider function.

    The provider fetches an Azure AD token for the dynamic sessions
    service with ``DefaultAzureCredential`` and caches it in the closure,
    refreshing only when the cached token is within five minutes of expiry.

    Returns:
        Callable[[], Optional[str]]: The access token provider function
    """
    cached_token: Optional[AccessToken] = None

    def provide_token() -> Optional[str]:
        nonlocal cached_token
        refresh_cutoff = datetime.now(timezone.utc) + timedelta(minutes=5)
        needs_refresh = cached_token is None or (
            datetime.fromtimestamp(cached_token.expires_on, timezone.utc)
            < refresh_cutoff
        )
        if needs_refresh:
            # Acquire a fresh token from the default credential chain.
            cached_token = DefaultAzureCredential().get_token(
                "https://dynamicsessions.io/.default"
            )
        return cached_token.token

    return provide_token
|
||||
|
||||
|
||||
def _sanitize_input(query: str) -> str:
|
||||
"""Sanitize input to the python REPL.
|
||||
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
|
||||
|
||||
Args:
|
||||
query: The query to sanitize
|
||||
|
||||
Returns:
|
||||
str: The sanitized query
|
||||
"""
|
||||
|
||||
# Removes `, whitespace & python from start
|
||||
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
|
||||
# Removes whitespace & ` from end
|
||||
query = re.sub(r"(\s|`)*$", "", query)
|
||||
return query
|
||||
|
||||
|
||||
@dataclass
class RemoteFileMetadata:
    """Metadata for a file in the session."""

    # Path of the file, relative to `/mnt/data` in the session container.
    filename: str

    # Size of the file in bytes, as reported by the service.
    size_in_bytes: int

    @property
    def full_path(self) -> str:
        """Get the full path of the file."""
        return "/mnt/data/" + self.filename

    @staticmethod
    def from_dict(data: dict) -> "RemoteFileMetadata":
        """Create a RemoteFileMetadata object from a dictionary."""
        props = data.get("properties", {})
        return RemoteFileMetadata(
            filename=props.get("filename"),
            size_in_bytes=props.get("size"),
        )
|
||||
|
||||
|
||||
class SessionsPythonREPLTool(BaseTool):
    """A tool for running Python code in an Azure Container Apps dynamic sessions
    code interpreter.

    Example:

        .. code-block:: python
            from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
            tool = SessionsPythonREPLTool(pool_management_endpoint="...")
            result = tool.run("6 * 7")
    """

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands "
        "when you need to perform calculations or computations. "
        "Input should be a valid python command. "
        "Returns a JSON object with the result, stdout, and stderr. "
    )

    sanitize_input: bool = True
    """Whether to sanitize input to the python REPL."""

    pool_management_endpoint: str
    """The management endpoint of the session pool. Should end with a '/'."""

    access_token_provider: Callable[
        [], Optional[str]
    ] = _access_token_provider_factory()
    """A function that returns the access token to use for the session pool."""

    # NOTE(review): this default is evaluated once at class-definition time,
    # so every instance created without an explicit session_id shares the
    # same session. Confirm whether a per-instance default (default_factory)
    # is intended.
    session_id: str = str(uuid4())
    """The session ID to use for the code interpreter. Defaults to a random UUID."""

    def _build_url(self, path: str) -> str:
        """Build the full API URL for ``path``, appending the URL-encoded
        session identifier and the API version as query parameters."""
        pool_management_endpoint = self.pool_management_endpoint
        if not pool_management_endpoint:
            raise ValueError("pool_management_endpoint is not set")
        if not pool_management_endpoint.endswith("/"):
            pool_management_endpoint += "/"
        encoded_session_id = urllib.parse.quote(self.session_id)
        query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
        # Append with '&' if the endpoint already carries a query string.
        query_separator = "&" if "?" in pool_management_endpoint else "?"
        full_url = pool_management_endpoint + path + query_separator + query
        return full_url

    def execute(self, python_code: str) -> Any:
        """Execute Python code in the session.

        Args:
            python_code: The code to execute (sanitized first if
                ``sanitize_input`` is set).

        Returns:
            The ``properties`` object of the service response, containing
            ``result``, ``stdout`` and ``stderr``.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        if self.sanitize_input:
            python_code = _sanitize_input(python_code)

        access_token = self.access_token_provider()
        api_url = self._build_url("code/execute")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "User-Agent": USER_AGENT,
        }
        body = {
            "properties": {
                "codeInputType": "inline",
                "executionType": "synchronous",
                "code": python_code,
            }
        }

        response = requests.post(api_url, headers=headers, json=body)
        response.raise_for_status()
        response_json = response.json()
        properties = response_json.get("properties", {})
        return properties

    def _run(self, python_code: str) -> Any:
        """Run the code and return a JSON string with result/stdout/stderr."""
        response = self.execute(python_code)

        # if the result is an image, remove the base64 data
        result = response.get("result")
        if isinstance(result, dict):
            if result.get("type") == "image" and "base64_data" in result:
                result.pop("base64_data")

        return json.dumps(
            {
                "result": result,
                "stdout": response.get("stdout"),
                "stderr": response.get("stderr"),
            },
            indent=2,
        )

    def upload_file(
        self,
        *,
        data: Optional[BinaryIO] = None,
        remote_file_path: Optional[str] = None,
        local_file_path: Optional[str] = None,
    ) -> RemoteFileMetadata:
        """Upload a file to the session.

        Args:
            data: The data to upload.
            remote_file_path: The path to upload the file to, relative to
                `/mnt/data`. If local_file_path is provided, this is defaulted
                to its filename.
            local_file_path: The path to the local file to upload.

        Returns:
            RemoteFileMetadata: The metadata for the uploaded file

        Raises:
            ValueError: If both, or neither, of ``data`` and
                ``local_file_path`` are provided.
        """
        if data and local_file_path:
            raise ValueError("data and local_file_path cannot be provided together")

        # Track whether we opened the file handle ourselves so we never
        # close a stream that belongs to the caller.
        close_after_upload = False
        if data:
            file_data: BinaryIO = data
        elif local_file_path:
            if not remote_file_path:
                remote_file_path = os.path.basename(local_file_path)
            file_data = open(local_file_path, "rb")
            close_after_upload = True
        else:
            # BUG FIX: previously fell through with file_data unbound,
            # raising UnboundLocalError instead of a clear error.
            raise ValueError("either data or local_file_path must be provided")

        try:
            access_token = self.access_token_provider()
            api_url = self._build_url("files/upload")
            headers = {
                "Authorization": f"Bearer {access_token}",
                "User-Agent": USER_AGENT,
            }
            files = [
                ("file", (remote_file_path, file_data, "application/octet-stream"))
            ]

            response = requests.request(
                "POST", api_url, headers=headers, data={}, files=files
            )
            response.raise_for_status()
        finally:
            # BUG FIX: the locally opened file handle was previously leaked.
            if close_after_upload:
                file_data.close()

        response_json = response.json()
        return RemoteFileMetadata.from_dict(response_json["value"][0])

    def download_file(
        self, *, remote_file_path: str, local_file_path: Optional[str] = None
    ) -> BinaryIO:
        """Download a file from the session.

        Args:
            remote_file_path: The path to download the file from,
                relative to `/mnt/data`.
            local_file_path: The path to save the downloaded file to.
                If not provided, the file is returned as a BufferedReader.

        Returns:
            BinaryIO: The data of the downloaded file.
        """
        access_token = self.access_token_provider()
        encoded_remote_file_path = urllib.parse.quote(remote_file_path)
        api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        # Optionally persist to disk; the in-memory copy is returned either way.
        if local_file_path:
            with open(local_file_path, "wb") as f:
                f.write(response.content)

        return BytesIO(response.content)

    def list_files(self) -> List[RemoteFileMetadata]:
        """List the files in the session.

        Returns:
            list[RemoteFileMetadata]: The metadata for the files in the session
        """
        access_token = self.access_token_provider()
        api_url = self._build_url("files")
        headers = {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        response_json = response.json()
        return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,104 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-azure-dynamic-sessions"
|
||||
version = "0.1.0rc0"
|
||||
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.52"
|
||||
azure-identity = "^1.16.0"
|
||||
requests = "^2.31.0"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
python-dotenv = "^1.0.1"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
python-dotenv = "^1.0.1"
|
||||
pytest = "^7.3.0"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
types-requests = "^2.31.0.20240406"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
ipykernel = "^6.29.4"
|
||||
langchain-openai = {path = "../openai", develop = true}
|
||||
langchainhub = "^0.1.15"
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
@ -0,0 +1,17 @@
|
||||
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    # Each CLI argument is a Python source file to import-check.
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            # NOTE: load_module is deprecated but retained here for parity
            # with the other partner-package scripts.
            SourceFileLoader("x", file).load_module()
        except Exception:
            # BUG FIX: was `has_faillure = True` (typo), so import failures
            # never affected the flag and the script always exited 0.
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    # Non-zero exit signals at least one module failed to import.
    sys.exit(1 if has_failure else 0)
|
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Direct pydantic imports are disallowed in partner packages: all pydantic
# usage must go through langchain_core.pydantic_v1.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
# (git grep only inspects tracked files, so build artifacts are ignored).
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found; any hit fails the check.
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
# NOTE: `git grep` exits non-zero when nothing matches; because it is the
# left side of an `&&` list, that "no match" result does not trip `set -e`.
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
    exit 1
else
    exit 0
fi
|
@ -0,0 +1 @@
|
||||
test file content
|
@ -0,0 +1,7 @@
|
||||
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
|
@ -0,0 +1,68 @@
|
||||
"""Integration test for SessionsPythonREPLTool against a live session pool.

Requires AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT (via env or .env)
and valid Azure credentials; it performs real network calls.
"""
import json
import os
from io import BytesIO

import dotenv

from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

dotenv.load_dotenv()

# Endpoint of the live session pool used by the test.
POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
# Local fixture file uploaded by several of the checks below.
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()


def test_end_to_end() -> None:
    """Exercise code execution, file upload/download and listing end to end."""
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    # Run code remotely: stdout is captured and the last expression is the result.
    result = tool.run("print('hello world')\n1 + 1")
    assert json.loads(result) == {
        "result": 2,
        "stdout": "hello world\n",
        "stderr": "",
    }

    # upload file content
    uploaded_file1_metadata = tool.upload_file(
        remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
    )
    assert uploaded_file1_metadata.filename == "test1.txt"
    assert uploaded_file1_metadata.size_in_bytes == 16
    assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
    downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
    assert downloaded_file1.read() == b"hello world!!!!!"

    # upload file from buffer
    with open(TEST_DATA_PATH, "rb") as f:
        uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
    assert uploaded_file2_metadata.filename == "test2.txt"
    downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
    assert downloaded_file2.read() == TEST_DATA_CONTENT

    # upload file from disk, specifying remote file path
    uploaded_file3_metadata = tool.upload_file(
        remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
    )
    assert uploaded_file3_metadata.filename == "test3.txt"
    downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
    assert downloaded_file3.read() == TEST_DATA_CONTENT

    # upload file from disk, without specifying remote file path
    # (the remote name should default to the local basename)
    uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
    assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
    downloaded_file4 = tool.download_file(
        remote_file_path=uploaded_file4_metadata.filename
    )
    assert downloaded_file4.read() == TEST_DATA_CONTENT

    # list files: all four uploads above should be visible in the session
    remote_files_metadata = tool.list_files()
    assert len(remote_files_metadata) == 4
    remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
    expected_filenames = [
        "test1.txt",
        "test2.txt",
        "test3.txt",
        os.path.basename(TEST_DATA_PATH),
    ]
    assert set(remote_file_paths) == set(expected_filenames)
|
@ -0,0 +1,9 @@
|
||||
from langchain_azure_dynamic_sessions import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"SessionsPythonREPLTool",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
@ -0,0 +1,208 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from unittest import mock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
||||
from langchain_azure_dynamic_sessions.tools.sessions import (
|
||||
_access_token_provider_factory,
|
||||
)
|
||||
|
||||
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
|
||||
|
||||
|
||||
def test_default_access_token_provider_returns_token() -> None:
|
||||
access_token_provider = _access_token_provider_factory()
|
||||
with mock.patch(
|
||||
"azure.identity.DefaultAzureCredential.get_token"
|
||||
) as mock_get_token:
|
||||
mock_get_token.return_value = AccessToken("token_value", 0)
|
||||
access_token = access_token_provider()
|
||||
assert access_token == "token_value"
|
||||
|
||||
|
||||
def test_default_access_token_provider_returns_cached_token() -> None:
|
||||
access_token_provider = _access_token_provider_factory()
|
||||
with mock.patch(
|
||||
"azure.identity.DefaultAzureCredential.get_token"
|
||||
) as mock_get_token:
|
||||
mock_get_token.return_value = AccessToken(
|
||||
"token_value", int(time.time() + 1000)
|
||||
)
|
||||
access_token = access_token_provider()
|
||||
assert access_token == "token_value"
|
||||
assert mock_get_token.call_count == 1
|
||||
|
||||
mock_get_token.return_value = AccessToken(
|
||||
"new_token_value", int(time.time() + 1000)
|
||||
)
|
||||
access_token = access_token_provider()
|
||||
assert access_token == "token_value"
|
||||
assert mock_get_token.call_count == 1
|
||||
|
||||
|
||||
def test_default_access_token_provider_refreshes_expiring_token() -> None:
|
||||
access_token_provider = _access_token_provider_factory()
|
||||
with mock.patch(
|
||||
"azure.identity.DefaultAzureCredential.get_token"
|
||||
) as mock_get_token:
|
||||
mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
|
||||
access_token = access_token_provider()
|
||||
assert access_token == "token_value"
|
||||
assert mock_get_token.call_count == 1
|
||||
|
||||
mock_get_token.return_value = AccessToken(
|
||||
"new_token_value", int(time.time() + 1000)
|
||||
)
|
||||
access_token = access_token_provider()
|
||||
assert access_token == "new_token_value"
|
||||
assert mock_get_token.call_count == 2
|
||||
|
||||
|
||||
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    """Running code POSTs to `{endpoint}/code/execute` with the expected
    auth header and request body, and the tool's JSON result surfaces the
    API's result/stdout/stderr fields."""
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "hello world\n",
            "stderr": "",
            "result": "",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))

    result = tool.run("print('hello world')")

    # The tool returns a JSON string containing only these three fields.
    assert json.loads(result) == {
        "result": "",
        "stdout": "hello world\n",
        "stderr": "",
    }

    api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
    headers = {
        "Authorization": "Bearer token_value",
        "Content-Type": "application/json",
        # Exact User-Agent is checked separately below with a regex.
        "User-Agent": mock.ANY,
    }
    body = {
        "properties": {
            "codeInputType": "inline",
            "executionType": "synchronous",
            "code": "print('hello world')",
        }
    }
    # URL is matched loosely (mock.ANY) because it carries query parameters.
    mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)

    # User-Agent must identify this package and its semantic version.
    called_headers = mock_post.call_args.kwargs["headers"]
    assert re.match(
        r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
        called_headers["User-Agent"],
    )

    called_api_url = mock_post.call_args.args[0]
    assert called_api_url.startswith(api_url)
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    """A session_id passed to the tool is forwarded to the API as the
    ``identifier`` query parameter of the execute URL."""
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        session_id="00000000-0000-0000-0000-000000000003",
    )
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "",
            "stderr": "",
            "result": "2",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
    tool.run("1 + 1")

    # Extract the "identifier" query parameter from the posted URL.
    call_url = mock_post.call_args.args[0]
    parsed_url = urlparse(call_url)
    call_identifier = parse_qs(parsed_url.query)["identifier"][0]
    assert call_identifier == "00000000-0000-0000-0000-000000000003"
def test_sanitizes_input() -> None:
    """By default the tool strips markdown code fences from the input
    before sending the code to the API."""
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        # Only the code inside the fence should reach the API.
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "print('hello world')"
def test_does_not_sanitize_input() -> None:
    """With sanitize_input=False the input is sent to the API verbatim,
    markdown fences included."""
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
    )
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        # The fenced input must pass through unchanged.
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "```python\nprint('hello world')\n```"
def test_uses_custom_access_token_provider() -> None:
    """A caller-supplied access_token_provider is used for the Bearer
    Authorization header instead of the default Azure credential."""

    def custom_access_token_provider() -> str:
        # Stand-in provider; never touches azure.identity.
        return "custom_token"

    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        access_token_provider=custom_access_token_provider,
    )

    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("print('hello world')")
        headers = mock_post.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer custom_token"
Loading…
Reference in New Issue