core[minor]: Add v2 implementation of astream events (#21638)

This PR introduces a v2 implementation of astream events that removes
intermediate abstractions and fixes some issues with the v1 implementation.

The v2 implementation significantly reduces the amount of code associated
with the astream events implementation, as well as its overhead.

After this PR, the astream events implementation:

- Uses an async callback handler
- No longer relies on `BaseTracer`
- No longer relies on JSON Patch

As a result of this re-write, a number of issues were discovered with
the existing implementation.
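For illustration, here is a minimal sketch of opting into the new implementation, adapted from the updated `astream_events` docstring example in this diff:

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def reverse(s: str) -> str:
    return s[::-1]


async def main() -> None:
    chain = RunnableLambda(func=reverse)
    # version="v2" selects the new implementation; "v1" remains available
    # for backwards compatibility.
    async for event in chain.astream_events("hello", version="v2"):
        print(event["event"], event["name"])


asyncio.run(main())
```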

## Changes in V2 vs. V1

### on_chat_model_end `output`

The output associated with `on_chat_model_end` changed depending on
whether the chat model was run as a root-level runnable or within a chain.

As a root-level runnable, the output was:

```python
"data": {"output": AIMessageChunk(content="hello world!", id='some id')}
```

As part of a chain the output was:

```python
"data": {
    "output": {
        "generations": [
            [
                {
                    "generation_info": None,
                    "message": AIMessageChunk(
                        content="hello world!", id=AnyStr()
                    ),
                    "text": "hello world!",
                    "type": "ChatGenerationChunk",
                }
            ]
        ],
        "llm_output": None,
    }
},
```

After this PR, we will always use the simpler representation:

```python
"data": {"output": AIMessageChunk(content="hello world!", id='some id')}
```

**NOTE:** Non-chat models (i.e., regular LLMs) still use the more verbose
format.
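As a sketch of what consumers gain (assuming `model` is any chat model runnable), there is no longer a need to branch on whether the model ran at the root or inside a chain:

```python
async def print_final_message(model) -> None:
    # With v2, the end event's output is the AIMessageChunk itself,
    # regardless of whether the model ran inside a chain.
    async for event in model.astream_events("hello", version="v2"):
        if event["event"] == "on_chat_model_end":
            print(event["data"]["output"].content)
```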

### Remove some `_stream` events

The `on_retriever_stream` and `on_tool_stream` events were removed -- these
were not real events, but artifacts of implementing on top of
`astream_log`.

The same information is already available in the corresponding `on_*_end`
events (e.g., `on_retriever_end` and `on_tool_end`).
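Code that listened for the removed events can switch to the end events instead; a sketch, assuming a hypothetical `chain` that wraps a retriever and/or a tool:

```python
async def handle_events(chain, query: str) -> None:
    async for event in chain.astream_events(query, version="v2"):
        # on_retriever_stream / on_tool_stream no longer exist; their
        # payloads are carried by the corresponding end events.
        if event["event"] == "on_retriever_end":
            print(event["data"]["output"])
        elif event["event"] == "on_tool_end":
            print(event["data"]["output"])
```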

### Propagating Names

Names of runnables have been updated to be more consistent.

```python
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
    messages=ConfigurableField(
        id="messages",
        name="Messages",
        description="Messages returned by the LLM",
    )
)
```

Before:
```python
"name": "RunnableConfigurableFields",
```

After:
```python
"name": "GenericFakeChatModel",
```
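Because the wrapped model's name now propagates, event filters can target it directly; a sketch reusing the configurable `model` above with the `include_names` parameter from the `astream_events` signature:

```python
async def stream_model_events(model) -> None:
    # Matches the underlying GenericFakeChatModel rather than the
    # RunnableConfigurableFields wrapper it was hidden behind in v1.
    async for event in model.astream_events(
        "hello", version="v2", include_names=["GenericFakeChatModel"]
    ):
        print(event["event"])
```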

### on_retriever_end

`on_retriever_end` will always return an `output` which is a list of
documents (rather than a dict containing a key called `"documents"`).
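A sketch of the difference, assuming a hypothetical `retriever` runnable:

```python
async def collect_documents(retriever, query: str) -> list:
    async for event in retriever.astream_events(query, version="v2"):
        if event["event"] == "on_retriever_end":
            # v2: the output is the list of documents itself;
            # the v1 shape was event["data"]["output"]["documents"].
            return event["data"]["output"]
    return []
```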

### Retry events

Removed the `on_retry` callback handler. It incorrectly showed the failed
function being retried as having invoked `on_chain_end`.


https://github.com/langchain-ai/langchain/pull/21638/files#diff-e512e3f84daf23029ebcceb11460f1c82056314653673e450a5831147d8cb84dL1394

@@ -52,7 +52,7 @@ from langchain_core.outputs import (
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
from langchain_core.tracers._streaming import _StreamingCallbackHandler
if TYPE_CHECKING:
from langchain_core.pydantic_v1 import BaseModel
@@ -608,7 +608,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
@@ -691,7 +691,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
if isinstance(h, _StreamingCallbackHandler)
),
False,
)

@@ -58,7 +58,7 @@ from langchain_core.runnables.config import (
var_child_runnable_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.schema import EventData, StreamEvent
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import (
AddableDict,
AnyConfigurableField,
@@ -90,7 +90,6 @@ if TYPE_CHECKING:
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain_core.tracers.log_stream import (
LogEntry,
RunLog,
RunLogPatch,
)
@@ -927,7 +926,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
version: Literal["v1", "v2"],
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
@@ -962,6 +961,8 @@ class Runnable(Generic[Input, Output], ABC):
chains. Metadata fields have been omitted from the table for brevity.
Chain definitions have been included after the table.
**ATTENTION** This reference table is for the V2 version of the schema.
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| event | name | chunk | input | output |
+======================+==================+=================================+===============================================+=================================================+
@@ -969,7 +970,7 @@ class Runnable(Generic[Input, Output], ABC):
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} |
| on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | AIMessageChunk(content="hello world") |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_llm_start | [model name] | | {'input': 'hello'} | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
@@ -985,15 +986,11 @@ class Runnable(Generic[Input, Output], ABC):
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_tool_start | some_tool | | {"x": 1, "y": "2"} | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_tool_stream | some_tool | {"x": 1, "y": "2"} | | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_tool_end | some_tool | | | {"x": 1, "y": "2"} |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_retriever_start | [retriever name] | | {"query": "hello"} | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_retriever_chunk | [retriever name] | {documents: [...]} | | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} |
| on_retriever_end | [retriever name] | | {"query": "hello"} | [Document(...), ..] |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
| on_prompt_start | [template_name] | | {"question": "hello"} | |
+----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
@@ -1042,7 +1039,7 @@ class Runnable(Generic[Input, Output], ABC):
chain = RunnableLambda(func=reverse)
events = [
event async for event in chain.astream_events("hello", version="v1")
event async for event in chain.astream_events("hello", version="v2")
]
# will produce the following events (run_id has been omitted for brevity):
@@ -1073,8 +1070,10 @@
Args:
input: The input to the runnable.
config: The config to use for the runnable.
version: The version of the schema to use.
Currently only version 1 is available.
version: The version of the schema to use either `v2` or `v1`.
Users should use `v2`.
`v1` is for backwards compatibility and will be deprecated
in 0.4.0.
No default will be assigned until the API is stabilized.
include_names: Only include events from runnables with matching names.
include_types: Only include events from runnables with matching types.
@@ -1089,180 +1088,45 @@
Returns:
An async stream of StreamEvents.
""" # noqa: E501
if version != "v1":
raise NotImplementedError(
'Only version "v1" of the schema is currently supported.'
)
from langchain_core.runnables.utils import (
_RootEventFilter,
)
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
RunLog,
_astream_log_implementation,
)
stream = LogStreamCallbackHandler(
auto_close=False,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
_schema_format="streaming_events",
)
run_log = RunLog(state=None) # type: ignore[arg-type]
encountered_start_event = False
_root_event_filter = _RootEventFilter(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
from langchain_core.tracers.event_stream import (
_astream_events_implementation_v1,
_astream_events_implementation_v2,
)
config = ensure_config(config)
root_tags = config.get("tags", [])
root_metadata = config.get("metadata", {})
root_name = config.get("run_name", self.get_name())
# Ignoring mypy complaint about too many different union combinations
# This arises because many of the argument types are unions
async for log in _astream_log_implementation( # type: ignore[misc]
self,
input,
config=config,
stream=stream,
diff=True,
with_streamed_output_list=True,
**kwargs,
):
run_log = run_log + log
if not encountered_start_event:
# Yield the start event for the root runnable.
encountered_start_event = True
state = run_log.state.copy()
event = StreamEvent(
event=f"on_{state['type']}_start",
run_id=state["id"],
name=root_name,
tags=root_tags,
metadata=root_metadata,
data={
"input": input,
},
)
if _root_event_filter.include_event(event, state["type"]):
yield event
paths = {
op["path"].split("/")[2]
for op in log.ops
if op["path"].startswith("/logs/")
}
# Elements in a set should be iterated in the same order
# as they were inserted in modern python versions.
for path in paths:
data: EventData = {}
log_entry: LogEntry = run_log.state["logs"][path]
if log_entry["end_time"] is None:
if log_entry["streamed_output"]:
event_type = "stream"
else:
event_type = "start"
else:
event_type = "end"
if event_type == "start":
# Include the inputs with the start event if they are available.
# Usually they will NOT be available for components that operate
# on streams, since those components stream the input and
# don't know its final value until the end of the stream.
inputs = log_entry["inputs"]
if inputs is not None:
data["input"] = inputs
pass
if event_type == "end":
inputs = log_entry["inputs"]
if inputs is not None:
data["input"] = inputs
# None is a VALID output for an end event
data["output"] = log_entry["final_output"]
if event_type == "stream":
num_chunks = len(log_entry["streamed_output"])
if num_chunks != 1:
raise AssertionError(
f"Expected exactly one chunk of streamed output, "
f"got {num_chunks} instead. This is impossible. "
f"Encountered in: {log_entry['name']}"
)
data = {"chunk": log_entry["streamed_output"][0]}
# Clean up the stream, we don't need it anymore.
# And this avoids duplicates as well!
log_entry["streamed_output"] = []
yield StreamEvent(
event=f"on_{log_entry['type']}_{event_type}",
name=log_entry["name"],
run_id=log_entry["id"],
tags=log_entry["tags"],
metadata=log_entry["metadata"],
data=data,
)
# Finally, we take care of the streaming output from the root chain
# if there is any.
state = run_log.state
if state["streamed_output"]:
num_chunks = len(state["streamed_output"])
if num_chunks != 1:
raise AssertionError(
f"Expected exactly one chunk of streamed output, "
f"got {num_chunks} instead. This is impossible. "
f"Encountered in: {state['name']}"
)
if version == "v2":
event_stream = _astream_events_implementation_v2(
self,
input,
config=config,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
**kwargs,
)
elif version == "v1":
# First implementation, built on top of astream_log API
# This implementation will be deprecated as of 0.2.0
event_stream = _astream_events_implementation_v1(
self,
input,
config=config,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
**kwargs,
)
else:
raise NotImplementedError(
'Only versions "v1" and "v2" of the schema is currently supported.'
)
data = {"chunk": state["streamed_output"][0]}
# Clean up the stream, we don't need it anymore.
state["streamed_output"] = []
event = StreamEvent(
event=f"on_{state['type']}_stream",
run_id=state["id"],
tags=root_tags,
metadata=root_metadata,
name=root_name,
data=data,
)
if _root_event_filter.include_event(event, state["type"]):
yield event
state = run_log.state
# Finally yield the end event for the root runnable.
event = StreamEvent(
event=f"on_{state['type']}_end",
name=root_name,
run_id=state["id"],
tags=root_tags,
metadata=root_metadata,
data={
"output": state["final_output"],
},
)
if _root_event_filter.include_event(event, state["type"]):
async for event in event_stream:
yield event
def transform(
@@ -1936,7 +1800,8 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
@@ -1964,16 +1829,18 @@ class Runnable(Generic[Input, Output], ABC):
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_log := next(
if stream_handler := next(
(
h
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator)
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator)
try:
while True:
if accepts_context(asyncio.create_task):

@@ -0,0 +1,28 @@
"""Internal tracers used for stream_log and astream events implementations."""
import abc
from typing import AsyncIterator, TypeVar
from uuid import UUID
T = TypeVar("T")
class _StreamingCallbackHandler(abc.ABC):
"""For internal use.
This is a common mixin that the callback handlers
for both astream events and astream log inherit from.
The `tap_output_aiter` method is invoked in some contexts
to produce callbacks for intermediate results.
"""
@abc.abstractmethod
def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Used for internal astream_log and astream events implementations."""
__all__ = [
"_StreamingCallbackHandler",
]

@@ -0,0 +1,790 @@
"""Internal tracer to power the event stream API."""
from __future__ import annotations
import asyncio
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from uuid import UUID
from typing_extensions import NotRequired, TypedDict
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import (
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.runnables.schema import EventData, StreamEvent
from langchain_core.runnables.utils import (
Input,
Output,
_RootEventFilter,
)
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.runnables import Runnable, RunnableConfig
logger = logging.getLogger(__name__)
class RunInfo(TypedDict):
"""Information about a run."""
name: str
tags: List[str]
metadata: Dict[str, Any]
run_type: str
inputs: NotRequired[Any]
def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> str:
"""Assign a name to a run."""
if name is not None:
return name
if "name" in serialized:
return serialized["name"]
elif "id" in serialized:
return serialized["id"][-1]
return "Unnamed"
T = TypeVar("T")
class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHandler):
"""An implementation of an async callback handler for astream events."""
def __init__(
self,
*args: Any,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the tracer."""
super().__init__(*args, **kwargs)
# Map of run ID to run info.
self.run_map: Dict[UUID, RunInfo] = {}
# Filter which events will be sent over the queue.
self.root_event_filter = _RootEventFilter(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
)
loop = asyncio.get_event_loop()
memory_stream = _MemoryStream[StreamEvent](loop)
self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream()
async def _send(self, event: StreamEvent, event_type: str) -> None:
"""Send an event to the stream."""
if self.root_event_filter.include_event(event, event_type):
await self.send_stream.send(event)
def __aiter__(self) -> AsyncIterator[Any]:
"""Iterate over the receive stream."""
return self.receive_stream.__aiter__()
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap the output aiter."""
async for chunk in output:
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
await self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
yield chunk
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
name_ = _assign_name(name, serialized)
run_type = "chat_model"
self.run_map[run_id] = {
"tags": tags or [],
"metadata": metadata or {},
"name": name_,
"run_type": run_type,
"inputs": {"messages": messages},
}
await self._send(
{
"event": "on_chat_model_start",
"data": {
"input": {"messages": messages},
},
"name": name_,
"tags": tags or [],
"run_id": str(run_id),
"metadata": metadata or {},
},
run_type,
)
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
name_ = _assign_name(name, serialized)
run_type = "llm"
self.run_map[run_id] = {
"tags": tags or [],
"metadata": metadata or {},
"name": name_,
"run_type": run_type,
"inputs": {"prompts": prompts},
}
await self._send(
{
"event": "on_llm_start",
"data": {
"input": {
"prompts": prompts,
}
},
"name": name_,
"tags": tags or [],
"run_id": str(run_id),
"metadata": metadata or {},
},
run_type,
)
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
run_info = self.run_map.get(run_id)
chunk_: Union[GenerationChunk, BaseMessageChunk]
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
if run_info["run_type"] == "chat_model":
event = "on_chat_model_stream"
if chunk is None:
chunk_ = AIMessageChunk(content=token)
else:
chunk_ = cast(ChatGenerationChunk, chunk).message
elif run_info["run_type"] == "llm":
event = "on_llm_stream"
if chunk is None:
chunk_ = GenerationChunk(text=token)
else:
chunk_ = cast(GenerationChunk, chunk)
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send(
{
"event": event,
"data": {
"chunk": chunk_,
},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
async def on_llm_end(
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for an LLM run."""
run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"]
generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]]
output: Union[dict, BaseMessage] = {}
if run_info["run_type"] == "chat_model":
generations = cast(List[List[ChatGenerationChunk]], response.generations)
for gen in generations:
if output != {}:
break
for chunk in gen:
output = chunk.message
break
event = "on_chat_model_end"
elif run_info["run_type"] == "llm":
generations = cast(List[List[GenerationChunk]], response.generations)
output = {
"generations": [
[
{
"text": chunk.text,
"generation_info": chunk.generation_info,
"type": chunk.type,
}
for chunk in gen
]
for gen in generations
],
"llm_output": response.llm_output,
}
event = "on_llm_end"
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send(
{
"event": event,
"data": {"output": output, "input": inputs_},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
name_ = _assign_name(name, serialized)
run_type_ = run_type or "chain"
run_info: RunInfo = {
"tags": tags or [],
"metadata": metadata or {},
"name": name_,
"run_type": run_type_,
}
data: EventData = {}
# Work-around Runnable core code not sending input in some
# cases.
if inputs != {"input": ""}:
data["input"] = inputs
run_info["inputs"] = inputs
self.run_map[run_id] = run_info
await self._send(
{
"event": f"on_{run_type_}_start",
"data": data,
"name": name_,
"tags": tags or [],
"run_id": str(run_id),
"metadata": metadata or {},
},
run_type_,
)
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""End a trace for a chain run."""
run_info = self.run_map.pop(run_id)
run_type = run_info["run_type"]
event = f"on_{run_type}_end"
inputs = inputs or run_info.get("inputs") or {}
data: EventData = {
"output": outputs,
"input": inputs,
}
await self._send(
{
"event": event,
"data": data,
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_type,
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
name_ = _assign_name(name, serialized)
self.run_map[run_id] = {
"tags": tags or [],
"metadata": metadata or {},
"name": name_,
"run_type": "tool",
"inputs": inputs,
}
await self._send(
{
"event": "on_tool_start",
"data": {
"input": inputs or {},
},
"name": name_,
"tags": tags or [],
"run_id": str(run_id),
"metadata": metadata or {},
},
"tool",
)
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run."""
run_info = self.run_map.pop(run_id)
if "inputs" not in run_info:
raise AssertionError(
f"Run ID {run_id} is a tool call and is expected to have "
f"inputs associated with it."
)
inputs = run_info["inputs"]
await self._send(
{
"event": "on_tool_end",
"data": {
"output": output,
"input": inputs,
},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
"tool",
)
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Run when Retriever starts running."""
name_ = _assign_name(name, serialized)
run_type = "retriever"
self.run_map[run_id] = {
"tags": tags or [],
"metadata": metadata or {},
"name": name_,
"run_type": run_type,
"inputs": {"query": query},
}
await self._send(
{
"event": "on_retriever_start",
"data": {
"input": {
"query": query,
}
},
"name": name_,
"tags": tags or [],
"run_id": str(run_id),
"metadata": metadata or {},
},
run_type,
)
async def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:
"""Run when Retriever ends running."""
run_info = self.run_map.pop(run_id)
await self._send(
{
"event": "on_retriever_end",
"data": {
"output": documents,
"input": run_info["inputs"],
},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
def __deepcopy__(self, memo: dict) -> _AstreamEventsCallbackHandler:
"""Deepcopy the tracer."""
return self
def __copy__(self) -> _AstreamEventsCallbackHandler:
"""Copy the tracer."""
return self
async def _astream_events_implementation_v1(
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
from langchain_core.runnables import ensure_config
from langchain_core.runnables.utils import _RootEventFilter
from langchain_core.tracers.log_stream import (
LogStreamCallbackHandler,
RunLog,
_astream_log_implementation,
)
stream = LogStreamCallbackHandler(
auto_close=False,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
_schema_format="streaming_events",
)
run_log = RunLog(state=None) # type: ignore[arg-type]
encountered_start_event = False
_root_event_filter = _RootEventFilter(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
)
config = ensure_config(config)
root_tags = config.get("tags", [])
root_metadata = config.get("metadata", {})
root_name = config.get("run_name", runnable.get_name())
# Ignoring mypy complaint about too many different union combinations
# This arises because many of the argument types are unions
async for log in _astream_log_implementation( # type: ignore[misc]
runnable,
input,
config=config,
stream=stream,
diff=True,
with_streamed_output_list=True,
**kwargs,
):
run_log = run_log + log
if not encountered_start_event:
# Yield the start event for the root runnable.
encountered_start_event = True
state = run_log.state.copy()
event = StreamEvent(
event=f"on_{state['type']}_start",
run_id=state["id"],
name=root_name,
tags=root_tags,
metadata=root_metadata,
data={
"input": input,
},
)
if _root_event_filter.include_event(event, state["type"]):
yield event
paths = {
op["path"].split("/")[2]
for op in log.ops
if op["path"].startswith("/logs/")
}
# Elements in a set should be iterated in the same order
# as they were inserted in modern python versions.
for path in paths:
data: EventData = {}
log_entry: LogEntry = run_log.state["logs"][path]
if log_entry["end_time"] is None:
if log_entry["streamed_output"]:
event_type = "stream"
else:
event_type = "start"
else:
event_type = "end"
if event_type == "start":
# Include the inputs with the start event if they are available.
# Usually they will NOT be available for components that operate
# on streams, since those components stream the input and
# don't know its final value until the end of the stream.
inputs = log_entry["inputs"]
if inputs is not None:
data["input"] = inputs
pass
if event_type == "end":
inputs = log_entry["inputs"]
if inputs is not None:
data["input"] = inputs
# None is a VALID output for an end event
data["output"] = log_entry["final_output"]
if event_type == "stream":
num_chunks = len(log_entry["streamed_output"])
if num_chunks != 1:
raise AssertionError(
f"Expected exactly one chunk of streamed output, "
f"got {num_chunks} instead. This is impossible. "
f"Encountered in: {log_entry['name']}"
)
data = {"chunk": log_entry["streamed_output"][0]}
# Clean up the stream, we don't need it anymore.
# And this avoids duplicates as well!
log_entry["streamed_output"] = []
yield StreamEvent(
event=f"on_{log_entry['type']}_{event_type}",
name=log_entry["name"],
run_id=log_entry["id"],
tags=log_entry["tags"],
metadata=log_entry["metadata"],
data=data,
)
# Finally, we take care of the streaming output from the root chain
# if there is any.
state = run_log.state
if state["streamed_output"]:
num_chunks = len(state["streamed_output"])
if num_chunks != 1:
raise AssertionError(
f"Expected exactly one chunk of streamed output, "
f"got {num_chunks} instead. This is impossible. "
f"Encountered in: {state['name']}"
)
data = {"chunk": state["streamed_output"][0]}
# Clean up the stream, we don't need it anymore.
state["streamed_output"] = []
event = StreamEvent(
event=f"on_{state['type']}_stream",
run_id=state["id"],
tags=root_tags,
metadata=root_metadata,
name=root_name,
data=data,
)
if _root_event_filter.include_event(event, state["type"]):
yield event
state = run_log.state
# Finally yield the end event for the root runnable.
event = StreamEvent(
event=f"on_{state['type']}_end",
name=root_name,
run_id=state["id"],
tags=root_tags,
metadata=root_metadata,
data={
"output": state["final_output"],
},
)
if _root_event_filter.include_event(event, state["type"]):
yield event
async def _astream_events_implementation_v2(
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
"""Implementation of the astream events API for V2 runnables."""
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.runnables import ensure_config
event_streamer = _AstreamEventsCallbackHandler(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
)
# Assign the stream handler to the config
config = ensure_config(config)
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [event_streamer]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [event_streamer]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.add_handler(event_streamer, inherit=True)
config["callbacks"] = callbacks
else:
raise ValueError(
f"Unexpected type for callbacks: {callbacks}."
"Expected None, list or AsyncCallbackManager."
)
# Call the runnable in streaming mode,
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for _ in runnable.astream(input, config, **kwargs):
# All the content will be picked up
pass
finally:
await event_streamer.send_stream.aclose()
# Start the runnable in a task, so we can start consuming output
task = asyncio.create_task(consume_astream())
first_event_sent = False
first_event_run_id = None
try:
async for event in event_streamer:
if not first_event_sent:
first_event_sent = True
# This is a work-around an issue where the inputs into the
# chain are not available until the entire input is consumed.
# As a temporary solution, we'll modify the input to be the input
# that was passed into the chain.
event["data"]["input"] = input
first_event_run_id = event["run_id"]
yield event
continue
if event["run_id"] == first_event_run_id and event["event"].endswith(
"_end"
):
# If it's the end event corresponding to the root runnable
# we dont include the input in the event since it's guaranteed
# to be included in the first event.
if "input" in event["data"]:
del event["data"]["input"]
yield event
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass

@@ -26,6 +26,7 @@ from langchain_core.load.load import load
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.tracers.schemas import Run
@@ -157,7 +158,7 @@ class RunLog(RunLogPatch):
T = TypeVar("T")
class LogStreamCallbackHandler(BaseTracer):
class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
"""Tracer that streams run logs to a stream."""
def __init__(

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@@ -1198,7 +1198,7 @@ files = [
[[package]]
name = "langchain-text-splitters"
version = "0.0.1"
version = "0.0.2"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -2142,7 +2142,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
