Pass through Run ID Explicitly (#21469)

pull/21420/head^2
William FH 4 weeks ago committed by GitHub
parent 83eecd54fe
commit b28be5d407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,5 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations
import asyncio
@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True,
**kwargs,
)
@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True,
**kwargs,
)

@ -14,6 +14,7 @@ from typing import (
Tuple,
Union,
)
from uuid import UUID
from langchain_core.agents import (
AgentAction,
@ -54,6 +55,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[UUID] = None,
include_run_info: bool = False,
yield_actions: bool = False,
):
@ -67,6 +69,7 @@ class AgentExecutorIterator:
self.tags = tags
self.metadata = metadata
self.run_name = run_name
self.run_id = run_id
self.include_run_info = include_run_info
self.yield_actions = yield_actions
self.reset()
@ -76,6 +79,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]]
run_name: Optional[str]
run_id: Optional[UUID]
include_run_info: bool
yield_actions: bool
@ -162,6 +166,7 @@ class AgentExecutorIterator:
run_manager = callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
self.run_id,
name=self.run_name,
)
try:
@ -227,6 +232,7 @@ class AgentExecutorIterator:
run_manager = await callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
self.run_id,
name=self.run_name,
)
try:

@ -1,4 +1,5 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id,
name=run_name,
)
try:
@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id,
name=run_name,
)
try:

@ -3,6 +3,7 @@ from uuid import UUID
import pytest
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain_core.tracers.context import collect_runs
from langchain.agents import (
AgentExecutor,
@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
assert isinstance(agent_iter.agent_executor, AgentExecutor)
def test_agent_iterator_manual_run_id() -> None:
"""Test react chain iterator with manually specified run_id."""
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
list(agent_iter)
run = cb.traced_runs[0]
assert run.id == run_id
async def test_manually_specify_rid_async() -> None:
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
res = agent.astream("bar", {"run_id": run_id})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id
def test_agent_iterator_reset() -> None:
"""Test reset functionality of AgentExecutorIterator."""
agent = _get_agent()

@ -1,9 +1,12 @@
"""Test logic on base chain class."""
import uuid
from typing import Any, Dict, List, Optional
import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY
@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
assert handler.errors == 1
def test_manually_specify_rid() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
chain.invoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
run = cb.traced_runs[0]
assert run.id == run_id2
async def test_manually_specify_rid_async() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id2
def test_run_with_callback_and_output_error() -> None:
"""Test callback manager catches run validation output error."""
handler = FakeCallbackHandler()

Loading…
Cancel
Save