langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache (#15817)

* This PR adds async methods to the LLM cache.
* Adds a Redis-backed implementation called `AsyncRedisCache`.
* Adds a docker compose file under /docker to make it easier to spin up Redis for integration tests.
* Updates the Redis tests to use a context manager so flushing always happens by default.
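With this change, an async Redis-backed LLM cache can be registered as the global cache the same way as the sync one. Below is a minimal sketch of the intended usage, assuming the compose Redis instance from this PR is running on the non-standard port 6020 and `redis` is installed; the 60-second TTL is just an illustrative value.

```python
# Sketch: register the new AsyncRedisCache as the global LLM cache.
# Assumes a Redis instance reachable at redis://localhost:6020
# (see the docker-compose file added below) and `pip install redis`.
from redis.asyncio import Redis

from langchain.cache import AsyncRedisCache
from langchain.globals import set_llm_cache

set_llm_cache(
    AsyncRedisCache(redis_=Redis.from_url("redis://localhost:6020"), ttl=60)
)
# Async LLM / chat model calls (e.g. `await llm.agenerate(...)`) now go through
# the cache's `alookup()` / `aupdate()` methods.
```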
Dmitry Kankalovich 4 months ago committed by GitHub
parent 19546081c6
commit f92738a6f6

@ -0,0 +1,17 @@
# docker-compose to make it easier to spin up integration tests.
# Services should use NON-standard ports to avoid collisions with
# services that may already be running locally.
version: "3"
name: langchain-tests
services:
redis:
image: redis/redis-stack-server:latest
# We use non standard ports since
# these instances are used for testing
# and users may already have existing
# redis instances set up locally
# for other projects
ports:
- "6020:6379"
volumes:
- ./redis-volume:/data
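Before running the integration tests, it can help to confirm the compose service is reachable. A quick sanity-check sketch (assumes `docker-compose up redis` was run from the /docker directory):

```python
# Sanity check (sketch): ping the test Redis exposed on the non-standard
# port 6020 mapped by the compose file above.
import redis

client = redis.Redis.from_url("redis://localhost:6020")
assert client.ping(), "Redis test instance is not reachable"
```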

@ -27,6 +27,7 @@ import json
import logging
import uuid
import warnings
from abc import ABC
from datetime import timedelta
from functools import lru_cache
from typing import (
@ -351,8 +352,60 @@ class UpstashRedisCache(BaseCache):
self.redis.flushdb(flush_type=asynchronous)
class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""
class _RedisCacheBase(BaseCache, ABC):
@staticmethod
def _key(prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)
@staticmethod
def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None:
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
@staticmethod
def _get_generations(
results: dict[str | bytes, str | bytes],
) -> Optional[List[Generation]]:
generations = []
if results:
for _, text in results.items():
try:
generations.append(loads(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
generations.append(Generation(text=text))
return generations if generations else None
@staticmethod
def _configure_pipeline_for_update(
key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
) -> None:
pipe.hset(
key,
mapping={
str(idx): dumps(generation) for idx, generation in enumerate(return_val)
},
)
if ttl is not None:
pipe.expire(key, ttl)
class RedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows the use of a sync `redis.Redis` client.
"""
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
@ -360,12 +413,12 @@ class RedisCache(BaseCache):
This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class, allowing the object to interact with a Redis
server for caching purposes.
client class (`redis.Redis`), allowing the object
to interact with a Redis server for caching purposes.
Parameters:
redis_ (Any): An instance of a Redis client class
(e.g., redis.Redis) used for caching.
(`redis.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
@ -377,61 +430,27 @@ class RedisCache(BaseCache):
from redis import Redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Could not import `redis` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass in Redis object.")
raise ValueError("Please pass a valid `redis.Redis` client.")
self.redis = redis_
self.ttl = ttl
def _key(self, prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
generations = []
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
if results:
for _, text in results.items():
try:
generations.append(loads(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
generations.append(Generation(text=text))
return generations if generations else None
return self._get_generations(results) # type: ignore[arg-type]
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
# Write to a Redis HASH
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)
with self.redis.pipeline() as pipe:
pipe.hset(
key,
mapping={
str(idx): dumps(generation)
for idx, generation in enumerate(return_val)
},
)
if self.ttl is not None:
pipe.expire(key, self.ttl)
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
pipe.execute()
def clear(self, **kwargs: Any) -> None:
@ -440,6 +459,89 @@ class RedisCache(BaseCache):
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class AsyncRedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows the use of an
async `redis.asyncio.Redis` client.
"""
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of AsyncRedisCache.
This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.asyncio.Redis`), allowing the object
to interact with a Redis server for caching purposes.
Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.asyncio.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis.asyncio import Redis
except ImportError:
raise ValueError(
"Could not import `redis.asyncio` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass a valid `redis.asyncio.Redis` client.")
self.redis = redis_
self.ttl = ttl
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `lookup()` method. "
"Consider using the async `alookup()` version."
)
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string. Async version."""
results = await self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results) # type: ignore[arg-type]
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `update()` method. "
"Consider using the async `aupdate()` version."
)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string. Async version."""
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)
async with self.redis.pipeline() as pipe:
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
await pipe.execute() # type: ignore[attr-defined]
def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
raise NotImplementedError(
"This async Redis cache does not implement `clear()` method. "
"Consider using the async `aclear()` version."
)
async def aclear(self, **kwargs: Any) -> None:
"""
Clear cache. If `asynchronous` is True, flush asynchronously.
Async version.
"""
# `pop` so that `asynchronous` is not passed twice through **kwargs
asynchronous = kwargs.pop("asynchronous", False)
await self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor
RETURN_VAL_TYPE = Sequence[Generation]
@ -22,3 +23,17 @@ class BaseCache(ABC):
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
return await run_in_executor(None, self.lookup, prompt, llm_string)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
return await run_in_executor(None, self.update, prompt, llm_string, return_val)
async def aclear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
return await run_in_executor(None, self.clear, **kwargs)
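The practical effect of these defaults is that any existing cache implementing only the sync interface immediately works at the new async call sites. A minimal sketch with a hypothetical in-memory cache (the `InMemoryDictCache` name and storage layout are illustrative, not part of this PR):

```python
# Sketch: a sync-only BaseCache subclass gains working alookup/aupdate/aclear
# through the run_in_executor-based defaults shown above.
from typing import Any, Dict, Optional, Sequence, Tuple

from langchain_core.caches import BaseCache
from langchain_core.outputs import Generation


class InMemoryDictCache(BaseCache):
    """Hypothetical cache that only implements the sync interface."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], Sequence[Generation]] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[Sequence[Generation]]:
        return self._store.get((prompt, llm_string))

    def update(
        self, prompt: str, llm_string: str, return_val: Sequence[Generation]
    ) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


# `await InMemoryDictCache().alookup("p", "llm")` now works with no
# async-specific code in the subclass.
```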

@ -622,7 +622,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
@ -632,7 +632,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
@abstractmethod

@ -139,6 +139,26 @@ def get_prompts(
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
async def aget_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached. Async version."""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
llm_cache = get_llm_cache()
for i, prompt in enumerate(prompts):
if llm_cache:
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
def update_cache(
existing_prompts: Dict[int, List],
llm_string: str,
@ -157,6 +177,24 @@ def update_cache(
return llm_output
async def aupdate_cache(
existing_prompts: Dict[int, List],
llm_string: str,
missing_prompt_idxs: List[int],
new_results: LLMResult,
prompts: List[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version"""
llm_cache = get_llm_cache()
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache:
await llm_cache.aupdate(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface.
@ -869,7 +907,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
) = await aget_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
@ -917,7 +955,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
llm_output = update_cache(
llm_output = await aupdate_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = (

@ -1,6 +1,7 @@
from langchain_community.cache import (
AstraDBCache,
AstraDBSemanticCache,
AsyncRedisCache,
CassandraCache,
CassandraSemanticCache,
FullLLMCache,
@ -22,6 +23,7 @@ __all__ = [
"SQLAlchemyCache",
"SQLiteCache",
"UpstashRedisCache",
"AsyncRedisCache",
"RedisCache",
"RedisSemanticCache",
"GPTCache",

@ -1,6 +1,7 @@
"""Test Redis cache functionality."""
import uuid
from typing import List, cast
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator, List, Optional, cast
import pytest
from langchain_core.embeddings import Embeddings
@ -8,7 +9,7 @@ from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.cache import AsyncRedisCache, RedisCache, RedisSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
ConsistentFakeEmbeddings,
@ -17,65 +18,176 @@ from tests.integration_tests.cache.fake_embeddings import (
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
# Using a non-standard port to avoid conflicts with any locally running
# redis instances
# You can spin up a local redis using docker compose
# cd [repository-root]/docker
# docker-compose up redis
REDIS_TEST_URL = "redis://localhost:6020"
def random_string() -> str:
return str(uuid.uuid4())
def test_redis_cache_ttl() -> None:
@contextmanager
def get_sync_redis(*, ttl: Optional[int] = 1) -> Generator[RedisCache, None, None]:
"""Get a sync RedisCache instance."""
import redis
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1))
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert llm_cache.redis.pttl(key) > 0
cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=ttl)
try:
yield cache
finally:
cache.clear()
def test_redis_cache() -> None:
import redis
@asynccontextmanager
async def get_async_redis(
*, ttl: Optional[int] = 1
) -> AsyncGenerator[AsyncRedisCache, None]:
"""Get an async RedisCache instance."""
from redis.asyncio import Redis
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.redis.flushall()
cache = AsyncRedisCache(redis_=Redis.from_url(REDIS_TEST_URL), ttl=ttl)
try:
yield cache
finally:
await cache.aclear()
def test_redis_cache_ttl() -> None:
from redis import Redis
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert isinstance(llm_cache.redis, Redis)
assert llm_cache.redis.pttl(key) > 0
async def test_async_redis_cache_ttl() -> None:
from redis.asyncio import Redis as AsyncRedis
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert isinstance(llm_cache.redis, AsyncRedis)
assert await llm_cache.redis.pttl(key) > 0
def test_sync_redis_cache() -> None:
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache.update("prompt", llm_string, [Generation(text="fizz0")])
output = llm.generate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz0")]],
llm_output={},
)
assert output == expected_output
async def test_sync_in_async_redis_cache() -> None:
"""Test the sync RedisCache invoked with async methods"""
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
# llm_cache.update("meow", llm_string, [Generation(text="meow")])
await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz1")])
output = await llm.agenerate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz1")]],
llm_output={},
)
assert output == expected_output
async def test_async_redis_cache() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz2")])
output = await llm.agenerate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz2")]],
llm_output={},
)
assert output == expected_output
async def test_async_in_sync_redis_cache() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisCache, get_llm_cache())
with pytest.raises(NotImplementedError):
llm_cache.update("foo", llm_string, [Generation(text="fizz")])
def test_redis_cache_chat() -> None:
import redis
with get_sync_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.update(
dumps(prompt),
llm_string,
[ChatGeneration(message=AIMessage(content="fizz"))],
)
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
set_llm_cache(RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)))
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
get_llm_cache().update(
dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.redis.flushall()
async def test_async_redis_cache_chat() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate(
dumps(prompt),
llm_string,
[ChatGeneration(message=AIMessage(content="fizz"))],
)
output = await llm.agenerate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
def test_redis_semantic_cache() -> None:
"""Test redis semantic cache functionality."""
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
@ -85,7 +197,8 @@ def test_redis_semantic_cache() -> None:
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
@ -95,13 +208,13 @@ def test_redis_semantic_cache() -> None:
)
assert output == expected_output
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
llm_cache.clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
get_llm_cache().clear(llm_string=llm_string)
llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_multi() -> None:
@ -114,7 +227,8 @@ def test_redis_semantic_cache_multi() -> None:
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
output = llm.generate(
@ -126,7 +240,7 @@ def test_redis_semantic_cache_multi() -> None:
)
assert output == expected_output
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
@ -140,7 +254,8 @@ def test_redis_semantic_cache_chat() -> None:
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
get_llm_cache().update(
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update(
dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)
output = llm.generate([prompt])
@ -149,7 +264,7 @@ def test_redis_semantic_cache_chat() -> None:
llm_output={},
)
assert output == expected_output
get_llm_cache().clear(llm_string=llm_string)
llm_cache.clear(llm_string=llm_string)
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@ -192,10 +307,11 @@ def test_redis_semantic_cache_hit(
]
for prompt_i_generations in generations
]
llm_cache = cast(RedisSemanticCache, get_llm_cache())
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
print(prompt_i)
print(llm_generations_i)
get_llm_cache().update(prompt_i, llm_string, llm_generations_i)
llm_cache.update(prompt_i, llm_string, llm_generations_i)
llm.generate(prompts)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}

@ -1,4 +1,5 @@
"""Test caching for LLMs and ChatModels."""
import sqlite3
from typing import Dict, Generator, List, Union
import pytest
@ -21,7 +22,11 @@ from langchain.globals import get_llm_cache, set_llm_cache
def get_sqlite_cache() -> SQLAlchemyCache:
return SQLAlchemyCache(engine=create_engine("sqlite://"))
return SQLAlchemyCache(
engine=create_engine(
"sqlite://", creator=lambda: sqlite3.connect("file::memory:?cache=shared")
)
)
CACHE_OPTIONS = [
@ -35,33 +40,41 @@ def set_cache_and_teardown(request: FixtureRequest) -> Generator[None, None, Non
# Will be run before each test
cache_instance = request.param
set_llm_cache(cache_instance())
if get_llm_cache():
get_llm_cache().clear()
if llm_cache := get_llm_cache():
llm_cache.clear()
else:
raise ValueError("Cache not set. This should never happen.")
yield
# Will be run after each test
if get_llm_cache():
get_llm_cache().clear()
if llm_cache:
llm_cache.clear()
set_llm_cache(None)
else:
raise ValueError("Cache not set. This should never happen.")
def test_llm_caching() -> None:
async def test_llm_caching() -> None:
prompt = "How are you?"
response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
if get_llm_cache():
get_llm_cache().update(
if llm_cache := get_llm_cache():
# sync test
llm_cache.update(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
assert llm(prompt) == cached_response
# async test
await llm_cache.aupdate(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
assert await llm.ainvoke(prompt) == cached_response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
@ -90,14 +103,15 @@ def test_old_sqlite_llm_caching() -> None:
assert llm(prompt) == cached_response
def test_chat_model_caching() -> None:
async def test_chat_model_caching() -> None:
prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
response = "Test response"
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if get_llm_cache():
get_llm_cache().update(
if llm_cache := get_llm_cache():
# sync test
llm_cache.update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(),
return_val=[ChatGeneration(message=cached_message)],
@ -105,6 +119,16 @@ def test_chat_model_caching() -> None:
result = llm(prompt)
assert isinstance(result, AIMessage)
assert result.content == cached_response
# async test
await llm_cache.aupdate(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(),
return_val=[ChatGeneration(message=cached_message)],
)
result = await llm.ainvoke(prompt)
assert isinstance(result, AIMessage)
assert result.content == cached_response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
@ -112,25 +136,38 @@ def test_chat_model_caching() -> None:
)
def test_chat_model_caching_params() -> None:
async def test_chat_model_caching_params() -> None:
prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
response = "Test response"
cached_response = "Cached test response"
cached_message = AIMessage(content=cached_response)
llm = FakeListChatModel(responses=[response])
if get_llm_cache():
get_llm_cache().update(
if llm_cache := get_llm_cache():
# sync test
llm_cache.update(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(functions=[]),
return_val=[ChatGeneration(message=cached_message)],
)
result = llm(prompt, functions=[])
result_no_params = llm(prompt)
assert isinstance(result, AIMessage)
assert result.content == cached_response
result_no_params = llm(prompt)
assert isinstance(result_no_params, AIMessage)
assert result_no_params.content == response
# async test
await llm_cache.aupdate(
prompt=dumps(prompt),
llm_string=llm._get_llm_string(functions=[]),
return_val=[ChatGeneration(message=cached_message)],
)
result = await llm.ainvoke(prompt, functions=[])
result_no_params = await llm.ainvoke(prompt)
assert isinstance(result, AIMessage)
assert result.content == cached_response
assert isinstance(result_no_params, AIMessage)
assert result_no_params.content == response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
@ -138,19 +175,31 @@ def test_chat_model_caching_params() -> None:
)
def test_llm_cache_clear() -> None:
async def test_llm_cache_clear() -> None:
prompt = "How are you?"
response = "Test response"
expected_response = "Test response"
cached_response = "Cached test response"
llm = FakeListLLM(responses=[response])
if get_llm_cache():
get_llm_cache().update(
llm = FakeListLLM(responses=[expected_response])
if llm_cache := get_llm_cache():
# sync test
llm_cache.update(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
llm_cache.clear()
response = llm(prompt)
assert response == expected_response
# async test
await llm_cache.aupdate(
prompt=prompt,
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
get_llm_cache().clear()
assert llm(prompt) == response
await llm_cache.aclear()
response = await llm.ainvoke(prompt)
assert response == expected_response
else:
raise ValueError(
"The cache not set. This should never happen, as the pytest fixture "
