Pass through Run ID Explicitly (#21469)

pull/21420/head^2
William FH 2 weeks ago committed by GitHub
parent 83eecd54fe
commit b28be5d407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,5 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations
import asyncio
@ -346,11 +347,11 @@ class RunnableAgent(BaseSingleActionAgent):
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
stream_runnable: bool = True
"""Whether to stream from the runnable or not.
"""Whether to stream from the runnable or not.
If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and
If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and
individual LLM tokens will not be available in stream_log.
"""
@ -455,11 +456,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
stream_runnable: bool = True
"""Whether to stream from the runnable or not.
If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and
"""Whether to stream from the runnable or not.
If True then underlying LLM is invoked in a streaming fashion to make it possible
to get access to the individual LLM tokens when using stream_log with the Agent
Executor. If False then LLM is invoked in a non-streaming fashion and
individual LLM tokens will not be available in stream_log.
"""
@ -926,7 +927,7 @@ class AgentExecutor(Chain):
max_iterations: Optional[int] = 15
"""The maximum number of steps to take before ending the execution
loop.
Setting to 'None' could lead to an infinite loop."""
max_execution_time: Optional[float] = None
"""The maximum amount of wall clock time to spend in the execution
@ -938,7 +939,7 @@ class AgentExecutor(Chain):
`"force"` returns a string saying that it stopped because it met a
time or iteration limit.
`"generate"` calls the agent's LLM Chain one final time to generate
a final answer based on the previous steps.
"""
@ -1565,6 +1566,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True,
**kwargs,
)
@ -1586,6 +1588,7 @@ class AgentExecutor(Chain):
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.get("run_id"),
yield_actions=True,
**kwargs,
)

@ -14,6 +14,7 @@ from typing import (
Tuple,
Union,
)
from uuid import UUID
from langchain_core.agents import (
AgentAction,
@ -54,6 +55,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[UUID] = None,
include_run_info: bool = False,
yield_actions: bool = False,
):
@ -67,6 +69,7 @@ class AgentExecutorIterator:
self.tags = tags
self.metadata = metadata
self.run_name = run_name
self.run_id = run_id
self.include_run_info = include_run_info
self.yield_actions = yield_actions
self.reset()
@ -76,6 +79,7 @@ class AgentExecutorIterator:
tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]]
run_name: Optional[str]
run_id: Optional[UUID]
include_run_info: bool
yield_actions: bool
@ -162,6 +166,7 @@ class AgentExecutorIterator:
run_manager = callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
self.run_id,
name=self.run_name,
)
try:
@ -227,6 +232,7 @@ class AgentExecutorIterator:
run_manager = await callback_manager.on_chain_start(
dumpd(self.agent_executor),
self.inputs,
self.run_id,
name=self.run_name,
)
try:

@ -1,4 +1,5 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
@ -127,6 +128,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@ -145,6 +147,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id,
name=run_name,
)
try:
@ -178,6 +181,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
run_id = config.get("run_id")
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@ -195,6 +199,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id,
name=run_name,
)
try:

@ -3,6 +3,7 @@ from uuid import UUID
import pytest
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain_core.tracers.context import collect_runs
from langchain.agents import (
AgentExecutor,
@ -251,6 +252,28 @@ def test_agent_iterator_properties_and_setters() -> None:
assert isinstance(agent_iter.agent_executor, AgentExecutor)
def test_agent_iterator_manual_run_id() -> None:
"""Test react chain iterator with manually specified run_id."""
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
agent_iter = agent.stream("when was langchain made", {"run_id": run_id})
list(agent_iter)
run = cb.traced_runs[0]
assert run.id == run_id
async def test_manually_specify_rid_async() -> None:
agent = _get_agent()
run_id = UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
with collect_runs() as cb:
res = agent.astream("bar", {"run_id": run_id})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id
def test_agent_iterator_reset() -> None:
"""Test reset functionality of AgentExecutorIterator."""
agent = _get_agent()

@ -1,9 +1,12 @@
"""Test logic on base chain class."""
import uuid
from typing import Any, Dict, List, Optional
import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY
@ -180,6 +183,37 @@ def test_run_with_callback_and_input_error() -> None:
assert handler.errors == 1
def test_manually_specify_rid() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
chain.invoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
list(chain.stream({"foo": "bar"}, {"run_id": run_id2}))
run = cb.traced_runs[0]
assert run.id == run_id2
async def test_manually_specify_rid_async() -> None:
chain = FakeChain()
run_id = uuid.uuid4()
with collect_runs() as cb:
await chain.ainvoke({"foo": "bar"}, {"run_id": run_id})
run = cb.traced_runs[0]
assert run.id == run_id
run_id2 = uuid.uuid4()
with collect_runs() as cb:
res = chain.astream({"foo": "bar"}, {"run_id": run_id2})
async for _ in res:
pass
run = cb.traced_runs[0]
assert run.id == run_id2
def test_run_with_callback_and_output_error() -> None:
"""Test callback manager catches run validation output error."""
handler = FakeCallbackHandler()

Loading…
Cancel
Save