|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
"""Manifest test."""
|
|
|
|
|
import asyncio
|
|
|
|
|
import os
|
|
|
|
|
from typing import cast
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
@ -17,6 +18,8 @@ try:
|
|
|
|
|
except Exception:
|
|
|
|
|
MODEL_ALIVE = False
|
|
|
|
|
|
|
|
|
|
# True when an OpenAI API key is exported; used to gate the OpenAI-backed tests.
OPENAI_ALIVE = "OPENAI_API_KEY" in os.environ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
|
|
|
|
|
def test_init(sqlite_cache: str) -> None:
|
|
|
|
@ -104,7 +107,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
@ -126,7 +133,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
result = manifest.run(prompt, run_id="34", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
@ -149,7 +160,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
@ -171,7 +186,11 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll")
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(stop_token="ll")
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
@ -209,9 +228,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
else:
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = cast(Response, result).get_response(
|
|
|
|
|
manifest.stop_token, is_batch=True
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["hello"]
|
|
|
|
@ -229,9 +251,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = cast(Response, result).get_response(
|
|
|
|
|
manifest.stop_token, is_batch=True
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["hello", "hello"]
|
|
|
|
@ -263,11 +288,14 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
prompt = ["This is a prompt", "New prompt"]
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = cast(Response, result).get_response(
|
|
|
|
|
manifest.stop_token, is_batch=True
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
# Cached because one item is in cache
|
|
|
|
|
assert cast(Response, result).is_cached()
|
|
|
|
|
assert result.is_cached()
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["hello", "hello"]
|
|
|
|
@ -275,7 +303,12 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
result = cast(Response, result)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
else:
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["he", "he"]
|
|
|
|
@ -290,9 +323,14 @@ def test_abatch_run(sqlite_cache: str) -> None:
|
|
|
|
|
cache_connection=sqlite_cache,
|
|
|
|
|
)
|
|
|
|
|
prompt = ["This is a prompt"]
|
|
|
|
|
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
result = cast(
|
|
|
|
|
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
assert res == ["hello"]
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get(
|
|
|
|
@ -306,8 +344,14 @@ def test_abatch_run(sqlite_cache: str) -> None:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
result = cast(
|
|
|
|
|
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
assert res == ["hello", "hello"]
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get(
|
|
|
|
@ -320,9 +364,15 @@ def test_abatch_run(sqlite_cache: str) -> None:
|
|
|
|
|
is not None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
assert cast(Response, result).is_cached()
|
|
|
|
|
result = cast(
|
|
|
|
|
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
assert result.is_cached()
|
|
|
|
|
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get(
|
|
|
|
@ -335,15 +385,27 @@ def test_abatch_run(sqlite_cache: str) -> None:
|
|
|
|
|
is None
|
|
|
|
|
)
|
|
|
|
|
prompt = ["This is a prompt", "New prompt"]
|
|
|
|
|
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
result = cast(
|
|
|
|
|
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
# Cached because one item is in cache
|
|
|
|
|
assert cast(Response, result).is_cached()
|
|
|
|
|
assert result.is_cached()
|
|
|
|
|
assert res == ["hello", "hello"]
|
|
|
|
|
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
result = cast(
|
|
|
|
|
Response, asyncio.run(manifest.arun_batch(prompt, return_response=True))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(result.get_json_response()["usage"]) == len(
|
|
|
|
|
result.get_json_response()["choices"]
|
|
|
|
|
)
|
|
|
|
|
res = result.get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
assert res == ["he", "he"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -484,3 +546,138 @@ def test_local_huggingface(sqlite_cache: str) -> None:
|
|
|
|
|
assert len(scores["response"]["choices"][0]["token_logprobs"]) == len(
|
|
|
|
|
scores["response"]["choices"][0]["tokens"]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openai(sqlite_cache: str) -> None:
    """Run the "openai" client through single, batch and async-batch calls.

    Skipped unless ``OPENAI_ALIVE`` is true. Every request goes through a
    sqlite-backed cache, so statement order matters: later ``is_cached()``
    checks rely on prompts issued earlier in this test. The expected
    ``total_tokens`` figures are hard-coded for the ``text-ada-001`` engine.
    """
    apples = "Why are there apples?"
    bananas = "Why are there bananas?"
    oranges = "Why are there oranges?"

    manifest = Manifest(
        client_name="openai",
        engine="text-ada-001",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        temperature=0.0,
    )

    # Plain string result for a single prompt.
    text = manifest.run(apples)
    assert isinstance(text, str)
    assert len(text) > 0

    # Same prompt again, now as a Response: expected to be a cache hit.
    reply = cast(Response, manifest.run(apples, return_response=True))
    assert isinstance(reply.get_response(), str)
    assert len(reply.get_response()) > 0
    assert reply.is_cached() is True
    assert "usage" in reply.get_json_response()
    assert reply.get_json_response()["usage"][0]["total_tokens"] == 15

    reply = cast(Response, manifest.run(apples, return_response=True))
    assert reply.is_cached() is True

    # Synchronous batch: one result per prompt.
    texts = manifest.run([apples, bananas])
    assert isinstance(texts, list)
    assert len(texts) == 2

    reply = cast(
        Response,
        manifest.run([apples, "Why are there mangos?"], return_response=True),
    )
    assert isinstance(reply.get_response(), list)
    assert len(reply.get_response()) == 2
    # One usage entry is expected per prompt in the batch.
    assert "usage" in reply.get_json_response()
    assert len(reply.get_json_response()["usage"]) == 2
    assert reply.get_json_response()["usage"][0]["total_tokens"] == 15
    assert reply.get_json_response()["usage"][1]["total_tokens"] == 16

    # The bananas prompt was part of the earlier batch, so it should be cached.
    reply = cast(Response, manifest.run(bananas, return_response=True))
    assert reply.is_cached() is True

    # Asynchronous batch: one result per prompt.
    texts = asyncio.run(manifest.arun_batch(["Why are there pears?", oranges]))
    assert isinstance(texts, list)
    assert len(texts) == 2

    reply = cast(
        Response,
        asyncio.run(
            manifest.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert isinstance(reply.get_response(), list)
    assert len(reply.get_response()) == 2
    assert "usage" in reply.get_json_response()
    assert len(reply.get_json_response()["usage"]) == 2
    assert reply.get_json_response()["usage"][0]["total_tokens"] == 17
    assert reply.get_json_response()["usage"][1]["total_tokens"] == 15

    # The oranges prompt went through the async batch above, so it should be cached.
    reply = cast(Response, manifest.run(oranges, return_response=True))
    assert reply.is_cached() is True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set")
@pytest.mark.usefixtures("sqlite_cache")
def test_openaichat(sqlite_cache: str) -> None:
    """Run the "openaichat" client through single and async-batch calls.

    Skipped unless ``OPENAI_ALIVE`` is true. Every request goes through a
    sqlite-backed cache, so statement order matters: the ``is_cached()``
    checks depend on which prompts were issued earlier in this test. The
    expected ``total_tokens`` figures are hard-coded.
    """
    apples = "Why are there apples?"
    oranges = "Why are there oranges?"

    manifest = Manifest(
        client_name="openaichat",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )

    # Plain string result for a single prompt.
    text = manifest.run(apples)
    assert isinstance(text, str)
    assert len(text) > 0

    # Same prompt again, now as a Response: expected to be a cache hit.
    reply = cast(Response, manifest.run(apples, return_response=True))
    assert isinstance(reply.get_response(), str)
    assert len(reply.get_response()) > 0
    assert reply.is_cached() is True
    assert "usage" in reply.get_json_response()
    assert reply.get_json_response()["usage"][0]["total_tokens"] == 22

    reply = cast(Response, manifest.run(apples, return_response=True))
    assert reply.is_cached() is True

    # First time the oranges prompt is seen: must not come from the cache.
    reply = cast(Response, manifest.run(oranges, return_response=True))
    assert reply.is_cached() is False

    # Asynchronous batch: one result per prompt.
    texts = asyncio.run(manifest.arun_batch(["Why are there pears?", oranges]))
    assert isinstance(texts, list)
    assert len(texts) == 2

    reply = cast(
        Response,
        asyncio.run(
            manifest.arun_batch(
                ["Why are there pinenuts?", "Why are there cocoa?"],
                return_response=True,
            )
        ),
    )
    assert isinstance(reply.get_response(), list)
    assert len(reply.get_response()) == 2
    # One usage entry is expected per prompt in the batch.
    assert "usage" in reply.get_json_response()
    assert len(reply.get_json_response()["usage"]) == 2
    assert reply.get_json_response()["usage"][0]["total_tokens"] == 24
    assert reply.get_json_response()["usage"][1]["total_tokens"] == 22

    # The oranges prompt has been requested before, so it should now be cached.
    reply = cast(Response, manifest.run(oranges, return_response=True))
    assert reply.is_cached() is True
|
|
|
|
|