diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs
index 1726bbe..9e96692 100644
--- a/src/client/azure_openai.rs
+++ b/src/client/azure_openai.rs
@@ -1,5 +1,5 @@
 use super::openai::openai_build_body;
-use super::{convert_models, AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
+use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
 
 use crate::utils::PromptKind;
 
@@ -20,6 +20,7 @@ pub struct AzureOpenAIConfig {
 openai_compatible_client!(AzureOpenAIClient);
 
 impl AzureOpenAIClient {
+    list_models_fn!(AzureOpenAIConfig);
     config_get_fn!(api_base, get_api_base);
     config_get_fn!(api_key, get_api_key);
 
@@ -35,11 +36,6 @@ impl AzureOpenAIClient {
         ),
     ];
 
-    pub fn list_models(local_config: &AzureOpenAIConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        convert_models(client_name, &local_config.models)
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_base = self.get_api_base()?;
         let api_key = self.get_api_key()?;
diff --git a/src/client/claude.rs b/src/client/claude.rs
index d87e64a..24bffbf 100644
--- a/src/client/claude.rs
+++ b/src/client/claude.rs
@@ -1,6 +1,6 @@
 use super::{
     patch_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
-    MessageContentPart, Model, PromptType, ReplyHandler, SendData,
+    MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
 };
 
 use crate::utils::PromptKind;
@@ -15,17 +15,19 @@ use serde_json::{json, Value};
 
 const API_BASE: &str = "https://api.anthropic.com/v1/messages";
 
-const MODELS: [(&str, usize, isize, &str); 3] = [
+const MODELS: [(&str, usize, &str); 3] = [
     // https://docs.anthropic.com/claude/docs/models-overview
-    ("claude-3-opus-20240229", 200000, 4096, "text,vision"),
-    ("claude-3-sonnet-20240229", 200000, 4096, "text,vision"),
-    ("claude-3-haiku-20240307", 200000, 4096, "text,vision"),
+    ("claude-3-opus-20240229", 200000, "text,vision"),
+    ("claude-3-sonnet-20240229", 200000, "text,vision"),
+    ("claude-3-haiku-20240307", 200000, "text,vision"),
 ];
 
 #[derive(Debug, Clone, Deserialize)]
 pub struct ClaudeConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -50,26 +52,12 @@ impl Client for ClaudeClient {
 }
 
 impl ClaudeClient {
+    list_models_fn!(ClaudeConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_key", "API Key:", false, PromptKind::String)];
 
-    pub fn list_models(local_config: &ClaudeConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(
-                |(name, max_input_tokens, max_output_tokens, capabilities)| {
-                    Model::new(client_name, name)
-                        .set_capabilities(capabilities.into())
-                        .set_max_input_tokens(Some(max_input_tokens))
-                        .set_max_output_tokens(Some(max_output_tokens))
-                },
-            )
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
diff --git a/src/client/cohere.rs b/src/client/cohere.rs
index bfea105..445c145 100644
--- a/src/client/cohere.rs
+++ b/src/client/cohere.rs
@@ -1,6 +1,6 @@
 use super::{
     json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
-    PromptType, ReplyHandler, SendData,
+    ModelConfig, PromptType, ReplyHandler, SendData,
 };
 
 use crate::utils::PromptKind;
@@ -23,6 +23,8 @@ const MODELS: [(&str, usize, &str); 2] = [
 pub struct CohereConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -47,23 +49,12 @@ impl Client for CohereClient {
 }
 
 impl CohereClient {
+    list_models_fn!(CohereConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_key", "API Key:", false, PromptKind::String)];
 
-    pub fn list_models(local_config: &CohereConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
 
@@ -182,7 +173,7 @@ pub(crate) fn build_body(data: SendData, model: &Model) -> Result<Value> {
         "model": &model.name,
         "message": message,
     });
-    
+
     if let Some(max_tokens) = model.max_output_tokens {
         body["max_tokens"] = max_tokens.into();
     }
diff --git a/src/client/common.rs b/src/client/common.rs
index 2206d21..b6ccce7 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -86,7 +86,7 @@ macro_rules! register_client {
 
         pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
             if !client.model().capabilities.contains(capabilities) {
-                let models = client.models();
+                let models = client.list_models();
                 if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) {
                     client.set_model(model);
                 } else {
@@ -137,7 +137,7 @@ macro_rules! client_common_fns {
                 (&self.global_config, &self.config.extra)
             }
 
-            fn models(&self) -> Vec<Model> {
+            fn list_models(&self) -> Vec<Model> {
                 Self::list_models(&self.config)
            }
 
@@ -197,11 +197,31 @@ macro_rules! config_get_fn {
     };
 }
 
+#[macro_export]
+macro_rules! list_models_fn {
+    ($config:ident) => {
+        pub fn list_models(local_config: &$config) -> Vec<Model> {
+            let client_name = Self::name(local_config);
+            Model::from_config(client_name, &local_config.models)
+        }
+    };
+    ($config:ident, $models:expr) => {
+        pub fn list_models(local_config: &$config) -> Vec<Model> {
+            let client_name = Self::name(local_config);
+            if local_config.models.is_empty() {
+                Model::from_static(client_name, $models)
+            } else {
+                Model::from_config(client_name, &local_config.models)
+            }
+        }
+    };
+}
+
 #[async_trait]
 pub trait Client: Sync + Send {
     fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
 
-    fn models(&self) -> Vec<Model>;
+    fn list_models(&self) -> Vec<Model>;
 
     fn model(&self) -> &Model;
diff --git a/src/client/ernie.rs b/src/client/ernie.rs
index 0b10022..b0a0087 100644
--- a/src/client/ernie.rs
+++ b/src/client/ernie.rs
@@ -1,6 +1,6 @@
 use super::{
-    patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, ReplyHandler,
-    SendData,
+    patch_system_message, Client, ErnieClient, ExtraConfig, Model, ModelConfig, PromptType,
+    ReplyHandler, SendData,
 };
 
 use crate::utils::PromptKind;
@@ -73,6 +73,8 @@ pub struct ErnieConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
     pub secret_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -106,14 +108,18 @@ impl ErnieClient {
 
     pub fn list_models(local_config: &ErnieConfig) -> Vec<Model> {
         let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, _, max_input_tokens, max_output_tokens)| {
-                Model::new(client_name, name)
-                    .set_max_input_tokens(Some(max_input_tokens))
-                    .set_max_output_tokens(Some(max_output_tokens))
-            }) // ERNIE tokenizer is different from cl100k_base
-            .collect()
+        if local_config.models.is_empty() {
+            MODELS
+                .into_iter()
+                .map(|(name, _, max_input_tokens, max_output_tokens)| {
+                    Model::new(client_name, name)
+                        .set_max_input_tokens(Some(max_input_tokens))
+                        .set_max_output_tokens(Some(max_output_tokens))
+                }) // ERNIE tokenizer is different from cl100k_base
+                .collect()
+        } else {
+            Model::from_config(client_name, &local_config.models)
+        }
     }
 
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
diff --git a/src/client/gemini.rs b/src/client/gemini.rs
index 5e42cd5..22f32c5 100644
--- a/src/client/gemini.rs
+++ b/src/client/gemini.rs
@@ -1,5 +1,7 @@
 use super::vertexai::{build_body, send_message, send_message_streaming};
-use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, ReplyHandler, SendData};
+use super::{
+    Client, ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, ReplyHandler, SendData,
+};
 
 use crate::utils::PromptKind;
 
@@ -22,6 +24,8 @@ pub struct GeminiConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
     pub block_threshold: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -46,23 +50,12 @@ impl Client for GeminiClient {
 }
 
 impl GeminiClient {
+    list_models_fn!(GeminiConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];
 
-    pub fn list_models(local_config: &GeminiConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key()?;
diff --git a/src/client/mistral.rs b/src/client/mistral.rs
index 4ff4787..e294e6a 100644
--- a/src/client/mistral.rs
+++ b/src/client/mistral.rs
@@ -1,5 +1,5 @@
 use super::openai::openai_build_body;
-use super::{ExtraConfig, MistralClient, Model, PromptType, SendData};
+use super::{ExtraConfig, MistralClient, Model, ModelConfig, PromptType, SendData};
 
 use crate::utils::PromptKind;
 
@@ -19,34 +19,23 @@ const MODELS: [(&str, usize, &str); 5] = [
     ("mistral-large-latest", 32000, "text"),
 ];
 
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct MistralConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
 openai_compatible_client!(MistralClient);
 
 impl MistralClient {
+    list_models_fn!(MistralConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
 
-    pub const PROMPTS: [PromptType<'static>; 1] = [
-        ("api_key", "API Key:", false, PromptKind::String),
-    ];
-
-    pub fn list_models(local_config: &MistralConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
+    pub const PROMPTS: [PromptType<'static>; 1] =
+        [("api_key", "API Key:", false, PromptKind::String)];
 
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
diff --git a/src/client/model.rs b/src/client/model.rs
index 53d1834..b97c244 100644
--- a/src/client/model.rs
+++ b/src/client/model.rs
@@ -36,6 +36,30 @@ impl Model {
         }
     }
 
+    pub fn from_config(client_name: &str, models: &[ModelConfig]) -> Vec<Self> {
+        models
+            .iter()
+            .map(|v| {
+                Model::new(client_name, &v.name)
+                    .set_capabilities(v.capabilities)
+                    .set_max_input_tokens(v.max_input_tokens)
+                    .set_max_output_tokens(v.max_output_tokens)
+                    .set_extra_fields(v.extra_fields.clone())
+            })
+            .collect()
+    }
+
+    pub fn from_static(client_name: &str, models: &[(&str, usize, &str)]) -> Vec<Self> {
+        models
+            .iter()
+            .map(|(name, max_input_tokens, capabilities)| {
+                Model::new(client_name, name)
+                    .set_capabilities((*capabilities).into())
+                    .set_max_input_tokens(Some(*max_input_tokens))
+            })
+            .collect()
+    }
+
     pub fn find(models: &[Self], value: &str) -> Option<Self> {
         let mut model = None;
         let (client_name, model_name) = match value.split_once(':') {
@@ -156,19 +180,6 @@ impl Model {
     }
 }
 
-pub fn convert_models(client_name: &str, models: &[ModelConfig]) -> Vec<Model> {
-    models
-        .iter()
-        .map(|v| {
-            Model::new(client_name, &v.name)
-                .set_capabilities(v.capabilities)
-                .set_max_input_tokens(v.max_input_tokens)
-                .set_max_output_tokens(v.max_output_tokens)
-                .set_extra_fields(v.extra_fields.clone())
-        })
-        .collect()
-}
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct ModelConfig {
     pub name: String,
diff --git a/src/client/ollama.rs b/src/client/ollama.rs
index 1692942..de658c1 100644
--- a/src/client/ollama.rs
+++ b/src/client/ollama.rs
@@ -1,6 +1,6 @@
 use super::{
-    convert_models, message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig,
-    OllamaClient, PromptType, ReplyHandler, SendData,
+    message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
+    PromptType, ReplyHandler, SendData,
 };
 
 use crate::utils::PromptKind;
@@ -43,6 +43,7 @@ impl Client for OllamaClient {
 }
 
 impl OllamaClient {
+    list_models_fn!(OllamaConfig);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 4] = [
@@ -57,11 +58,6 @@ impl OllamaClient {
         ),
     ];
 
-    pub fn list_models(local_config: &OllamaConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        convert_models(client_name, &local_config.models)
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
diff --git a/src/client/openai.rs b/src/client/openai.rs
index 535e739..7925ec6 100644
--- a/src/client/openai.rs
+++ b/src/client/openai.rs
@@ -1,4 +1,4 @@
-use super::{ExtraConfig, Model, OpenAIClient, PromptType, ReplyHandler, SendData};
+use super::{ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, ReplyHandler, SendData};
 
 use crate::utils::PromptKind;
 
@@ -30,30 +30,21 @@ pub struct OpenAIConfig {
     pub api_key: Option<String>,
     pub api_base: Option<String>,
     pub organization_id: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
 openai_compatible_client!(OpenAIClient);
 
 impl OpenAIClient {
+    list_models_fn!(OpenAIConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
     config_get_fn!(api_base, get_api_base);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];
 
-    pub fn list_models(local_config: &OpenAIConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key()?;
         let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());
diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs
index f25d931..b4cb5f4 100644
--- a/src/client/openai_compatible.rs
+++ b/src/client/openai_compatible.rs
@@ -1,5 +1,5 @@
 use super::openai::openai_build_body;
-use super::{convert_models, ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
+use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
 
 use crate::utils::PromptKind;
 
@@ -21,6 +21,7 @@ pub struct OpenAICompatibleConfig {
 openai_compatible_client!(OpenAICompatibleClient);
 
 impl OpenAICompatibleClient {
+    list_models_fn!(OpenAICompatibleConfig);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 5] = [
@@ -36,11 +37,6 @@ impl OpenAICompatibleClient {
         ),
     ];
 
-    pub fn list_models(local_config: &OpenAICompatibleConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        convert_models(client_name, &local_config.models)
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();
diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs
index 6225b96..a3ee988 100644
--- a/src/client/qianwen.rs
+++ b/src/client/qianwen.rs
@@ -1,5 +1,6 @@
 use super::{
-    message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, ReplyHandler, SendData,
+    message::*, Client, ExtraConfig, Model, ModelConfig, PromptType, QianwenClient, ReplyHandler,
+    SendData,
 };
 
 use crate::utils::{sha256sum, PromptKind};
@@ -38,6 +39,8 @@ const MODELS: [(&str, usize, &str); 6] = [
 pub struct QianwenConfig {
     pub name: Option<String>,
     pub api_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -70,23 +73,12 @@ impl Client for QianwenClient {
 }
 
 impl QianwenClient {
+    list_models_fn!(QianwenConfig, &MODELS);
     config_get_fn!(api_key, get_api_key);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_key", "API Key:", true, PromptKind::String)];
 
-    pub fn list_models(local_config: &QianwenConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key()?;
diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs
index eceeb2c..66bb098 100644
--- a/src/client/vertexai.rs
+++ b/src/client/vertexai.rs
@@ -1,6 +1,6 @@
 use super::{
-    json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
-    ReplyHandler, SendData, VertexAIClient,
+    json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig,
+    PromptType, ReplyHandler, SendData, VertexAIClient,
 };
 
 use crate::utils::PromptKind;
@@ -28,6 +28,8 @@ pub struct VertexAIConfig {
     pub api_base: Option<String>,
     pub adc_file: Option<String>,
     pub block_threshold: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
     pub extra: Option<ExtraConfig>,
 }
 
@@ -54,23 +56,12 @@ impl Client for VertexAIClient {
 }
 
 impl VertexAIClient {
+    list_models_fn!(VertexAIConfig, &MODELS);
     config_get_fn!(api_base, get_api_base);
 
     pub const PROMPTS: [PromptType<'static>; 1] =
         [("api_base", "API Base:", true, PromptKind::String)];
 
-    pub fn list_models(local_config: &VertexAIConfig) -> Vec<Model> {
-        let client_name = Self::name(local_config);
-        MODELS
-            .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
-            .collect()
-    }
-
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_base = self.get_api_base()?;
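
Note (not part of the patch): based on the macro definition added in src/client/common.rs above, the two-argument form used by clients with builtin model tables, e.g. list_models_fn!(ClaudeConfig, &MODELS), would expand to roughly the sketch below, so `models` entries declared in a client's config supersede the builtin list:

    // Hypothetical expansion of `list_models_fn!(ClaudeConfig, &MODELS)`,
    // reconstructed from the macro definition in src/client/common.rs.
    pub fn list_models(local_config: &ClaudeConfig) -> Vec<Model> {
        let client_name = Self::name(local_config);
        if local_config.models.is_empty() {
            // No custom models configured: fall back to the builtin
            // (name, max_input_tokens, capabilities) table.
            Model::from_static(client_name, &MODELS)
        } else {
            // Custom models from the config take precedence.
            Model::from_config(client_name, &local_config.models)
        }
    }

The single-argument form (e.g. list_models_fn!(OllamaConfig)) expands to the Model::from_config call alone, matching the behavior of the removed convert_models helper for clients that have no builtin models.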