fix: added client pool support (#81)

* fix: added client pool support

* Added async across client pool
pull/82/head
Laurel Orr 1 year ago committed by GitHub
parent d375ef0c74
commit db963cf4a7

@ -1,4 +1,14 @@
0.1.1 - Unreleased
0.1.4 - Unreleased
---------------------
Added
^^^^^
* Connection pools to swap between clients
Fixed
^^^^^
* Determine cache and response by request type, not client name
0.1.1
---------------------
Added
^^^^^

@ -8,6 +8,7 @@ How to make prompt programming with Foundation Models a little easier.
- [Manifest](#manifest-components)
- [Local HuggingFace Models](#local-huggingface-models)
- [Embedding Models](#embedding-models)
- [Road Map](#road-map)
- [Development](#development)
- [Cite](#cite)
@ -56,7 +57,7 @@ Manifest is meant to be a very light weight package to help with prompt design a
* All models are behind APIs
* Supports caching of model inputs/outputs for iteration, reproducibility, and cost saving
* Unified API of generate, score, and embed
* Unified API to support generate, score, and embed
## Models
Manifest provides model clients for [OpenAI](https://openai.com/), [AI21](https://studio.ai21.com/), [Cohere](https://cohere.ai/), [Together](https://together.xyz/), and HuggingFace (see [below](#huggingface-models) for how to use locally hosted HuggingFace models). You can toggle between the models by changing `client_name` and `client_connection`. For example, if a HuggingFace model is loaded locally, run
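A minimal sketch of the local HuggingFace case (the host and port are assumptions; point `client_connection` at wherever your local model server is running):
```python
from manifest import Manifest

# Assumes a locally hosted HuggingFace model server at this address.
manifest = Manifest(
    client_name="huggingface",
    client_connection="http://127.0.0.1:5000",
)
print(manifest.client_pool.get_client().get_model_params())
```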
@ -82,6 +83,33 @@ print(manifest.client.get_model_params())
print(manifest.client.get_model_inputs())
```
## Model Pools
Manifest supports querying multiple models with different schedulers. This is very much a work in progress, but Manifest will round-robin (or randomly) select among the clients you provide. You can use the same client multiple times with different connection strings (e.g. different API keys), or you can mix and match clients. The only requirement is that all clients in a pool share the same request type, i.e. you cannot mix generation models and embedding models in one pool.
To query across a local model and OpenAI,
```python
from manifest.connections.client_pool import ClientConnection
from manifest import Manifest
client_connection1 = ClientConnection(
    client_name="huggingface",
    client_connection="http://127.0.0.1:5000",
)
client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")

manifest = Manifest(
    client_pool=[client_connection1, client_connection2],
    cache_name="sqlite",
    cache_connection="my_sqlite_manifest.sqlite",
)
manifest.run(...)
```
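By default the pool round-robins between clients. To pick clients at random instead, set `client_pool_schedule` (a minimal sketch reusing the connections above):
```python
manifest = Manifest(
    client_pool=[client_connection1, client_connection2],
    client_pool_schedule="random",
)
```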
The speed benefit also shows up with async batched runs. When calling `arun_batch` with a list of prompts, Manifest supports a `chunk_size` param. This breaks the prompts into chunks of `chunk_size` to send across all clients in the pool asynchronously. By default `chunk_size` is `-1`, which means all prompts go to a single client. Set `chunk_size > 0` to distribute prompts across the pool. There is a further `batch_size` param which controls the individual client `batch_size` sent to the model.
```python
responses = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))
```
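To also control the per-request batch size each client sends to the model, pass `batch_size` alongside `chunk_size` (a sketch; the prompts and sizes are illustrative):
```python
import asyncio

prompts = [f"Tell me something interesting about {i}" for i in range(400)]
# Chunks of 20 prompts are distributed across the pool; each client then
# batches its chunk into requests of 10 prompts.
responses = asyncio.run(
    manifest.arun_batch(prompts, max_tokens=30, chunk_size=20, batch_size=10)
)
```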
## Global Cache
We support having queries and results stored in a global cache that can be shared across users. We treat inputs and outputs as key value pairs and support SQLite or Redis backends. To start with global caching using SQLite, run
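A minimal sketch (the client and cache file name here are assumptions):
```python
from manifest import Manifest

# Any supported client works; the cache is configured independently.
manifest = Manifest(
    client_name="openai",
    cache_name="sqlite",
    cache_connection="my_sqlite_manifest.sqlite",
)
```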
@ -205,6 +233,19 @@ python3 -m manifest.api.app \
--device 0
```
# Road Map
Here's what's coming up next
- [ ] Clients
- [ ] HuggingFace Hub
- [ ] Azure OpenAI
- [ ] Anthropic
- [ ] Data Types
- [ ] Diffusion Models
- [ ] Orchestration
- [ ] Connection pools
- [ ] Local Inference
- [ ] FlexGen
# Development
Before submitting a PR, run
```bash

@ -0,0 +1,208 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"OPENAI_KEY1 = \"sk-XXX\"\n",
"OPENAI_KEY2 = \"sk-XX\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use OpenAI\n",
"\n",
"Set you `OPENAI_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"openai_ada = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY1,\n",
" engine=\"text-ada-001\"\n",
")\n",
"\n",
"openai_curie = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY2,\n",
" engine=\"text-curie-001\"\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[openai_ada, openai_curie], client_pool_schedule=\"round_robin\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"I am a model.\n",
"1\n",
"I am a MacBook Pro with a retina\n"
]
}
],
"source": [
"res = manifest.run(\"What model are you?\", temperature=0.0)\n",
"print(manifest.client_pool.current_client_id)\n",
"print(res)\n",
"res = manifest.run(\"What model are you?\", temperature=0.0)\n",
"print(manifest.client_pool.current_client_id)\n",
"print(res)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## With Async"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"# This is required for asyncio.run(...) to work in Jupyter notebooks.\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"openai_ada = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY1,\n",
" engine=\"text-ada-001\"\n",
")\n",
"\n",
"openai_babbage = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY2,\n",
" engine=\"text-babbage-001\"\n",
")\n",
"\n",
"openai_curie = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY2,\n",
" engine=\"text-curie-001\"\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[openai_ada, openai_babbage, openai_curie], client_pool_schedule=\"round_robin\")\n",
"manifest_single_client = Manifest(client_pool=[openai_babbage])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For loop: 229.93\n",
"Running with async single client\n",
"Running 1 tasks across all clients.\n",
"Async loop: 1.39\n",
"Running with async two clients but not chunking\n",
"Running 1 tasks across all clients.\n",
"Async loop: 1.65\n",
"Running with async two clients and chunk size\n",
"Running 20 tasks across all clients.\n",
"Async loop: 0.64\n"
]
}
],
"source": [
"import time\n",
"import asyncio\n",
"\n",
"prompts = [f\"Tell me something interesting about {i}\" for i in range(400)]\n",
"st = time.time()\n",
"for pmt in prompts:\n",
" _ = manifest_single_client.run(pmt, max_tokens=30)\n",
"print(f\"For loop: {time.time() - st :.2f}\")\n",
"\n",
"print(\"Running with async single client\")\n",
"st = time.time()\n",
"_ = asyncio.run(manifest_single_client.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
"print(f\"Async loop: {time.time() - st :.2f}\")\n",
"\n",
"print(\"Running with async two clients but not chunking\")\n",
"st = time.time()\n",
"_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
"print(f\"Async loop: {time.time() - st :.2f}\")\n",
"\n",
"print(\"Running with async two clients and chunk size\")\n",
"st = time.time()\n",
"_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))\n",
"print(f\"Async loop: {time.time() - st :.2f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -47,7 +47,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client.get_model_params())"
"print(manifest.client_pool.get_client().get_model_params())"
]
},
{
@ -86,7 +86,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest_diff.client.get_model_params())"
"print(manifest_diff.client_pool.get_client().get_model_params())"
]
},
{

@ -37,7 +37,7 @@
"from manifest import Manifest\n",
"\n",
"manifest = Manifest(client_name=\"openaiembedding\")\n",
"print(manifest.client.get_model_params())"
"print(manifest.client_pool.get_client().get_model_params())"
]
},
{
@ -100,7 +100,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client.get_model_params())"
"print(manifest.client_pool.get_client().get_model_params())"
]
},
{

@ -30,8 +30,7 @@ class ModelResponse:
"""Return dictionary representation of response."""
key = (
"text"
if self.response_type
not in {"prompt_logit_score", "image_generation", "embedding_generation"}
if self.response_type not in {"image_generation", "embedding_generation"}
else "array"
)
return {

@ -1,17 +1,13 @@
"""Cache for queries and responses."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
from typing import Any, Dict, Type, Union
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
from manifest.response import RESPONSE_CONSTRUCTORS, Response
# Non-text return type caches
ARRAY_CACHE_TYPES = {
"diffuser",
"tomadiffuser",
"openaiembedding",
"huggingfaceembedding",
}
ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
class Cache(ABC):
@ -20,7 +16,7 @@ class Cache(ABC):
def __init__(
self,
connection_str: str,
client_name: str = "None",
request_type: Type[Request] = LMRequest,
cache_args: Dict[str, Any] = {},
):
"""
@ -28,7 +24,7 @@ class Cache(ABC):
Args:
connection_str: connection string.
client_name: name of client.
request_type: request type.
cache_args: arguments for cache.
cache_args are any arguments needed to initialize the cache.
@ -41,12 +37,12 @@ class Cache(ABC):
the entire byte string. `byte_string` is default.
Args:
connection_str: connection string for client.
connection_str: connection string for cache.
cache_args: cache arguments.
"""
self.client_name = client_name
self.request_type = request_type
self.connect(connection_str, cache_args)
if self.client_name in ARRAY_CACHE_TYPES:
if self.request_type in ARRAY_CACHE_TYPES:
array_serializer = cache_args.pop("array_serializer", "byte_string")
if array_serializer not in ["local_file", "byte_string"]:
raise ValueError(
@ -65,13 +61,13 @@ class Cache(ABC):
@abstractmethod
def close(self) -> None:
"""Close the client."""
"""Close the cache."""
raise NotImplementedError()
@abstractmethod
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.
Connect to cache.
Args:
connection_str: connection string.
@ -129,7 +125,7 @@ class Cache(ABC):
response,
cached,
request,
**RESPONSE_CONSTRUCTORS.get(self.client_name, {}),
**RESPONSE_CONSTRUCTORS.get(self.request_type, {}),
)
return None

@ -128,8 +128,12 @@ class Client(ABC):
request.
"""
params = {"prompt": prompt}
# Adds default values from self.PARAMS if not in request_args
for key in self.PARAMS:
params[key] = request_args.pop(key, getattr(self, key))
# Allows for overriding DEFAULT_REQUEST_KEYS even if they are not
# in self.PARAMS. Note that DEFAULT_REQUEST_KEYS match the default
# values in Request.
for key in DEFAULT_REQUEST_KEYS:
if key not in params and key in request_args:
params[key] = request_args.pop(key)
@ -142,7 +146,10 @@ class Client(ABC):
We drop these before sending to the model.
"""
params_to_add = DEFAULT_REQUEST_KEYS.copy()
# This will override DEFAULT_REQUEST_KEYS with those in PARAMS
params_to_add.update(self.PARAMS)
# to_dict will handle parameter renaming but not any
# default value handling - that is done in get_request()
request_params = request.to_dict(params_to_add)
return request_params
@ -298,7 +305,7 @@ class Client(ABC):
response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
**RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore
)
async def arun_batch_request(self, request: Request) -> Response:
@ -352,7 +359,7 @@ class Client(ABC):
final_response_dict,
cached=False,
request_params=request_params,
**RESPONSE_CONSTRUCTORS.get(self.NAME, {}), # type: ignore
**RESPONSE_CONSTRUCTORS.get(self.REQUEST_CLS, {}), # type: ignore
)
def get_score_prompt_request(

@ -40,6 +40,7 @@ class OpenAIClient(Client):
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS: Type[Request] = LMRequest
NAME = "openai"

@ -0,0 +1 @@
"""Connection init."""

@ -0,0 +1,173 @@
"""Client connection."""
import logging
import time
from typing import Any, Dict, List, Optional, Type
from pydantic import BaseModel, Extra
from manifest.clients.ai21 import AI21Client
from manifest.clients.client import Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.toma import TOMAClient
from manifest.connections.scheduler import RandomScheduler, RoundRobinScheduler
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
HuggingFaceClient.NAME: HuggingFaceClient,
HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
DummyClient.NAME: DummyClient,
TOMAClient.NAME: TOMAClient,
}
CLIENT_REQUEST_TYPES: Dict[str, Type] = {
k: v.REQUEST_CLS for k, v in CLIENT_CONSTRUCTORS.items()
}
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
logger.info("Diffusion not supported. Skipping import.")
pass
SCHEDULER_CONSTRUCTORS = {
RandomScheduler.NAME: RandomScheduler,
RoundRobinScheduler.NAME: RoundRobinScheduler,
}
class Timing(BaseModel):
"""Timing class."""
start: float = -1.0
end: float = -1.0
class ClientConnection(BaseModel):
"""Client Connection class."""
client_name: str
# Use environment variables (depending on client)
client_connection: Optional[str] = None
# Use default engine
engine: Optional[str] = None
# Prevent extra args
class Config:
"""Config class.
Allows to override pydantic behavior.
"""
extra = Extra.forbid
class ClientConnectionPool:
"""Client connection pool."""
def __init__(
self,
client_pool: List[ClientConnection],
client_pool_scheduler: str = "round_robin",
client_args: Dict[str, Any] = {},
):
"""Init."""
# Verify the clients are allowed and supported
for client in client_pool:
if client.client_name not in CLIENT_CONSTRUCTORS:
if client.client_name in DIFFUSION_CLIENTS:
raise ImportError(
f"Diffusion client {client.client_name} requires "
"the proper install. Make sure to run "
"`pip install manifest-ml[diffusers]` "
"or install Pillow."
)
else:
raise ValueError(
f"Unknown client name: {client.client_name}. "
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
)
# Verify that the serialization of all clients is the same
request_types = set(
[CLIENT_REQUEST_TYPES[client.client_name] for client in client_pool]
)
if len(request_types) > 1:
raise ValueError(
"All clients in the client pool must use the same request type. "
f"You have {sorted(list(map(str, request_types)))}"
)
# Verify scheduler
if client_pool_scheduler not in SCHEDULER_CONSTRUCTORS:
raise ValueError(f"Unknown scheduler: {client_pool_scheduler}.")
self.request_type = request_types.pop()
# Initialize the clients
# We must keep track of the used args so we know
# if a user passed in an arg that was never used
used_args = set()
self.client_pool = []
for client in client_pool:
to_pass_kwargs = client_args.copy()
# Override the engine param for each
to_pass_kwargs.pop("engine", None)
if client.engine:
to_pass_kwargs["engine"] = client.engine
self.client_pool.append(
CLIENT_CONSTRUCTORS[client.client_name]( # type: ignore
client.client_connection, client_args=to_pass_kwargs
)
)
# Update used args
for k in client_args:
if k not in to_pass_kwargs:
used_args.add(k)
# Remove used args
for k in used_args:
client_args.pop(k)
# Get the scheduler
self.scheduler = SCHEDULER_CONSTRUCTORS[client_pool_scheduler](
num_clients=len(self.client_pool)
)
self.current_client_id = 0
# Record timing metrics for each client for load balancing
# TODO: Implement this in the future
self.client_pool_metrics = [Timing() for _ in self.client_pool]
def close(self) -> None:
"""Close."""
for client in self.client_pool:
client.close()
def get_client(self) -> Client:
"""Get client."""
client_int = self.scheduler.get_client()
self.current_client_id = client_int
return self.client_pool[client_int]
def start_timer(self) -> None:
"""Start timer."""
self.client_pool_metrics[self.current_client_id].start = time.time()
def end_timer(self) -> None:
"""End timer."""
self.client_pool_metrics[self.current_client_id].end = time.time()

@ -0,0 +1,52 @@
"""Request client schedulers.
Supports random selection and round robin selection.
"""
import numpy as np
class Scheduler:
"""Scheduler base class."""
NAME: str = "scheduler"
def __init__(self, num_clients: int):
"""Initialize scheduler."""
self.num_clients = num_clients
def get_client(self) -> int:
"""Get client by id."""
raise NotImplementedError
class RandomScheduler(Scheduler):
"""Random scheduler."""
NAME: str = "random"
def __init__(self, num_clients: int):
"""Initialize scheduler."""
super().__init__(num_clients)
# Set seed
np.random.seed(0)
def get_client(self) -> int:
"""Get client by id."""
return np.random.randint(self.num_clients)
class RoundRobinScheduler(Scheduler):
"""Round robin scheduler."""
NAME: str = "round_robin"
def __init__(self, num_clients: int):
"""Initialize scheduler."""
super().__init__(num_clients)
self.current_client = 0
def get_client(self) -> int:
"""Get client by id."""
client = self.current_client
self.current_client = (self.current_client + 1) % self.num_clients
return client

@ -1,4 +1,5 @@
"""Manifest class."""
import asyncio
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@ -9,45 +10,19 @@ from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.client import Client
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.toma import TOMAClient
from manifest.connections.client_pool import (
CLIENT_CONSTRUCTORS,
ClientConnection,
ClientConnectionPool,
)
from manifest.request import Request
from manifest.response import Response
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
HuggingFaceClient.NAME: HuggingFaceClient,
HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
DummyClient.NAME: DummyClient,
TOMAClient.NAME: TOMAClient,
}
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
logger.info("Diffusion not supported. Skipping import.")
pass
CACHE_CONSTRUCTORS = {
"redis": RedisCache,
@ -62,8 +37,10 @@ class Manifest:
def __init__(
self,
client_name: str = "openai",
client_name: Optional[str] = None,
client_connection: Optional[str] = None,
client_pool: Optional[List[ClientConnection]] = None,
client_pool_schedule: str = "round_robin",
cache_name: str = "noop",
cache_connection: Optional[str] = None,
stop_token: str = "",
@ -75,6 +52,8 @@ class Manifest:
Args:
client_name: name of client.
client_connection: connection string for client.
client_pool: list of client connections for multi-client.
client_pool_schedule: schedule for client pool.
cache_name: name of cache.
cache_connection: connection string for cache.
stop_token: stop token prompt generation.
@ -82,30 +61,33 @@ class Manifest:
Remaining kwargs sent to client and cache.
"""
if client_name not in CLIENT_CONSTRUCTORS:
if client_name in DIFFUSION_CLIENTS:
raise ImportError(
f"Diffusion client {client_name} requires the proper install. "
"Make sure to run `pip install manifest-ml[diffusers]` "
"or install Pillow."
)
else:
raise ValueError(
f"Unknown client name: {client_name}. "
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
if not client_name and not client_pool:
raise ValueError(
"Must specify client_name or client_pool. "
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
)
if client_name and client_pool:
raise ValueError("Cannot specify both client_name and client_pool")
if client_name:
client_pool = [
ClientConnection(
client_name=client_name,
client_connection=client_connection,
# Remove engine from kwargs
engine=kwargs.pop("engine", None),
)
]
self.client_pool = ClientConnectionPool(
client_pool, client_pool_schedule, client_args=kwargs
)
if cache_name not in CACHE_CONSTRUCTORS:
raise ValueError(
f"Unknown cache name: {cache_name}. "
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
)
self.client_name = client_name
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
self.cache = CACHE_CONSTRUCTORS[cache_name]( # type: ignore
cache_connection, self.client_name, cache_args=kwargs
)
self.client = CLIENT_CONSTRUCTORS[self.client_name]( # type: ignore
client_connection, client_args=kwargs
cache_connection, self.client_pool.request_type, cache_args=kwargs
)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@ -114,45 +96,9 @@ class Manifest:
def close(self) -> None:
"""Close the client and cache."""
self.client.close()
self.client_pool.close()
self.cache.close()
def change_client(
self,
client_name: Optional[str] = None,
client_connection: Optional[str] = None,
stop_token: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Change manifest client.
Args:
client_name: name of client.
client_connection: connection string for client.
stop_token: stop token prompt generation.
Can be overridden in run
Remaining kwargs sent to client.
"""
if client_name:
if client_name not in CLIENT_CONSTRUCTORS:
raise ValueError(
f"Unknown client name: {client_name}. "
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
)
self.client_name = client_name
self.client = CLIENT_CONSTRUCTORS[client_name]( # type: ignore
client_connection, client_args=kwargs
)
if len(kwargs) > 0:
raise ValueError(
f"{list(kwargs.items())} arguments are not recognized."
)
if stop_token is not None:
self.stop_token = stop_token
def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
"""Validate kwargs.
@ -180,6 +126,7 @@ class Manifest:
def _split_cached_requests(
self,
request: Request,
client: Client,
overwrite_cache: bool,
) -> Tuple[Dict[int, Response], Request]:
"""Split a request into cached responses and Requests to run.
@ -201,16 +148,14 @@ class Manifest:
single_request = copy.deepcopy(request)
single_request.prompt = prompt_str
possible_response = self.cache.get(
self.client.get_cache_key(single_request)
client.get_cache_key(single_request)
)
if possible_response:
cached_idx_to_response[idx] = possible_response
else:
new_request.prompt.append(prompt_str)
else:
possible_response = self.cache.get(
self.client.get_cache_key(new_request)
)
possible_response = self.cache.get(client.get_cache_key(new_request))
if possible_response:
cached_idx_to_response[0] = possible_response
new_request.prompt = None
@ -219,6 +164,7 @@ class Manifest:
def _stitch_responses_and_cache(
self,
request: Request,
client: Client,
response: Union[Response, None],
cached_idx_to_response: Dict[int, Response],
) -> Response:
@ -283,7 +229,7 @@ class Manifest:
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt
cache_key = self.client.get_cache_key(new_request)
cache_key = client.get_cache_key(new_request)
new_response_key = copy.deepcopy(response.get_json_response())
new_response_key[response_gen_key] = current_choices
if response_usage_key:
@ -303,7 +249,7 @@ class Manifest:
response_obj = Response(
new_response,
cached=len(cached_idx_to_response) > 0,
request_params=self.client.get_cache_key(new_request),
request_params=client.get_cache_key(new_request),
generation_key=response_gen_key,
logits_key=response_logits_key,
item_key=response_item_key,
@ -334,27 +280,32 @@ class Manifest:
response from prompt.
"""
is_batch = isinstance(prompt, list)
# Get the client to run
client = self.client_pool.get_client()
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
request_params = self.client.get_request(prompt, kwargs)
request_params = client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if is_batch and request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
request_params, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
response = self.client.run_request(request_params)
# Start timing metrics
self.client_pool.start_timer()
response = client.run_request(request_params)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
@ -371,54 +322,119 @@ class Manifest:
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
chunk_size: int = -1,
**kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
"""
Run a batch of prompts with async.
If the client pool is a single client, all prompts will be sent
to that client and batch_size (which is passed in as kwargs) will
determine how the prompts are split.
If the client pool has multiple clients, the prompts will be split
into chunks and sent to the clients. Each client will further split
its chunk into batches of batch_size prompts to send to the model.
Args:
prompts: prompts to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
chunk_size: number of prompts to send to a client in each chunk.
For each chunk, the client will split the chunk into
batches of batch_size prompts to send to the model.
For a single manifest client, there is no impact from
setting chunk_size. For a client pool, chunk_size
can be used to distribute the load across the clients.
Returns:
response from prompt.
"""
# Split the prompts into chunks
prompt_chunks: List[Tuple[Client, List[str]]] = []
if chunk_size > 0:
for i in range(0, len(prompts), chunk_size):
prompt_chunks.append(
(self.client_pool.get_client(), prompts[i : i + chunk_size])
)
else:
prompt_chunks = [(self.client_pool.get_client(), prompts)]
# Run the chunks
tasks = []
for client, chunk in prompt_chunks:
tasks.append(
asyncio.create_task(
self._arun_batch_client(
prompts=chunk,
client=client,
overwrite_cache=overwrite_cache,
**kwargs,
)
)
)
print(f"Running {len(tasks)} tasks across all clients.")
logger.info(f"Running {len(tasks)} tasks across all clients.")
responses = await asyncio.gather(*tasks)
final_response = Response.union_all(responses)
stop_token = stop_token if stop_token is not None else self.stop_token
# Extract text results
if return_response:
return final_response
else:
return cast(
Union[List[str], List[np.ndarray]],
final_response.get_response(stop_token, True),
)
async def _arun_batch_client(
self,
prompts: List[str],
client: Client,
overwrite_cache: bool = False,
**kwargs: Any,
) -> Response:
"""
Run a batch of prompts with async for single client.
Args:
prompts: prompts to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
Returns:
response from prompt.
"""
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
request_params = self.client.get_request(prompts, kwargs)
request_params = client.get_request(prompts, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
request_params, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
response = await self.client.arun_batch_request(request_params)
self.client_pool.start_timer()
response = await client.arun_batch_request(request_params)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return cast(
Union[List[str], List[np.ndarray]],
final_response.get_response(stop_token, True),
)
return final_response
def score_prompt(
self,
@ -438,8 +454,9 @@ class Manifest:
Returns:
response from prompt.
"""
client = self.client_pool.get_client()
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
request_params = self.client.get_request(prompt, kwargs)
request_params = client.get_request(prompt, kwargs)
request_params.request_type = "score_prompt"
if request_params.n > 1:
@ -447,14 +464,14 @@ class Manifest:
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, overwrite_cache
request_params, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
try:
response = cast(
HuggingFaceClient, self.client
).get_score_prompt_request(request_params)
response = cast(HuggingFaceClient, client).get_score_prompt_request(
request_params
)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
else:
@ -463,6 +480,7 @@ class Manifest:
final_response = self._stitch_responses_and_cache(
request=request_params,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)

@ -4,9 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
# The below should match those in Request.
DEFAULT_REQUEST_KEYS = {
"client_timeout": ("client_timeout", 60), # seconds
"batch_size": ("batch_size", 1),
"batch_size": ("batch_size", 8),
"run_id": ("run_id", None),
"request_type": ("request_type", None),
}
@ -25,7 +26,7 @@ class Request(BaseModel):
n: int = 1
# Timeout
client_timeout: int = 120
client_timeout: int = 60
# Run id used to repeat run with same parameters
run_id: Optional[str] = None
@ -33,7 +34,7 @@ class Request(BaseModel):
# Batch size for async batch run
batch_size: int = 8
# Request type None is for completion. Used to scoring prompt
# Request type None is for completion. Used for scoring prompt
request_type: str = None
def to_dict(
@ -42,6 +43,9 @@ class Request(BaseModel):
"""
Convert request to a dictionary.
Handles parameter renaming but does not fill in default values.
It will drop any None values.
Add prompt ensures the prompt is always in the output dictionary.
"""
if allowable_keys:

@ -1,23 +1,18 @@
"""Client response."""
import copy
import json
from typing import Any, Dict, List, Union
import numpy as np
from manifest.request import DiffusionRequest, EmbeddingRequest
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"tomadiffuser": {
"logits_key": "token_logprobs",
"item_key": "array",
},
"openaiembedding": {
EmbeddingRequest: {
"logits_key": "token_logprobs",
"item_key": "array",
},
"huggingfaceembedding": {
DiffusionRequest: {
"logits_key": "token_logprobs",
"item_key": "array",
},
@ -150,6 +145,72 @@ class Response:
else:
return processed_results
@classmethod
def union_all(cls, responses: List["Response"]) -> "Response":
"""Union a list of response."""
if not responses:
raise ValueError("Response list is empty.")
if len(responses) == 1:
return responses[0]
first_response = responses[0]
generation_key = first_response.generation_key
logits_key = first_response.logits_key
item_key = first_response.item_key
# Usage key may be None, so get first not-None value
possible_usage_keys = [r.usage_key for r in responses if r.usage_key]
if possible_usage_keys:
usage_key = possible_usage_keys[0]
else:
usage_key = None
request = first_response._request_params
# Make sure all responses have the same keys
if not all(
[
(r.generation_key == generation_key)
and (r.logits_key == logits_key)
and (r.item_key == item_key)
# Usage key can be empty
and (not r.usage_key or not usage_key or r.usage_key == usage_key)
for r in responses
]
):
raise ValueError("All responses must have the same keys.")
# Get all the prompts and model choices
all_prompts = []
all_choices = []
all_usages = []
for res in responses:
json_response = res.get_json_response()
res_prompt = res.get_request()["prompt"]
if isinstance(res_prompt, str):
res_prompt = [res_prompt]
all_prompts.extend(res_prompt)
all_choices.extend(json_response[generation_key])
if usage_key and usage_key in json_response:
all_usages.extend(json_response[usage_key])
else:
# Add empty usage
all_usages.extend([{}] * len(res_prompt))
new_request = copy.deepcopy(request)
# TODO: add both models back into the request. This should be a lot
# easier once the response and request are modeled more formally with pydantic
new_request["prompt"] = all_prompts
new_response = {generation_key: all_choices}
if usage_key:
new_response[usage_key] = all_usages
response_obj = cls(
new_response,
cached=any(res.is_cached() for res in responses),
request_params=new_request,
generation_key=generation_key,
logits_key=logits_key,
item_key=item_key,
usage_key=usage_key,
)
return response_obj
def serialize(self) -> str:
"""
Serialize response to string.

@ -1,5 +1,5 @@
"""Cache test."""
from typing import Dict, cast
from typing import Dict, Type, cast
import numpy as np
import pytest
@ -11,16 +11,17 @@ from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.request import DiffusionRequest, LMRequest, Request
def _get_postgres_cache(
client_name: str = "", cache_args: Dict = {}
request_type: Type[Request] = LMRequest, cache_args: Dict = {}
) -> Cache: # type: ignore
"""Get postgres cache."""
cache_args.update({"cache_user": "", "cache_password": "", "cache_db": ""})
return PostgresCache(
"postgres",
client_name=client_name,
request_type=request_type,
cache_args=cache_args,
)
@ -105,11 +106,11 @@ def test_get(
compute_arr_response = {"choices": [{"array": arr}]}
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
elif cache_type == "redis":
cache = RedisCache(redis_cache, client_name="diffuser")
cache = RedisCache(redis_cache, request_type=DiffusionRequest)
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
cache = _get_postgres_cache(request_type=DiffusionRequest)
response = cache.get(test_request)
assert response is None
@ -128,18 +129,19 @@ def test_get(
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
response = cache.get(test_request)
@ -186,11 +188,11 @@ def test_get_batch_prompt(
compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]}
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
cache = SQLiteCache(sqlite_cache, request_type=DiffusionRequest)
elif cache_type == "redis":
cache = RedisCache(redis_cache, client_name="diffuser")
cache = RedisCache(redis_cache, request_type=DiffusionRequest)
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
cache = _get_postgres_cache(request_type=DiffusionRequest)
response = cache.get(test_request)
assert response is None
@ -211,18 +213,19 @@ def test_get_batch_prompt(
if cache_type == "sqlite":
cache = SQLiteCache(
sqlite_cache,
client_name="diffuser",
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "redis":
cache = RedisCache(
redis_cache,
client_name="diffuser",
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
elif cache_type == "postgres":
cache = _get_postgres_cache(
client_name="diffuser", cache_args={"array_serializer": "byte_string"}
request_type=DiffusionRequest,
cache_args={"array_serializer": "byte_string"},
)
response = cache.get(test_request)

@ -0,0 +1,63 @@
"""Test client pool."""
import time
import pytest
from manifest.connections.client_pool import ClientConnection, ClientConnectionPool
from manifest.request import LMRequest
def test_init() -> None:
"""Test initialization."""
client_connection1 = ClientConnection(
client_name="openai", client_connection="XXX", engine="text-davinci-002"
)
client_connection2 = ClientConnection(
client_name="openai", client_connection="XXX", engine="text-ada-001"
)
client_connection3 = ClientConnection(
client_name="openaiembedding", client_connection="XXX"
)
with pytest.raises(ValueError) as exc_info:
ClientConnectionPool(
[client_connection1, client_connection2], client_pool_scheduler="bad"
)
assert str(exc_info.value) == "Unknown scheduler: bad."
with pytest.raises(ValueError) as exc_info:
ClientConnectionPool([client_connection1, client_connection3])
assert (
str(exc_info.value)
== "All clients in the client pool must use the same request type. You have [\"<class 'manifest.request.EmbeddingRequest'>\", \"<class 'manifest.request.LMRequest'>\"]" # noqa: E501"
)
pool = ClientConnectionPool([client_connection1, client_connection2])
assert pool.request_type == LMRequest
assert len(pool.client_pool) == 2
assert len(pool.client_pool_metrics) == 2
assert pool.client_pool[0].engine == "text-davinci-002" # type: ignore
assert pool.client_pool[1].engine == "text-ada-001" # type: ignore
def test_timing() -> None:
"""Test timing client."""
client_connection1 = ClientConnection(client_name="dummy")
client_connection2 = ClientConnection(client_name="dummy")
connection_pool = ClientConnectionPool([client_connection1, client_connection2])
connection_pool.get_client()
assert connection_pool.current_client_id == 0
connection_pool.start_timer()
time.sleep(2)
connection_pool.end_timer()
connection_pool.get_client()
assert connection_pool.current_client_id == 1
connection_pool.start_timer()
time.sleep(1)
connection_pool.end_timer()
timing = connection_pool.client_pool_metrics
assert timing[0].end - timing[0].start > 1.9
assert timing[1].end - timing[1].start > 0.9

@ -13,6 +13,7 @@ from manifest import Manifest, Response
from manifest.caches.noop import NoopCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.connections.client_pool import ClientConnection
URL = "http://localhost:6000"
try:
@ -41,10 +42,11 @@ def test_init(sqlite_cache: str) -> None:
cache_name="sqlite",
cache_connection=sqlite_cache,
)
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert len(manifest.client_pool.client_pool) == 1
client = manifest.client_pool.get_client()
assert isinstance(client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.client.n == 1 # type: ignore
assert client.n == 1 # type: ignore
assert manifest.stop_token == ""
manifest = Manifest(
@ -53,34 +55,11 @@ def test_init(sqlite_cache: str) -> None:
n=3,
stop_token="\n",
)
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert len(manifest.client_pool.client_pool) == 1
client = manifest.client_pool.get_client()
assert isinstance(client, DummyClient)
assert isinstance(manifest.cache, NoopCache)
assert manifest.client.n == 3 # type: ignore
assert manifest.stop_token == "\n"
@pytest.mark.usefixtures("sqlite_cache")
def test_change_manifest(sqlite_cache: str) -> None:
"""Test manifest change."""
manifest = Manifest(
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
manifest.change_client()
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == ""
manifest.change_client(stop_token="\n")
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.client.n == 1 # type: ignore
assert client.n == 3 # type: ignore
assert manifest.stop_token == "\n"
@ -102,7 +81,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
# Allow params in the request object but not in the client to go through
assert "top_k" not in manifest.client.PARAMS
assert "top_k" not in manifest.client_pool.get_client().PARAMS
result = manifest.run(prompt, return_response=return_response, top_k=5)
assert result is not None
@ -891,6 +870,133 @@ def test_openaiembedding(sqlite_cache: str) -> None:
assert response.is_cached() is True
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openai_pool(sqlite_cache: str) -> None:
"""Test openai and openaichat client."""
client_connection1 = ClientConnection(
client_name="openaichat",
)
client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
client = Manifest(
client_pool=[client_connection1, client_connection2],
cache_name="sqlite",
cache_connection=sqlite_cache,
)
res = client.run("Why are there apples?")
assert isinstance(res, str) and len(res) > 0
res2 = client.run("Why are there apples?")
assert isinstance(res2, str) and len(res2) > 0
# Different models
assert res != res2
assert cast(
Response, client.run("Why are there apples?", return_response=True)
).is_cached()
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list, list) and len(res_list) == 2
res_list2 = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list2, list) and len(res_list2) == 2
# Different models
assert res_list != res_list2
assert cast(
Response,
asyncio.run(
client.arun_batch(
["Why are there pears?", "Why are there oranges?"], return_response=True
)
),
).is_cached()
# Test chunk size of 1
res_list = asyncio.run(
client.arun_batch(
["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
)
)
assert isinstance(res_list, list) and len(res_list) == 2
res_list2 = asyncio.run(
client.arun_batch(
["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
)
)
# Because we split across both models exactly in first run,
# we will get the same result
assert res_list == res_list2
@pytest.mark.skipif(
not OPENAI_ALIVE or not MODEL_ALIVE, reason="No openai or local model set"
)
@pytest.mark.usefixtures("sqlite_cache")
def test_mixed_pool(sqlite_cache: str) -> None:
"""Test openai and openaichat client."""
client_connection1 = ClientConnection(
client_name="huggingface",
client_connection=URL,
)
client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
client = Manifest(
client_pool=[client_connection1, client_connection2],
cache_name="sqlite",
cache_connection=sqlite_cache,
)
res = client.run("Why are there apples?")
assert isinstance(res, str) and len(res) > 0
res2 = client.run("Why are there apples?")
assert isinstance(res2, str) and len(res2) > 0
# Different models
assert res != res2
assert cast(
Response, client.run("Why are there apples?", return_response=True)
).is_cached()
res_list = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list, list) and len(res_list) == 2
res_list2 = asyncio.run(
client.arun_batch(["Why are there pears?", "Why are there oranges?"])
)
assert isinstance(res_list2, list) and len(res_list2) == 2
# Different models
assert res_list != res_list2
assert cast(
Response,
asyncio.run(
client.arun_batch(
["Why are there pears?", "Why are there oranges?"], return_response=True
)
),
).is_cached()
# Test chunk size of 1
res_list = asyncio.run(
client.arun_batch(
["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
)
)
assert isinstance(res_list, list) and len(res_list) == 2
res_list2 = asyncio.run(
client.arun_batch(
["Why are there pineapples?", "Why are there pinecones?"], chunk_size=1
)
)
# Because we split across both models exactly in first run,
# we will get the same result
assert res_list == res_list2
def test_retry_handling() -> None:
"""Test retry handling."""
# We'll mock the response so we won't need a real connection

@ -1,8 +1,11 @@
"""Response test."""
from typing import Any, Dict
import numpy as np
import pytest
from manifest import Response
from manifest.request import LMRequest
def test_init() -> None:
@ -141,3 +144,87 @@ def test_get_results() -> None:
)
assert response.get_response() == [float_arr, float_arr]
assert response.get_response(stop_token="m") == [float_arr, float_arr]
def test_union_all() -> None:
"""Test union all."""
request_paramsa = LMRequest(prompt=["apple", "orange", "pear"]).to_dict()
request_paramsa["model"] = "modelA"
response_paramsa = {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
]
}
responsea = Response(response_paramsa, False, request_paramsa)
request_paramsb = LMRequest(prompt=["banana", "pineapple", "mango"]).to_dict()
request_paramsb["model"] = "modelB"
response_paramsb = {
"choices": [
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
]
}
responseb = Response(response_paramsb, False, request_paramsb)
final_response = Response.union_all([responsea, responseb])
assert final_response.get_json_response() == {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
]
}
final_request = LMRequest(
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
).to_dict()
final_request["model"] = "modelA"
assert final_response.get_request() == final_request
assert not final_response.is_cached()
# Modify A to have usage and cached
response_paramsa_2: Dict[str, Any] = {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
],
"usage": [
{"completion_tokens": 10},
{"completion_tokens": 10},
{"completion_tokens": 10},
],
}
responsea = Response(response_paramsa_2, True, request_paramsa)
final_response = Response.union_all([responsea, responseb])
assert final_response.get_json_response() == {
"choices": [
{"text": "hello", "token_logprobs": [1]},
{"text": "hello 2", "token_logprobs": [1]},
{"text": "hello 3", "token_logprobs": [1]},
{"text": "bye", "token_logprobs": [2]},
{"text": "bye 2", "token_logprobs": [2]},
{"text": "bye 3", "token_logprobs": [2]},
],
"usage": [
{"completion_tokens": 10},
{"completion_tokens": 10},
{"completion_tokens": 10},
{},
{},
{},
],
}
final_request = LMRequest(
prompt=["apple", "orange", "pear", "banana", "pineapple", "mango"]
).to_dict()
final_request["model"] = "modelA"
assert final_response.get_request() == final_request
assert final_response.is_cached()

@ -0,0 +1,25 @@
"""Test scheduler."""
from manifest.connections.scheduler import RandomScheduler, RoundRobinScheduler
def test_random_scheduler() -> None:
"""Test random scheduler."""
scheduler = RandomScheduler(num_clients=2)
# Draw 20 clients and make sure both 0 and 1 are
# returned
client_ids = set()
for _ in range(20):
client_id = scheduler.get_client()
assert client_id in [0, 1]
client_ids.add(client_id)
assert len(client_ids) == 2
def test_round_robin_scheduler() -> None:
"""Test round robin scheduler."""
scheduler = RoundRobinScheduler(num_clients=2)
assert scheduler.get_client() == 0
assert scheduler.get_client() == 1
assert scheduler.get_client() == 0
assert scheduler.get_client() == 1