feat: added run_chat for chat models (#87)

Laurel Orr 1 year ago committed by GitHub
parent c0b4644a1c
commit afe0fc5a1d

@@ -1,5 +1,8 @@
0.1.5 - Unreleased
---------------------
Added
^^^^^
* Added chat input for chat models.
0.1.4 - 2022-04-24
---------------------

@@ -7,6 +7,7 @@ How to make prompt programming with Foundation Models a little easier.
- [Getting Started](#getting-started)
- [Manifest](#manifest-components)
- [Local HuggingFace Models](#local-huggingface-models)
- [Chat Models](#chat-models)
- [Embedding Models](#embedding-models)
- [Road Map](#road-map)
- [Development](#development)
@@ -79,8 +80,8 @@ You can also just set `export COHERE_API_KEY=<COHERE_API_KEY>` and not use `clie
You can see the model details and possible model inputs to `run()` via
```python
print(manifest.client.get_model_params())
print(manifest.client.get_model_inputs())
print(manifest.client_pool.get_client().get_model_params())
print(manifest.client_pool.get_client().get_model_inputs())
```
## Global Cache
@@ -216,6 +217,18 @@ python3 -m manifest.api.app \
--percent_max_gpu_mem_reduction 0.85
```
# Chat Models
Manifest has dedicated support for executing against chat models that use the standard "system" / "user" dialogue format. To pass a dialogue history to Manifest, use the `run_chat` command with an associated chat model such as `openaichat`.
```python
manifest = Manifest(client_name="openaichat")
dialogue = [
{"role": "system", "content": "You are a helpful assistant who also responds in rhymes"},
{"role": "user", "content": "What is the date?"},
]
res = manifest.run_chat(dialogue, max_tokens=100)
```
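By default `run_chat` returns the model's reply as a string. Below is a minimal sketch of getting the full `Response` object instead, reusing the `dialogue` above (the accessors shown are the ones exercised in this change's tests):
```python
# Ask for the full Response object rather than just the reply text
response = manifest.run_chat(dialogue, max_tokens=100, return_response=True)
print(response.get_response())   # the assistant's reply text
print(response.is_cached())      # True if the result was served from the cache
print(response.get_usage_obj())  # token usage, when the API reports it
```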
# Embedding Models
Manifest also supports getting embeddings from models and available APIs. This is all controlled through the `client_name` argument; you still use `run` and `abatch_run`.
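For example, a minimal sketch, assuming an OpenAI embedding client is available under the `openaiembedding` client name:
```python
# Swapping the client makes `run` return an embedding array instead of text
manifest = Manifest(client_name="openaiembedding")
emb = manifest.run("Where was the 2020 World Series played?")
```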

@@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"OPENAI_KEY = \"sk-xxx\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use ChatOpenAI\n",
"\n",
"Set you `OPENAI_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"openai_chat = ClientConnection(\n",
" client_name=\"openaichat\",\n",
" client_connection=OPENAI_KEY,\n",
" engine=\"gpt-3.5-turbo\"\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[openai_chat])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The 2020 World Series was played at Globe\n"
]
}
],
"source": [
"# Simple question\n",
"chat_dict = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
"]\n",
"print(manifest.run_chat(chat_dict))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -13,6 +13,7 @@ from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_expo
from manifest.request import (
DEFAULT_REQUEST_KEYS,
NOT_CACHE_KEYS,
LMChatRequest,
LMScoreRequest,
Request,
)
@@ -308,7 +309,7 @@ class Client(ABC):
def run_request(self, request: Request) -> Response:
"""
Get request string function.
Run request.
Args:
request: request.
@@ -342,7 +343,7 @@ class Client(ABC):
async def arun_batch_request(self, request: Request) -> Response:
"""
Get async request string function.
Run async request.
Args:
request: request.
@@ -396,7 +397,39 @@ class Client(ABC):
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
def get_score_prompt_request(
def run_chat_request(
self,
request: LMChatRequest,
) -> Response:
"""
Get the response from chat model.
Args:
request: request.
Returns:
response.
"""
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
usages = None
if "usage" in response_dict:
usages = [Usage(**usage) for usage in response_dict["usage"]]
return Response(
response=self.get_model_choices(response_dict),
cached=False,
request=request,
usages=Usages(usages=usages) if usages else None,
**RESPONSE_CONSTRUCTORS[LMChatRequest], # type: ignore
)
def run_score_prompt_request(
self,
request: LMScoreRequest,
) -> Response:
@@ -407,8 +440,7 @@ class Client(ABC):
request: request.
Returns:
request function that takes no input.
request parameters as dict.
response.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support prompt scoring request."

@@ -3,7 +3,7 @@ import logging
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest, LMScoreRequest, Request
from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages
logger = logging.getLogger(__name__)
@@ -123,7 +123,53 @@ class DummyClient(Client):
"""
return self.run_request(request)
def get_score_prompt_request(
def run_chat_request(
self,
request: LMChatRequest,
) -> Response:
"""
Get the response from chat model.
Args:
request: request.
Returns:
response.
"""
num_results = 1
response_dict = {
"choices": [
{
"text": request.prompt[0]["content"],
}
for i in range(num_results)
]
}
return Response(
response=ModelChoices(
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
cached=False,
request=request,
usages=Usages(
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
),
response_type="text",
request_type=LMChatRequest,
)
def run_score_prompt_request(
self,
request: LMScoreRequest,
) -> Response:

@@ -80,7 +80,7 @@ class HuggingFaceClient(Client):
res["client_name"] = self.NAME
return res
def get_score_prompt_request(
def run_score_prompt_request(
self,
request: LMScoreRequest,
) -> Response:

@@ -92,10 +92,23 @@ class OpenAIChatClient(OpenAIClient):
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and isinstance(prompt[0], str):
prompt_list = prompt
messages = [{"role": "user", "content": prompt} for prompt in prompt_list]
messages = [{"role": "user", "content": prompt} for prompt in prompt_list]
elif isinstance(prompt, list) and isinstance(prompt[0], dict):
for pmt_dict in prompt:
if "role" not in pmt_dict or "content" not in pmt_dict:
raise ValueError(
"Prompt must be list of dicts with 'role' and 'content' "
f"keys. Got {prompt}."
)
messages = prompt
else:
raise ValueError(
"Prompt must be string, list of strings, or list of dicts."
f"Got {prompt}"
)
request_params["messages"] = messages
return request_params

@@ -17,7 +17,7 @@ from manifest.connections.client_pool import (
ClientConnection,
ClientConnectionPool,
)
from manifest.request import LMScoreRequest, Request
from manifest.request import LMChatRequest, LMScoreRequest, Request
from manifest.response import ModelChoices, Response, Usage, Usages
logging.getLogger("openai").setLevel(logging.WARNING)
@@ -142,7 +142,9 @@ class Manifest:
cached_idx_to_response: Dict[int, Response] = {}
new_request = copy.deepcopy(request)
if not overwrite_cache:
if isinstance(new_request.prompt, list):
if isinstance(new_request.prompt, list) and not isinstance(
request, LMChatRequest
):
new_request.prompt = []
for idx, prompt_str in enumerate(request.prompt):
single_request = copy.deepcopy(request)
@@ -154,11 +156,21 @@
cached_idx_to_response[idx] = possible_response
else:
new_request.prompt.append(prompt_str)
else:
# Chat or single string requests are not broken down into
# subprompts for caching.
elif (isinstance(new_request.prompt, str)) or (
isinstance(new_request.prompt, list)
and isinstance(request, LMChatRequest)
):
possible_response = self.cache.get(client.get_cache_key(new_request))
if possible_response:
cached_idx_to_response[0] = possible_response
new_request.prompt = None
else:
raise ValueError(
f"Invalid prompt type: {type(new_request.prompt)}"
f" with request type: {type(request)}"
)
return cached_idx_to_response, new_request
def _stitch_responses_and_cache(
@@ -173,16 +185,28 @@
# cached entries.
all_model_choices = []
all_usages = []
all_input_prompts = []
all_input_prompts: List[Union[str, List[str], List[Dict]]] = []
response_idx = 0
number_prompts = len(cached_idx_to_response)
single_output = False
single_completion_output = False
if response:
if isinstance(response.get_request_obj().prompt, str):
single_output = True
single_completion_output = True
number_prompts += 1
else:
elif isinstance(response.get_request_obj().prompt, list) and not isinstance(
request, LMChatRequest
):
number_prompts += len(response.get_request_obj().prompt)
elif isinstance(response.get_request_obj().prompt, list) and isinstance(
request, LMChatRequest
):
assert len(cached_idx_to_response) <= 1
number_prompts += 1
else:
raise ValueError(
f"Invalid prompt type: {type(response.get_request_obj().prompt)}"
f" with request type: {type(request)}"
)
response_type = None
request_type: Type[Request] = None
for idx in range(number_prompts):
@@ -210,10 +234,24 @@
]
all_model_choices.extend(current_choices)
if isinstance(response.get_request_obj().prompt, list):
prompt = response.get_request_obj().prompt[response_idx]
if isinstance(
response.get_request_obj().prompt, list
) and not isinstance(request, LMChatRequest):
prompt: Union[
str, List[str], List[Dict]
] = response.get_request_obj().prompt[response_idx]
# Chat request
elif isinstance(response.get_request_obj().prompt, list) and isinstance(
request, LMChatRequest
):
# We will only have response_idx == 0 here as we can only
# support single chat requests.
assert request.n == 1
assert number_prompts <= 1
prompt = response.get_request_obj().prompt
else:
prompt = str(response.get_request_obj().prompt)
usages: Optional[List[Usage]] = None
if response.get_usage_obj().usages:
usages = response.get_usage_obj().usages[
@@ -223,7 +261,7 @@
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt
new_request.prompt = prompt # type: ignore
cache_key = client.get_cache_key(new_request)
new_response = copy.deepcopy(response)
new_response._response.choices = current_choices
@@ -234,7 +272,7 @@
new_request = copy.deepcopy(request)
new_request.prompt = (
all_input_prompts # type: ignore
if len(all_input_prompts) > 1 or not single_output
if len(all_input_prompts) > 1 or not single_completion_output
else all_input_prompts[0]
)
response_obj = Response(
@@ -426,6 +464,67 @@ class Manifest:
)
return final_response
def run_chat(
self,
prompt: List[Dict[str, str]],
overwrite_cache: bool = False,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, Response]:
"""
Run the chat prompt.
Args:
prompt: list of chat message dictionaries to run.
overwrite_cache: whether to overwrite cache.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
# Get the client to run
client = self.client_pool.get_client()
# Get a request for an empty prompt to handle all kwargs
request_params = client.get_request("", kwargs)
# Add prompt and cast as chat request
request_params_dict = request_params.to_dict()
request_params_dict["prompt"] = prompt
request_params_as_chat = LMChatRequest(**request_params_dict)
# Avoid nested list of results - enforce n = 1 for chat requests
if request_params_as_chat.n > 1:
raise ValueError("Chat mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params_as_chat)
cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_chat, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params_as_chat.prompt:
# Start timing metrics
self.client_pool.start_timer()
response = client.run_chat_request(request_params_as_chat)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params_as_chat,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return cast(str, final_response.get_response("", is_batch))
def score_prompt(
self,
prompt: Union[str, List[str]],
@@ -459,7 +558,7 @@
# If not None value or empty list - run new request
if request_params_as_score.prompt:
try:
response = cast(HuggingFaceClient, client).get_score_prompt_request(
response = cast(HuggingFaceClient, client).run_score_prompt_request(
request_params_as_score
)
except AttributeError:

@@ -101,6 +101,12 @@ class LMRequest(Request):
frequency_penalty: float = 0
class LMChatRequest(LMRequest):
"""Language Model Chat Request object."""
prompt: List[Dict[str, str]] = {} # type: ignore
class LMScoreRequest(LMRequest):
"""Language Model Score Request object."""

@@ -10,6 +10,7 @@ from manifest.request import (
ENGINE_SEP,
DiffusionRequest,
EmbeddingRequest,
LMChatRequest,
LMRequest,
LMScoreRequest,
Request,
@@ -17,6 +18,7 @@ from manifest.request import (
RESPONSE_CONSTRUCTORS: Dict[Type[Request], Dict[str, Union[str, Type[Request]]]] = {
LMRequest: {"response_type": "text", "request_type": LMRequest},
LMChatRequest: {"response_type": "text", "request_type": LMChatRequest},
LMScoreRequest: {"response_type": "text", "request_type": LMScoreRequest},
EmbeddingRequest: {"response_type": "array", "request_type": EmbeddingRequest},
DiffusionRequest: {"response_type": "array", "request_type": DiffusionRequest},
@@ -291,6 +293,8 @@ class Response:
response_type = response_dict["response_type"]
if response_dict["request_type"] == "LMRequest":
request_type: Type[Request] = LMRequest
elif response_dict["request_type"] == "LMChatRequest":
request_type = LMChatRequest
elif response_dict["request_type"] == "LMScoreRequest":
request_type = LMScoreRequest
elif response_dict["request_type"] == "EmbeddingRequest":

@@ -391,6 +391,58 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert res == ["he", "he"]
@pytest.mark.usefixtures("sqlite_cache")
def test_run_chat(sqlite_cache: str) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
prompt = [
{"role": "system", "content": "Hello."},
]
result = manifest.run_chat(prompt, return_response=False)
assert result == "Hello."
assert (
manifest.cache.get(
{
"prompt": [{"content": "Hello.", "role": "system"}],
"engine": "dummy",
"num_results": 1,
"request_cls": "LMChatRequest",
},
)
is not None
)
prompt = [
{"role": "system", "content": "Hello."},
{"role": "user", "content": "Goodbye?"},
]
result = manifest.run_chat(prompt, return_response=True)
assert isinstance(result, Response)
result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response()
assert res == "Hello."
assert (
manifest.cache.get(
{
"prompt": [
{"role": "system", "content": "Hello."},
{"role": "user", "content": "Goodbye?"},
],
"engine": "dummy",
"num_results": 1,
"request_cls": "LMChatRequest",
},
)
is not None
)
@pytest.mark.usefixtures("sqlite_cache")
def test_score_run(sqlite_cache: str) -> None:
"""Test manifest run."""
@@ -788,6 +840,32 @@ def test_openaichat(sqlite_cache: str) -> None:
)
assert response.is_cached() is True
chat_dict = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020.",
},
{"role": "user", "content": "Where was it played?"},
]
res = client.run_chat(chat_dict)
assert isinstance(res, str) and len(res) > 0
response = cast(Response, client.run_chat(chat_dict, return_response=True))
assert response.is_cached() is True
assert response.get_usage_obj().usages[0].total_tokens == 67
chat_dict = [
{"role": "system", "content": "You are a helpful assistanttttt."},
{"role": "user", "content": "Who won the world series in 2020?"},
{
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020.",
},
{"role": "user", "content": "Where was it played?"},
]
response = cast(Response, client.run_chat(chat_dict, return_response=True))
assert response.is_cached() is False
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
