feat: openai-compatible platforms share the same client (#469)

sigoden committed 1 month ago (via GitHub)
parent 8a65337d59
commit 8dba46becf

@@ -17,10 +17,10 @@ test-init-config() {
     cargo run -- "$@"
 }

-# @cmd Test running without the config file
-# @env AICHAT_CLIENT_TYPE!
+# @cmd Test running with the AICHAT_PLATFORM environment variable
+# @env AICHAT_PLATFORM!
 # @arg args~
-test-without-config() {
+test-platform-env() {
     cargo run -- "$@"
 }
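The renamed test drives aichat purely through environment variables. A hypothetical invocation (key value is a placeholder; `{PLATFORM}_API_KEY` is the convention the rest of this script uses):

```bash
# Sketch: exercising the env-driven startup path via the renamed test task.
# AICHAT_PLATFORM picks the client; the key follows the {PLATFORM}_API_KEY
# convention (values are placeholders).
AICHAT_PLATFORM=groq GROQ_API_KEY=gsk-xxxx argc test-platform-env "say hello"
```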
@@ -50,44 +50,85 @@ test-server() {
     "$@"
 }

+OPENAI_COMPATIBLE_CLIENTS=( \
+    openai,gpt-3.5-turbo,https://api.openai.com/v1 \
+    anyscale,meta-llama/Meta-Llama-3-8B-Instruct,https://api.endpoints.anyscale.com/v1 \
+    deepinfra,meta-llama/Meta-Llama-3-8B-Instruct,https://api.deepinfra.com/v1/openai \
+    fireworks,accounts/fireworks/models/llama-v3-8b-instruct,https://api.fireworks.ai/inference/v1 \
+    groq,llama3-8b-8192,https://api.groq.com/openai/v1 \
+    mistral,mistral-small-latest,https://api.mistral.ai/v1 \
+    moonshot,moonshot-v1-8k,https://api.moonshot.cn/v1 \
+    openrouter,meta-llama/llama-3-8b-instruct,https://openrouter.ai/api/v1 \
+    octoai,meta-llama-3-8b-instruct,https://text.octoai.run/v1 \
+    perplexity,llama-3-8b-instruct,https://api.perplexity.ai \
+    together,meta-llama/Llama-3-8b-chat-hf,https://api.together.xyz/v1 \
+)
+
 # @cmd Chat with openai-compatible api
-# @option --api-base! $$
-# @option --api-key! $$
-# @option -m --model! $$
 # @flag -S --no-stream
+# @arg platform![`_choice_platform`]
 # @arg text~
-chat-llm() {
-    curl_args="$CURL_ARGS"
-    _openai_chat "$@"
+chat() {
+    for client_config in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
+        if [[ "$argc_platform" == "${client_config%%,*}" ]]; then
+            api_base="${client_config##*,}"
+            break
+        fi
+    done
+    if [[ -n "$api_base" ]]; then
+        env_prefix="$(echo "$argc_platform" | tr '[:lower:]' '[:upper:]')"
+        api_key_env="${env_prefix}_API_KEY"
+        api_key="${!api_key_env}"
+        if [[ -z "$model" ]]; then
+            model_env="${env_prefix}_MODEL"
+            model="${!model_env}"
+        fi
+        if [[ -z "$model" ]]; then
+            model="$(echo "$client_config" | cut -d, -f2)"
+        fi
+        argc chat-openai-compatible \
+            --api-base "$api_base" \
+            --api-key "$api_key" \
+            --model "$model" \
+            "${argc_text[@]}"
+    else
+        argc chat-$argc_platform "${argc_text[@]}"
+    fi
 }

 # @cmd List models by openai-compatible api
-# @option --api-base! $$
-# @option --api-key! $$
-models-llm() {
-    curl_args="$CURL_ARGS"
-    _openai_models
+# @arg platform![`_choice_platform`]
+models() {
+    for client_config in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
+        if [[ "$argc_platform" == "${client_config%%,*}" ]]; then
+            api_base="${client_config##*,}"
+            break
+        fi
+    done
+    if [[ -n "$api_base" ]]; then
+        env_prefix="$(echo "$argc_platform" | tr '[:lower:]' '[:upper:]')"
+        api_key_env="${env_prefix}_API_KEY"
+        api_key="${!api_key_env}"
+        _openai_models
+    else
+        argc models-$argc_platform
+    fi
 }

-# @cmd Chat with openai api
-# @env OPENAI_API_KEY!
-# @option -m --model=gpt-3.5-turbo $OPENAI_MODEL
+# @cmd Chat with openai-compatible api
+# @option --api-base! $$
+# @option --api-key! $$
+# @option -m --model! $$
 # @flag -S --no-stream
 # @arg text~
-chat-openai() {
-    api_base=https://api.openai.com/v1
-    api_key=$OPENAI_API_KEY
-    curl_args="-i $OPENAI_CURL_ARGS"
+chat-openai-compatible() {
     _openai_chat "$@"
 }

-# @cmd List openai models
-# @env OPENAI_API_KEY!
-models-openai() {
-    api_base=https://api.openai.com/v1
-    api_key=$OPENAI_API_KEY
-    curl_args="$OPENAI_CURL_ARGS"
+# @cmd List models by openai-compatible api
+# @option --api-base! $$
+# @option --api-key! $$
+models-openai-compatible() {
     _openai_models
 }
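Each `OPENAI_COMPATIBLE_CLIENTS` entry packs `platform,default-model,api-base` into one comma-separated string, and `chat()`/`models()` lean on two Bash features: prefix/suffix parameter expansion to split the triple, and `${!var}` indirection to resolve `{PLATFORM}_API_KEY`. A standalone sketch of that lookup (illustrative only, not part of the diff):

```bash
#!/usr/bin/env bash
# Standalone sketch of the lookup performed by chat()/models() above.
entry="groq,llama3-8b-8192,https://api.groq.com/openai/v1"

platform="${entry%%,*}"      # strip longest ",..." suffix  -> groq
api_base="${entry##*,}"      # strip longest "...," prefix  -> https://api.groq.com/openai/v1
default_model="$(echo "$entry" | cut -d, -f2)"  # middle field -> llama3-8b-8192

env_prefix="$(echo "$platform" | tr '[:lower:]' '[:upper:]')"
api_key_env="${env_prefix}_API_KEY"  # -> GROQ_API_KEY
api_key="${!api_key_env}"            # indirect expansion: value of $GROQ_API_KEY

echo "$platform | $default_model | $api_base | key from \$$api_key_env"
```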
@@ -101,7 +142,7 @@ chat-gemini() {
     if [[ -n "$argc_no_stream" ]]; then
         method="generateContent"
     fi
-    _wrapper curl -i $GEMINI_CURL_ARGS "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
+    _wrapper curl -i "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
         -i -X POST \
         -H 'Content-Type: application/json' \
         -d '{

@@ -113,7 +154,7 @@ chat-gemini() {
 # @cmd List gemini models
 # @env GEMINI_API_KEY!
 models-gemini() {
-    _wrapper curl $GEMINI_CURL_ARGS "https://generativelanguage.googleapis.com/v1beta/models?key=${GEMINI_API_KEY}" \
+    _wrapper curl "https://generativelanguage.googleapis.com/v1beta/models?key=${GEMINI_API_KEY}" \
         -H 'Content-Type: application/json' \

 }

@@ -124,7 +165,7 @@ models-gemini() {
 # @flag -S --no-stream
 # @arg text~
 chat-claude() {
-    _wrapper curl -i $CLAUDE_CURL_ARGS https://api.anthropic.com/v1/messages \
+    _wrapper curl -i https://api.anthropic.com/v1/messages \
         -X POST \
         -H 'content-type: application/json' \
         -H 'anthropic-version: 2023-06-01' \

@@ -138,34 +179,13 @@ chat-claude() {
     '
 }

-# @cmd Chat with mistral api
-# @env MISTRAL_API_KEY!
-# @option -m --model=mistral-small-latest $MISTRAL_MODEL
-# @flag -S --no-stream
-# @arg text~
-chat-mistral() {
-    api_base=https://api.mistral.ai/v1
-    api_key=$MISTRAL_API_KEY
-    curl_args="$MISTRAL_CURL_ARGS"
-    _openai_chat "$@"
-}
-
-# @cmd List mistral models
-# @env MISTRAL_API_KEY!
-models-mistral() {
-    api_base=https://api.mistral.ai/v1
-    api_key=$MISTRAL_API_KEY
-    curl_args="$MISTRAL_CURL_ARGS"
-    _openai_models
-}
-
 # @cmd Chat with cohere api
 # @env COHERE_API_KEY!
 # @option -m --model=command-r $COHERE_MODEL
 # @flag -S --no-stream
 # @arg text~
 chat-cohere() {
-    _wrapper curl -i $COHERE_CURL_ARGS https://api.cohere.ai/v1/chat \
+    _wrapper curl -i https://api.cohere.ai/v1/chat \
         -X POST \
         -H 'Content-Type: application/json' \
         -H "Authorization: Bearer $COHERE_API_KEY" \

@@ -180,50 +200,17 @@ chat-cohere() {
 # @cmd List cohere models
 # @env COHERE_API_KEY!
 models-cohere() {
-    _wrapper curl $COHERE_CURL_ARGS https://api.cohere.ai/v1/models \
+    _wrapper curl https://api.cohere.ai/v1/models \
         -H "Authorization: Bearer $COHERE_API_KEY" \

 }

-# @cmd Chat with perplexity api
-# @env PERPLEXITY_API_KEY!
-# @option -m --model=sonar-small-chat $PERPLEXITY_MODEL
-# @flag -S --no-stream
-# @arg text~
-chat-perplexity() {
-    api_base=https://api.perplexity.ai
-    api_key=$PERPLEXITY_API_KEY
-    curl_args="$PERPLEXITY_CURL_ARGS"
-    _openai_chat "$@"
-}
-
-# @cmd Chat with groq api
-# @env GROQ_API_KEY!
-# @option -m --model=llama3-70b-8192 $GROQ_MODEL
-# @flag -S --no-stream
-# @arg text~
-chat-groq() {
-    api_base=https://api.groq.com/openai/v1
-    api_key=$GROQ_API_KEY
-    curl_args="$GROQ_CURL_ARGS"
-    _openai_chat "$@"
-}
-
-# @cmd List groq models
-# @env GROQ_API_KEY!
-models-groq() {
-    api_base=https://api.groq.com/openai/v1
-    api_key=$GROQ_API_KEY
-    curl_args="$GROQ_CURL_ARGS"
-    _openai_models
-}
-
 # @cmd Chat with ollama api
 # @option -m --model=codegemma $OLLAMA_MODEL
 # @flag -S --no-stream
 # @arg text~
 chat-ollama() {
-    _wrapper curl -i $OLLAMA_CURL_ARGS http://localhost:11434/api/chat \
+    _wrapper curl -i http://localhost:11434/api/chat \
         -X POST \
         -H 'Content-Type: application/json' \
         -d '{

@@ -247,7 +234,7 @@ chat-vertexai-gemini() {
         func="generateContent"
     fi
     url=https://$VERTEXAI_LOCATION-aiplatform.googleapis.com/v1/projects/$VERTEXAI_PROJECT_ID/locations/$VERTEXAI_LOCATION/publishers/google/models/$argc_model:$func
-    _wrapper curl -i $VERTEXAI_CURL_ARGS $url \
+    _wrapper curl -i $url \
         -X POST \
         -H "Authorization: Bearer $api_key" \
         -H 'Content-Type: application/json' \

@@ -267,7 +254,7 @@ chat-vertexai-gemini() {
 chat-vertexai-claude() {
     api_key="$(gcloud auth print-access-token)"
     url=https://$VERTEXAI_LOCATION-aiplatform.googleapis.com/v1/projects/$VERTEXAI_PROJECT_ID/locations/$VERTEXAI_LOCATION/publishers/anthropic/models/$argc_model:streamRawPredict
-    _wrapper curl -i $VERTEXAI_CURL_ARGS $url \
+    _wrapper curl -i $url \
         -X POST \
         -H "Authorization: Bearer $api_key" \
         -H 'Content-Type: application/json' \

@@ -316,7 +303,7 @@ chat-bedrock() {
 # @arg text~
 chat-cloudflare() {
     url="https://api.cloudflare.com/client/v4/accounts/$CLOUDFLARE_ACCOUNT_ID/ai/run/$argc_model"
-    _wrapper curl -i $CLOUDFLARE_CURL_ARGS "$url" \
+    _wrapper curl -i "$url" \
         -X POST \
         -H "Authorization: Bearer $CLOUDFLARE_API_KEY" \
         -d '{

@@ -332,7 +319,7 @@ chat-cloudflare() {
 # @arg text~
 chat-replicate() {
     url="https://api.replicate.com/v1/models/$argc_model/predictions"
-    res="$(_wrapper curl -s $DEEPINFRA_CURL_ARGS "$url" \
+    res="$(_wrapper curl -s "$url" \
         -X POST \
         -H "Authorization: Bearer $REPLICATE_API_KEY" \
         -H "Content-Type: application/json" \
@@ -346,7 +333,7 @@ chat-replicate() {
     if [[ -n "$argc_no_stream" ]]; then
         prediction_url="$(echo "$res" | jq -r '.urls.get')"
         while true; do
-            output="$(_wrapper curl $DEEPINFRA_CURL_ARGS -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
+            output="$(_wrapper curl -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
             prediction_status=$(printf "%s" "$output" | jq -r .status)
             if [ "$prediction_status" == "succeeded" ]; then
                 echo "$output"
@@ -359,7 +346,7 @@ chat-replicate() {
         done
     else
         stream_url="$(echo "$res" | jq -r '.urls.stream')"
-        _wrapper curl -i $DEEPINFRA_CURL_ARGS --no-buffer "$stream_url" \
+        _wrapper curl -i --no-buffer "$stream_url" \
             -H "Accept: text/event-stream" \

     fi

@@ -376,7 +363,7 @@ chat-ernie() {
     auth_url="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=$ERNIE_API_KEY&client_secret=$ERNIE_SECRET_KEY"
     ACCESS_TOKEN="$(curl -fsSL "$auth_url" | jq -r '.access_token')"
     url="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/$argc_model?access_token=$ACCESS_TOKEN"
-    _wrapper curl -i $ERNIE_CURL_ARGS "$url" \
+    _wrapper curl -i "$url" \
         -X POST \
         -d '{
             "messages": '"$(_build_msg $*)"',

@@ -398,7 +385,7 @@ chat-qianwen() {
         parameters_args='{}'
     fi
     url=https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
-    _wrapper curl -i $QIANWEN_CURL_ARGS "$url" \
+    _wrapper curl -i "$url" \
         -X POST \
         -H "Authorization: Bearer $QIANWEN_API_KEY" \
         -H 'Content-Type: application/json' $stream_args \
@@ -411,31 +398,6 @@ chat-qianwen() {
     }'
 }

-# @cmd Chat with moonshot api
-# @env MOONSHOT_API_KEY!
-# @option -m --model=moonshot-v1-8k $MOONSHOT_MODEL
-# @flag -S --no-stream
-# @arg text~
-chat-moonshot() {
-    api_base=https://api.moonshot.cn/v1
-    api_key=$MOONSHOT_API_KEY
-    curl_args="$MOONSHOT_CURL_ARGS"
-    _openai_chat "$@"
-}
-
-# @cmd List moonshot models
-# @env MOONSHOT_API_KEY!
-models-moonshot() {
-    api_base=https://api.moonshot.cn/v1
-    api_key=$MOONSHOT_API_KEY
-    curl_args="$MOONSHOT_CURL_ARGS"
-    _openai_models
-}
-
-_choice_model() {
-    aichat --list-models
-}
-
 _argc_before() {
     stream="true"
     if [[ -n "$argc_no_stream" ]]; then
@@ -466,8 +428,23 @@ _openai_models() {
 }

+_choice_model() {
+    aichat --list-models
+}
+
+_choice_platform() {
+    _choice_client
+    _choice_openai_compatible_platform
+}
+
 _choice_client() {
-    printf "%s\n" openai gemini claude mistral cohere ollama vertexai bedrock ernie qianwen moonshot
+    printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare replicate ernie qianwen moonshot
+}
+
+_choice_openai_compatible_platform() {
+    for v in "${OPENAI_COMPATIBLE_CLIENTS[@]}"; do
+        echo "${v%%,*}"
+    done
 }

 _build_msg() {
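Everything in the platform table speaks OpenAI's chat-completions wire format, which is what lets a single `_openai_chat`/`_openai_models` pair serve all of them. The helpers themselves are outside this diff; a minimal sketch of the request `_openai_chat` presumably issues, assuming it posts to `${api_base}/chat/completions` with a bearer token:

```bash
# Minimal sketch of an OpenAI-compatible chat request (the real
# _openai_chat helper is not shown in this diff; this assumes it
# posts to ${api_base}/chat/completions with a bearer token).
api_base="https://api.groq.com/openai/v1"
api_key="$GROQ_API_KEY"
model="llama3-8b-8192"

curl -sS "$api_base/chat/completions" \
    -H "Authorization: Bearer $api_key" \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "'"$model"'",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": false
    }'
```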

@@ -47,6 +47,16 @@ clients:
     api_base: https://api.openai.com/v1 # ENV: {client_name}_API_BASE
     organization_id: org-xxx # Optional

+  # For any platform compatible with OpenAI's API
+  - type: openai-compatible
+    name: localai
+    api_base: http://localhost:8080/v1 # ENV: {client_name}_API_BASE
+    api_key: xxx # ENV: {client_name}_API_KEY
+    chat_endpoint: /chat/completions # Optional
+    models:
+      - name: llama3
+        max_input_tokens: 8192
+
   # See https://ai.google.dev/docs
   - type: gemini
     api_key: xxx # ENV: {client_name}_API_KEY

@@ -58,7 +68,8 @@
     api_key: sk-ant-xxx # ENV: {client_name}_API_KEY

   # See https://docs.mistral.ai/
-  - type: mistral
+  - type: openai-compatible
+    name: mistral
     api_key: xxx # ENV: {client_name}_API_KEY

   # See https://docs.cohere.com/docs/the-cohere-platform

@@ -129,20 +140,9 @@ clients:
   - type: moonshot
     api_key: sk-xxx # ENV: {client_name}_API_KEY

-  # For any platform compatible with OpenAI's API
-  - type: openai-compatible
-    name: localai
-    api_base: http://localhost:8080/v1 # ENV: {client_name}_API_BASE
-    api_key: sk-xxx # ENV: {client_name}_API_KEY
-    chat_endpoint: /chat/completions # Optional
-    models: # Required
-      - name: llama3
-        max_input_tokens: 8192
-
   # See https://docs.endpoints.anyscale.com/
   - type: openai-compatible
     name: anyscale
-    api_base: https://api.endpoints.anyscale.com/v1
     api_key: xxx
     models:
       # https://docs.endpoints.anyscale.com/text-generation/query-a-model#select-a-model

@@ -154,7 +154,6 @@ clients:
   # See https://deepinfra.com/docs
   - type: openai-compatible
     name: deepinfra
-    api_base: https://api.deepinfra.com/v1/openai
     api_key: xxx
     models:
       # https://deepinfra.com/models

@@ -166,7 +165,6 @@ clients:
   # See https://readme.fireworks.ai/docs/quickstart
   - type: openai-compatible
     name: fireworks
-    api_base: https://api.fireworks.ai/inference/v1
     api_key: xxx
     models:
       # https://fireworks.ai/models

@@ -175,12 +173,21 @@ clients:
         input_price: 0.9
         output_price: 0.9

+  # See https://openrouter.ai/docs#quick-start
+  - type: openai-compatible
+    name: openrouter
+    api_key: xxx # ENV: {client_name}_API_KEY
+    models:
+      # https://openrouter.ai/docs#models
+      - name: meta-llama/llama-3-70b-instruct
+        max_input_tokens: 8192
+        input_price: 0.81
+        output_price: 0.81
+
   # See https://octo.ai/docs/getting-started/quickstart
   - type: openai-compatible
     name: octoai
-    api_base: https://text.octoai.run/v1
-    api_key: xxx
+    api_key: xxx # ENV: {client_name}_API_KEY
     models:
       # https://octo.ai/docs/getting-started/inference-models
       - name: meta-llama-3-70b-instruct

@@ -191,8 +198,7 @@ clients:
   # See https://docs.together.ai/docs/quickstart
   - type: openai-compatible
     name: together
-    api_base: https://api.together.xyz/v1
-    api_key: xxx
+    api_key: xxx # ENV: {client_name}_API_KEY
     models:
       # https://docs.together.ai/docs/inference-models
       - name: meta-llama/Llama-3-70b-chat-hf
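The config hunks above delete `api_base` lines for known platforms because the client now falls back to a built-in table (see the Rust hunks below). A sketch of the two styles of `openai-compatible` entries after this commit (a fragment to merge into your own `clients:` list; values are placeholders):

```bash
# Sketch: print a config fragment covering both cases
# (merge into the clients: list; values are placeholders).
cat <<'EOF'
# Known platform: api_base and builtin models come from aichat itself,
# so only the key is needed.
- type: openai-compatible
  name: fireworks
  api_key: xxx

# Unknown endpoint: api_base and models must be spelled out.
- type: openai-compatible
  name: localai
  api_base: http://localhost:8080/v1
  models:
    - name: llama3
      max_input_tokens: 8192
EOF
```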

@@ -2,7 +2,7 @@
 # - This model list is scheduled to be updated with each new aichat release. Please do not submit PRs to add new models.
 # - This model list does not include models officially marked as legacy or beta.

-- type: openai
+- platform: openai
   # docs:
   #   - https://platform.openai.com/docs/models
   #   - https://openai.com/pricing

@@ -53,7 +53,7 @@
       input_price: 60
       output_price: 120

-- type: gemini
+- platform: gemini
   # docs:
   #   - https://ai.google.dev/models/gemini
   #   - https://ai.google.dev/pricing

@@ -79,7 +79,7 @@
       output_price: 21
       supports_vision: true

-- type: claude
+- platform: claude
   # docs:
   #   - https://docs.anthropic.com/claude/docs/models-overview
   #   - https://docs.anthropic.com/claude/reference/messages-streaming

@@ -105,7 +105,7 @@
       output_price: 1.25
       supports_vision: true

-- type: mistral
+- platform: mistral
   # docs:
   #   - https://docs.mistral.ai/getting-started/models/
   #   - https://mistral.ai/technology/#pricing

@@ -138,7 +138,7 @@
       input_price: 8
       output_price: 24

-- type: cohere
+- platform: cohere
   # docs:
   #   - https://docs.cohere.com/docs/command-r
   #   - https://cohere.com/pricing

@@ -157,7 +157,7 @@
       input_price: 3
       output_price: 15

-- type: perplexity
+- platform: perplexity
   # docs:
   #   - https://docs.perplexity.ai/docs/model-cards
   #   - https://docs.perplexity.ai/docs/pricing

@@ -209,7 +209,7 @@
       input_price: 1
       output_price: 1

-- type: groq
+- platform: groq
   # docs:
   #   - https://console.groq.com/docs/models
   #   - https://wow.groq.com

@@ -239,7 +239,7 @@
       input_price: 0.10
       output_price: 0.10

-- type: vertexai
+- platform: vertexai
   # docs:
   #   - https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
   #   - https://cloud.google.com/vertex-ai/generative-ai/pricing

@@ -284,7 +284,7 @@
       output_price: 1.25
       supports_vision: true

-- type: bedrock
+- platform: bedrock
   # docs:
   #   - https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
   #   - https://aws.amazon.com/bedrock/pricing/

@@ -346,7 +346,7 @@
       input_price: 8
       output_price: 2.4

-- type: cloudflare
+- platform: cloudflare
   # docs:
   #   - https://developers.cloudflare.com/workers-ai/models/
   #   - https://developers.cloudflare.com/workers-ai/platform/pricing/

@@ -367,7 +367,7 @@
       input_price: 0.11
       output_price: 0.19

-- type: replicate
+- platform: replicate
   # docs:
   #   - https://replicate.com/docs
   #   - https://replicate.com/pricing

@@ -395,7 +395,7 @@
       input_price: 0.3
       output_price: 1

-- type: ernie
+- platform: ernie
   # docs:
   #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
   #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7

@@ -428,7 +428,7 @@
       input_price: 0.14
       output_price: 0.14

-- type: qianwen
+- platform: qianwen
   # docs:
   #   - https://help.aliyun.com/zh/dashscope/developer-reference/tongyiqianwen-large-language-models/
   #   - https://help.aliyun.com/zh/dashscope/developer-reference/qwen-vl-plus/

@@ -462,7 +462,7 @@
       output_price: 2.8
       supports_vision: true

-- type: moonshot
+- platform: moonshot
   # docs:
   #   - https://platform.moonshot.cn/docs/intro
   #   - https://platform.moonshot.cn/docs/pricing
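The `type` to `platform` rename in models.yaml is mechanical, but it makes the builtin model list straightforward to query per platform. For example, with mikefarah's yq v4 (a tool assumption; any YAML processor works):

```bash
# List the builtin model names declared for one platform
# (assumes mikefarah's yq v4 is installed).
yq '.[] | select(.platform == "openai") | .models[].name' models.yaml
```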

@@ -1,5 +1,7 @@
 use super::openai::openai_build_body;
-use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData};
+use super::{
+    AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
+};

 use anyhow::Result;
 use reqwest::{Client as ReqwestClient, RequestBuilder};

@@ -18,7 +20,7 @@ impl AzureOpenAIClient {
     config_get_fn!(api_base, get_api_base);
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 4] = [
+    pub const PROMPTS: [PromptAction<'static>; 4] = [
         ("api_base", "API Base:", true, PromptKind::String),
         ("api_key", "API Key:", true, PromptKind::String),
         ("models[].name", "Model Name:", true, PromptKind::String),

@@ -1,8 +1,8 @@
 use super::claude::{claude_build_body, claude_extract_completion};
 use super::{
     catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
-    ModelConfig, PromptFormat, PromptKind, PromptType, SendData, SseHandler, LLAMA2_PROMPT_FORMAT,
-    LLAMA3_PROMPT_FORMAT,
+    ModelConfig, PromptAction, PromptFormat, PromptKind, SendData, SseHandler,
+    LLAMA2_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT,
 };

 use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};

@@ -65,7 +65,7 @@ impl BedrockClient {
     config_get_fn!(secret_access_key, get_secret_access_key);
     config_get_fn!(region, get_region);

-    pub const PROMPTS: [PromptType<'static>; 3] = [
+    pub const PROMPTS: [PromptAction<'static>; 3] = [
         (
             "access_key_id",
             "AWS Access Key ID",

@@ -1,6 +1,6 @@
 use super::{
     catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
-    ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptKind, PromptType,
+    ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptAction, PromptKind,
     SendData, SsMmessage, SseHandler,
 };

@@ -23,7 +23,7 @@ pub struct ClaudeConfig {
 impl ClaudeClient {
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,6 +1,6 @@
 use super::{
     catch_error, sse_stream, CloudflareClient, CompletionDetails, ExtraConfig, Model, ModelConfig,
-    PromptKind, PromptType, SendData, SsMmessage, SseHandler,
+    PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
 };

 use anyhow::{anyhow, Result};

@@ -24,7 +24,7 @@ impl CloudflareClient {
     config_get_fn!(account_id, get_account_id);
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 2] = [
+    pub const PROMPTS: [PromptAction<'static>; 2] = [
         ("account_id", "Account ID:", true, PromptKind::String),
         ("api_key", "API Key:", true, PromptKind::String),
     ];

@@ -1,6 +1,6 @@
 use super::{
     catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
-    ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SseHandler,
+    ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
 };

 use anyhow::{anyhow, bail, Result};

@@ -22,7 +22,7 @@ pub struct CohereConfig {
 impl CohereClient {
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,4 +1,4 @@
-use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, SseHandler};
+use super::{openai::OpenAIConfig, BuiltinModels, ClientConfig, Message, Model, SseHandler};

 use crate::{
     config::{GlobalConfig, Input},

@@ -20,7 +20,8 @@ use tokio::{sync::mpsc::unbounded_channel, time::sleep};
 const MODELS_YAML: &str = include_str!("../../models.yaml");

 lazy_static! {
-    pub static ref CLIENT_MODELS: Vec<ClientModel> = serde_yaml::from_str(MODELS_YAML).unwrap();
+    pub static ref ALL_CLIENT_MODELS: Vec<BuiltinModels> =
+        serde_yaml::from_str(MODELS_YAML).unwrap();
 }

 #[macro_export]

@@ -90,13 +91,10 @@ macro_rules! register_client {
             pub fn list_models(local_config: &$config) -> Vec<Model> {
                 let client_name = Self::name(local_config);
                 if local_config.models.is_empty() {
-                    for model in $crate::client::CLIENT_MODELS.iter() {
-                        match model {
-                            $crate::client::ClientModel::$config { models } => {
-                                return Model::from_config(client_name, models);
-                            }
-                            _ => {}
-                        }
-                    }
+                    if let Some(client_models) = $crate::client::ALL_CLIENT_MODELS.iter().find(|v| {
+                        v.platform == $name || ($name == "openai-compatible" && local_config.name.as_deref() == Some(&v.platform))
+                    }) {
+                        return Model::from_config(client_name, &client_models.models);
+                    }
                     vec![]
                 } else {

@@ -135,7 +133,7 @@ macro_rules! register_client {
         pub fn list_client_types() -> Vec<&'static str> {
             let mut client_types: Vec<_> = vec![$($client::NAME,)+];
-            client_types.extend($crate::client::KNOWN_OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
+            client_types.extend($crate::client::OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
             client_types
         }

@@ -170,69 +168,6 @@ macro_rules! register_client {
     };
 }

-#[macro_export]
-macro_rules! openai_compatible_client {
-    (
-        $config:ident,
-        $client:ident,
-        $api_base:literal,
-    ) => {
-        use $crate::client::openai::openai_build_body;
-        use $crate::client::{$client, ExtraConfig, Model, ModelConfig, PromptType, SendData};
-        use $crate::utils::PromptKind;
-
-        use anyhow::Result;
-        use reqwest::{Client as ReqwestClient, RequestBuilder};
-        use serde::Deserialize;
-
-        const API_BASE: &str = $api_base;
-
-        #[derive(Debug, Clone, Deserialize)]
-        pub struct $config {
-            pub name: Option<String>,
-            pub api_key: Option<String>,
-            #[serde(default)]
-            pub models: Vec<ModelConfig>,
-            pub extra: Option<ExtraConfig>,
-        }
-
-        impl_client_trait!(
-            $client,
-            $crate::client::openai::openai_send_message,
-            $crate::client::openai::openai_send_message_streaming
-        );
-
-        impl $client {
-            config_get_fn!(api_key, get_api_key);
-
-            pub const PROMPTS: [PromptType<'static>; 1] =
-                [("api_key", "API Key:", true, PromptKind::String)];
-
-            fn request_builder(
-                &self,
-                client: &ReqwestClient,
-                data: SendData,
-            ) -> Result<RequestBuilder> {
-                let api_key = self.get_api_key().ok();
-
-                let body = openai_build_body(data, &self.model);
-
-                let url = format!("{API_BASE}/chat/completions");
-
-                debug!("Request: {url} {body}");
-
-                let mut builder = client.post(url).json(&body);
-                if let Some(api_key) = api_key {
-                    builder = builder.bearer_auth(api_key);
-                }
-
-                Ok(builder)
-            }
-        }
-    };
-}
-
 #[macro_export]
 macro_rules! client_common_fns {
     () => {

@@ -437,36 +372,45 @@ pub struct CompletionDetails {
     pub output_tokens: Option<u64>,
 }

-pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
+pub type PromptAction<'a> = (&'a str, &'a str, bool, PromptKind);

-pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value)> {
+pub fn create_config(prompts: &[PromptAction], client: &str) -> Result<(String, Value)> {
     let mut config = json!({
         "type": client,
     });
     let mut model = client.to_string();
-    set_client_config_values(list, &mut model, &mut config)?;
+    set_client_config_values(prompts, &mut model, &mut config)?;
     let clients = json!(vec![config]);
     Ok((model, clients))
 }

 pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(String, Value)>> {
-    match super::KNOWN_OPENAI_COMPATIBLE_PLATFORMS
+    match super::OPENAI_COMPATIBLE_PLATFORMS
         .iter()
         .find(|(name, _)| client == *name)
     {
         None => Ok(None),
-        Some((name, api_base)) => {
+        Some((name, _)) => {
             let mut config = json!({
                 "type": "openai-compatible",
                 "name": name,
-                "api_base": api_base,
             });
+            let prompts = if ALL_CLIENT_MODELS.iter().any(|v| &v.platform == name) {
+                vec![("api_key", "API Key:", false, PromptKind::String)]
+            } else {
+                vec![
+                    ("api_key", "API Key:", false, PromptKind::String),
+                    ("models[].name", "Model Name:", true, PromptKind::String),
+                    (
+                        "models[].max_input_tokens",
+                        "Max Input Tokens:",
+                        false,
+                        PromptKind::Integer,
+                    ),
+                ]
+            };
             let mut model = client.to_string();
-            set_client_config_values(
-                &super::KNOWN_OPENAI_COMPATIBLE_PROMPTS,
-                &mut model,
-                &mut config,
-            )?;
+            set_client_config_values(&prompts, &mut model, &mut config)?;
             let clients = json!(vec![config]);
             Ok(Some((model, clients)))
         }

@@ -683,7 +627,7 @@ where
 }

 fn set_client_config_values(
-    list: &[PromptType],
+    list: &[PromptAction],
     model: &mut String,
     client_config: &mut Value,
 ) -> Result<()> {

@@ -1,6 +1,6 @@
 use super::{
     maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
-    ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SsMmessage, SseHandler,
+    ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
 };

 use anyhow::{anyhow, Context, Result};

@@ -27,7 +27,7 @@ pub struct ErnieConfig {
 }

 impl ErnieClient {
-    pub const PROMPTS: [PromptType<'static>; 2] = [
+    pub const PROMPTS: [PromptAction<'static>; 2] = [
         ("api_key", "API Key:", true, PromptKind::String),
         ("secret_key", "Secret Key:", true, PromptKind::String),
     ];

@@ -1,5 +1,5 @@
 use super::vertexai::gemini_build_body;
-use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptKind, PromptType, SendData};
+use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptAction, PromptKind, SendData};

 use anyhow::Result;
 use reqwest::{Client as ReqwestClient, RequestBuilder};

@@ -20,7 +20,7 @@ pub struct GeminiConfig {
 impl GeminiClient {
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1 +0,0 @@
-openai_compatible_client!(GroqConfig, GroqClient, "https://api.groq.com/openai/v1",);

@@ -1 +0,0 @@
-openai_compatible_client!(MistralConfig, MistralClient, "https://api.mistral.ai/v1",);

@@ -14,12 +14,15 @@ pub use sse_handler::*;

 register_client!(
     (openai, "openai", OpenAIConfig, OpenAIClient),
+    (
+        openai_compatible,
+        "openai-compatible",
+        OpenAICompatibleConfig,
+        OpenAICompatibleClient
+    ),
     (gemini, "gemini", GeminiConfig, GeminiClient),
     (claude, "claude", ClaudeConfig, ClaudeClient),
-    (mistral, "mistral", MistralConfig, MistralClient),
     (cohere, "cohere", CohereConfig, CohereClient),
-    (perplexity, "perplexity", PerplexityConfig, PerplexityClient),
-    (groq, "groq", GroqConfig, GroqClient),
     (ollama, "ollama", OllamaConfig, OllamaClient),
     (
         azure_openai,

@@ -33,30 +36,17 @@ register_client!(
     (replicate, "replicate", ReplicateConfig, ReplicateClient),
     (ernie, "ernie", ErnieConfig, ErnieClient),
     (qianwen, "qianwen", QianwenConfig, QianwenClient),
-    (moonshot, "moonshot", MoonshotConfig, MoonshotClient),
-    (
-        openai_compatible,
-        "openai-compatible",
-        OpenAICompatibleConfig,
-        OpenAICompatibleClient
-    ),
 );

-pub const KNOWN_OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 5] = [
+pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 10] = [
     ("anyscale", "https://api.endpoints.anyscale.com/v1"),
     ("deepinfra", "https://api.deepinfra.com/v1/openai"),
     ("fireworks", "https://api.fireworks.ai/inference/v1"),
+    ("groq", "https://api.groq.com/openai/v1"),
+    ("mistral", "https://api.mistral.ai/v1"),
+    ("moonshot", "https://api.moonshot.cn/v1"),
+    ("openrouter", "https://openrouter.ai/api/v1"),
     ("octoai", "https://text.octoai.run/v1"),
+    ("perplexity", "https://api.perplexity.ai"),
     ("together", "https://api.together.xyz/v1"),
 ];
-
-pub const KNOWN_OPENAI_COMPATIBLE_PROMPTS: [PromptType<'static>; 3] = [
-    ("api_key", "API Key:", false, PromptKind::String),
-    ("models[].name", "Model Name:", true, PromptKind::String),
-    (
-        "models[].max_input_tokens",
-        "Max Input Tokens:",
-        false,
-        PromptKind::Integer,
-    ),
-];
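With groq, mistral, moonshot, openrouter, and perplexity now rows in `OPENAI_COMPATIBLE_PLATFORMS` rather than dedicated clients, they should be usable with nothing but an API key. A hedged usage sketch (model id and key are illustrative):

```bash
# Sketch: a platform that now routes through the shared
# openai-compatible client (model id and key are illustrative).
export GROQ_API_KEY=gsk-xxxx
aichat -m groq:llama3-8b-8192 "Explain SSE in one sentence"
```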

@@ -242,6 +242,12 @@ pub struct ModelConfig {
     pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
 }

+#[derive(Debug, Clone, Deserialize)]
+pub struct BuiltinModels {
+    pub platform: String,
+    pub models: Vec<ModelConfig>,
+}
+
 bitflags::bitflags! {
     #[derive(Debug, Clone, Copy, PartialEq)]
     pub struct ModelCapabilities: u32 {

@@ -1 +0,0 @@
-openai_compatible_client!(MoonshotConfig, MoonshotClient, "https://api.moonshot.cn/v1",);

@@ -1,6 +1,6 @@
 use super::{
     catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient,
-    PromptKind, PromptType, SendData, SseHandler,
+    PromptAction, PromptKind, SendData, SseHandler,
 };

 use anyhow::{anyhow, bail, Result};

@@ -23,7 +23,7 @@ impl OllamaClient {
     config_get_fn!(api_base, get_api_base);
     config_get_fn!(api_auth, get_api_auth);

-    pub const PROMPTS: [PromptType<'static>; 4] = [
+    pub const PROMPTS: [PromptAction<'static>; 4] = [
         ("api_base", "API Base:", true, PromptKind::String),
         ("api_auth", "API Auth:", false, PromptKind::String),
         ("models[].name", "Model Name:", true, PromptKind::String),

@@ -1,6 +1,6 @@
 use super::{
     catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient,
-    PromptKind, PromptType, SendData, SsMmessage, SseHandler,
+    PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
 };

 use anyhow::{anyhow, Result};

@@ -25,7 +25,7 @@ impl OpenAIClient {
     config_get_fn!(api_key, get_api_key);
     config_get_fn!(api_base, get_api_base);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,6 +1,8 @@
+use crate::client::OPENAI_COMPATIBLE_PLATFORMS;
+
 use super::openai::openai_build_body;
 use super::{
-    ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptKind, PromptType, SendData,
+    ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
 };

 use anyhow::Result;

@@ -13,6 +15,7 @@ pub struct OpenAICompatibleConfig {
     pub api_base: Option<String>,
     pub api_key: Option<String>,
     pub chat_endpoint: Option<String>,
+    #[serde(default)]
     pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }

@@ -21,7 +24,7 @@ impl OpenAICompatibleClient {
     config_get_fn!(api_base, get_api_base);
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 5] = [
+    pub const PROMPTS: [PromptAction<'static>; 5] = [
         ("name", "Platform Name:", true, PromptKind::String),
         ("api_base", "API Base:", true, PromptKind::String),
         ("api_key", "API Key:", false, PromptKind::String),

@@ -35,7 +38,23 @@ impl OpenAICompatibleClient {
     ];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
-        let api_base = self.get_api_base()?;
+        let api_base = match self.get_api_base() {
+            Ok(v) => v,
+            Err(err) => {
+                match OPENAI_COMPATIBLE_PLATFORMS
+                    .into_iter()
+                    .find_map(|(name, api_base)| {
+                        if name == self.model.client_name {
+                            Some(api_base.to_string())
+                        } else {
+                            None
+                        }
+                    }) {
+                    Some(v) => v,
+                    None => return Err(err),
+                }
+            }
+        };
         let api_key = self.get_api_key().ok();

         let mut body = openai_build_body(data, &self.model);
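The new `request_builder` resolves `api_base` in three steps: the config value (or its `{client_name}_API_BASE` env override, per the `# ENV:` comments in config.example.yaml), then the builtin platforms table, and only then the original error. Sketched from the shell (values are placeholders):

```bash
# api_base resolution for an openai-compatible client named "fireworks":
#   1. config api_base / FIREWORKS_API_BASE env, if set
#   2. the builtin OPENAI_COMPATIBLE_PLATFORMS entry
#   3. otherwise the original error surfaces
export FIREWORKS_API_KEY=xxx                       # placeholder
aichat -m fireworks:accounts/fireworks/models/llama-v3-8b-instruct "hi"

export FIREWORKS_API_BASE=https://example.com/v1   # hypothetical override
aichat -m fireworks:accounts/fireworks/models/llama-v3-8b-instruct "hi"
```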

@@ -1,5 +0,0 @@
-openai_compatible_client!(
-    PerplexityConfig,
-    PerplexityClient,
-    "https://api.perplexity.ai",
-);

@@ -1,6 +1,6 @@
 use super::{
     maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model,
-    ModelConfig, PromptKind, PromptType, QianwenClient, SendData, SsMmessage, SseHandler,
+    ModelConfig, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
 };

 use crate::utils::{base64_decode, sha256};

@@ -33,7 +33,7 @@ pub struct QianwenConfig {
 impl QianwenClient {
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -2,8 +2,8 @@ use std::time::Duration;

 use super::{
     catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
-    ExtraConfig, Model, ModelConfig, PromptKind, PromptType, ReplicateClient, SendData, SsMmessage,
-    SseHandler,
+    ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, ReplicateClient, SendData,
+    SsMmessage, SseHandler,
 };

 use anyhow::{anyhow, Result};

@@ -26,7 +26,7 @@ pub struct ReplicateConfig {
 impl ReplicateClient {
     config_get_fn!(api_key, get_api_key);

-    pub const PROMPTS: [PromptType<'static>; 1] =
+    pub const PROMPTS: [PromptAction<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];

     fn request_builder(

@@ -1,7 +1,8 @@
 use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
 use super::{
     catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
-    ExtraConfig, Model, ModelConfig, PromptKind, PromptType, SendData, SseHandler, VertexAIClient,
+    ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
+    VertexAIClient,
 };

 use anyhow::{anyhow, bail, Context, Result};

@@ -30,7 +31,7 @@ impl VertexAIClient {
     config_get_fn!(project_id, get_project_id);
     config_get_fn!(location, get_location);

-    pub const PROMPTS: [PromptType<'static>; 2] = [
+    pub const PROMPTS: [PromptAction<'static>; 2] = [
         ("project_id", "Project ID", true, PromptKind::String),
         ("location", "Location", true, PromptKind::String),
     ];

@@ -9,6 +9,7 @@ use self::session::{Session, TEMP_SESSION_NAME};
 use crate::client::{
     create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData,
+    OPENAI_COMPATIBLE_PLATFORMS,
 };
 use crate::render::{MarkdownRender, RenderOptions};
 use crate::utils::{

@@ -21,6 +22,7 @@ use inquire::{Confirm, Select, Text};
 use is_terminal::IsTerminal;
 use parking_lot::RwLock;
 use serde::Deserialize;
+use serde_json::json;
 use std::collections::{HashMap, HashSet};
 use std::{
     env,

@@ -126,12 +128,12 @@ impl Config {
     pub fn init(working_mode: WorkingMode) -> Result<Self> {
         let config_path = Self::config_file()?;
-        let client_type = env::var(get_env_name("client_type")).ok();
-        if working_mode != WorkingMode::Command && client_type.is_none() && !config_path.exists() {
+        let platform = env::var(get_env_name("platform")).ok();
+        if working_mode != WorkingMode::Command && platform.is_none() && !config_path.exists() {
             create_config_file(&config_path)?;
         }
-        let mut config = if client_type.is_some() {
-            Self::load_config_env(&client_type.unwrap())?
+        let mut config = if platform.is_some() {
+            Self::load_config_env(&platform.unwrap())?
         } else {
             Self::load_config_file(&config_path)?
         };

@@ -926,37 +928,45 @@ impl Config {
     fn load_config_file(config_path: &Path) -> Result<Self> {
         let ctx = || format!("Failed to load config at {}", config_path.display());
         let content = read_to_string(config_path).with_context(ctx)?;
-        let config = Self::load_config(&content).with_context(ctx)?;
+        let config: Self = serde_yaml::from_str(&content).map_err(|err| {
+            let err_msg = err.to_string();
+            let err_msg = if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
+                // location is incorrect, get rid of it
+                err_msg
+                    .split_once(" at line")
+                    .map(|(v, _)| {
+                        format!("{v} (Sorry for being unable to provide an exact location)")
+                    })
+                    .unwrap_or_else(|| "clients: invalid value".into())
+            } else {
+                err_msg
+            };
+            anyhow!("{err_msg}")
+        })?;
         Ok(config)
     }

-    fn load_config_env(client_type: &str) -> Result<Self> {
+    fn load_config_env(platform: &str) -> Result<Self> {
         let model_id = match env::var(get_env_name("model_name")) {
-            Ok(model_name) => format!("{client_type}:{model_name}"),
-            Err(_) => client_type.to_string(),
+            Ok(model_name) => format!("{platform}:{model_name}"),
+            Err(_) => platform.to_string(),
         };
-        let content = format!(
-            r#"
-model: {model_id}
-save: false
-clients:
-    - type: {client_type}
-"#
-        );
-        let config = Self::load_config(&content).with_context(|| "Failed to load config")?;
-        Ok(config)
-    }
-
-    fn load_config(content: &str) -> Result<Self> {
-        let config: Self = serde_yaml::from_str(content).map_err(|err| {
-            let err_msg = err.to_string();
-            if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
-                anyhow!("clients: invalid value")
-            } else {
-                anyhow!("{err_msg}")
-            }
-        })?;
+        let is_openai_compatible = OPENAI_COMPATIBLE_PLATFORMS
+            .into_iter()
+            .any(|(name, _)| platform == name);
+        let client = if is_openai_compatible {
+            json!({ "type": "openai-compatible", "name": platform })
+        } else {
+            json!({ "type": platform })
+        };
+        let config = json!({
+            "model": model_id,
+            "save": false,
+            "clients": vec![client],
+        });
+        let config =
+            serde_json::from_value(config).with_context(|| "Failed to load config from env")?;
         Ok(config)
     }
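Taken together with the `Config::init` change above, aichat can now start without any config file: `AICHAT_PLATFORM` selects the client (mapped to `openai-compatible` when the name appears in the platforms table) and the model-name variable pins the model. A sketch, assuming `get_env_name` prefixes names with `AICHAT_` (only `AICHAT_PLATFORM` is confirmed by the test script above; the rest are placeholders):

```bash
# Sketch: config-free startup driven by environment variables.
export AICHAT_PLATFORM=deepinfra
export AICHAT_MODEL_NAME=meta-llama/Meta-Llama-3-8B-Instruct  # assumed name
export DEEPINFRA_API_KEY=xxx                                  # placeholder
aichat "What is a monad?"
```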
