[Core] Check is async callable (#21714)

To permit proper coercion of objects like the following:


```python
class MyAsyncCallable:
    async def __call__(self, foo):
        return await ...

class MyAsyncGenerator:
    async def __call__(self, foo):
        await ...
        yield 
```
pull/18735/merge
William FH 3 weeks ago committed by GitHub
parent 7128c2d8ad
commit ca768c8353
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -76,6 +76,8 @@ from langchain_core.runnables.utils import (
get_lambda_source,
get_unique_config_specs,
indent_lines_after_first,
is_async_callable,
is_async_generator,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
@ -3300,7 +3302,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self._atransform = atransform
func_for_name: Callable = atransform
if inspect.isasyncgenfunction(transform):
if is_async_generator(transform):
self._atransform = transform # type: ignore[assignment]
func_for_name = transform
elif inspect.isgeneratorfunction(transform):
@ -3513,7 +3515,7 @@ class RunnableLambda(Runnable[Input, Output]):
self.afunc = afunc
func_for_name: Callable = afunc
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
if is_async_callable(func) or is_async_generator(func):
if afunc is not None:
raise TypeError(
"Func was provided as a coroutine function, but afunc was "
@ -3774,7 +3776,7 @@ class RunnableLambda(Runnable[Input, Output]):
afunc = f
if inspect.isasyncgenfunction(afunc):
if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
@ -3992,7 +3994,7 @@ class RunnableLambda(Runnable[Input, Output]):
afunc = f
if inspect.isasyncgenfunction(afunc):
if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
@ -4034,7 +4036,7 @@ class RunnableLambda(Runnable[Input, Output]):
),
):
yield chunk
elif not inspect.isasyncgenfunction(afunc):
elif not is_async_generator(afunc):
# Otherwise, just yield it
yield cast(Output, output)
@ -4836,7 +4838,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
"""
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(cast(Callable[[Input], Output], thing))

@ -1,4 +1,5 @@
"""Utility code for runnables."""
from __future__ import annotations
import ast
@ -11,6 +12,8 @@ from itertools import groupby
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@ -27,6 +30,8 @@ from typing import (
Union,
)
from typing_extensions import TypeGuard
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent
@ -533,3 +538,25 @@ def _create_model_cached(
return _create_model_base(
__model_name, __config__=_SchemaConfig, **field_definitions
)
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__")
and asyncio.iscoroutinefunction(func.__call__)
)

@ -4883,6 +4883,23 @@ async def test_runnable_gen() -> None:
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert await arunnable.abatch([None, None]) == [6, 6]
class AsyncGen:
async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3
arunnablecallable = RunnableGenerator(AsyncGen())
assert await arunnablecallable.ainvoke(None) == 6
assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
assert await arunnablecallable.abatch([None, None]) == [6, 6]
with pytest.raises(NotImplementedError):
arunnablecallable.invoke(None)
with pytest.raises(NotImplementedError):
arunnablecallable.stream(None)
with pytest.raises(NotImplementedError):
arunnablecallable.batch([None, None])
async def test_runnable_gen_context_config() -> None:
"""Test that a generator can call other runnables with config

Loading…
Cancel
Save