mirror of https://github.com/sigoden/aichat
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
258 lines
8.3 KiB
Rust
258 lines
8.3 KiB
Rust
use super::message::{Message, MessageContent};
|
|
|
|
use crate::utils::{count_tokens, format_option_value};
|
|
|
|
use anyhow::{bail, Result};
|
|
use serde::Deserialize;
|
|
|
|
/// Token overhead added once per counted message by `Model::total_tokens`.
const PER_MESSAGES_TOKENS: usize = 5;
/// Extra tokens added to `Model::total_tokens` before comparing against `max_input_tokens`.
const BASIS_TOKENS: usize = 2;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Model {
|
|
pub client_name: String,
|
|
pub name: String,
|
|
pub max_input_tokens: Option<usize>,
|
|
pub max_output_tokens: Option<isize>,
|
|
pub ref_max_output_tokens: Option<isize>,
|
|
pub input_price: Option<f64>,
|
|
pub output_price: Option<f64>,
|
|
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
|
pub capabilities: ModelCapabilities,
|
|
}
|
|
|
|
impl Default for Model {
|
|
fn default() -> Self {
|
|
Model::new("", "")
|
|
}
|
|
}
|
|
|
|
impl Model {
|
|
pub fn new(client_name: &str, name: &str) -> Self {
|
|
Self {
|
|
client_name: client_name.into(),
|
|
name: name.into(),
|
|
extra_fields: None,
|
|
max_input_tokens: None,
|
|
max_output_tokens: None,
|
|
ref_max_output_tokens: None,
|
|
input_price: None,
|
|
output_price: None,
|
|
capabilities: ModelCapabilities::Text,
|
|
}
|
|
}
|
|
|
|
pub fn from_config(client_name: &str, models: &[ModelConfig]) -> Vec<Self> {
|
|
models
|
|
.iter()
|
|
.map(|v| {
|
|
let mut model = Model::new(client_name, &v.name);
|
|
model
|
|
.set_max_input_tokens(v.max_input_tokens)
|
|
.set_max_output_tokens(v.max_output_tokens)
|
|
.set_ref_max_output_tokens(v.ref_max_output_tokens)
|
|
.set_input_price(v.input_price)
|
|
.set_output_price(v.output_price)
|
|
.set_supports_vision(v.supports_vision)
|
|
.set_extra_fields(&v.extra_fields);
|
|
model
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
pub fn find(models: &[&Self], value: &str) -> Option<Self> {
|
|
let mut model = None;
|
|
let (client_name, model_name) = match value.split_once(':') {
|
|
Some((client_name, model_name)) => {
|
|
if model_name.is_empty() {
|
|
(client_name, None)
|
|
} else {
|
|
(client_name, Some(model_name))
|
|
}
|
|
}
|
|
None => (value, None),
|
|
};
|
|
match model_name {
|
|
Some(model_name) => {
|
|
if let Some(found) = models.iter().find(|v| v.id() == value) {
|
|
model = Some((*found).clone());
|
|
} else if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
|
|
let mut found = (*found).clone();
|
|
found.name = model_name.to_string();
|
|
model = Some(found)
|
|
}
|
|
}
|
|
None => {
|
|
if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
|
|
model = Some((*found).clone());
|
|
}
|
|
}
|
|
}
|
|
model
|
|
}
|
|
|
|
pub fn id(&self) -> String {
|
|
format!("{}:{}", self.client_name, self.name)
|
|
}
|
|
|
|
pub fn description(&self) -> String {
|
|
let max_input_tokens = format_option_value(&self.max_input_tokens);
|
|
let max_output_tokens = format_option_value(&self.show_max_output_tokens());
|
|
let input_price = format_option_value(&self.input_price);
|
|
let output_price = format_option_value(&self.output_price);
|
|
let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
|
|
"👁"
|
|
} else {
|
|
""
|
|
};
|
|
format!(
|
|
"{:>8} / {:>8} | {:>6} / {:>6} {}",
|
|
max_input_tokens, max_output_tokens, input_price, output_price, vision
|
|
)
|
|
}
|
|
|
|
pub fn show_max_output_tokens(&self) -> Option<isize> {
|
|
self.max_output_tokens.or(self.ref_max_output_tokens)
|
|
}
|
|
|
|
pub fn set_max_input_tokens(&mut self, max_input_tokens: Option<usize>) -> &mut Self {
|
|
match max_input_tokens {
|
|
None | Some(0) => self.max_input_tokens = None,
|
|
_ => self.max_input_tokens = max_input_tokens,
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_max_output_tokens(&mut self, max_output_tokens: Option<isize>) -> &mut Self {
|
|
match max_output_tokens {
|
|
None | Some(0) => self.max_output_tokens = None,
|
|
_ => self.max_output_tokens = max_output_tokens,
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_ref_max_output_tokens(&mut self, ref_max_output_tokens: Option<isize>) -> &mut Self {
|
|
match ref_max_output_tokens {
|
|
None | Some(0) => self.ref_max_output_tokens = None,
|
|
_ => self.ref_max_output_tokens = ref_max_output_tokens,
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_input_price(&mut self, input_price: Option<f64>) -> &mut Self {
|
|
match input_price {
|
|
None => self.input_price = None,
|
|
_ => self.input_price = input_price,
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_output_price(&mut self, output_price: Option<f64>) -> &mut Self {
|
|
match output_price {
|
|
None => self.output_price = None,
|
|
_ => self.output_price = output_price,
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_supports_vision(&mut self, supports_vision: bool) -> &mut Self {
|
|
if supports_vision {
|
|
self.capabilities |= ModelCapabilities::Vision;
|
|
} else {
|
|
self.capabilities &= !ModelCapabilities::Vision;
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn set_extra_fields(
|
|
&mut self,
|
|
extra_fields: &Option<serde_json::Map<String, serde_json::Value>>,
|
|
) -> &mut Self {
|
|
self.extra_fields.clone_from(extra_fields);
|
|
self
|
|
}
|
|
|
|
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
|
|
messages
|
|
.iter()
|
|
.map(|v| {
|
|
match &v.content {
|
|
MessageContent::Text(text) => count_tokens(text),
|
|
MessageContent::Array(_) => 0, // TODO
|
|
}
|
|
})
|
|
.sum()
|
|
}
|
|
|
|
pub fn total_tokens(&self, messages: &[Message]) -> usize {
|
|
if messages.is_empty() {
|
|
return 0;
|
|
}
|
|
let num_messages = messages.len();
|
|
let message_tokens = self.messages_tokens(messages);
|
|
if messages[num_messages - 1].role.is_user() {
|
|
num_messages * PER_MESSAGES_TOKENS + message_tokens
|
|
} else {
|
|
(num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
|
|
}
|
|
}
|
|
|
|
pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
|
|
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
|
|
if let Some(max_input_tokens) = self.max_input_tokens {
|
|
if total_tokens >= max_input_tokens {
|
|
bail!("Exceed max input tokens limit")
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
|
|
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) {
|
|
for (key, extra_field) in extra_fields {
|
|
if body.contains_key(key) {
|
|
if let (Some(sub_body), Some(extra_field)) =
|
|
(body[key].as_object_mut(), extra_field.as_object())
|
|
{
|
|
for (subkey, sub_field) in extra_field {
|
|
if !sub_body.contains_key(subkey) {
|
|
sub_body.insert(subkey.clone(), sub_field.clone());
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
body.insert(key.clone(), extra_field.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct ModelConfig {
|
|
pub name: String,
|
|
pub max_input_tokens: Option<usize>,
|
|
pub max_output_tokens: Option<isize>,
|
|
#[serde(rename = "max_output_tokens?")]
|
|
pub ref_max_output_tokens: Option<isize>,
|
|
pub input_price: Option<f64>,
|
|
pub output_price: Option<f64>,
|
|
#[serde(default)]
|
|
pub supports_vision: bool,
|
|
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct BuiltinModels {
|
|
pub platform: String,
|
|
pub models: Vec<ModelConfig>,
|
|
}
|
|
|
|
bitflags::bitflags! {
    /// Feature flags describing what input a model supports.
    ///
    /// Every `Model::new` starts with `Text`; `Vision` is toggled by `set_supports_vision`.
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct ModelCapabilities: u32 {
        /// Text capability (bit 0).
        const Text = 1 << 0;
        /// Vision / image-input capability (bit 1).
        const Vision = 1 << 1;
    }
}
|