|
|
|
@ -27,7 +27,10 @@ class Pi(BaseProvider):
|
|
|
|
|
session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout)
|
|
|
|
|
if not conversation_id:
|
|
|
|
|
conversation_id = cls.start_conversation(session)
|
|
|
|
|
answer = cls.ask(session, messages, conversation_id)
|
|
|
|
|
prompt = format_prompt(messages)
|
|
|
|
|
else:
|
|
|
|
|
prompt = messages[-1]["content"]
|
|
|
|
|
answer = cls.ask(session, prompt, conversation_id)
|
|
|
|
|
for line in answer:
|
|
|
|
|
if "text" in line:
|
|
|
|
|
yield line["text"]
|
|
|
|
@ -51,9 +54,9 @@ class Pi(BaseProvider):
|
|
|
|
|
raise RuntimeError('Error: Cloudflare detected')
|
|
|
|
|
return response.json()
|
|
|
|
|
|
|
|
|
|
def ask(session: Session, messages: Messages, conversation_id: str):
|
|
|
|
|
def ask(session: Session, prompt: str, conversation_id: str):
|
|
|
|
|
json_data = {
|
|
|
|
|
'text': format_prompt(messages),
|
|
|
|
|
'text': prompt,
|
|
|
|
|
'conversation': conversation_id,
|
|
|
|
|
'mode': 'BASE',
|
|
|
|
|
}
|
|
|
|
|