|
|
|
@ -11,7 +11,7 @@ pub type TokensCountFactors = (usize, usize); // (per-messages, bias)
|
|
|
|
|
pub struct Model {
|
|
|
|
|
pub client_name: String,
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub max_tokens: Option<usize>,
|
|
|
|
|
pub max_input_tokens: Option<usize>,
|
|
|
|
|
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
|
|
|
|
pub tokens_count_factors: TokensCountFactors,
|
|
|
|
|
pub capabilities: ModelCapabilities,
|
|
|
|
@ -29,7 +29,7 @@ impl Model {
|
|
|
|
|
client_name: client_name.into(),
|
|
|
|
|
name: name.into(),
|
|
|
|
|
extra_fields: None,
|
|
|
|
|
max_tokens: None,
|
|
|
|
|
max_input_tokens: None,
|
|
|
|
|
tokens_count_factors: Default::default(),
|
|
|
|
|
capabilities: ModelCapabilities::Text,
|
|
|
|
|
}
|
|
|
|
@ -83,10 +83,10 @@ impl Model {
|
|
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
|
|
|
|
|
match max_tokens {
|
|
|
|
|
None | Some(0) => self.max_tokens = None,
|
|
|
|
|
_ => self.max_tokens = max_tokens,
|
|
|
|
|
/// Builder-style setter for the model's `max_input_tokens` field.
///
/// Takes `self` by value and returns it so calls can be chained.
/// `None` and `Some(0)` are both stored as `None`; any other value is
/// stored unchanged.
pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self {
    // A zero limit is normalised to `None` rather than kept as `Some(0)`.
    self.max_input_tokens = max_input_tokens.filter(|&limit| limit != 0);
    self
}
|
|
|
|
@ -122,12 +122,12 @@ impl Model {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn max_tokens_limit(&self, messages: &[Message]) -> Result<()> {
|
|
|
|
|
pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
|
|
|
|
|
let (_, bias) = self.tokens_count_factors;
|
|
|
|
|
let total_tokens = self.total_tokens(messages) + bias;
|
|
|
|
|
if let Some(max_tokens) = self.max_tokens {
|
|
|
|
|
if total_tokens >= max_tokens {
|
|
|
|
|
bail!("Exceed max tokens limit")
|
|
|
|
|
if let Some(max_input_tokens) = self.max_input_tokens {
|
|
|
|
|
if total_tokens >= max_input_tokens {
|
|
|
|
|
bail!("Exceed max input tokens limit")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
@ -147,7 +147,7 @@ impl Model {
|
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
|
|
pub struct ModelConfig {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub max_tokens: Option<usize>,
|
|
|
|
|
pub max_input_tokens: Option<usize>,
|
|
|
|
|
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
|
|
|
|
#[serde(deserialize_with = "deserialize_capabilities")]
|
|
|
|
|
#[serde(default = "default_capabilities")]
|
|
|
|
|