feat: customize model's max_output_tokens (#428)

sigoden committed 1 month ago
parent 1cc89eff51
commit d1aafa1115

@@ -28,7 +28,14 @@ right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consum
 clients:
   # All clients have the following configuration:
   # - type: xxxx
-  #   name: nova # Only use it to distinguish clients with the same client type. Optional
+  #   name: xxxx # Only use it to distinguish clients with the same client type. Optional
+  #   models:
+  #     - name: xxxx # The model name
+  #       max_input_tokens: 100000 # Optional field
+  #       max_output_tokens: 4096 # Optional field
+  #       capabilities: text,vision # Optional field, supported capabilities: text, vision
+  #       extra_fields: # Optional field, set custom parameters, will merge with the body json
+  #         key: value
   #   extra:
   #     proxy: socks5://127.0.0.1:1080 # Specify https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also works.
   #     connect_timeout: 10 # Set a timeout in seconds for connect to server
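
A concrete model entry using the new field deserializes as sketched below. This is an illustrative standalone snippet: `ModelEntry` is hypothetical and stands in for the crate's `ModelConfig` (which also handles `capabilities` and `extra_fields`); it assumes the `serde` and `serde_yaml` crates. A configured value of 0 is treated as unset (see the `set_max_output_tokens` hunk below).

```rust
// Illustrative sketch only; `ModelEntry` is hypothetical and stands in for
// the crate's ModelConfig. Assumes the serde and serde_yaml crates.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelEntry {
    name: String,
    max_input_tokens: Option<usize>,
    max_output_tokens: Option<isize>, // the new optional field
}

fn main() -> Result<(), serde_yaml::Error> {
    let yaml = r#"
- name: mymodel
  max_input_tokens: 100000
  max_output_tokens: 4096
"#;
    let models: Vec<ModelEntry> = serde_yaml::from_str(yaml)?;
    assert_eq!(models[0].name, "mymodel");
    assert_eq!(models[0].max_output_tokens, Some(4096));
    Ok(())
}
```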

@@ -1,5 +1,5 @@
 use super::openai::openai_build_body;
-use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};
+use super::{convert_models, AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData};

 use crate::utils::PromptKind;

@@ -37,23 +37,14 @@ impl AzureOpenAIClient {
     pub fn list_models(local_config: &AzureOpenAIConfig) -> Vec<Model> {
         let client_name = Self::name(local_config);
-        local_config
-            .models
-            .iter()
-            .map(|v| {
-                Model::new(client_name, &v.name)
-                    .set_max_input_tokens(v.max_input_tokens)
-                    .set_capabilities(v.capabilities)
-            })
-            .collect()
+        convert_models(client_name, &local_config.models)
     }

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_base = self.get_api_base()?;
         let api_key = self.get_api_key()?;

-        let mut body = openai_build_body(data, self.model.name.clone());
+        let mut body = openai_build_body(data, &self.model);
         self.model.merge_extra_fields(&mut body);

         let url = format!(

@@ -15,11 +15,11 @@ use serde_json::{json, Value};
 const API_BASE: &str = "https://api.anthropic.com/v1/messages";

-const MODELS: [(&str, usize, &str); 3] = [
+const MODELS: [(&str, usize, isize, &str); 3] = [
     // https://docs.anthropic.com/claude/docs/models-overview
-    ("claude-3-opus-20240229", 200000, "text,vision"),
-    ("claude-3-sonnet-20240229", 200000, "text,vision"),
-    ("claude-3-haiku-20240307", 200000, "text,vision"),
+    ("claude-3-opus-20240229", 200000, 4096, "text,vision"),
+    ("claude-3-sonnet-20240229", 200000, 4096, "text,vision"),
+    ("claude-3-haiku-20240307", 200000, 4096, "text,vision"),
 ];

 #[derive(Debug, Clone, Deserialize)]

@@ -59,18 +59,21 @@ impl ClaudeClient {
         let client_name = Self::name(local_config);
         MODELS
             .into_iter()
-            .map(|(name, max_input_tokens, capabilities)| {
-                Model::new(client_name, name)
-                    .set_capabilities(capabilities.into())
-                    .set_max_input_tokens(Some(max_input_tokens))
-            })
+            .map(
+                |(name, max_input_tokens, max_output_tokens, capabilities)| {
+                    Model::new(client_name, name)
+                        .set_capabilities(capabilities.into())
+                        .set_max_input_tokens(Some(max_input_tokens))
+                        .set_max_output_tokens(Some(max_output_tokens))
+                },
+            )
             .collect()
     }

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();

-        let body = build_body(data, self.model.name.clone())?;
+        let body = build_body(data, &self.model)?;

         let url = API_BASE;

@@ -143,7 +146,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
     Ok(())
 }

-fn build_body(data: SendData, model: String) -> Result<Value> {
+fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,

@@ -197,9 +200,11 @@ fn build_body(data: SendData, model: String) -> Result<Value> {
         );
     }

+    let max_tokens = model.max_output_tokens.unwrap_or(4096);
+
     let mut body = json!({
-        "model": model,
-        "max_tokens": 4096,
+        "model": &model.name,
+        "max_tokens": max_tokens,
         "messages": messages,
     });
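
Anthropic's messages API requires `max_tokens` on every request, so this client always sends a value and falls back to 4096 when the model has none configured. A minimal standalone sketch of that fallback, assuming only `serde_json`:

```rust
use serde_json::json;

fn main() {
    // Mirrors `model.max_output_tokens.unwrap_or(4096)` above.
    let max_output_tokens: Option<isize> = None; // nothing configured
    let body = json!({
        "model": "claude-3-haiku-20240307",
        "max_tokens": max_output_tokens.unwrap_or(4096),
        "messages": [{"role": "user", "content": "hi"}],
    });
    assert_eq!(body["max_tokens"], 4096);
}
```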

@@ -67,7 +67,7 @@ impl CohereClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();

-        let body = build_body(data, self.model.name.clone())?;
+        let body = build_body(data, &self.model)?;

         let url = API_URL;

@@ -131,7 +131,7 @@ fn check_error(data: &Value) -> Result<()> {
     }
 }

-pub(crate) fn build_body(data: SendData, model: String) -> Result<Value> {
+pub(crate) fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,

@@ -179,9 +179,13 @@ pub(crate) fn build_body(data: SendData, model: String) -> Result<Value> {
     let message = message["message"].as_str().unwrap_or_default();

     let mut body = json!({
-        "model": model,
+        "model": &model.name,
         "message": message,
     });

+    if let Some(max_tokens) = model.max_output_tokens {
+        body["max_tokens"] = max_tokens.into();
+    }
+
     if !messages.is_empty() {
         body["chat_history"] = messages.into();

@@ -18,27 +18,50 @@ use std::{env, sync::Mutex};
 const API_BASE: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1";
 const ACCESS_TOKEN_URL: &str = "https://aip.baidubce.com/oauth/2.0/token";

-const MODELS: [(&str, usize, &str); 7] = [
+const MODELS: [(&str, &str, usize, isize); 7] = [
     // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
-    ("ernie-4.0-8k", 5120, "/wenxinworkshop/chat/completions_pro"),
-    (
-        "ernie-3.5-8k",
-        5120,
-        "/wenxinworkshop/chat/ernie-3.5-8k-0205",
-    ),
-    (
-        "ernie-3.5-4k",
-        2048,
-        "/wenxinworkshop/chat/ernie-3.5-4k-0205",
-    ),
-    ("ernie-speed-8k", 7168, "/wenxinworkshop/chat/ernie_speed"),
-    (
-        "ernie-speed-128k",
-        124000,
-        "/wenxinworkshop/chat/ernie-speed-128k",
-    ),
-    ("ernie-lite-8k", 7168, "/wenxinworkshop/chat/ernie-lite-8k"),
-    ("ernie-tiny-8k", 7168, "/wenxinworkshop/chat/ernie-tiny-8k"),
+    (
+        "ernie-4.0-8k",
+        "/wenxinworkshop/chat/completions_pro",
+        5120,
+        2048,
+    ),
+    (
+        "ernie-3.5-8k",
+        "/wenxinworkshop/chat/ernie-3.5-8k-0205",
+        5120,
+        2048,
+    ),
+    (
+        "ernie-3.5-4k",
+        "/wenxinworkshop/chat/ernie-3.5-4k-0205",
+        2048,
+        2048,
+    ),
+    (
+        "ernie-speed-8k",
+        "/wenxinworkshop/chat/ernie_speed",
+        7168,
+        2048,
+    ),
+    (
+        "ernie-speed-128k",
+        "/wenxinworkshop/chat/ernie-speed-128k",
+        124000,
+        4096,
+    ),
+    (
+        "ernie-lite-8k",
+        "/wenxinworkshop/chat/ernie-lite-8k",
+        7168,
+        2048,
+    ),
+    (
+        "ernie-tiny-8k",
+        "/wenxinworkshop/chat/ernie-tiny-8k",
+        7168,
+        2048,
+    ),
 ];

 lazy_static! {
@@ -85,17 +108,21 @@ impl ErnieClient {
         let client_name = Self::name(local_config);
         MODELS
             .into_iter()
-            .map(|(name, _, _)| Model::new(client_name, name)) // ERNIE tokenizer is different from cl100k_base
+            .map(|(name, _, max_input_tokens, max_output_tokens)| {
+                Model::new(client_name, name)
+                    .set_max_input_tokens(Some(max_input_tokens))
+                    .set_max_output_tokens(Some(max_output_tokens))
+            }) // ERNIE tokenizer is different from cl100k_base
             .collect()
     }

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
-        let body = build_body(data, self.model.name.clone());
+        let body = build_body(data, &self.model);

-        let model = self.model.name.clone();
-        let (_, _, chat_endpoint) = MODELS
+        let model = &self.model.name;
+        let (_, chat_endpoint, _, _) = MODELS
             .iter()
-            .find(|(v, _, _)| v == &model)
+            .find(|(v, _, _, _)| v == model)
             .ok_or_else(|| anyhow!("Miss Model '{}'", self.model.id()))?;

         let access_token = ACCESS_TOKEN

@@ -207,7 +234,7 @@ fn check_error(data: &Value) -> Result<()> {
     Ok(())
 }

-fn build_body(data: SendData, _model: String) -> Value {
+fn build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         mut messages,
         temperature,

@@ -223,6 +250,11 @@ fn build_body(data: SendData, _model: String) -> Value {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+
+    if let Some(max_output_tokens) = model.max_output_tokens {
+        body["max_output_tokens"] = max_output_tokens.into();
+    }
+
     if stream {
         body["stream"] = true.into();
     }

@@ -73,9 +73,9 @@ impl GeminiClient {
         let block_threshold = self.config.block_threshold.clone();

-        let body = build_body(data, self.model.name.clone(), block_threshold)?;
+        let body = build_body(data, &self.model, block_threshold)?;

-        let model = self.model.name.clone();
+        let model = &self.model.name;

         let url = format!("{API_BASE}{}:{}?key={}", model, func, api_key);

@@ -51,7 +51,7 @@ impl MistralClient {
     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();

-        let body = openai_build_body(data, self.model.name.clone());
+        let body = openai_build_body(data, &self.model);

         let url = API_URL;

@@ -13,6 +13,7 @@ pub struct Model {
     pub client_name: String,
     pub name: String,
     pub max_input_tokens: Option<usize>,
+    pub max_output_tokens: Option<isize>,
     pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
     pub capabilities: ModelCapabilities,
 }

@@ -30,6 +31,7 @@ impl Model {
             name: name.into(),
             extra_fields: None,
             max_input_tokens: None,
+            max_output_tokens: None,
             capabilities: ModelCapabilities::Text,
         }
     }

@@ -90,6 +92,14 @@ impl Model {
         self
     }

+    pub fn set_max_output_tokens(mut self, max_output_tokens: Option<isize>) -> Self {
+        match max_output_tokens {
+            None | Some(0) => self.max_output_tokens = None,
+            _ => self.max_output_tokens = max_output_tokens,
+        }
+        self
+    }
+
     pub fn messages_tokens(&self, messages: &[Message]) -> usize {
         messages
             .iter()
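
The setter normalizes `Some(0)` to `None`, so a zero in the config behaves like leaving the field unset and the downstream `if let Some(...)` branches never emit a zero token limit. A standalone sketch of that normalization (hypothetical free function mirroring the method above):

```rust
// Hypothetical mirror of Model::set_max_output_tokens's normalization.
fn normalize(max_output_tokens: Option<isize>) -> Option<isize> {
    match max_output_tokens {
        None | Some(0) => None, // 0 means "not set"
        other => other,
    }
}

fn main() {
    assert_eq!(normalize(None), None);
    assert_eq!(normalize(Some(0)), None);
    assert_eq!(normalize(Some(4096)), Some(4096));
}
```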
@@ -127,19 +137,43 @@ impl Model {
     pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
         if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) {
-            for (k, v) in extra_fields {
-                if !body.contains_key(k) {
-                    body.insert(k.clone(), v.clone());
+            for (key, extra_field) in extra_fields {
+                if body.contains_key(key) {
+                    if let (Some(sub_body), Some(extra_field)) =
+                        (body[key].as_object_mut(), extra_field.as_object())
+                    {
+                        for (subkey, sub_field) in extra_field {
+                            if !sub_body.contains_key(subkey) {
+                                sub_body.insert(subkey.clone(), sub_field.clone());
+                            }
+                        }
+                    }
+                } else {
+                    body.insert(key.clone(), extra_field.clone());
                 }
             }
         }
     }
 }

+pub fn convert_models(client_name: &str, models: &[ModelConfig]) -> Vec<Model> {
+    models
+        .iter()
+        .map(|v| {
+            Model::new(client_name, &v.name)
+                .set_capabilities(v.capabilities)
+                .set_max_input_tokens(v.max_input_tokens)
+                .set_max_output_tokens(v.max_output_tokens)
+                .set_extra_fields(v.extra_fields.clone())
+        })
+        .collect()
+}
+
 #[derive(Debug, Clone, Deserialize)]
 pub struct ModelConfig {
     pub name: String,
     pub max_input_tokens: Option<usize>,
+    pub max_output_tokens: Option<isize>,
     pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
     #[serde(deserialize_with = "deserialize_capabilities")]
     #[serde(default = "default_capabilities")]
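
`merge_extra_fields` previously only copied missing top-level keys; it now also merges one level deep, so extras aimed at an object the body already contains (like Ollama's `options`) are folded in instead of being dropped. In both cases, values already present in the body win. A standalone sketch of the new behavior, assuming only `serde_json` (the free function is hypothetical):

```rust
use serde_json::{json, Value};

// Hypothetical standalone version of Model::merge_extra_fields.
fn merge_extra_fields(body: &mut Value, extra_fields: &serde_json::Map<String, Value>) {
    if let Some(body) = body.as_object_mut() {
        for (key, extra_field) in extra_fields {
            if body.contains_key(key) {
                // Key exists: merge sub-keys one level deep, body wins.
                if let (Some(sub_body), Some(extra_field)) =
                    (body[key].as_object_mut(), extra_field.as_object())
                {
                    for (subkey, sub_field) in extra_field {
                        if !sub_body.contains_key(subkey) {
                            sub_body.insert(subkey.clone(), sub_field.clone());
                        }
                    }
                }
            } else {
                // Key missing: copy it over wholesale.
                body.insert(key.clone(), extra_field.clone());
            }
        }
    }
}

fn main() {
    let mut body = json!({ "options": { "num_predict": 2048 } });
    let extra = json!({ "options": { "num_ctx": 8192, "num_predict": 1 } });
    merge_extra_fields(&mut body, extra.as_object().unwrap());
    // Existing body values win; new sub-keys are merged in.
    assert_eq!(body["options"]["num_predict"], 2048);
    assert_eq!(body["options"]["num_ctx"], 8192);
}
```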

@@ -1,6 +1,6 @@
 use super::{
-    message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
-    PromptType, ReplyHandler, SendData,
+    convert_models, message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig,
+    OllamaClient, PromptType, ReplyHandler, SendData,
 };

 use crate::utils::PromptKind;

@@ -59,23 +59,13 @@ impl OllamaClient {
     pub fn list_models(local_config: &OllamaConfig) -> Vec<Model> {
         let client_name = Self::name(local_config);
-        local_config
-            .models
-            .iter()
-            .map(|v| {
-                Model::new(client_name, &v.name)
-                    .set_capabilities(v.capabilities)
-                    .set_max_input_tokens(v.max_input_tokens)
-                    .set_extra_fields(v.extra_fields.clone())
-            })
-            .collect()
+        convert_models(client_name, &local_config.models)
     }

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();

-        let mut body = build_body(data, self.model.name.clone())?;
+        let mut body = build_body(data, &self.model)?;
         self.model.merge_extra_fields(&mut body);

         let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");

@@ -133,7 +123,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
     Ok(())
 }

-fn build_body(data: SendData, model: String) -> Result<Value> {
+fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,

@@ -189,15 +179,18 @@ fn build_body(data: SendData, model: String) -> Result<Value> {
     }

     let mut body = json!({
-        "model": model,
+        "model": &model.name,
         "messages": messages,
         "stream": stream,
+        "options": {},
     });

+    if let Some(num_predict) = model.max_output_tokens {
+        body["options"]["num_predict"] = num_predict.into();
+    }
+
     if let Some(temperature) = temperature {
-        body["options"] = json!({
-            "temperature": temperature,
-        });
+        body["options"]["temperature"] = temperature.into();
     }

     Ok(body)
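
Pre-seeding `"options": {}` lets `num_predict` and `temperature` be assigned independently; the old code assigned a fresh object for temperature, which would have clobbered any earlier `options` keys. The VertexAI hunk at the end of this commit applies the same pattern to `generationConfig`. A minimal sketch of the difference, assuming only `serde_json`:

```rust
use serde_json::json;

fn main() {
    // Old pattern: assigning a fresh object overwrites earlier keys.
    let mut old = json!({ "options": {} });
    old["options"]["num_predict"] = 2048.into();
    old["options"] = json!({ "temperature": 0.7 }); // num_predict is lost
    assert!(old["options"].get("num_predict").is_none());

    // New pattern: index into the pre-seeded object instead.
    let mut new = json!({ "options": {} });
    new["options"]["num_predict"] = 2048.into();
    new["options"]["temperature"] = 0.7.into();
    assert_eq!(new["options"]["num_predict"], 2048);
}
```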

@@ -58,7 +58,7 @@ impl OpenAIClient {
         let api_key = self.get_api_key()?;
         let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());

-        let body = openai_build_body(data, self.model.name.clone());
+        let body = openai_build_body(data, &self.model);

         let url = format!("{api_base}/chat/completions");

@@ -139,7 +139,7 @@ pub async fn openai_send_message_streaming(
     Ok(())
 }

-pub fn openai_build_body(data: SendData, model: String) -> Value {
+pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         messages,
         temperature,

@@ -147,15 +147,16 @@ pub fn openai_build_body(data: SendData, model: String) -> Value {
     } = data;

     let mut body = json!({
-        "model": model,
+        "model": &model.name,
         "messages": messages,
     });

-    // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
-    if model == "gpt-4-vision-preview" {
+    if let Some(max_tokens) = model.max_output_tokens {
+        body["max_tokens"] = json!(max_tokens);
+    } else if model.name == "gpt-4-vision-preview" {
+        // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
         body["max_tokens"] = json!(4096);
     }

     if let Some(v) = temperature {
         body["temperature"] = v.into();
     }
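
A configured `max_output_tokens` now takes precedence, and the gpt-4-vision-preview special case only applies as a fallback when nothing is configured. A standalone sketch of that precedence (hypothetical helper mirroring the logic above):

```rust
use serde_json::{json, Value};

// Hypothetical mirror of the max_tokens precedence above.
fn apply_max_tokens(body: &mut Value, name: &str, max_output_tokens: Option<isize>) {
    if let Some(max_tokens) = max_output_tokens {
        body["max_tokens"] = json!(max_tokens);
    } else if name == "gpt-4-vision-preview" {
        // Fallback: this model's default max_tokens is only 16.
        body["max_tokens"] = json!(4096);
    }
}

fn main() {
    let mut body = json!({ "model": "gpt-4-vision-preview" });
    apply_max_tokens(&mut body, "gpt-4-vision-preview", Some(1024));
    assert_eq!(body["max_tokens"], 1024); // config wins over the fallback
}
```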

@@ -1,5 +1,5 @@
 use super::openai::openai_build_body;
-use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};
+use super::{convert_models, ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, SendData};

 use crate::utils::PromptKind;

@@ -38,23 +38,13 @@ impl OpenAICompatibleClient {
     pub fn list_models(local_config: &OpenAICompatibleConfig) -> Vec<Model> {
         let client_name = Self::name(local_config);
-        local_config
-            .models
-            .iter()
-            .map(|v| {
-                Model::new(client_name, &v.name)
-                    .set_capabilities(v.capabilities)
-                    .set_max_input_tokens(v.max_input_tokens)
-                    .set_extra_fields(v.extra_fields.clone())
-            })
-            .collect()
+        convert_models(client_name, &local_config.models)
     }

     fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
         let api_key = self.get_api_key().ok();

-        let mut body = openai_build_body(data, self.model.name.clone());
+        let mut body = openai_build_body(data, &self.model);
         self.model.merge_extra_fields(&mut body);

         let chat_endpoint = self

@@ -97,7 +97,7 @@ impl QianwenClient {
             true => API_URL_VL,
             false => API_URL,
         };
-        let (body, has_upload) = build_body(data, self.model.name.clone(), is_vl)?;
+        let (body, has_upload) = build_body(data, &self.model, is_vl)?;

         debug!("Qianwen Request: {url} {body}");

@@ -180,7 +180,7 @@ fn check_error(data: &Value) -> Result<()> {
     Ok(())
 }

-fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool)> {
+fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
     let SendData {
         messages,
         temperature,

@@ -233,6 +233,10 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool
         parameters["incremental_output"] = true.into();
     }

+    if let Some(max_tokens) = model.max_output_tokens {
+        parameters["max_tokens"] = max_tokens.into();
+    }
+
     if let Some(v) = temperature {
         parameters["temperature"] = v.into();
     }

@@ -240,7 +244,7 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool
     };

     let body = json!({
-        "model": model,
+        "model": &model.name,
         "input": input,
         "parameters": parameters
     });

@@ -81,9 +81,9 @@ impl VertexAIClient {
         let block_threshold = self.config.block_threshold.clone();

-        let body = build_body(data, self.model.name.clone(), block_threshold)?;
+        let body = build_body(data, &self.model, block_threshold)?;

-        let model = self.model.name.clone();
+        let model = &self.model.name;

         let url = format!("{api_base}/{}:{}", model, func);

@@ -176,7 +176,7 @@ fn check_error(data: &Value) -> Result<()> {
 pub(crate) fn build_body(
     data: SendData,
-    _model: String,
+    model: &Model,
     block_threshold: Option<String>,
 ) -> Result<Value> {
     let SendData {

@@ -228,7 +228,7 @@ pub(crate) fn build_body(
         );
     }

-    let mut body = json!({ "contents": contents });
+    let mut body = json!({ "contents": contents, "generationConfig": {} });

     if let Some(block_threshold) = block_threshold {
         body["safetySettings"] = json!([

@@ -239,10 +239,12 @@ pub(crate) fn build_body(
         ]);
     }

+    if let Some(max_output_tokens) = model.max_output_tokens {
+        body["generationConfig"]["maxOutputTokens"] = max_output_tokens.into();
+    }
+
     if let Some(temperature) = temperature {
-        body["generationConfig"] = json!({
-            "temperature": temperature,
-        });
+        body["generationConfig"]["temperature"] = temperature.into();
     }

     Ok(body)
