Remove RunnableEach

pull/19042/head
Nuno Campos 3 months ago
parent 2b7c3c548d
commit 6f59366799

@@ -439,18 +439,6 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"RunnableSequence",
),
("langchain", "schema", "runnable", "RunnableEach"): (
"langchain_core",
"runnables",
"base",
"RunnableEach",
),
("langchain", "schema", "runnable", "RunnableEachBase"): (
"langchain_core",
"runnables",
"base",
"RunnableEachBase",
),
("langchain", "schema", "runnable", "RunnableConfigurableAlternatives"): (
"langchain_core",
"runnables",
@@ -831,18 +819,6 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"RunnableSequence",
),
("langchain_core", "runnables", "base", "RunnableEach"): (
"langchain_core",
"runnables",
"base",
"RunnableEach",
),
("langchain_core", "runnables", "base", "RunnableEachBase"): (
"langchain_core",
"runnables",
"base",
"RunnableEachBase",
),
(
"langchain_core",
"runnables",

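These mapping tables let the deserializer revive objects saved under legacy paths; with RunnableEach and RunnableEachBase deleted, their entries have to go as well. A minimal sketch of the lookup-and-import pattern such a table supports (resolve_class and the trimmed-down MAPPING are hypothetical stand-ins, not langchain's actual load machinery):

from importlib import import_module
from typing import Dict, Tuple

# Hypothetical miniature of the table this diff trims: legacy serialized
# paths on the left, current import paths on the right.
MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
    ("langchain", "schema", "runnable", "RunnableSequence"): (
        "langchain_core", "runnables", "base", "RunnableSequence",
    ),
}

def resolve_class(old_path: Tuple[str, ...]) -> type:
    # Split the new path into module parts and a class name, then import.
    new_path = MAPPING[old_path]
    module = import_module(".".join(new_path[:-1]))
    return getattr(module, new_path[-1])

cls = resolve_class(("langchain", "schema", "runnable", "RunnableSequence"))
assert cls.__name__ == "RunnableSequence"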
@@ -1339,7 +1339,24 @@ class Runnable(Generic[Input, Output], ABC):
Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
"""
return RunnableEach(bound=self)
return RunnableLambda(
self.batch,
self.abatch,
input_schema=create_model(
self.get_name("Input"),
__root__=(
List[self.get_input_schema()], # type: ignore
None,
),
),
output_schema=create_model(
self.get_name("Output"),
__root__=(
List[self.get_output_schema()], # type: ignore
None,
),
),
)
def with_fallbacks(
self,
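From the caller's side nothing changes: map() still turns a Runnable over one input into a Runnable over a list of inputs, now by wrapping batch/abatch in a plain RunnableLambda instead of a dedicated RunnableEach. A quick sketch of the unchanged behavior (mirroring the add_one test deleted further down in this diff):

from langchain_core.runnables import RunnableLambda

def add_one(x: int) -> int:
    return x + 1

# map() applies the runnable to every element of a list input.
mapped = RunnableLambda(add_one).map()
assert mapped.invoke([1, 2, 3]) == [2, 3, 4]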
@@ -3341,6 +3358,8 @@ class RunnableLambda(Runnable[Input, Output]):
]
] = None,
name: Optional[str] = None,
input_schema: Optional[Type[BaseModel]] = None,
output_schema: Optional[Type[BaseModel]] = None,
) -> None:
"""Create a RunnableLambda from a callable, and async callable or both.
@@ -3381,6 +3400,9 @@ class RunnableLambda(Runnable[Input, Output]):
except AttributeError:
pass
self._input_schema = input_schema
self._output_schema = output_schema
@property
def InputType(self) -> Any:
"""The type of the input to this runnable."""
@@ -3399,6 +3421,9 @@ class RunnableLambda(Runnable[Input, Output]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The pydantic schema for the input to this runnable."""
if self._input_schema:
return self._input_schema
func = getattr(self, "func", None) or getattr(self, "afunc")
if isinstance(func, itemgetter):
@@ -3449,6 +3474,13 @@ class RunnableLambda(Runnable[Input, Output]):
except ValueError:
return Any
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
if self._output_schema:
return self._output_schema
return super().get_output_schema(config)
@property
def deps(self) -> List[Runnable]:
"""The dependencies of this runnable."""
@@ -3891,184 +3923,6 @@ class RunnableLambda(Runnable[Input, Output]):
yield chunk
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
"""Runnable that delegates calls to another Runnable
with each element of the input sequence.
Use only if creating a new RunnableEach subclass with different __init__ args.
See documentation for RunnableEach for more details.
"""
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
self.get_name("Input"),
__root__=(
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
)
@property
def OutputType(self) -> Type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema = self.bound.get_output_schema(config)
return create_model(
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
),
)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
return self.bound.get_graph(config)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _invoke(
self,
inputs: List[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
return self.bound.batch(
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
)
def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]:
return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke(
self,
inputs: List[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
return await self.bound.abatch(
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
)
async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
async def astream_events(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[StreamEvent]:
for _ in range(1):
raise NotImplementedError(
"RunnableEach does not support astream_events yet."
)
yield
class RunnableEach(RunnableEachBase[Input, Output]):
"""Runnable that delegates calls to another Runnable
with each element of the input sequence.
It allows you to call the bound Runnable with multiple inputs.
In the example below, we run three inputs through a Runnable:
.. code-block:: python
from langchain_core.runnables.base import RunnableEach
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
model = ChatOpenAI()
output_parser = StrOutputParser()
runnable = prompt | model | output_parser
runnable_each = RunnableEach(bound=runnable)
output = runnable_each.invoke([{'topic':'Computer Science'},
{'topic':'Art'},
{'topic':'Biology'}])
print(output) # noqa: T201
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = name or self.name or f"RunnableEach<{self.bound.get_name()}>"
return super().get_name(suffix, name=name)
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.with_config(config, **kwargs))
def with_listeners(
self,
*,
on_start: Optional[Listener] = None,
on_end: Optional[Listener] = None,
on_error: Optional[Listener] = None,
) -> RunnableEach[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
return RunnableEach(
bound=self.bound.with_listeners(
on_start=on_start, on_end=on_end, on_error=on_error
)
)
class RunnableBindingBase(RunnableSerializable[Input, Output]):
"""Runnable that delegates calls to another Runnable with a set of kwargs.

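The example from the deleted RunnableEach docstring ports directly to map(); the same joke chain, with the wrapper class replaced by the method (requires langchain_openai and an OpenAI API key to actually run):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
runnable = prompt | ChatOpenAI() | StrOutputParser()

# Where the old code built RunnableEach(bound=runnable), map() now gives
# the same list-in/list-out behavior.
runnable_each = runnable.map()
output = runnable_each.invoke(
    [{"topic": "Computer Science"}, {"topic": "Art"}, {"topic": "Biology"}]
)
print(output)  # noqa: T201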
@@ -572,7 +572,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
},
"items": {"$ref": "#/definitions/PromptInput"},
"type": "array",
"title": "RunnableEach<PromptTemplate>Input",
"title": "PromptTemplateInput",
}
assert prompt_mapper.output_schema.schema() == snapshot
@@ -3856,9 +3856,11 @@ def test_each(snapshot: SnapshotAssertion) -> None:
chain = prompt | first_llm | parser | second_llm.map()
tracer = FakeTracer()
assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"})
output = chain.invoke({"question": "What up"}, {"callbacks": [tracer]})
assert output == ["this", "is", "a"]
assert len(tracer.runs[0].child_runs) == 4
assert (parser | second_llm.map()).invoke("first item, second item") == [
"test",
@@ -3866,6 +3868,30 @@ def test_each(snapshot: SnapshotAssertion) -> None:
]
def test_pipe_to_batch(snapshot: SnapshotAssertion) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
parser = FakeSplitIntoListParser()
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | parser | second_llm.batch
assert dumps(chain, pretty=True) == snapshot
tracer = FakeTracer()
output = chain.invoke({"question": "What up"}, {"callbacks": [tracer]})
assert output == ["this", "is", "a"]
assert len(tracer.runs[0].child_runs) == 4
assert (parser | second_llm.batch).invoke("first item, second item") == [
"test",
"this",
]
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:

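test_pipe_to_batch leans on the fact that the | operator coerces any callable, including a bound .batch method, into a RunnableLambda. A standalone sketch of that coercion (assuming langchain_core's usual callable coercion; no fake LLMs needed):

from langchain_core.runnables import RunnableLambda

double = RunnableLambda(lambda x: x * 2)

# Piping into a bound .batch method works because | wraps callables
# in RunnableLambda, which is what the new test exercises.
chain = RunnableLambda(lambda x: [x, x + 1]) | double.batch
assert chain.invoke(3) == [6, 8]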
@@ -2,8 +2,6 @@
from itertools import cycle
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
import pytest
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
@@ -1073,20 +1071,6 @@ async def test_with_llm() -> None:
]
async def test_runnable_each() -> None:
"""Test runnable each astream_events."""
async def add_one(x: int) -> int:
return x + 1
add_one_map = RunnableLambda(add_one).map() # type: ignore
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
infinite_cycle = cycle([AIMessage(content="hello world!")])

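The deleted test pinned down a limitation: map() built on RunnableEach raised NotImplementedError from astream_events. Removing the test suggests the RunnableLambda-backed map() streams events normally; a sketch under that assumption (not verified against this branch):

import asyncio

from langchain_core.runnables import RunnableLambda

async def add_one(x: int) -> int:
    return x + 1

async def main() -> None:
    mapped = RunnableLambda(add_one).map()
    assert await mapped.ainvoke([1, 2, 3]) == [2, 3, 4]
    # Assumption: with map() backed by RunnableLambda, astream_events no
    # longer raises NotImplementedError, hence the test deletion above.
    async for event in mapped.astream_events([1, 2, 3], version="v1"):
        print(event["event"])

asyncio.run(main())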
@@ -3,8 +3,6 @@ from langchain_core.runnables.base import (
Runnable,
RunnableBinding,
RunnableBindingBase,
RunnableEach,
RunnableEachBase,
RunnableGenerator,
RunnableLambda,
RunnableLike,
@@ -29,8 +27,6 @@ __all__ = [
"RunnableParallel",
"RunnableGenerator",
"RunnableLambda",
"RunnableEachBase",
"RunnableEach",
"RunnableBindingBase",
"RunnableBinding",
"RunnableMap",

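For downstream code, the migration is mechanical: the exported classes disappear and the map() method takes their place. A before/after sketch:

# Before: wrapper class imported from langchain_core.runnables.
# from langchain_core.runnables import RunnableEach
# each = RunnableEach(bound=runnable)

# After: every Runnable carries the method directly.
from langchain_core.runnables import RunnableLambda

runnable = RunnableLambda(lambda x: x.upper())
each = runnable.map()
assert each.invoke(["a", "b"]) == ["A", "B"]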
@@ -4,8 +4,6 @@ EXPECTED_ALL = [
"Runnable",
"RunnableBinding",
"RunnableBindingBase",
"RunnableEach",
"RunnableEachBase",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
