mirror of https://github.com/xtekky/gpt4free
fix error to make it only works for Ollama to ignore ":"
parent
6fb0dedfd5
commit
7ca2eda9c2
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from ..helper import filter_none
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
|
||||
from ...typing import Union, Optional, AsyncResult, Messages
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
|
||||
class Ollama(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Ollama"
|
||||
url = "http://localhost:11434"
|
||||
working = True
|
||||
needs_auth = True
|
||||
supports_message_history = True
|
||||
supports_system_message = True
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
timeout: int = 120,
|
||||
api_key: str = 'Ollama',
|
||||
api_base: str = "http://localhost:11434/v1/",
|
||||
temperature: float = None,
|
||||
max_tokens: int = None,
|
||||
top_p: float = None,
|
||||
stop: Union[str, list[str]] = None,
|
||||
stream: bool = False,
|
||||
headers: dict = None,
|
||||
extra_data: dict = {},
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if not model:
|
||||
model='phi3'
|
||||
if cls.needs_auth and api_key is None:
|
||||
raise MissingAuthError('Add a "api_key"')
|
||||
async with StreamSession(
|
||||
proxies={"all": proxy},
|
||||
headers=cls.get_headers(stream, api_key, headers),
|
||||
timeout=timeout
|
||||
) as session:
|
||||
data = filter_none(
|
||||
messages=messages,
|
||||
model=cls.get_model(model),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
**extra_data
|
||||
)
|
||||
|
||||
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
|
||||
await raise_for_status(response)
|
||||
if not stream:
|
||||
data = await response.json()
|
||||
cls.raise_error(data)
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["message"]:
|
||||
yield choice["message"]["content"].strip()
|
||||
finish = cls.read_finish_reason(choice)
|
||||
if finish is not None:
|
||||
yield finish
|
||||
else:
|
||||
first = True
|
||||
async for line in response.iter_lines():
|
||||
if line.startswith(b"data: "):
|
||||
chunk = line[6:]
|
||||
if chunk == b"[DONE]":
|
||||
break
|
||||
data = json.loads(chunk)
|
||||
cls.raise_error(data)
|
||||
choice = data["choices"][0]
|
||||
if "content" in choice["delta"] and choice["delta"]["content"]:
|
||||
delta = choice["delta"]["content"]
|
||||
if first:
|
||||
delta = delta.lstrip()
|
||||
if delta:
|
||||
first = False
|
||||
yield delta
|
||||
finish = cls.read_finish_reason(choice)
|
||||
if finish is not None:
|
||||
yield finish
|
||||
|
||||
@staticmethod
|
||||
def read_finish_reason(choice: dict) -> Optional[FinishReason]:
|
||||
if "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
return FinishReason(choice["finish_reason"])
|
||||
|
||||
@staticmethod
|
||||
def raise_error(data: dict):
|
||||
if "error_message" in data:
|
||||
raise ResponseError(data["error_message"])
|
||||
elif "error" in data:
|
||||
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
return {
|
||||
"Accept": "text/event-stream" if stream else "application/json",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{"Authorization": f"Bearer {api_key}"}
|
||||
if cls.needs_auth and api_key is not None
|
||||
else {}
|
||||
),
|
||||
**({} if headers is None else headers)
|
||||
}
|
Loading…
Reference in New Issue