@@ -0,0 +1,245 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upstash Ratelimit Callback\n",
    "\n",
    "In this guide, we will go over how to add rate limiting based on the number of requests or the number of tokens using `UpstashRatelimitHandler`. This handler uses the [ratelimit library of Upstash](https://github.com/upstash/ratelimit-py/), which utilizes [Upstash Redis](https://upstash.com/docs/redis/overall/getstarted).\n",
    "\n",
    "Upstash Ratelimit works by sending an HTTP request to Upstash Redis every time the `limit` method is called. The remaining tokens/requests of the user are checked and updated. Based on the remaining tokens, we can stop the execution of costly operations like invoking an LLM or querying a vector store:\n",
    "\n",
    "```py\n",
    "response = ratelimit.limit()\n",
    "if response.allowed:\n",
    "    execute_costly_operation()\n",
    "```\n",
    "\n",
    "`UpstashRatelimitHandler` allows you to incorporate this ratelimit logic into your chain in a few minutes.\n",
    "\n",
    "First, you will need to go to [the Upstash Console](https://console.upstash.com/login) and create a Redis database ([see our docs](https://upstash.com/docs/redis/overall/getstarted)). After creating a database, you will need to set the environment variables:\n",
    "\n",
    "```\n",
    "UPSTASH_REDIS_REST_URL=\"****\"\n",
    "UPSTASH_REDIS_REST_TOKEN=\"****\"\n",
    "```\n",
    "\n",
    "Next, you will need to install the Upstash Ratelimit and Redis libraries with:\n",
    "\n",
    "```\n",
    "pip install upstash-ratelimit upstash-redis\n",
    "```\n",
    "\n",
    "You are now ready to add rate limiting to your chain!\n",
    "\n",
    "## Ratelimiting Per Request\n",
    "\n",
    "Let's imagine that we want to allow our users to invoke our chain 10 times per minute. Achieving this is as simple as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error in UpstashRatelimitHandler.on_chain_start callback: UpstashRatelimitError('Request limit reached!')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Handling ratelimit. <class 'langchain_community.callbacks.upstash_ratelimit_callback.UpstashRatelimitError'>\n"
     ]
    }
   ],
   "source": [
    "# set env variables\n",
    "import os\n",
    "\n",
    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
    "\n",
    "from langchain_community.callbacks import (\n",
    "    UpstashRatelimitError,\n",
    "    UpstashRatelimitHandler,\n",
    ")\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
    "from upstash_redis import Redis\n",
    "\n",
    "# create ratelimit\n",
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 10 requests per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=10, window=60),\n",
    ")\n",
    "\n",
    "# create handler\n",
    "user_id = \"user_id\"  # in practice, use the id of the user making the request\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, request_ratelimit=ratelimit)\n",
    "\n",
    "# create mock chain\n",
    "chain = RunnableLambda(str)\n",
    "\n",
    "# invoke chain with handler:\n",
    "try:\n",
    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
    "except UpstashRatelimitError:\n",
    "    print(\"Handling ratelimit.\", UpstashRatelimitError)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we pass the handler to the `invoke` method instead of passing the handler when defining the chain.\n",
    "\n",
    "For rate limiting algorithms other than `FixedWindow`, see the [upstash-ratelimit docs](https://github.com/upstash/ratelimit-py?tab=readme-ov-file#ratelimiting-algorithms).\n",
    "\n",
    "Before executing any steps in our pipeline, ratelimit will check whether the user has passed the request limit. If so, `UpstashRatelimitError` is raised.\n",
    "\n",
    "## Ratelimiting Per Token\n",
    "\n",
    "Another option is to rate limit chain invocations based on:\n",
    "\n",
    "1. the number of tokens in the prompt\n",
    "2. the number of tokens in the prompt and the LLM completion\n",
    "\n",
    "This only works if you have an LLM in your chain. Another requirement is that the LLM you are using should return the token usage in its `LLMOutput`.\n",
    "\n",
    "### How it works\n",
    "\n",
    "The handler gets the remaining tokens before calling the LLM. If there are tokens remaining, the LLM is called. Otherwise, `UpstashRatelimitError` is raised.\n",
    "\n",
    "After the LLM is called, the token usage is subtracted from the remaining tokens of the user. No error is raised at this stage of the chain.\n",
    "\n",
    "### Configuration\n",
    "\n",
    "For the first configuration, simply initialize the handler like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 1000 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
    ")\n",
    "\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the second configuration, here is how to initialize the handler:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 1000 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
    ")\n",
    "\n",
    "handler = UpstashRatelimitHandler(\n",
    "    identifier=user_id,\n",
    "    token_ratelimit=ratelimit,\n",
    "    include_output_tokens=True,  # set to True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also rate limit based on requests and tokens at the same time, simply by passing both the `request_ratelimit` and `token_ratelimit` parameters.\n",
    "\n",
    "Here is a full example with a chain utilizing an LLM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error in UpstashRatelimitHandler.on_llm_start callback: UpstashRatelimitError('Token limit reached!')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Handling ratelimit. <class 'langchain_community.callbacks.upstash_ratelimit_callback.UpstashRatelimitError'>\n"
     ]
    }
   ],
   "source": [
    "# set env variables\n",
    "import os\n",
    "\n",
    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"****\"\n",
    "\n",
    "from langchain_community.callbacks import (\n",
    "    UpstashRatelimitError,\n",
    "    UpstashRatelimitHandler,\n",
    ")\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langchain_openai import ChatOpenAI\n",
    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
    "from upstash_redis import Redis\n",
    "\n",
    "# create ratelimit\n",
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 500 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=500, window=60),\n",
    ")\n",
    "\n",
    "# create handler\n",
    "user_id = \"user_id\"  # in practice, use the id of the user making the request\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)\n",
    "\n",
    "# create mock chain\n",
    "as_str = RunnableLambda(str)\n",
    "model = ChatOpenAI()\n",
    "\n",
    "chain = as_str | model\n",
    "\n",
    "# invoke chain with handler:\n",
    "try:\n",
    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
    "except UpstashRatelimitError:\n",
    "    print(\"Handling ratelimit.\", UpstashRatelimitError)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lc39",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,199 @@
"""Ratelimiting Handler to limit requests or tokens"""

from typing import Any, Dict, List, Literal, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from upstash_ratelimit import Ratelimit


class UpstashRatelimitError(Exception):
    """
    Upstash Ratelimit Error

    Raised when the rate limit is reached in `UpstashRatelimitHandler`.
    """

    def __init__(
        self,
        message: str,
        type: Literal["token", "request"],
        limit: Optional[int] = None,
        reset: Optional[float] = None,
    ):
        """
        Args:
            message (str): error message
            type (str): The kind of limit which was reached. One of
                "token" or "request"
            limit (Optional[int]): The limit which was reached. Passed when
                type is "request"
            reset (Optional[float]): unix timestamp in milliseconds when the
                limits are reset. Passed when type is "request"
        """
        # Call the base class constructor with the parameters it needs
        super().__init__(message)
        self.type = type
        self.limit = limit
        self.reset = reset


class UpstashRatelimitHandler(BaseCallbackHandler):
    """
    Callback to handle rate limiting based on the number of requests
    or the number of tokens in the input.

    It uses Upstash Ratelimit to track the rate limit, which utilizes
    Upstash Redis to track the state.

    Should not be passed to the chain when initialising the chain.
    This is because the handler has a state which should be fresh
    every time invoke is called. Instead, initialise and pass a handler
    every time you invoke.
    """
    raise_error = True
    _checked = False

    def __init__(
        self,
        identifier: str,
        token_ratelimit: Optional[Ratelimit] = None,
        request_ratelimit: Optional[Ratelimit] = None,
        include_output_tokens: bool = False,
    ):
        """
        Creates UpstashRatelimitHandler. Must be passed an identifier to
        ratelimit like a user id or an ip address.

        Additionally, it must be passed at least one of token_ratelimit
        or request_ratelimit parameters.

        Args:
            identifier (str): the identifier to rate limit, such as a user
                id or an ip address
            token_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
                number of tokens. Only works with OpenAI models since only
                these models provide the number of tokens as information
                in their output.
            request_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
                number of requests
            include_output_tokens (bool): Whether to count output tokens when
                rate limiting based on the number of tokens. Only used when
                `token_ratelimit` is passed. False by default.

        Example:
            .. code-block:: python

                from langchain_core.runnables import RunnableLambda
                from upstash_redis import Redis
                from upstash_ratelimit import Ratelimit, FixedWindow

                redis = Redis.from_env()
                ratelimit = Ratelimit(
                    redis=redis,
                    # fixed window to allow 10 requests every 10 seconds:
                    limiter=FixedWindow(max_requests=10, window=10),
                )

                user_id = "foo"
                handler = UpstashRatelimitHandler(
                    identifier=user_id,
                    request_ratelimit=ratelimit,
                )

                # Initialize a simple runnable to test
                chain = RunnableLambda(str)

                # pass handler as callback:
                output = chain.invoke(
                    "input",
                    config={"callbacks": [handler]},
                )
        """
        if not any([token_ratelimit, request_ratelimit]):
            raise ValueError(
                "You must pass at least one of token_ratelimit or"
                " request_ratelimit parameters for the handler to work."
            )

        self.identifier = identifier
        self.token_ratelimit = token_ratelimit
        self.request_ratelimit = request_ratelimit
        self.include_output_tokens = include_output_tokens

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """
        Run when chain starts running.

        on_chain_start runs multiple times during a chain execution. To make
        sure that it's only called once, we keep a bool state `_checked`. If
        not `self._checked`, we call limit with `request_ratelimit` and raise
        `UpstashRatelimitError` if the identifier is rate limited.
        """
        if self.request_ratelimit and not self._checked:
            response = self.request_ratelimit.limit(self.identifier)
            if not response.allowed:
                raise UpstashRatelimitError(
                    "Request limit reached!", "request", response.limit, response.reset
                )
            self._checked = True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """
        Run when LLM starts running.
        """
        if self.token_ratelimit:
            remaining = self.token_ratelimit.get_remaining(self.identifier)
            if remaining <= 0:
                raise UpstashRatelimitError("Token limit reached!", "token")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """
        Run when LLM ends running.

        If `include_output_tokens` is set to True, the number of tokens
        in the LLM completion is also counted for rate limiting.
        """
        if self.token_ratelimit:
            try:
                llm_output = response.llm_output or {}
                token_usage = llm_output["token_usage"]
                token_count = (
                    token_usage["total_tokens"]
                    if self.include_output_tokens
                    else token_usage["prompt_tokens"]
                )
            except KeyError:
                raise ValueError(
                    "LLM response doesn't include"
                    " `token_usage: {total_tokens: int, prompt_tokens: int}`"
                    " field. To use UpstashRatelimitHandler with token_ratelimit,"
                    " either use a model which returns token_usage (like"
                    " OpenAI models) or rate limit only with request_ratelimit."
                )

            # call limit to subtract the used tokens from the rate limit,
            # but don't raise an exception since we have already generated
            # the tokens and would rather continue execution.
            self.token_ratelimit.limit(self.identifier, rate=token_count)

    def reset(self, identifier: Optional[str] = None) -> "UpstashRatelimitHandler":
        """
        Creates a new UpstashRatelimitHandler object with the same
        ratelimit configurations but with a new identifier if it's
        provided.

        Also resets the state of the handler.
        """
        return UpstashRatelimitHandler(
            identifier=identifier or self.identifier,
            token_ratelimit=self.token_ratelimit,
            request_ratelimit=self.request_ratelimit,
            include_output_tokens=self.include_output_tokens,
        )
@@ -0,0 +1,210 @@
from typing import Any
from unittest.mock import create_autospec

import pytest
from langchain_core.outputs import LLMResult
from upstash_ratelimit import Ratelimit, Response

from langchain_community.callbacks import (
    UpstashRatelimitError,
    UpstashRatelimitHandler,
)


# Fixtures
@pytest.fixture
def request_ratelimit() -> Ratelimit:
    ratelimit = create_autospec(Ratelimit)
    response = Response(allowed=True, limit=10, remaining=10, reset=10000)
    ratelimit.limit.return_value = response
    return ratelimit


@pytest.fixture
def token_ratelimit() -> Ratelimit:
    ratelimit = create_autospec(Ratelimit)
    response = Response(allowed=True, limit=1000, remaining=1000, reset=10000)
    ratelimit.limit.return_value = response
    ratelimit.get_remaining.return_value = 1000
    return ratelimit


@pytest.fixture
def handler_with_both_limits(
    request_ratelimit: Ratelimit, token_ratelimit: Ratelimit
) -> UpstashRatelimitHandler:
    return UpstashRatelimitHandler(
        identifier="user123",
        token_ratelimit=token_ratelimit,
        request_ratelimit=request_ratelimit,
        include_output_tokens=False,
    )


# Tests
def test_init_no_limits() -> None:
    with pytest.raises(ValueError):
        UpstashRatelimitHandler(identifier="user123")


def test_init_request_limit_only(request_ratelimit: Ratelimit) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", request_ratelimit=request_ratelimit
    )
    assert handler.request_ratelimit is not None
    assert handler.token_ratelimit is None


def test_init_token_limit_only(token_ratelimit: Ratelimit) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit
    )
    assert handler.token_ratelimit is not None
    assert handler.request_ratelimit is None


def test_on_chain_start_request_limit(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.request_ratelimit.limit.assert_called_once_with("user123")
    handler_with_both_limits.token_ratelimit.limit.assert_not_called()


def test_on_chain_start_request_limit_reached(request_ratelimit: Any) -> None:
    request_ratelimit.limit.return_value = Response(
        allowed=False, limit=10, remaining=0, reset=10000
    )
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_chain_start(serialized={}, inputs={})


def test_on_llm_start_token_limit_reached(token_ratelimit: Any) -> None:
    token_ratelimit.get_remaining.return_value = 0
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_llm_start(serialized={}, prompts=["test"])


def test_on_llm_start_token_limit_reached_negative(token_ratelimit: Any) -> None:
    token_ratelimit.get_remaining.return_value = -10
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_llm_start(serialized={}, prompts=["test"])


def test_on_llm_end_with_token_limit(handler_with_both_limits: Any) -> None:
    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler_with_both_limits.on_llm_end(response)
    handler_with_both_limits.token_ratelimit.limit.assert_called_once_with(
        "user123", rate=2
    )


def test_on_llm_end_with_token_limit_include_output_tokens(
    token_ratelimit: Any,
) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123",
        token_ratelimit=token_ratelimit,
        request_ratelimit=None,
        include_output_tokens=True,
    )
    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler.on_llm_end(response)
    token_ratelimit.limit.assert_called_once_with("user123", rate=5)


def test_on_llm_end_without_token_usage(handler_with_both_limits: Any) -> None:
    response = LLMResult(generations=[], llm_output={})
    with pytest.raises(ValueError):
        handler_with_both_limits.on_llm_end(response)


def test_reset_handler(handler_with_both_limits: Any) -> None:
    new_handler = handler_with_both_limits.reset(identifier="user456")
    assert new_handler.identifier == "user456"
    assert not new_handler._checked


def test_reset_handler_no_new_identifier(handler_with_both_limits: Any) -> None:
    new_handler = handler_with_both_limits.reset()
    assert new_handler.identifier == "user123"
    assert not new_handler._checked


def test_on_chain_start_called_once(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1


def test_on_chain_start_reset_checked(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    new_handler = handler_with_both_limits.reset(identifier="user456")
    new_handler.on_chain_start(serialized={}, inputs={})

    # the call count becomes two because `reset` keeps the same mock
    # ratelimit objects while resetting the `_checked` state
    assert new_handler.request_ratelimit.limit.call_count == 2


def test_on_llm_start_no_token_limit(request_ratelimit: Any) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
    )
    handler.on_llm_start(serialized={}, prompts=["test"])
    assert request_ratelimit.limit.call_count == 0


def test_on_llm_start_token_limit(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1


def test_full_chain_with_both_limits(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
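
    # the request limit is checked only once per invocation; token limits
    # are untouched at this stage: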
    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 0

    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])

    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1

    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler_with_both_limits.on_llm_end(response)

    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1