diff --git a/src/client/common.rs b/src/client/common.rs index 842741c..85255ee 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -147,17 +147,22 @@ macro_rules! register_client { anyhow::bail!("Unknown client '{}'", client) } - pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> { - config - .clients - .iter() - .flat_map(|v| match v { - $(ClientConfig::$config(c) => $client::list_models(c),)+ - ClientConfig::Unknown => vec![], - }) - .collect() + static mut ALL_CLIENTS: Option> = None; + + pub fn list_models(config: &$crate::config::Config) -> Vec<&$crate::client::Model> { + if unsafe { ALL_CLIENTS.is_none() } { + let models: Vec<_> = config + .clients + .iter() + .flat_map(|v| match v { + $(ClientConfig::$config(c) => $client::list_models(c),)+ + ClientConfig::Unknown => vec![], + }) + .collect(); + unsafe { ALL_CLIENTS = Some(models) }; + } + unsafe { ALL_CLIENTS.as_ref().unwrap().iter().collect() } } - }; } diff --git a/src/client/model.rs b/src/client/model.rs index 459d94e..aface38 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -1,6 +1,6 @@ use super::message::{Message, MessageContent}; -use crate::utils::count_tokens; +use crate::utils::{count_tokens, format_option_value}; use anyhow::{bail, Result}; use serde::Deserialize; @@ -14,6 +14,9 @@ pub struct Model { pub name: String, pub max_input_tokens: Option, pub max_output_tokens: Option, + pub ref_max_output_tokens: Option, + pub input_price: Option, + pub output_price: Option, pub extra_fields: Option>, pub capabilities: ModelCapabilities, } @@ -32,6 +35,9 @@ impl Model { extra_fields: None, max_input_tokens: None, max_output_tokens: None, + ref_max_output_tokens: None, + input_price: None, + output_price: None, capabilities: ModelCapabilities::Text, } } @@ -43,13 +49,16 @@ impl Model { Model::new(client_name, &v.name) .set_max_input_tokens(v.max_input_tokens) .set_max_output_tokens(v.max_output_tokens) + .set_ref_max_output_tokens(v.ref_max_output_tokens) + .set_input_price(v.input_price) + .set_output_price(v.output_price) .set_supports_vision(v.supports_vision) .set_extra_fields(&v.extra_fields) }) .collect() } - pub fn find(models: &[Self], value: &str) -> Option { + pub fn find(models: &[&Self], value: &str) -> Option { let mut model = None; let (client_name, model_name) = match value.split_once(':') { Some((client_name, model_name)) => { @@ -64,16 +73,16 @@ impl Model { match model_name { Some(model_name) => { if let Some(found) = models.iter().find(|v| v.id() == value) { - model = Some(found.clone()); + model = Some((*found).clone()); } else if let Some(found) = models.iter().find(|v| v.client_name == client_name) { - let mut found = found.clone(); + let mut found = (*found).clone(); found.name = model_name.to_string(); model = Some(found) } } None => { if let Some(found) = models.iter().find(|v| v.client_name == client_name) { - model = Some(found.clone()); + model = Some((*found).clone()); } } } @@ -84,6 +93,23 @@ impl Model { format!("{}:{}", self.client_name, self.name) } + pub fn description(&self) -> String { + let max_input_tokens = format_option_value(&self.max_input_tokens); + let max_output_tokens = + format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens)); + let input_price = format_option_value(&self.input_price); + let output_price = format_option_value(&self.output_price); + let vision = if self.capabilities.contains(ModelCapabilities::Vision) { + "👁" + } else { + "" + }; + format!( + "{:>8} / {:>8} | {:>6} / {:>6} {}", + max_input_tokens, max_output_tokens, input_price, output_price, vision + ) + } + pub fn set_max_input_tokens(mut self, max_input_tokens: Option) -> Self { match max_input_tokens { None | Some(0) => self.max_input_tokens = None, @@ -100,6 +126,30 @@ impl Model { self } + pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option) -> Self { + match ref_max_output_tokens { + None | Some(0) => self.ref_max_output_tokens = None, + _ => self.ref_max_output_tokens = ref_max_output_tokens, + } + self + } + + pub fn set_input_price(mut self, input_price: Option) -> Self { + match input_price { + None => self.input_price = None, + _ => self.input_price = input_price, + } + self + } + + pub fn set_output_price(mut self, output_price: Option) -> Self { + match output_price { + None => self.output_price = None, + _ => self.output_price = output_price, + } + self + } + pub fn set_supports_vision(mut self, supports_vision: bool) -> Self { if supports_vision { self.capabilities |= ModelCapabilities::Vision; @@ -178,6 +228,8 @@ pub struct ModelConfig { pub name: String, pub max_input_tokens: Option, pub max_output_tokens: Option, + #[serde(rename = "max_output_tokens?")] + pub ref_max_output_tokens: Option, pub input_price: Option, pub output_price: Option, #[serde(default)] diff --git a/src/config/mod.rs b/src/config/mod.rs index 56f3122..85c3056 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,7 +11,10 @@ use crate::client::{ create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData, }; use crate::render::{MarkdownRender, RenderOptions}; -use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, render_prompt, set_text}; +use crate::utils::{ + format_option_value, fuzzy_match, get_env_name, light_theme_from_colorfgbg, now, render_prompt, + set_text, +}; use anyhow::{anyhow, bail, Context, Result}; use inquire::{Confirm, Select, Text}; @@ -415,18 +418,18 @@ impl Config { .map_or_else(|| String::from("no"), |v| v.to_string()); let items = vec![ ("model", self.model.id()), - ("temperature", format_option(&self.temperature)), - ("top_p", format_option(&self.top_p)), + ("temperature", format_option_value(&self.temperature)), + ("top_p", format_option_value(&self.top_p)), ("dry_run", self.dry_run.to_string()), ("save", self.save.to_string()), - ("save_session", format_option(&self.save_session)), + ("save_session", format_option_value(&self.save_session)), ("highlight", self.highlight.to_string()), ("light_theme", self.light_theme.to_string()), ("wrap", wrap), ("wrap_code", self.wrap_code.to_string()), ("auto_copy", self.auto_copy.to_string()), ("keybindings", self.keybindings.stringify().into()), - ("prelude", format_option(&self.prelude)), + ("prelude", format_option_value(&self.prelude)), ("compress_threshold", self.compress_threshold.to_string()), ("config_file", display_path(&Self::config_file()?)), ("roles_file", display_path(&Self::roles_file()?)), @@ -476,12 +479,23 @@ impl Config { .unwrap_or_default() } - pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec { + pub fn repl_complete(&self, cmd: &str, args: &[&str]) -> Vec<(String, String)> { let (values, filter) = if args.len() == 1 { let values = match cmd { - ".role" => self.roles.iter().map(|v| v.name.clone()).collect(), - ".model" => list_models(self).into_iter().map(|v| v.id()).collect(), - ".session" => self.list_sessions(), + ".role" => self + .roles + .iter() + .map(|v| (v.name.clone(), String::new())) + .collect(), + ".model" => list_models(self) + .into_iter() + .map(|v| (v.id(), v.description())) + .collect(), + ".session" => self + .list_sessions() + .into_iter() + .map(|v| (v.clone(), String::new())) + .collect(), ".set" => vec![ "temperature ", "top_p ", @@ -493,7 +507,7 @@ impl Config { "auto_copy ", ] .into_iter() - .map(|v| v.to_string()) + .map(|v| (v.to_string(), String::new())) .collect(), _ => vec![], }; @@ -514,13 +528,16 @@ impl Config { "auto_copy" => complete_bool(self.auto_copy), _ => vec![], }; - (values, args[1]) + ( + values.into_iter().map(|v| (v, String::new())).collect(), + args[1], + ) } else { return vec![]; }; values .into_iter() - .filter(|v| v.starts_with(filter)) + .filter(|(value, _)| fuzzy_match(value, filter)) .collect() } @@ -1136,16 +1153,6 @@ where Ok(value) } -fn format_option(value: &Option) -> String -where - T: std::fmt::Display, -{ - match value { - Some(value) => value.to_string(), - None => "-".to_string(), - } -} - fn complete_bool(value: bool) -> Vec { vec![(!value).to_string()] } diff --git a/src/repl/completer.rs b/src/repl/completer.rs index aea1ee6..2e7d79c 100644 --- a/src/repl/completer.rs +++ b/src/repl/completer.rs @@ -54,7 +54,7 @@ impl Completer for ReplCompleter { .read() .repl_complete(cmd, &args) .iter() - .map(|name| create_suggestion(name.clone(), None, span)), + .map(|(value, description)| create_suggestion(value, description, span)), ) } @@ -69,7 +69,7 @@ impl Completer for ReplCompleter { } else { format!("{name} ") }; - create_suggestion(name, Some(description.to_string()), span) + create_suggestion(&name, description, span) })) } suggestions @@ -105,9 +105,14 @@ impl ReplCompleter { } } -fn create_suggestion(value: String, description: Option, span: Span) -> Suggestion { +fn create_suggestion(value: &str, description: &str, span: Span) -> Suggestion { + let description = if description.is_empty() { + None + } else { + Some(description.to_string()) + }; Suggestion { - value, + value: value.to_string(), description, style: None, extra: None, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 38cf766..771baf3 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -166,6 +166,33 @@ pub fn extract_block(input: &str) -> String { } } +pub fn format_option_value(value: &Option) -> String +where + T: std::fmt::Display, +{ + match value { + Some(value) => value.to_string(), + None => "-".to_string(), + } +} + +pub fn fuzzy_match(text: &str, pattern: &str) -> bool { + let text_chars: Vec = text.chars().collect(); + let pattern_chars: Vec = pattern.chars().collect(); + + let mut pattern_index = 0; + let mut text_index = 0; + + while pattern_index < pattern_chars.len() && text_index < text_chars.len() { + if pattern_chars[pattern_index] == text_chars[text_index] { + pattern_index += 1; + } + text_index += 1; + } + + pattern_index == pattern_chars.len() +} + #[cfg(test)] mod tests { use super::*; @@ -180,4 +207,11 @@ mod tests { fn test_count_tokens() { assert_eq!(count_tokens("😊 hello world"), 4); } + + #[test] + fn test_fuzzy_match() { + assert!(fuzzy_match("openai:gpt-4-turbo", "gpt4")); + assert!(fuzzy_match("openai:gpt-4-turbo", "oai4")); + assert!(!fuzzy_match("openai:gpt-4-turbo", "4gpt")); + } }