core: Tap output of sync iterators for astream_events (#21842)

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
pull/21852/head
Nuno Campos 2 weeks ago committed by GitHub
parent 9a39f92aba
commit b1e7b40b6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1716,6 +1716,9 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
# Mixin that is used by both the astream_log and astream_events implementations
from langchain_core.tracers._streaming import _StreamingCallbackHandler
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
@ -1742,6 +1745,17 @@ class Runnable(Generic[Input, Output], ABC):
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
try:
while True:
chunk: Output = context.run(next, iterator) # type: ignore

@ -1,6 +1,6 @@
"""Internal tracers used for stream_log and astream events implementations."""
import abc
from typing import AsyncIterator, TypeVar
from typing import AsyncIterator, Iterator, TypeVar
from uuid import UUID
T = TypeVar("T")
@ -22,6 +22,10 @@ class _StreamingCallbackHandler(abc.ABC):
) -> AsyncIterator[T]:
"""Used for internal astream_log and astream events implementations."""
@abc.abstractmethod
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Used for internal astream_log and astream events implementations.

    Synchronous counterpart of the async tap declared on this mixin: it
    receives the ID of a run together with the iterator of that run's
    output and returns an iterator over items of the same type.

    Args:
        run_id: ID of the run whose output is being tapped.
        output: Synchronous iterator of the run's output chunks.

    Returns:
        An iterator of the same item type as ``output``.
    """
__all__ = [
"_StreamingCallbackHandler",

@ -9,6 +9,7 @@ from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
@ -102,10 +103,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream()
async def _send(self, event: StreamEvent, event_type: str) -> None:
def _send(self, event: StreamEvent, event_type: str) -> None:
"""Send an event to the stream."""
if self.root_event_filter.include_event(event, event_type):
await self.send_stream.send(event)
self.send_stream.send_nowait(event)
def __aiter__(self) -> AsyncIterator[Any]:
"""Iterate over the receive stream."""
@ -119,7 +120,26 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
await self._send(
self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap the output iter."""
for chunk in output:
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
@ -155,7 +175,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"messages": messages},
}
await self._send(
self._send(
{
"event": "on_chat_model_start",
"data": {
@ -192,7 +212,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"prompts": prompts},
}
await self._send(
self._send(
{
"event": "on_llm_start",
"data": {
@ -241,7 +261,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send(
self._send(
{
"event": event,
"data": {
@ -295,7 +315,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")
await self._send(
self._send(
{
"event": event,
"data": {"output": output, "input": inputs_},
@ -340,7 +360,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.run_map[run_id] = run_info
await self._send(
self._send(
{
"event": f"on_{run_type_}_start",
"data": data,
@ -373,7 +393,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"input": inputs,
}
await self._send(
self._send(
{
"event": event,
"data": data,
@ -408,7 +428,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": inputs,
}
await self._send(
self._send(
{
"event": "on_tool_start",
"data": {
@ -432,7 +452,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
)
inputs = run_info["inputs"]
await self._send(
self._send(
{
"event": "on_tool_end",
"data": {
@ -470,7 +490,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"inputs": {"query": query},
}
await self._send(
self._send(
{
"event": "on_retriever_start",
"data": {
@ -492,7 +512,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"""Run when Retriever ends running."""
run_info = self.run_map.pop(run_id)
await self._send(
self._send(
{
"event": "on_retriever_end",
"data": {

@ -8,6 +8,7 @@ from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
@ -252,6 +253,25 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Tap a synchronous output iterator to stream its values to the log.

    For every chunk of a non-root run that has an entry in
    ``self._key_map_by_run_id``, an "add" op targeting
    ``/logs/{key}/streamed_output/-`` is handed to ``self.send`` before the
    chunk is yielded unchanged.

    Args:
        run_id: ID of the run producing ``output``.
        output: Iterator of chunks produced by that run.

    Yields:
        The chunks of ``output``, in order. Iteration stops early, without
        yielding the current chunk, if ``self.send`` returns a falsy value.
    """
    for item in output:
        # The root run is handled in .astream_log(), so it is never logged
        # here. For any other run, a missing key means the run wasn't
        # included in the log, in which case we silently pass the chunk on.
        log_key = (
            self._key_map_by_run_id.get(run_id)
            if run_id != self.root_id
            else None
        )
        if log_key:
            patch = {
                "op": "add",
                "path": f"/logs/{log_key}/streamed_output/-",
                "value": item,
            }
            # A falsy result from send() ends the tap immediately.
            if not self.send(patch):
                break

        yield item
def include_run(self, run: Run) -> bool:
if run.id == self.root_id:
return False

@ -1650,27 +1650,22 @@ EXPECTED_EVENTS = [
]
@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_async_stream_lambdas() -> None:
"""Test invoking nested runnable lambda."""
def add_one_(x: int) -> int:
def add_one(x: int) -> int:
return x + 1
add_one = RunnableLambda(add_one_)
add_one_ = RunnableLambda(add_one)
async def add_one_proxy_(x: int, config: RunnableConfig) -> int:
streaming = add_one.stream(x, config)
async def add_one_proxy(x: int, config: RunnableConfig) -> int:
streaming = add_one_.stream(x, config)
results = [result for result in streaming]
return results[0]
add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
events = await _collect_events(add_one_proxy.astream_events(1, version="v2"))
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS
@ -1694,11 +1689,6 @@ async def test_async_in_async_stream_lambdas() -> None:
assert events == EXPECTED_EVENTS
@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_sync_lambdas() -> None:
"""Test invoking nested runnable lambda."""

Loading…
Cancel
Save