@@ -1,4 +1,5 @@
"""Manifest class."""
import asyncio
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -9,45 +10,19 @@ from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.client import Client
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.toma import TOMAClient
from manifest.connections.client_pool import (
    CLIENT_CONSTRUCTORS,
    ClientConnection,
    ClientConnectionPool,
)
from manifest.request import Request
from manifest.response import Response

logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

CLIENT_CONSTRUCTORS = {
    OpenAIClient.NAME: OpenAIClient,
    OpenAIChatClient.NAME: OpenAIChatClient,
    OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
    CohereClient.NAME: CohereClient,
    AI21Client.NAME: AI21Client,
    HuggingFaceClient.NAME: HuggingFaceClient,
    HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
    DummyClient.NAME: DummyClient,
    TOMAClient.NAME: TOMAClient,
}
# Diffusion
DIFFUSION_CLIENTS = ["diffuser", "tomadiffuser"]
try:
    from manifest.clients.diffuser import DiffuserClient
    from manifest.clients.toma_diffuser import TOMADiffuserClient

    CLIENT_CONSTRUCTORS[DiffuserClient.NAME] = DiffuserClient
    CLIENT_CONSTRUCTORS[TOMADiffuserClient.NAME] = TOMADiffuserClient
except Exception:
    logger.info("Diffusion not supported. Skipping import.")
    pass

CACHE_CONSTRUCTORS = {
    "redis": RedisCache,
@@ -62,8 +37,10 @@ class Manifest:
    def __init__(
        self,
        client_name: str = "openai",
        client_name: Optional[str] = None,
        client_connection: Optional[str] = None,
        client_pool: Optional[List[ClientConnection]] = None,
        client_pool_schedule: str = "round_robin",
        cache_name: str = "noop",
        cache_connection: Optional[str] = None,
        stop_token: str = "",
@@ -75,6 +52,8 @@ class Manifest:
        Args:
            client_name: name of client.
            client_connection: connection string for client.
            client_pool: list of client connections for multi-client.
            client_pool_schedule: schedule for client pool.
            cache_name: name of cache.
            cache_connection: connection string for cache.
            stop_token: stop token for prompt generation.
@@ -82,30 +61,33 @@ class Manifest:
            Remaining kwargs sent to client and cache.
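
        Example (an illustrative sketch; the engine values below are
        placeholders, not prescribed by this change):

            # engines are illustrative only
            pool = [
                ClientConnection(client_name="openai", engine="text-davinci-003"),
                ClientConnection(client_name="openai", engine="text-curie-001"),
            ]
            manifest = Manifest(client_pool=pool, client_pool_schedule="round_robin")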
"""
        if client_name not in CLIENT_CONSTRUCTORS:
            if client_name in DIFFUSION_CLIENTS:
                raise ImportError(
                    f"Diffusion client {client_name} requires the proper install. "
                    "Make sure to run `pip install manifest-ml[diffusers]` "
                    "or install Pillow."
                )
            else:
                raise ValueError(
                    f"Unknown client name: {client_name}. "
                    f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
        if not client_name and not client_pool:
            raise ValueError(
                "Must specify client_name or client_pool. "
                f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
            )
        if client_name and client_pool:
            raise ValueError("Cannot specify both client_name and client_pool")
        if client_name:
            # A bare client_name is wrapped into a single-connection pool
            client_pool = [
                ClientConnection(
                    client_name=client_name,
                    client_connection=client_connection,
                    # Remove engine from kwargs
                    engine=kwargs.pop("engine", None),
                )
            ]
        self.client_pool = ClientConnectionPool(
            client_pool, client_pool_schedule, client_args=kwargs
        )
        if cache_name not in CACHE_CONSTRUCTORS:
            raise ValueError(
                f"Unknown cache name: {cache_name}. "
                f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
            )
        self.client_name = client_name
        # Must pass kwargs as a dict so client "pop" methods can remove used arguments
        self.cache = CACHE_CONSTRUCTORS[cache_name](  # type: ignore
            cache_connection, self.client_name, cache_args=kwargs
        )
        self.client = CLIENT_CONSTRUCTORS[self.client_name](  # type: ignore
            client_connection, client_args=kwargs
            cache_connection, self.client_pool.request_type, cache_args=kwargs
        )
        if len(kwargs) > 0:
            raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@@ -114,45 +96,9 @@ class Manifest:
    def close(self) -> None:
        """Close the client and cache."""
        self.client.close()
        self.client_pool.close()
        self.cache.close()

    def change_client(
        self,
        client_name: Optional[str] = None,
        client_connection: Optional[str] = None,
        stop_token: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Change manifest client.

        Args:
            client_name: name of client.
            client_connection: connection string for client.
            stop_token: stop token for prompt generation.
                Can be overridden in run.

        Remaining kwargs sent to client.
        """
        if client_name:
            if client_name not in CLIENT_CONSTRUCTORS:
                raise ValueError(
                    f"Unknown client name: {client_name}. "
                    f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
                )
            self.client_name = client_name
            self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
                client_connection, client_args=kwargs
            )
            if len(kwargs) > 0:
                raise ValueError(
                    f"{list(kwargs.items())} arguments are not recognized."
                )
        if stop_token is not None:
            self.stop_token = stop_token

    def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
        """Validate kwargs.
@@ -180,6 +126,7 @@
    def _split_cached_requests(
        self,
        request: Request,
        client: Client,
        overwrite_cache: bool,
    ) -> Tuple[Dict[int, Response], Request]:
        """Split a request into cached responses and Requests to run.
@@ -201,16 +148,14 @@
                single_request = copy.deepcopy(request)
                single_request.prompt = prompt_str
                possible_response = self.cache.get(
                    self.client.get_cache_key(single_request)
                    client.get_cache_key(single_request)
                )
                if possible_response:
                    cached_idx_to_response[idx] = possible_response
                else:
                    new_request.prompt.append(prompt_str)
        else:
            possible_response = self.cache.get(
                self.client.get_cache_key(new_request)
            )
            possible_response = self.cache.get(client.get_cache_key(new_request))
            if possible_response:
                cached_idx_to_response[0] = possible_response
                new_request.prompt = None
@@ -219,6 +164,7 @@
    def _stitch_responses_and_cache(
        self,
        request: Request,
        client: Client,
        response: Union[Response, None],
        cached_idx_to_response: Dict[int, Response],
    ) -> Response:
@@ -283,7 +229,7 @@
                # set cache
                new_request = copy.deepcopy(request)
                new_request.prompt = prompt
                cache_key = self.client.get_cache_key(new_request)
                cache_key = client.get_cache_key(new_request)
                new_response_key = copy.deepcopy(response.get_json_response())
                new_response_key[response_gen_key] = current_choices
                if response_usage_key:
@@ -303,7 +249,7 @@
        response_obj = Response(
            new_response,
            cached=len(cached_idx_to_response) > 0,
            request_params=self.client.get_cache_key(new_request),
            request_params=client.get_cache_key(new_request),
            generation_key=response_gen_key,
            logits_key=response_logits_key,
            item_key=response_item_key,
@@ -334,27 +280,32 @@
            response from prompt.
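
        Example (an illustrative sketch; the prompt and stop token are
        placeholders):

            # prompt and stop_token are illustrative only
            text = manifest.run("Why is the sky blue?", stop_token="\n")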
"""
        is_batch = isinstance(prompt, list)
        # Get the client to run
        client = self.client_pool.get_client()
        stop_token = stop_token if stop_token is not None else self.stop_token
        # Must pass kwargs as a dict so client "pop" methods can remove used arguments
        request_params = self.client.get_request(prompt, kwargs)
        request_params = client.get_request(prompt, kwargs)
        # Avoid nested list of results - enforce n = 1 for batch
        if is_batch and request_params.n > 1:
            raise ValueError("Batch mode does not support n > 1.")
        self._validate_kwargs(kwargs, request_params)
        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
            request_params, client, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            response = self.client.run_request(request_params)
            # Start timing metrics
            self.client_pool.start_timer()
            response = client.run_request(request_params)
            self.client_pool.end_timer()
        else:
            # Nothing to run
            response = None
        final_response = self._stitch_responses_and_cache(
            request=request_params,
            client=client,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
@@ -371,54 +322,119 @@
        overwrite_cache: bool = False,
        stop_token: Optional[str] = None,
        return_response: bool = False,
        chunk_size: int = -1,
        **kwargs: Any,
    ) -> Union[List[str], List[np.ndarray], Response]:
"""
Run a batch of prompts with async .
If the client pool is a single client , all prompts will be sent
to one client and batch_size ( which is passed it as kwargs ) will
determine how the prompts are split .
If the client pool is a pool of clients , the prompts will be split
into chunks and sent to the clients . Each client will split the
chunk into batch_size prompts to send to the model .
Args :
prompts : prompts to run .
overwrite_cache : whether to overwrite cache .
stop_token : stop token for prompt generation .
Default is self . stop_token .
" " for no stop token .
Default is self . stop_token .
" " for no stop token .
return_response : whether to return Response object .
chunk_size : number of prompts to send to a client in chunks .
For each chunk , the client will split the chunk into
batch_sized prompts to send to the model .
For a single manifest client , there is no impact to
setting chunk_size . For a client pool , chunk_size
can be used to distribute the load across the clients .
        Returns:
            response from prompt.
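
        Example (an illustrative sketch; the chunk_size and batch_size values
        are placeholders):

            # values are illustrative only
            responses = asyncio.run(
                manifest.arun_batch(prompts, chunk_size=16, batch_size=8)
            )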
"""
        # Split the prompts into chunks
        prompt_chunks: List[Tuple[Client, List[str]]] = []
        if chunk_size > 0:
            for i in range(0, len(prompts), chunk_size):
                prompt_chunks.append(
                    (self.client_pool.get_client(), prompts[i : i + chunk_size])
                )
        else:
            prompt_chunks = [(self.client_pool.get_client(), prompts)]

        # Run the chunks
        tasks = []
        for client, chunk in prompt_chunks:
            tasks.append(
                asyncio.create_task(
                    self._arun_batch_client(
                        prompts=chunk,
                        client=client,
                        overwrite_cache=overwrite_cache,
                        **kwargs,
                    )
                )
            )
        print(f"Running {len(tasks)} tasks across all clients.")
        logger.info(f"Running {len(tasks)} tasks across all clients.")
        responses = await asyncio.gather(*tasks)
        final_response = Response.union_all(responses)
        stop_token = stop_token if stop_token is not None else self.stop_token
        # Extract text results
        if return_response:
            return final_response
        else:
            return cast(
                Union[List[str], List[np.ndarray]],
                final_response.get_response(stop_token, True),
            )

    async def _arun_batch_client(
        self,
        prompts: List[str],
        client: Client,
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Response:
        """
        Run a batch of prompts with async for single client.

        Args:
            prompts: prompts to run.
            client: client to run.
            overwrite_cache: whether to overwrite cache.

        Returns:
            response from prompt.
        """
        # Must pass kwargs as a dict so client "pop" methods can remove used arguments
        request_params = self.client.get_request(prompts, kwargs)
        request_params = client.get_request(prompts, kwargs)
        # Avoid nested list of results - enforce n = 1 for batch
        if request_params.n > 1:
            raise ValueError("Batch mode does not support n > 1.")
        self._validate_kwargs(kwargs, request_params)
        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
            request_params, client, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            response = await self.client.arun_batch_request(request_params)
            self.client_pool.start_timer()
            response = await client.arun_batch_request(request_params)
            self.client_pool.end_timer()
        else:
            # Nothing to run
            response = None
        final_response = self._stitch_responses_and_cache(
            request=request_params,
            client=client,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )
        # Extract text results
        if return_response:
            return final_response
        else:
            return cast(
                Union[List[str], List[np.ndarray]],
                final_response.get_response(stop_token, True),
            )
        return final_response

    def score_prompt(
        self,
@@ -438,8 +454,9 @@ class Manifest:
        Returns:
            response from prompt.
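
        Example (an illustrative sketch; assumes the pool is backed by a
        HuggingFace client):

            # prompt is illustrative only
            scores = manifest.score_prompt("The sky is blue.")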
"""
        client = self.client_pool.get_client()
        # Must pass kwargs as a dict so client "pop" methods can remove used arguments
        request_params = self.client.get_request(prompt, kwargs)
        request_params = client.get_request(prompt, kwargs)
        request_params.request_type = "score_prompt"
        if request_params.n > 1:
@@ -447,14 +464,14 @@
        self._validate_kwargs(kwargs, request_params)
        cached_idx_to_response, request_params = self._split_cached_requests(
            request_params, overwrite_cache
            request_params, client, overwrite_cache
        )
        # If not None value or empty list - run new request
        if request_params.prompt:
            try:
                response = cast(
                    HuggingFaceClient, self.client
                ).get_score_prompt_request(request_params)
                response = cast(HuggingFaceClient, client).get_score_prompt_request(
                    request_params
                )
            except AttributeError:
                raise ValueError("`score_prompt` only supported for HF models.")
        else:
@@ -463,6 +480,7 @@
        final_response = self._stitch_responses_and_cache(
            request=request_params,
            client=client,
            response=response,
            cached_idx_to_response=cached_idx_to_response,
        )