|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import openai
|
|
|
|
|
from tenacity import (
|
|
|
|
|
retry,
|
|
|
|
@ -6,7 +7,12 @@ from tenacity import (
|
|
|
|
|
wait_random_exponential, # type: ignore
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from typing import Optional, List, Union, Literal
|
|
|
|
|
from typing import Optional, List
|
|
|
|
|
if sys.version_info >= (3, 7):
|
|
|
|
|
from typing import Literal
|
|
|
|
|
else:
|
|
|
|
|
from typing_extensions import Literal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Model = Literal["gpt-4", "gpt-3.5-turbo", "text-davinci-003"]
|
|
|
|
|
|
|
|
|
@ -41,4 +47,4 @@ def get_chat(prompt: str, model: Model, max_tokens: int = 256, stop_strs: Option
|
|
|
|
|
max_tokens=max_tokens,
|
|
|
|
|
stop=stop_strs,
|
|
|
|
|
)
|
|
|
|
|
return response.choices[0].message.content
|
|
|
|
|
return response.choices[0].message.content
|
|
|
|
|