refactor: rename model's `max_tokens` to `max_input_tokens` (#339)

BREAKING CHANGE: rename model's `max_tokens` to `max_input_tokens`
sigoden authored 3 months ago, committed by GitHub
parent be4e5e569a
commit 8e5d4e55b1

@@ -97,7 +97,7 @@ clients:
api_base: http://localhost:8080/v1
models:
- name: gpt4all-j
max_tokens: 8192
max_input_tokens: 8192
```
Take a look at [config.example.yaml](config.example.yaml) for complete configuration details.
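
Note for existing configs: since this is a breaking rename, a config that still sets `max_tokens` will not fail to load; unless the parser rejects unknown fields, serde simply ignores the old key and the limit silently stops applying. A minimal sketch illustrating this, assuming `serde` and `serde_yaml` as dependencies and a trimmed-down `ModelConfig` (not the crate's full struct):

```rust
// Minimal sketch (not part of this commit): a trimmed-down ModelConfig showing
// that a config still using the old `max_tokens` key loads without error but
// leaves `max_input_tokens` unset, because serde ignores unknown fields by default.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    name: String,
    max_input_tokens: Option<usize>,
}

fn main() {
    let old_style = "name: gpt4all-j\nmax_tokens: 8192";
    let cfg: ModelConfig = serde_yaml::from_str(old_style).expect("still parses");
    assert_eq!(cfg.max_input_tokens, None); // the old key is ignored, not migrated
    println!("{:?}", cfg);
}
```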

@@ -52,11 +52,11 @@ clients:
chat_endpoint: /chat/completions # Optional field
models:
- name: llama2
max_tokens: 8192
max_input_tokens: 8192
extra_fields: # Optional field, set custom parameters
key: value
- name: llava
max_tokens: 8192
max_input_tokens: 8192
capabilities: text,vision # Optional field, possible values: text, vision
# See https://github.com/jmorganca/ollama
@@ -66,7 +66,7 @@ clients:
chat_endpoint: /chat # Optional field
models:
- name: mistral
max_tokens: 8192
max_input_tokens: 8192
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
@@ -74,7 +74,7 @@ clients:
api_key: xxx
models:
- name: MyGPT4 # Model deployment name
max_tokens: 8192
max_input_tokens: 8192
# See https://cloud.google.com/vertex-ai
- type: vertexai

@@ -28,8 +28,8 @@ impl AzureOpenAIClient {
("api_key", "API Key:", true, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
"models[].max_input_tokens",
"Max Input Tokens:",
true,
PromptKind::Integer,
),
@@ -43,7 +43,7 @@ impl AzureOpenAIClient {
.iter()
.map(|v| {
Model::new(client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_max_input_tokens(v.max_input_tokens)
.set_capabilities(v.capabilities)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})

@@ -20,11 +20,12 @@ use serde_json::{json, Value};
const API_BASE: &str = "https://api.anthropic.com/v1/messages";
const MODELS: [(&str, usize, &str); 5] = [
("claude-3-opus-20240229", 204096, "text,vision"),
("claude-3-sonnet-20240229", 204096, "text,vision"),
("claude-2.1", 204096, "text"),
("claude-2.0", 104096, "text"),
("claude-instant-1.2", 104096, "text"),
// https://docs.anthropic.com/claude/docs/models-overview
("claude-3-opus-20240229", 200000, "text,vision"),
("claude-3-sonnet-20240229", 200000, "text,vision"),
("claude-2.1", 200000, "text"),
("claude-2.0", 100000, "text"),
("claude-instant-1.2", 100000, "text"),
];
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -66,10 +67,10 @@ impl ClaudeClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})
.collect()
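
The same mechanical rename repeats in the client modules below (gemini, localai, mistral, ollama, openai, qianwen, vertexai): the tuple element previously treated as `max_tokens` now carries the provider's documented input window, and the hardcoded numbers are adjusted accordingly (for Claude, the old figures appear to have bundled a 4096-token output allowance on top of the 200K/100K context). A distilled, self-contained sketch of the shared pattern, using placeholder model names and a stub type instead of the crate's `Model`:

```rust
// Distilled sketch of the per-client pattern in this commit (placeholder names,
// stub type): a const table of (name, max_input_tokens, capabilities) mapped
// into model values, where the usize is now an input limit rather than a total
// token budget.
#[derive(Debug)]
struct ModelStub {
    name: String,
    max_input_tokens: Option<usize>,
    capabilities: String,
}

const MODELS: [(&str, usize, &str); 2] = [
    ("example-text-model", 8_192, "text"),
    ("example-vision-model", 8_192, "text,vision"),
];

fn list_models() -> Vec<ModelStub> {
    MODELS
        .into_iter()
        .map(|(name, max_input_tokens, capabilities)| ModelStub {
            name: name.to_string(),
            max_input_tokens: Some(max_input_tokens),
            capabilities: capabilities.to_string(),
        })
        .collect()
}

fn main() {
    println!("{:#?}", list_models());
}
```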

@@ -11,8 +11,9 @@ use serde::Deserialize;
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/";
const MODELS: [(&str, usize, &str); 2] = [
("gemini-pro", 32768, "text"),
("gemini-pro-vision", 16384, "vision"),
// https://ai.google.dev/models/gemini
("gemini-pro", 30720, "text"),
("gemini-pro-vision", 12288, "vision"),
];
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -54,10 +55,10 @@ impl GeminiClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})
.collect()

@@ -28,8 +28,8 @@ impl LocalAIClient {
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
"models[].max_input_tokens",
"Max Input Tokens:",
false,
PromptKind::Integer,
),
@@ -44,7 +44,7 @@ impl LocalAIClient {
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_max_input_tokens(v.max_input_tokens)
.set_extra_fields(v.extra_fields.clone())
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})

@@ -11,6 +11,7 @@ use serde::Deserialize;
const API_URL: &str = "https://api.mistral.ai/v1/chat/completions";
const MODELS: [(&str, usize, &str); 5] = [
// https://docs.mistral.ai/platform/endpoints/
("mistral-small-latest", 32000, "text"),
("mistral-medium-latest", 32000, "text"),
("mistral-larget-latest", 32000, "text"),
@@ -39,10 +40,10 @@ impl MistralClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
.collect()

@@ -11,7 +11,7 @@ pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
pub struct Model {
pub client_name: String,
pub name: String,
pub max_tokens: Option<usize>,
pub max_input_tokens: Option<usize>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
pub tokens_count_factors: TokensCountFactors,
pub capabilities: ModelCapabilities,
@@ -29,7 +29,7 @@ impl Model {
client_name: client_name.into(),
name: name.into(),
extra_fields: None,
max_tokens: None,
max_input_tokens: None,
tokens_count_factors: Default::default(),
capabilities: ModelCapabilities::Text,
}
@@ -83,10 +83,10 @@ impl Model {
self
}
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
match max_tokens {
None | Some(0) => self.max_tokens = None,
_ => self.max_tokens = max_tokens,
pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self {
match max_input_tokens {
None | Some(0) => self.max_input_tokens = None,
_ => self.max_input_tokens = max_input_tokens,
}
self
}
@@ -122,12 +122,12 @@ impl Model {
}
}
pub fn max_tokens_limit(&self, messages: &[Message]) -> Result<()> {
pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
let (_, bias) = self.tokens_count_factors;
let total_tokens = self.total_tokens(messages) + bias;
if let Some(max_tokens) = self.max_tokens {
if total_tokens >= max_tokens {
bail!("Exceed max tokens limit")
if let Some(max_input_tokens) = self.max_input_tokens {
if total_tokens >= max_input_tokens {
bail!("Exceed max input tokens limit")
}
}
Ok(())
@@ -147,7 +147,7 @@ impl Model {
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
pub name: String,
pub max_tokens: Option<usize>,
pub max_input_tokens: Option<usize>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
#[serde(deserialize_with = "deserialize_capabilities")]
#[serde(default = "default_capabilities")]
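
A self-contained sketch of the renamed limit check above; the real `Model::max_input_tokens_limit` derives the token count from the messages and the per-message factor, whereas here both the count and the bias are passed in directly:

```rust
// Self-contained sketch of the renamed check in Model::max_input_tokens_limit:
// the bias from tokens_count_factors is added to the message tokens before
// comparing against max_input_tokens; None means no limit is configured.
fn check_input_limit(
    message_tokens: usize,
    bias: usize,
    max_input_tokens: Option<usize>,
) -> Result<(), String> {
    if let Some(limit) = max_input_tokens {
        if message_tokens + bias >= limit {
            return Err("Exceed max input tokens limit".to_string());
        }
    }
    Ok(())
}

fn main() {
    // OPENAI_TOKENS_COUNT_FACTORS is (5, 2) above, so the bias is 2 here
    assert!(check_input_limit(8_000, 2, Some(8_192)).is_ok());
    assert!(check_input_limit(8_191, 2, Some(8_192)).is_err());
    assert!(check_input_limit(1_000_000, 2, None).is_ok()); // no limit configured
    println!("limit checks behave as expected");
}
```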

@@ -52,8 +52,8 @@ impl OllamaClient {
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
"models[].max_input_tokens",
"Max Input Tokens:",
false,
PromptKind::Integer,
),
@@ -68,7 +68,7 @@ impl OllamaClient {
.map(|v| {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_max_input_tokens(v.max_input_tokens)
.set_extra_fields(v.extra_fields.clone())
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})

@@ -12,14 +12,14 @@ use serde_json::{json, Value};
const API_BASE: &str = "https://api.openai.com/v1";
const MODELS: [(&str, usize, &str); 7] = [
const MODELS: [(&str, usize, &str); 5] = [
// https://platform.openai.com/docs/models/gpt-3-5-turbo
("gpt-3.5-turbo", 16385, "text"),
("gpt-3.5-turbo-1106", 16385, "text"),
// https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
("gpt-4-turbo-preview", 128000, "text"),
("gpt-4-vision-preview", 128000, "text,vision"),
("gpt-4-1106-preview", 128000, "text"),
("gpt-4", 8192, "text"),
("gpt-4-32k", 32768, "text"),
];
pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@@ -46,10 +46,10 @@ impl OpenAIClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
.collect()

@@ -25,10 +25,12 @@ const API_URL_VL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";
const MODELS: [(&str, usize, &str); 6] = [
("qwen-turbo", 8192, "text"),
("qwen-plus", 32768, "text"),
("qwen-max", 8192, "text"),
("qwen-max-longcontext", 30720, "text"),
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
("qwen-turbo", 6000, "text"),
("qwen-plus", 30000, "text"),
("qwen-max", 6000, "text"),
("qwen-max-longcontext", 28000, "text"),
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api
("qwen-vl-plus", 0, "text,vision"),
("qwen-vl-max", 0, "text,vision"),
];
@@ -78,10 +80,10 @@ impl QianwenClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
})
.collect()
}

@@ -14,9 +14,10 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::path::PathBuf;
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
const MODELS: [(&str, usize, &str); 5] = [
("gemini-1.0-pro", 32760, "text"),
("gemini.1.0-pro-vision", 16384, "text,vision"),
("gemini-1.0-pro", 24568, "text"),
("gemini.1.0-pro-vision", 14336, "text,vision"),
("gemini-1.0-ultra", 8192, "text"),
("gemini.1.0-ultra-vision", 8192, "text,vision"),
("gemini-1.5-pro", 1000000, "text"),
@@ -66,10 +67,10 @@ impl VertexAIClient {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens, capabilities)| {
.map(|(name, max_input_tokens, capabilities)| {
Model::new(client_name, name)
.set_capabilities(capabilities.into())
.set_max_tokens(Some(max_tokens))
.set_max_input_tokens(Some(max_input_tokens))
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})
.collect()

@@ -750,7 +750,7 @@ impl Config {
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
let messages = self.build_messages(input)?;
self.model.max_tokens_limit(&messages)?;
self.model.max_input_tokens_limit(&messages)?;
Ok(SendData {
messages,
temperature: self.get_temperature(),
@@ -773,8 +773,8 @@ impl Config {
output.insert("client_name", self.model.client_name.clone());
output.insert("model_name", self.model.name.clone());
output.insert(
"max_tokens",
self.model.max_tokens.unwrap_or_default().to_string(),
"max_input_tokens",
self.model.max_input_tokens.unwrap_or_default().to_string(),
);
if let Some(temperature) = self.temperature {
if temperature != 0.0 {

@@ -108,8 +108,8 @@ impl Session {
data["temperature"] = temperature.into();
}
data["total_tokens"] = tokens.into();
if let Some(max_tokens) = self.model.max_tokens {
data["max_tokens"] = max_tokens.into();
if let Some(max_input_tokens) = self.model.max_input_tokens {
data["max_input_tokens"] = max_input_tokens.into();
}
if percent != 0.0 {
data["total/max"] = format!("{}%", percent).into();
@@ -138,8 +138,8 @@ impl Session {
items.push(("compress_threshold", compress_threshold.to_string()));
}
if let Some(max_tokens) = self.model.max_tokens {
items.push(("max_tokens", max_tokens.to_string()));
if let Some(max_input_tokens) = self.model.max_input_tokens {
items.push(("max_input_tokens", max_input_tokens.to_string()));
}
let mut lines: Vec<String> = items
@@ -179,11 +179,11 @@
pub fn tokens_and_percent(&self) -> (usize, f32) {
let tokens = self.tokens();
let max_tokens = self.model.max_tokens.unwrap_or_default();
let percent = if max_tokens == 0 {
let max_input_tokens = self.model.max_input_tokens.unwrap_or_default();
let percent = if max_input_tokens == 0 {
0.0
} else {
let percent = tokens as f32 / max_tokens as f32 * 100.0;
let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
(percent * 100.0).round() / 100.0
};
(tokens, percent)
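
A standalone sketch of the renamed percentage math in `tokens_and_percent`, with the same two-decimal rounding; the inputs are illustrative:

```rust
// Standalone sketch mirroring tokens_and_percent: percent of the input window
// used, rounded to two decimal places; 0.0 when no limit is configured.
fn percent_used(tokens: usize, max_input_tokens: Option<usize>) -> f32 {
    let max_input_tokens = max_input_tokens.unwrap_or_default();
    if max_input_tokens == 0 {
        return 0.0;
    }
    let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
    (percent * 100.0).round() / 100.0
}

fn main() {
    // e.g. 3000 session tokens against an 8192-token input window
    println!("{}%", percent_used(3_000, Some(8_192))); // prints 36.62%
}
```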
