feat: builtin models can be overridden by models config (#429)

pull/430/head
sigoden 4 weeks ago committed by GitHub
parent d1aafa1115
commit 9c6c9f10a2

@@ -1,5 +1,5 @@
use super::openai::openai_build_body;
use super::{convert_models, AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;
@@ -20,6 +20,7 @@ pub struct AzureOpenAIConfig
openai_compatible_client!(AzureOpenAIClient);
impl AzureOpenAIClient {
list_models_fn!(AzureOpenAIConfig);
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
@@ -35,11 +36,6 @@ impl AzureOpenAIClient {
),
];
pub fn list_models(local_config: &AzureOpenAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
convert_models(client_name, &local_config.models)
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;

@@ -1,6 +1,6 @@
use super::{
patch_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
MessageContentPart, Model, PromptType, ReplyHandler, SendData,
MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
@@ -15,17 +15,19 @@ use serde_json::{json, Value};
const API_BASE: &str = "https://api.anthropic.com/v1/messages";
const MODELS: [(&str, usize, isize, &str); 3] = [
const MODELS: [(&str, usize, &str); 3] = [
// https://docs.anthropic.com/claude/docs/models-overview
("claude-3-opus-20240229", 200000, 4096, "text,vision"),
("claude-3-sonnet-20240229", 200000, 4096, "text,vision"),
("claude-3-haiku-20240307", 200000, 4096, "text,vision"),
("claude-3-opus-20240229", 200000, "text,vision"),
("claude-3-sonnet-20240229", 200000, "text,vision"),
("claude-3-haiku-20240307", 200000, "text,vision"),
];
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -50,26 +52,12 @@ impl Client for ClaudeClient {
}
impl ClaudeClient {
list_models_fn!(ClaudeConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", false, PromptKind::String)];
pub fn list_models(local_config: &ClaudeConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(
|(name, max_input_tokens, max_output_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
.set_max_output_tokens(Some(max_output_tokens))
},
)
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
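With max_output_tokens dropped from the builtin Claude table, the only remaining source for it is a `models` entry in the user's config. A minimal sketch of what such an override might deserialize from, using a hypothetical stand-in for the crate's ModelConfig (field names are taken from the struct added in this diff; the exact YAML layout around the list is an assumption):

use serde::Deserialize;

// Hypothetical stand-in for the crate's ModelConfig; `capabilities` and
// `extra_fields` are omitted for brevity.
#[derive(Debug, Deserialize)]
struct ModelConfigSketch {
    name: String,
    #[serde(default)]
    max_input_tokens: Option<usize>,
    #[serde(default)]
    max_output_tokens: Option<isize>,
}

fn main() {
    // A user's `models` list overriding the builtin claude-3-opus entry and
    // restoring the max_output_tokens the static table no longer carries.
    let yaml = r#"
- name: claude-3-opus-20240229
  max_input_tokens: 200000
  max_output_tokens: 4096
"#;
    let models: Vec<ModelConfigSketch> = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(models[0].max_output_tokens, Some(4096));
}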

@@ -1,6 +1,6 @@
use super::{
json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
PromptType, ReplyHandler, SendData,
ModelConfig, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
@@ -23,6 +23,8 @@ const MODELS: [(&str, usize, &str); 2] = [
pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -47,23 +49,12 @@ impl Client for CohereClient {
}
impl CohereClient {
list_models_fn!(CohereConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", false, PromptKind::String)];
pub fn list_models(local_config: &CohereConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
@@ -182,7 +173,7 @@ pub(crate) fn build_body(data: SendData, model: &Model) -> Result<Value> {
"model": &model.name,
"message": message,
});
if let Some(max_tokens) = model.max_output_tokens {
body["max_tokens"] = max_tokens.into();
}
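The max_tokens field is now driven by the model rather than hardcoded, so a config override flows straight into the request body. A self-contained sketch of the conditional field insertion build_body performs (the surrounding request fields are placeholders):

use serde_json::json;

fn main() {
    // Comes from Model::max_output_tokens, i.e. a builtin table entry or a
    // user `models` override; None simply leaves the field out of the body.
    let max_output_tokens: Option<isize> = Some(4096);

    let mut body = json!({
        "model": "command-r",
        "message": "Hello",
    });
    if let Some(max_tokens) = max_output_tokens {
        body["max_tokens"] = max_tokens.into();
    }
    assert_eq!(body["max_tokens"], 4096);
}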

@@ -86,7 +86,7 @@ macro_rules! register_client {
pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
if !client.model().capabilities.contains(capabilities) {
let models = client.models();
let models = client.list_models();
if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) {
client.set_model(model);
} else {
@@ -137,7 +137,7 @@ macro_rules! client_common_fns {
(&self.global_config, &self.config.extra)
}
fn models(&self) -> Vec<Model> {
fn list_models(&self) -> Vec<Model> {
Self::list_models(&self.config)
}
@@ -197,11 +197,31 @@ macro_rules! config_get_fn {
};
}
#[macro_export]
macro_rules! list_models_fn {
($config:ident) => {
pub fn list_models(local_config: &$config) -> Vec<Model> {
let client_name = Self::name(local_config);
Model::from_config(client_name, &local_config.models)
}
};
($config:ident, $models:expr) => {
pub fn list_models(local_config: &$config) -> Vec<Model> {
let client_name = Self::name(local_config);
if local_config.models.is_empty() {
Model::from_static(client_name, $models)
} else {
Model::from_config(client_name, &local_config.models)
}
}
};
}
#[async_trait]
pub trait Client: Sync + Send {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
fn models(&self) -> Vec<Model>;
fn list_models(&self) -> Vec<Model>;
fn model(&self) -> &Model;
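The two macro arms encode a single precedence rule: clients with builtin tables keep them only until the user supplies a non-empty `models` list, while config-only clients (Azure OpenAI, Ollama, OpenAI-compatible) always read from config. A self-contained sketch of that rule, with `Entry` as a hypothetical stand-in for the crate's Model/ModelConfig types:

// Hypothetical stand-in covering just the fields the sketch needs.
#[derive(Debug, Clone, PartialEq)]
struct Entry {
    name: String,
    max_input_tokens: Option<usize>,
}

// Mirrors list_models_fn!'s second arm: builtins apply only when overrides are empty.
fn effective_models(builtin: &[(&str, usize)], overrides: &[Entry]) -> Vec<Entry> {
    if overrides.is_empty() {
        // from_static path: materialize the compiled-in table.
        builtin
            .iter()
            .map(|(name, max_input_tokens)| Entry {
                name: (*name).to_string(),
                max_input_tokens: Some(*max_input_tokens),
            })
            .collect()
    } else {
        // from_config path: the user's list replaces the builtins wholesale.
        overrides.to_vec()
    }
}

fn main() {
    let builtin = [("gpt-4-turbo", 128000)];
    // No overrides: the builtin table wins.
    assert_eq!(effective_models(&builtin, &[])[0].name, "gpt-4-turbo");
    // Overrides present: only the configured models are listed.
    let mine = vec![Entry { name: "my-proxy-model".into(), max_input_tokens: Some(32000) }];
    assert_eq!(effective_models(&builtin, &mine), mine);
}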

@@ -1,6 +1,6 @@
use super::{
patch_system_message, Client, ErnieClient, ExtraConfig, Model, PromptType, ReplyHandler,
SendData,
patch_system_message, Client, ErnieClient, ExtraConfig, Model, ModelConfig, PromptType,
ReplyHandler, SendData,
};
use crate::utils::PromptKind;
@@ -73,6 +73,8 @@ pub struct ErnieConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub secret_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -106,14 +108,18 @@ impl ErnieClient {
pub fn list_models(local_config: &ErnieConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, _, max_input_tokens, max_output_tokens)| {
Model::new(client_name, name)
.set_max_input_tokens(Some(max_input_tokens))
.set_max_output_tokens(Some(max_output_tokens))
}) // ERNIE tokenizer is different from cl100k_base
.collect()
if local_config.models.is_empty() {
MODELS
.into_iter()
.map(|(name, _, max_input_tokens, max_output_tokens)| {
Model::new(client_name, name)
.set_max_input_tokens(Some(max_input_tokens))
.set_max_output_tokens(Some(max_output_tokens))
}) // ERNIE tokenizer is different from cl100k_base
.collect()
} else {
Model::from_config(client_name, &local_config.models)
}
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {

@@ -1,5 +1,7 @@
use super::vertexai::{build_body, send_message, send_message_streaming};
use super::{Client, ExtraConfig, GeminiClient, Model, PromptType, ReplyHandler, SendData};
use super::{
Client, ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
@@ -22,6 +24,8 @@ pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub block_threshold: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -46,23 +50,12 @@ impl Client for GeminiClient {
}
impl GeminiClient {
list_models_fn!(GeminiConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
pub fn list_models(local_config: &GeminiConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

@@ -1,5 +1,5 @@
use super::openai::openai_build_body;
use super::{ExtraConfig, MistralClient, Model, PromptType, SendData};
use super::{ExtraConfig, MistralClient, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;
@@ -19,34 +19,23 @@ const MODELS: [(&str, usize, &str); 5] = [
("mistral-large-latest", 32000, "text"),
];
#[derive(Debug, Clone, Deserialize)]
pub struct MistralConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(MistralClient);
impl MistralClient {
list_models_fn!(MistralConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] = [
("api_key", "API Key:", false, PromptKind::String),
];
pub fn list_models(local_config: &MistralConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", false, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();

@@ -36,6 +36,30 @@ impl Model {
}
}
pub fn from_config(client_name: &str, models: &[ModelConfig]) -> Vec<Self> {
models
.iter()
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens)
.set_extra_fields(v.extra_fields.clone())
})
.collect()
}
pub fn from_static(client_name: &str, models: &[(&str, usize, &str)]) -> Vec<Self> {
models
.iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities((*capabilities).into())
.set_max_input_tokens(Some(*max_input_tokens))
})
.collect()
}
pub fn find(models: &[Self], value: &str) -> Option<Self> {
let mut model = None;
let (client_name, model_name) = match value.split_once(':') {
@@ -156,19 +180,6 @@ impl Model {
}
}
pub fn convert_models(client_name: &str, models: &[ModelConfig]) -> Vec<Model> {
models
.iter()
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens)
.set_extra_fields(v.extra_fields.clone())
})
.collect()
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
pub name: String,
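The capability strings in the MODELS tables ("text,vision") are converted via .into() in from_static and later queried with .contains() by ensure_model_capabilities. A hedged sketch of such a flag set, assuming a bitflags-style ModelCapabilities (the crate's actual type and flag names may differ):

use bitflags::bitflags;

bitflags! {
    // Hypothetical stand-in for the crate's ModelCapabilities.
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Caps: u8 {
        const TEXT   = 0b01;
        const VISION = 0b10;
    }
}

impl From<&str> for Caps {
    // Parse the comma-separated capability strings used by the MODELS tables.
    fn from(s: &str) -> Self {
        s.split(',').fold(Caps::empty(), |acc, cap| {
            acc | match cap.trim() {
                "text" => Caps::TEXT,
                "vision" => Caps::VISION,
                _ => Caps::empty(), // unknown capabilities are ignored
            }
        })
    }
}

fn main() {
    let caps: Caps = "text,vision".into();
    // The check ensure_model_capabilities relies on when switching models.
    assert!(caps.contains(Caps::VISION));
    assert!(!Caps::from("text").contains(Caps::VISION));
}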

@@ -1,6 +1,6 @@
use super::{
convert_models, message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig,
OllamaClient, PromptType, ReplyHandler, SendData,
message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
@@ -43,6 +43,7 @@ impl Client for OllamaClient {
}
impl OllamaClient {
list_models_fn!(OllamaConfig);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 4] = [
@@ -57,11 +58,6 @@ impl OllamaClient {
),
];
pub fn list_models(local_config: &OllamaConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
convert_models(client_name, &local_config.models)
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();

@@ -1,4 +1,4 @@
use super::{ExtraConfig, Model, OpenAIClient, PromptType, ReplyHandler, SendData};
use super::{ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, ReplyHandler, SendData};
use crate::utils::PromptKind;
@@ -30,30 +30,21 @@ pub struct OpenAIConfig {
pub api_key: Option<String>,
pub api_base: Option<String>,
pub organization_id: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(OpenAIClient);
impl OpenAIClient {
list_models_fn!(OpenAIConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
pub fn list_models(local_config: &OpenAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());

@@ -1,5 +1,5 @@
use super::openai::openai_build_body;
use super::{convert_models, ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
use crate::utils::PromptKind;
@@ -21,6 +21,7 @@ pub struct OpenAICompatibleConfig {
openai_compatible_client!(OpenAICompatibleClient);
impl OpenAICompatibleClient {
list_models_fn!(OpenAICompatibleConfig);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 5] = [
@@ -36,11 +37,6 @@ impl OpenAICompatibleClient {
),
];
pub fn list_models(local_config: &OpenAICompatibleConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
convert_models(client_name, &local_config.models)
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();

@@ -1,5 +1,6 @@
use super::{
message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, ReplyHandler, SendData,
message::*, Client, ExtraConfig, Model, ModelConfig, PromptType, QianwenClient, ReplyHandler,
SendData,
};
use crate::utils::{sha256sum, PromptKind};
@@ -38,6 +39,8 @@ const MODELS: [(&str, usize, &str); 6] = [
pub struct QianwenConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -70,23 +73,12 @@ impl Client for QianwenClient {
}
impl QianwenClient {
list_models_fn!(QianwenConfig, &MODELS);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
pub fn list_models(local_config: &QianwenConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

@@ -1,6 +1,6 @@
use super::{
json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
ReplyHandler, SendData, VertexAIClient,
json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig,
PromptType, ReplyHandler, SendData, VertexAIClient,
};
use crate::utils::PromptKind;
@@ -28,6 +28,8 @@ pub struct VertexAIConfig {
pub api_base: Option<String>,
pub adc_file: Option<String>,
pub block_threshold: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
}
@@ -54,23 +56,12 @@ impl Client for VertexAIClient {
}
impl VertexAIClient {
list_models_fn!(VertexAIConfig, &MODELS);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_base", "API Base:", true, PromptKind::String)];
pub fn list_models(local_config: &VertexAIConfig) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
