feat: openai-compatible platforms share the same client (#469)

sigoden committed 3 weeks ago
parent 8a65337d59
commit 8dba46becf

@@ -17,10 +17,10 @@ test-init-config() {
cargo run -- "$@"
}
# @cmd Test running without the config file
# @env AICHAT_CLIENT_TYPE!
# @cmd Test running with AICHAT_PLATFORM environment variable
# @env AICHAT_PLATFORM!
# @arg args~
test-without-config() {
test-platform-env() {
cargo run -- "$@"
}
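# A quick usage sketch for the command above (hypothetical key value; assumes
# argc is installed and the chosen platform only needs an API key):
#   AICHAT_PLATFORM=groq GROQ_API_KEY=xxx argc test-platform-env "hello"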
@@ -50,44 +50,85 @@ test-server() {
"$@"
}
OPENAI_COMPATIBLE_CLIENTS=( \
openai,gpt-3.5-turbo,https://api.openai.com/v1 \
anyscale,meta-llama/Meta-Llama-3-8B-Instruct,https://api.endpoints.anyscale.com/v1 \
deepinfra,meta-llama/Meta-Llama-3-8B-Instruct,https://api.deepinfra.com/v1/openai \
fireworks,accounts/fireworks/models/llama-v3-8b-instruct,https://api.fireworks.ai/inference/v1 \
groq,llama3-8b-8192,https://api.groq.com/openai/v1 \
mistral,mistral-small-latest,https://api.mistral.ai/v1 \
moonshot,moonshot-v1-8k,https://api.moonshot.cn/v1 \
openrouter,meta-llama/llama-3-8b-instruct,https://openrouter.ai/api/v1 \
octoai,meta-llama-3-8b-instruct,https://text.octoai.run/v1 \
perplexity,llama-3-8b-instruct,https://api.perplexity.ai \
together,meta-llama/Llama-3-8b-chat-hf,https://api.together.xyz/v1 \
)
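# How each `name,model,api_base` entry above gets split (plain bash parameter
# expansion, as used by chat() and models() below; `entry` is a hypothetical
# stand-in for one array element):
#   entry="groq,llama3-8b-8192,https://api.groq.com/openai/v1"
#   name="${entry%%,*}"                     # strip longest `,*` suffix -> groq
#   api_base="${entry##*,}"                 # strip longest `*,` prefix -> https://api.groq.com/openai/v1
#   model="$(echo "$entry" | cut -d, -f2)"  # middle field -> llama3-8b-8192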
# @cmd Chat with openai-compatible api
# @option --api-base! $$
# @option --api-key! $$
# @option -m --model! $$
# @flag -S --no-stream
# @arg platform![`_choice_platform`]
# @arg text~
chat-llm() {
curl_args="$CURL_ARGS"
_openai_chat "$@"
chat() {
for client_config in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
if [[ "$argc_platform" == "${client_config%%,*}" ]]; then
api_base="${client_config##*,}"
break
fi
done
if [[ -n "$api_base" ]]; then
env_prefix="$(echo "$argc_platform" | tr '[:lower:]' '[:upper:]')"
api_key_env="${env_prefix}_API_KEY"
api_key="${!api_key_env}"
# Prefer ${PLATFORM}_MODEL from the environment, then fall back to the
# default model recorded in OPENAI_COMPATIBLE_CLIENTS; checking the table
# default first would make the env fallback unreachable.
if [[ -z "$model" ]]; then
model_env="${env_prefix}_MODEL"
model="${!model_env}"
fi
if [[ -z "$model" ]]; then
model="$(echo "$client_config" | cut -d, -f2)"
fi
argc chat-openai-compatible \
--api-base "$api_base" \
--api-key "$api_key" \
--model "$model" \
"${argc_text[@]}"
else
argc chat-$argc_platform "${argc_text[@]}"
fi
}
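# Example invocations (hypothetical keys; platforms named in
# OPENAI_COMPATIBLE_CLIENTS are routed to chat-openai-compatible, anything
# else falls through to its dedicated chat-<platform> command):
#   GROQ_API_KEY=xxx argc chat groq "hello"
#   argc chat ollama "hello"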
# @cmd List models by openai-compatible api
# @option --api-base! $$
# @option --api-key! $$
models-llm() {
curl_args="$CURL_ARGS"
_openai_models
# @arg platform![`_choice_platform`]
models() {
for client_config in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
if [[ "$argc_platform" == "${client_config%%,*}" ]]; then
api_base="${client_config##*,}"
break
fi
done
if [[ -n "$api_base" ]]; then
env_prefix="$(echo "$argc_platform" | tr '[:lower:]' '[:upper:]')"
api_key_env="${env_prefix}_API_KEY"
api_key="${!api_key_env}"
_openai_models
else
argc models-$argc_platform
fi
}
# @cmd Chat with openai api
# @env OPENAI_API_KEY!
# @option -m --model=gpt-3.5-turbo $OPENAI_MODEL
# @cmd Chat with openai-compatible api
# @option --api-base! $$
# @option --api-key! $$
# @option -m --model! $$
# @flag -S --no-stream
# @arg text~
chat-openai() {
api_base=https://api.openai.com/v1
api_key=$OPENAI_API_KEY
curl_args="-i $OPENAI_CURL_ARGS"
chat-openai-compatible() {
_openai_chat "$@"
}
# @cmd List openai models
# @env OPENAI_API_KEY!
models-openai() {
api_base=https://api.openai.com/v1
api_key=$OPENAI_API_KEY
curl_args="$OPENAI_CURL_ARGS"
# @cmd List models by openai-compatible api
# @option --api-base! $$
# @option --api-key! $$
models-openai-compatible() {
_openai_models
}
@@ -101,7 +142,7 @@ chat-gemini() {
if [[ -n "$argc_no_stream" ]]; then
method="generateContent"
fi
_wrapper curl -i $GEMINI_CURL_ARGS "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
_wrapper curl -i "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
-i -X POST \
-H 'Content-Type: application/json' \
-d '{
@@ -113,7 +154,7 @@ chat-gemini() {
# @cmd List gemini models
# @env GEMINI_API_KEY!
models-gemini() {
_wrapper curl $GEMINI_CURL_ARGS "https://generativelanguage.googleapis.com/v1beta/models?key=${GEMINI_API_KEY}" \
_wrapper curl "https://generativelanguage.googleapis.com/v1beta/models?key=${GEMINI_API_KEY}" \
-H 'Content-Type: application/json' \
}
@@ -124,7 +165,7 @@ models-gemini() {
# @flag -S --no-stream
# @arg text~
chat-claude() {
_wrapper curl -i $CLAUDE_CURL_ARGS https://api.anthropic.com/v1/messages \
_wrapper curl -i https://api.anthropic.com/v1/messages \
-X POST \
-H 'content-type: application/json' \
-H 'anthropic-version: 2023-06-01' \
@@ -138,34 +179,13 @@ chat-claude() {
'
}
# @cmd Chat with mistral api
# @env MISTRAL_API_KEY!
# @option -m --model=mistral-small-latest $MISTRAL_MODEL
# @flag -S --no-stream
# @arg text~
chat-mistral() {
api_base=https://api.mistral.ai/v1
api_key=$MISTRAL_API_KEY
curl_args="$MISTRAL_CURL_ARGS"
_openai_chat "$@"
}
# @cmd List mistral models
# @env MISTRAL_API_KEY!
models-mistral() {
api_base=https://api.mistral.ai/v1
api_key=$MISTRAL_API_KEY
curl_args="$MISTRAL_CURL_ARGS"
_openai_models
}
# @cmd Chat with cohere api
# @env COHERE_API_KEY!
# @option -m --model=command-r $COHERE_MODEL
# @flag -S --no-stream
# @arg text~
chat-cohere() {
_wrapper curl -i $COHERE_CURL_ARGS https://api.cohere.ai/v1/chat \
_wrapper curl -i https://api.cohere.ai/v1/chat \
-X POST \
-H 'Content-Type: application/json' \
-H "Authorization: Bearer $COHERE_API_KEY" \
@@ -180,50 +200,17 @@ chat-cohere() {
# @cmd List cohere models
# @env COHERE_API_KEY!
models-cohere() {
_wrapper curl $COHERE_CURL_ARGS https://api.cohere.ai/v1/models \
_wrapper curl https://api.cohere.ai/v1/models \
-H "Authorization: Bearer $COHERE_API_KEY" \
}
# @cmd Chat with perplexity api
# @env PERPLEXITY_API_KEY!
# @option -m --model=sonar-small-chat $PERPLEXITY_MODEL
# @flag -S --no-stream
# @arg text~
chat-perplexity() {
api_base=https://api.perplexity.ai
api_key=$PERPLEXITY_API_KEY
curl_args="$PERPLEXITY_CURL_ARGS"
_openai_chat "$@"
}
# @cmd Chat with groq api
# @env GROQ_API_KEY!
# @option -m --model=llama3-70b-8192 $GROQ_MODEL
# @flag -S --no-stream
# @arg text~
chat-groq() {
api_base=https://api.groq.com/openai/v1
api_key=$GROQ_API_KEY
curl_args="$GROQ_CURL_ARGS"
_openai_chat "$@"
}
# @cmd List groq models
# @env GROQ_API_KEY!
models-groq() {
api_base=https://api.groq.com/openai/v1
api_key=$GROQ_API_KEY
curl_args="$GROQ_CURL_ARGS"
_openai_models
}
# @cmd Chat with ollama api
# @option -m --model=codegemma $OLLAMA_MODEL
# @flag -S --no-stream
# @arg text~
chat-ollama() {
_wrapper curl -i $OLLAMA_CURL_ARGS http://localhost:11434/api/chat \
_wrapper curl -i http://localhost:11434/api/chat \
-X POST \
-H 'Content-Type: application/json' \
-d '{
@@ -247,7 +234,7 @@ chat-vertexai-gemini() {
func="generateContent"
fi
url=https://$VERTEXAI_LOCATION-aiplatform.googleapis.com/v1/projects/$VERTEXAI_PROJECT_ID/locations/$VERTEXAI_LOCATION/publishers/google/models/$argc_model:$func
_wrapper curl -i $VERTEXAI_CURL_ARGS $url \
_wrapper curl -i $url \
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
@@ -267,7 +254,7 @@ chat-vertexai-gemini() {
chat-vertexai-claude() {
api_key="$(gcloud auth print-access-token)"
url=https://$VERTEXAI_LOCATION-aiplatform.googleapis.com/v1/projects/$VERTEXAI_PROJECT_ID/locations/$VERTEXAI_LOCATION/publishers/anthropic/models/$argc_model:streamRawPredict
_wrapper curl -i $VERTEXAI_CURL_ARGS $url \
_wrapper curl -i $url \
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
@@ -316,7 +303,7 @@ chat-bedrock() {
# @arg text~
chat-cloudflare() {
url="https://api.cloudflare.com/client/v4/accounts/$CLOUDFLARE_ACCOUNT_ID/ai/run/$argc_model"
_wrapper curl -i $CLOUDFLARE_CURL_ARGS "$url" \
_wrapper curl -i "$url" \
-X POST \
-H "Authorization: Bearer $CLOUDFLARE_API_KEY" \
-d '{
@@ -332,7 +319,7 @@ chat-cloudflare() {
# @arg text~
chat-replicate() {
url="https://api.replicate.com/v1/models/$argc_model/predictions"
res="$(_wrapper curl -s $DEEPINFRA_CURL_ARGS "$url" \
res="$(_wrapper curl -s "$url" \
-X POST \
-H "Authorization: Bearer $REPLICATE_API_KEY" \
-H "Content-Type: application/json" \
@@ -346,7 +333,7 @@ chat-replicate() {
if [[ -n "$argc_no_stream" ]]; then
prediction_url="$(echo "$res" | jq -r '.urls.get')"
while true; do
output="$(_wrapper curl $DEEPINFRA_CURL_ARGS -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
output="$(_wrapper curl -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
prediction_status=$(printf "%s" "$output" | jq -r .status)
if [ "$prediction_status"=="succeeded" ]; then
echo "$output"
@@ -359,7 +346,7 @@ chat-replicate() {
done
else
stream_url="$(echo "$res" | jq -r '.urls.stream')"
_wrapper curl -i $DEEPINFRA_CURL_ARGS --no-buffer "$stream_url" \
_wrapper curl -i --no-buffer "$stream_url" \
-H "Accept: text/event-stream" \
fi
@@ -376,7 +363,7 @@ chat-ernie() {
auth_url="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=$ERNIE_API_KEY&client_secret=$ERNIE_SECRET_KEY"
ACCESS_TOKEN="$(curl -fsSL "$auth_url" | jq -r '.access_token')"
url="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/$argc_model?access_token=$ACCESS_TOKEN"
_wrapper curl -i $ERNIE_CURL_ARGS "$url" \
_wrapper curl -i "$url" \
-X POST \
-d '{
"messages": '"$(_build_msg $*)"',
@@ -398,7 +385,7 @@ chat-qianwen() {
parameters_args='{}'
fi
url=https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
_wrapper curl -i $QIANWEN_CURL_ARGS "$url" \
_wrapper curl -i "$url" \
-X POST \
-H "Authorization: Bearer $QIANWEN_API_KEY" \
-H 'Content-Type: application/json' $stream_args \
@@ -411,31 +398,6 @@ chat-qianwen() {
}'
}
# @cmd Chat with moonshot api
# @env MOONSHOT_API_KEY!
# @option -m --model=moonshot-v1-8k $MOONSHOT_MODEL
# @flag -S --no-stream
# @arg text~
chat-moonshot() {
api_base=https://api.moonshot.cn/v1
api_key=$MOONSHOT_API_KEY
curl_args="$MOONSHOT_CURL_ARGS"
_openai_chat "$@"
}
# @cmd List moonshot models
# @env MOONSHOT_API_KEY!
models-moonshot() {
api_base=https://api.moonshot.cn/v1
api_key=$MOONSHOT_API_KEY
curl_args="$MOONSHOT_CURL_ARGS"
_openai_models
}
_choice_model() {
aichat --list-models
}
_argc_before() {
stream="true"
if [[ -n "$argc_no_stream" ]]; then
@@ -466,8 +428,23 @@ _openai_models() {
}
_choice_model() {
aichat --list-models
}
_choice_platform() {
_choice_client
_choice_openai_compatible_platform
}
_choice_client() {
printf "%s\n" openai gemini claude mistral cohere ollama vertexai bedrock ernie qianwen moonshot
printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare replicate ernie qianwen moonshot
}
_choice_openai_compatible_platform() {
for v in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
echo "${v%%,*}"
done
}
_build_msg() {

@@ -47,6 +47,16 @@ clients:
api_base: https://api.openai.com/v1 # ENV: {client_name}_API_BASE
organization_id: org-xxx # Optional
# For any platform compatible with OpenAI's API
- type: openai-compatible
name: localai
api_base: http://localhost:8080/v1 # ENV: {client_name}_API_BASE
api_key: xxx # ENV: {client_name}_API_KEY
chat_endpoint: /chat/completions # Optional
models:
- name: llama3
max_input_tokens: 8192
# See https://ai.google.dev/docs
- type: gemini
api_key: xxx # ENV: {client_name}_API_KEY
@@ -58,7 +68,8 @@ clients:
api_key: sk-ant-xxx # ENV: {client_name}_API_KEY
# See https://docs.mistral.ai/
- type: mistral
- type: openai-compatible
name: mistral
api_key: xxx # ENV: {client_name}_API_KEY
# See https://docs.cohere.com/docs/the-cohere-platform
@@ -129,20 +140,9 @@ clients:
- type: moonshot
api_key: sk-xxx # ENV: {client_name}_API_KEY
# For any platform compatible with OpenAI's API
- type: openai-compatible
name: localai
api_base: http://localhost:8080/v1 # ENV: {client_name}_API_BASE
api_key: sk-xxx # ENV: {client_name}_API_KEY
chat_endpoint: /chat/completions # Optional
models: # Required
- name: llama3
max_input_tokens: 8192
# See https://docs.endpoints.anyscale.com/
- type: openai-compatible
name: anyscale
api_base: https://api.endpoints.anyscale.com/v1
api_key: xxx
models:
# https://docs.endpoints.anyscale.com/text-generation/query-a-model#select-a-model
@@ -154,7 +154,6 @@ clients:
# See https://deepinfra.com/docs
- type: openai-compatible
name: deepinfra
api_base: https://api.deepinfra.com/v1/openai
api_key: xxx
models:
# https://deepinfra.com/models
@@ -166,7 +165,6 @@ clients:
# See https://readme.fireworks.ai/docs/quickstart
- type: openai-compatible
name: fireworks
api_base: https://api.fireworks.ai/inference/v1
api_key: xxx
models:
# https://fireworks.ai/models
@@ -175,12 +173,21 @@ clients:
input_price: 0.9
output_price: 0.9
# See https://openrouter.ai/docs#quick-start
- type: openai-compatible
name: openrouter
api_key: xxx # ENV: {client_name}_API_KEY
models:
# https://openrouter.ai/docs#models
- name: meta-llama/llama-3-70b-instruct
max_input_tokens: 8192
input_price: 0.81
output_price: 0.81
# See https://octo.ai/docs/getting-started/quickstart
- type: openai-compatible
name: octoai
api_base: https://text.octoai.run/v1
api_key: xxx
api_key: xxx # ENV: {client_name}_API_KEY
models:
# https://octo.ai/docs/getting-started/inference-models
- name: meta-llama-3-70b-instruct
@@ -191,8 +198,7 @@ clients:
# See https://docs.together.ai/docs/quickstart
- type: openai-compatible
name: together
api_base: https://api.together.xyz/v1
api_key: xxx
api_key: xxx # ENV: {client_name}_API_KEY
models:
# https://docs.together.ai/docs/inference-models
- name: meta-llama/Llama-3-70b-chat-hf

@@ -2,7 +2,7 @@
# - This model list is scheduled to be updated with each new aichat release. Please do not submit PRs to add new models.
# - This model list does not include models officially marked as legacy or beta.
- type: openai
- platform: openai
# docs:
# - https://platform.openai.com/docs/models
# - https://openai.com/pricing
@@ -53,7 +53,7 @@
input_price: 60
output_price: 120
- type: gemini
- platform: gemini
# docs:
# - https://ai.google.dev/models/gemini
# - https://ai.google.dev/pricing
@@ -79,7 +79,7 @@
output_price: 21
supports_vision: true
- type: claude
- platform: claude
# docs:
# - https://docs.anthropic.com/claude/docs/models-overview
# - https://docs.anthropic.com/claude/reference/messages-streaming
@@ -105,7 +105,7 @@
output_price: 1.25
supports_vision: true
- type: mistral
- platform: mistral
# docs:
# - https://docs.mistral.ai/getting-started/models/
# - https://mistral.ai/technology/#pricing
@@ -138,7 +138,7 @@
input_price: 8
output_price: 24
- type: cohere
- platform: cohere
# docs:
# - https://docs.cohere.com/docs/command-r
# - https://cohere.com/pricing
@@ -157,7 +157,7 @@
input_price: 3
output_price: 15
- type: perplexity
- platform: perplexity
# docs:
# - https://docs.perplexity.ai/docs/model-cards
# - https://docs.perplexity.ai/docs/pricing
@@ -209,7 +209,7 @@
input_price: 1
output_price: 1
- type: groq
- platform: groq
# docs:
# - https://console.groq.com/docs/models
# - https://wow.groq.com
@@ -239,7 +239,7 @@
input_price: 0.10
output_price: 0.10
- type: vertexai
- platform: vertexai
# docs:
# - https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
# - https://cloud.google.com/vertex-ai/generative-ai/pricing
@@ -284,7 +284,7 @@
output_price: 1.25
supports_vision: true
- type: bedrock
- platform: bedrock
# docs:
# - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
# - https://aws.amazon.com/bedrock/pricing/
@@ -346,7 +346,7 @@
input_price: 8
output_price: 2.4
- type: cloudflare
- platform: cloudflare
# docs:
# - https://developers.cloudflare.com/workers-ai/models/
# - https://developers.cloudflare.com/workers-ai/platform/pricing/
@@ -367,7 +367,7 @@
input_price: 0.11
output_price: 0.19
- type: replicate
- platform: replicate
# docs:
# - https://replicate.com/docs
# - https://replicate.com/pricing
@@ -395,7 +395,7 @@
input_price: 0.3
output_price: 1
- type: ernie
- platform: ernie
# docs:
# - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
# - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
@@ -428,7 +428,7 @@
input_price: 0.14
output_price: 0.14
- type: qianwen
- platform: qianwen
# docs:
# - https://help.aliyun.com/zh/dashscope/developer-reference/tongyiqianwen-large-language-models/
# - https://help.aliyun.com/zh/dashscope/developer-reference/qwen-vl-plus/
@@ -462,7 +462,7 @@
output_price: 2.8
supports_vision: true
- type: moonshot
- platform: moonshot
# docs:
# - https://platform.moonshot.cn/docs/intro
# - https://platform.moonshot.cn/docs/pricing

@@ -1,5 +1,7 @@
use super::openai::openai_build_body;
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData};
use super::{
AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@@ -18,7 +20,7 @@ impl AzureOpenAIClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 4] = [
pub const PROMPTS: [PromptAction<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),

@@ -1,8 +1,8 @@
use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptFormat, PromptKind, PromptType, SendData, SseHandler, LLAMA2_PROMPT_FORMAT,
LLAMA3_PROMPT_FORMAT,
ModelConfig, PromptAction, PromptFormat, PromptKind, SendData, SseHandler,
LLAMA2_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
@@ -65,7 +65,7 @@ impl BedrockClient {
config_get_fn!(secret_access_key, get_secret_access_key);
config_get_fn!(region, get_region);
pub const PROMPTS: [PromptType<'static>; 3] = [
pub const PROMPTS: [PromptAction<'static>; 3] = [
(
"access_key_id",
"AWS Access Key ID",

@@ -1,6 +1,6 @@
use super::{
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptKind, PromptType,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptAction, PromptKind,
SendData, SsMmessage, SseHandler,
};
@@ -23,7 +23,7 @@ pub struct ClaudeConfig {
impl ClaudeClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CloudflareClient, CompletionDetails, ExtraConfig, Model, ModelConfig,
PromptKind, PromptType, SendData, SsMmessage, SseHandler,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
use anyhow::{anyhow, Result};
@@ -24,7 +24,7 @@ impl CloudflareClient {
config_get_fn!(account_id, get_account_id);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 2] = [
pub const PROMPTS: [PromptAction<'static>; 2] = [
("account_id", "Account ID:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
];

@@ -1,6 +1,6 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SseHandler,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
};
use anyhow::{anyhow, bail, Result};
@@ -22,7 +22,7 @@ pub struct CohereConfig {
impl CohereClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,4 +1,4 @@
use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, SseHandler};
use super::{openai::OpenAIConfig, BuiltinModels, ClientConfig, Message, Model, SseHandler};
use crate::{
config::{GlobalConfig, Input},
@@ -20,7 +20,8 @@ use tokio::{sync::mpsc::unbounded_channel, time::sleep};
const MODELS_YAML: &str = include_str!("../../models.yaml");
lazy_static! {
pub static ref CLIENT_MODELS: Vec<ClientModel> = serde_yaml::from_str(MODELS_YAML).unwrap();
pub static ref ALL_CLIENT_MODELS: Vec<BuiltinModels> =
serde_yaml::from_str(MODELS_YAML).unwrap();
}
#[macro_export]
@@ -90,13 +91,10 @@ macro_rules! register_client {
pub fn list_models(local_config: &$config) -> Vec<Model> {
let client_name = Self::name(local_config);
if local_config.models.is_empty() {
for model in $crate::client::CLIENT_MODELS.iter() {
match model {
$crate::client::ClientModel::$config { models } => {
return Model::from_config(client_name, models);
}
_ => {}
}
if let Some(client_models) = $crate::client::ALL_CLIENT_MODELS.iter().find(|v| {
v.platform == $name || ($name == "openai-compatible" && local_config.name.as_deref() == Some(&v.platform))
}) {
return Model::from_config(client_name, &client_models.models);
}
vec![]
} else {
@@ -135,7 +133,7 @@ macro_rules! register_client {
pub fn list_client_types() -> Vec<&'static str> {
let mut client_types: Vec<_> = vec![$($client::NAME,)+];
client_types.extend($crate::client::KNOWN_OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
client_types.extend($crate::client::OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
client_types
}
@@ -170,69 +168,6 @@
};
}
#[macro_export]
macro_rules! openai_compatible_client {
(
$config:ident,
$client:ident,
$api_base:literal,
) => {
use $crate::client::openai::openai_build_body;
use $crate::client::{$client, ExtraConfig, Model, ModelConfig, PromptType, SendData};
use $crate::utils::PromptKind;
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
const API_BASE: &str = $api_base;
#[derive(Debug, Clone, Deserialize)]
pub struct $config {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
impl_client_trait!(
$client,
$crate::client::openai::openai_send_message,
$crate::client::openai::openai_send_message_streaming
);
impl $client {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let body = openai_build_body(data, &self.model);
let url = format!("{API_BASE}/chat/completions");
debug!("Request: {url} {body}");
let mut builder = client.post(url).json(&body);
if let Some(api_key) = api_key {
builder = builder.bearer_auth(api_key);
}
Ok(builder)
}
}
};
}
#[macro_export]
macro_rules! client_common_fns {
() => {
@@ -437,36 +372,45 @@ pub struct CompletionDetails {
pub output_tokens: Option<u64>,
}
pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
pub type PromptAction<'a> = (&'a str, &'a str, bool, PromptKind);
pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value)> {
pub fn create_config(prompts: &[PromptAction], client: &str) -> Result<(String, Value)> {
let mut config = json!({
"type": client,
});
let mut model = client.to_string();
set_client_config_values(list, &mut model, &mut config)?;
set_client_config_values(prompts, &mut model, &mut config)?;
let clients = json!(vec![config]);
Ok((model, clients))
}
pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(String, Value)>> {
match super::KNOWN_OPENAI_COMPATIBLE_PLATFORMS
match super::OPENAI_COMPATIBLE_PLATFORMS
.iter()
.find(|(name, _)| client == *name)
{
None => Ok(None),
Some((name, api_base)) => {
Some((name, _)) => {
let mut config = json!({
"type": "openai-compatible",
"name": name,
"api_base": api_base,
});
let prompts = if ALL_CLIENT_MODELS.iter().any(|v| &v.platform == name) {
vec![("api_key", "API Key:", false, PromptKind::String)]
} else {
vec![
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_input_tokens",
"Max Input Tokens:",
false,
PromptKind::Integer,
),
]
};
let mut model = client.to_string();
set_client_config_values(
&super::KNOWN_OPENAI_COMPATIBLE_PROMPTS,
&mut model,
&mut config,
)?;
set_client_config_values(&prompts, &mut model, &mut config)?;
let clients = json!(vec![config]);
Ok(Some((model, clients)))
}
@@ -683,7 +627,7 @@ where
}
fn set_client_config_values(
list: &[PromptType],
list: &[PromptAction],
model: &mut String,
client_config: &mut Value,
) -> Result<()> {

@@ -1,6 +1,6 @@
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SsMmessage, SseHandler,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
use anyhow::{anyhow, Context, Result};
@@ -27,7 +27,7 @@ pub struct ErnieConfig {
}
impl ErnieClient {
pub const PROMPTS: [PromptType<'static>; 2] = [
pub const PROMPTS: [PromptAction<'static>; 2] = [
("api_key", "API Key:", true, PromptKind::String),
("secret_key", "Secret Key:", true, PromptKind::String),
];

@@ -1,5 +1,5 @@
use super::vertexai::gemini_build_body;
use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptKind, PromptType, SendData};
use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptAction, PromptKind, SendData};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@@ -20,7 +20,7 @@ pub struct GeminiConfig {
impl GeminiClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1 +0,0 @@
openai_compatible_client!(GroqConfig, GroqClient, "https://api.groq.com/openai/v1",);

@@ -1 +0,0 @@
openai_compatible_client!(MistralConfig, MistralClient, "https://api.mistral.ai/v1",);

@@ -14,12 +14,15 @@ pub use sse_handler::*;
register_client!(
(openai, "openai", OpenAIConfig, OpenAIClient),
(
openai_compatible,
"openai-compatible",
OpenAICompatibleConfig,
OpenAICompatibleClient
),
(gemini, "gemini", GeminiConfig, GeminiClient),
(claude, "claude", ClaudeConfig, ClaudeClient),
(mistral, "mistral", MistralConfig, MistralClient),
(cohere, "cohere", CohereConfig, CohereClient),
(perplexity, "perplexity", PerplexityConfig, PerplexityClient),
(groq, "groq", GroqConfig, GroqClient),
(ollama, "ollama", OllamaConfig, OllamaClient),
(
azure_openai,
@@ -33,30 +36,17 @@ register_client!(
(replicate, "replicate", ReplicateConfig, ReplicateClient),
(ernie, "ernie", ErnieConfig, ErnieClient),
(qianwen, "qianwen", QianwenConfig, QianwenClient),
(moonshot, "moonshot", MoonshotConfig, MoonshotClient),
(
openai_compatible,
"openai-compatible",
OpenAICompatibleConfig,
OpenAICompatibleClient
),
);
pub const KNOWN_OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 5] = [
pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 10] = [
("anyscale", "https://api.endpoints.anyscale.com/v1"),
("deepinfra", "https://api.deepinfra.com/v1/openai"),
("fireworks", "https://api.fireworks.ai/inference/v1"),
("groq", "https://api.groq.com/openai/v1"),
("mistral", "https://api.mistral.ai/v1"),
("moonshot", "https://api.moonshot.cn/v1"),
("openrouter", "https://openrouter.ai/api/v1"),
("octoai", "https://text.octoai.run/v1"),
("perplexity", "https://api.perplexity.ai"),
("together", "https://api.together.xyz/v1"),
];
pub const KNOWN_OPENAI_COMPATIBLE_PROMPTS: [PromptType<'static>; 3] = [
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_input_tokens",
"Max Input Tokens:",
false,
PromptKind::Integer,
),
];

@@ -242,6 +242,12 @@ pub struct ModelConfig {
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct BuiltinModels {
pub platform: String,
pub models: Vec<ModelConfig>,
}
bitflags::bitflags! {
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelCapabilities: u32 {

@@ -1 +0,0 @@
openai_compatible_client!(MoonshotConfig, MoonshotClient, "https://api.moonshot.cn/v1",);

@@ -1,6 +1,6 @@
use super::{
catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptKind, PromptType, SendData, SseHandler,
PromptAction, PromptKind, SendData, SseHandler,
};
use anyhow::{anyhow, bail, Result};
@@ -23,7 +23,7 @@ impl OllamaClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_auth, get_api_auth);
pub const PROMPTS: [PromptType<'static>; 4] = [
pub const PROMPTS: [PromptAction<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_auth", "API Auth:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),

@@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient,
PromptKind, PromptType, SendData, SsMmessage, SseHandler,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
use anyhow::{anyhow, Result};
@@ -25,7 +25,7 @@ impl OpenAIClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,6 +1,8 @@
use crate::client::OPENAI_COMPATIBLE_PLATFORMS;
use super::openai::openai_build_body;
use super::{
ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptKind, PromptType, SendData,
ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
};
use anyhow::Result;
@@ -13,6 +15,7 @@ pub struct OpenAICompatibleConfig {
pub api_base: Option<String>,
pub api_key: Option<String>,
pub chat_endpoint: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -21,7 +24,7 @@ impl OpenAICompatibleClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 5] = [
pub const PROMPTS: [PromptAction<'static>; 5] = [
("name", "Platform Name:", true, PromptKind::String),
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", false, PromptKind::String),
@@ -35,7 +38,23 @@
];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_base = match self.get_api_base() {
Ok(v) => v,
Err(err) => {
match OPENAI_COMPATIBLE_PLATFORMS
.into_iter()
.find_map(|(name, api_base)| {
if name == self.model.client_name {
Some(api_base.to_string())
} else {
None
}
}) {
Some(v) => v,
None => return Err(err),
}
}
};
let api_key = self.get_api_key().ok();
let mut body = openai_build_body(data, &self.model);

@@ -1,5 +0,0 @@
openai_compatible_client!(
PerplexityConfig,
PerplexityClient,
"https://api.perplexity.ai",
);

@@ -1,6 +1,6 @@
use super::{
maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptKind, PromptType, QianwenClient, SendData, SsMmessage, SseHandler,
ModelConfig, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
};
use crate::utils::{base64_decode, sha256};
@@ -33,7 +33,7 @@ pub struct QianwenConfig {
impl QianwenClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -2,8 +2,8 @@ use std::time::Duration;
use super::{
catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptKind, PromptType, ReplicateClient, SendData, SsMmessage,
SseHandler,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, ReplicateClient, SendData,
SsMmessage, SseHandler,
};
use anyhow::{anyhow, Result};
@@ -26,7 +26,7 @@ pub struct ReplicateConfig {
impl ReplicateClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(

@@ -1,7 +1,8 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SseHandler, VertexAIClient,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
VertexAIClient,
};
use anyhow::{anyhow, bail, Context, Result};
@@ -30,7 +31,7 @@ impl VertexAIClient {
config_get_fn!(project_id, get_project_id);
config_get_fn!(location, get_location);
pub const PROMPTS: [PromptType<'static>; 2] = [
pub const PROMPTS: [PromptAction<'static>; 2] = [
("project_id", "Project ID", true, PromptKind::String),
("location", "Location", true, PromptKind::String),
];

@@ -9,6 +9,7 @@ use self::session::{Session, TEMP_SESSION_NAME};
use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData,
OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{
@@ -21,6 +22,7 @@ use inquire::{Confirm, Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::{
env,
@@ -126,12 +128,12 @@ impl Config {
pub fn init(working_mode: WorkingMode) -> Result<Self> {
let config_path = Self::config_file()?;
let client_type = env::var(get_env_name("client_type")).ok();
if working_mode != WorkingMode::Command && client_type.is_none() && !config_path.exists() {
let platform = env::var(get_env_name("platform")).ok();
if working_mode != WorkingMode::Command && platform.is_none() && !config_path.exists() {
create_config_file(&config_path)?;
}
let mut config = if client_type.is_some() {
Self::load_config_env(&client_type.unwrap())?
let mut config = if platform.is_some() {
Self::load_config_env(&platform.unwrap())?
} else {
Self::load_config_file(&config_path)?
};
@@ -926,37 +928,45 @@ impl Config {
fn load_config_file(config_path: &Path) -> Result<Self> {
let ctx = || format!("Failed to load config at {}", config_path.display());
let content = read_to_string(config_path).with_context(ctx)?;
let config = Self::load_config(&content).with_context(ctx)?;
let config: Self = serde_yaml::from_str(&content).map_err(|err| {
let err_msg = err.to_string();
let err_msg = if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
// location is incorrect, get rid of it
err_msg
.split_once(" at line")
.map(|(v, _)| {
format!("{v} (Sorry for being unable to provide an exact location)")
})
.unwrap_or_else(|| "clients: invalid value".into())
} else {
err_msg
};
anyhow!("{err_msg}")
})?;
Ok(config)
}
fn load_config_env(client_type: &str) -> Result<Self> {
fn load_config_env(platform: &str) -> Result<Self> {
let model_id = match env::var(get_env_name("model_name")) {
Ok(model_name) => format!("{client_type}:{model_name}"),
Err(_) => client_type.to_string(),
Ok(model_name) => format!("{platform}:{model_name}"),
Err(_) => platform.to_string(),
};
let content = format!(
r#"
model: {model_id}
save: false
clients:
- type: {client_type}
"#
);
let config = Self::load_config(&content).with_context(|| "Failed to load config")?;
Ok(config)
}
fn load_config(content: &str) -> Result<Self> {
let config: Self = serde_yaml::from_str(content).map_err(|err| {
let err_msg = err.to_string();
if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
anyhow!("clients: invalid value")
} else {
anyhow!("{err_msg}")
}
})?;
let is_openai_compatible = OPENAI_COMPATIBLE_PLATFORMS
.into_iter()
.any(|(name, _)| platform == name);
let client = if is_openai_compatible {
json!({ "type": "openai-compatible", "name": platform })
} else {
json!({ "type": platform })
};
let config = json!({
"model": model_id,
"save": false,
"clients": vec![client],
});
let config =
serde_json::from_value(config).with_context(|| "Failed to load config from env")?;
Ok(config)
}
