From 8a65337d590729f96a5f3c0b35dc5a08fae5bf94 Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 30 Apr 2024 08:57:22 +0800 Subject: [PATCH] feat: add `.set max_output_tokens` (#468) --- src/client/common.rs | 9 +++++++-- src/client/model.rs | 29 +++++++++++++++++------------ src/config/mod.rs | 32 ++++++++++++++++++++++++-------- src/serve.rs | 2 +- 4 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 5ddaf87..ee190d5 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -115,8 +115,7 @@ macro_rules! register_client { None $(.or_else(|| $client::init(config)))+ .ok_or_else(|| { - let model = config.read().model.clone(); - anyhow::anyhow!("Unknown client '{}'", &model.client_name) + anyhow::anyhow!("Unknown client '{}'", &config.read().model.client_name) }) } @@ -254,6 +253,10 @@ macro_rules! client_common_fns { &self.model } + fn model_mut(&mut self) -> &mut Model { + &mut self.model + } + fn set_model(&mut self, model: Model) { self.model = model; } @@ -323,6 +326,8 @@ pub trait Client: Sync + Send { fn model(&self) -> &Model; + fn model_mut(&mut self) -> &mut Model; + fn set_model(&mut self, model: Model); fn build_client(&self) -> Result { diff --git a/src/client/model.rs b/src/client/model.rs index aface38..42040fb 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -46,14 +46,16 @@ impl Model { models .iter() .map(|v| { - Model::new(client_name, &v.name) + let mut model = Model::new(client_name, &v.name); + model .set_max_input_tokens(v.max_input_tokens) .set_max_output_tokens(v.max_output_tokens) .set_ref_max_output_tokens(v.ref_max_output_tokens) .set_input_price(v.input_price) .set_output_price(v.output_price) .set_supports_vision(v.supports_vision) - .set_extra_fields(&v.extra_fields) + .set_extra_fields(&v.extra_fields); + model }) .collect() } @@ -95,8 +97,7 @@ impl Model { pub fn description(&self) -> String { let max_input_tokens = format_option_value(&self.max_input_tokens); - let max_output_tokens = - format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens)); + let max_output_tokens = format_option_value(&self.show_max_output_tokens()); let input_price = format_option_value(&self.input_price); let output_price = format_option_value(&self.output_price); let vision = if self.capabilities.contains(ModelCapabilities::Vision) { @@ -110,7 +111,11 @@ impl Model { ) } - pub fn set_max_input_tokens(mut self, max_input_tokens: Option) -> Self { + pub fn show_max_output_tokens(&self) -> Option { + self.max_output_tokens.or(self.ref_max_output_tokens) + } + + pub fn set_max_input_tokens(&mut self, max_input_tokens: Option) -> &mut Self { match max_input_tokens { None | Some(0) => self.max_input_tokens = None, _ => self.max_input_tokens = max_input_tokens, @@ -118,7 +123,7 @@ impl Model { self } - pub fn set_max_output_tokens(mut self, max_output_tokens: Option) -> Self { + pub fn set_max_output_tokens(&mut self, max_output_tokens: Option) -> &mut Self { match max_output_tokens { None | Some(0) => self.max_output_tokens = None, _ => self.max_output_tokens = max_output_tokens, @@ -126,7 +131,7 @@ impl Model { self } - pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option) -> Self { + pub fn set_ref_max_output_tokens(&mut self, ref_max_output_tokens: Option) -> &mut Self { match ref_max_output_tokens { None | Some(0) => self.ref_max_output_tokens = None, _ => self.ref_max_output_tokens = ref_max_output_tokens, @@ -134,7 +139,7 @@ impl Model { self } - pub fn set_input_price(mut self, input_price: Option) -> Self { + pub fn set_input_price(&mut self, input_price: Option) -> &mut Self { match input_price { None => self.input_price = None, _ => self.input_price = input_price, @@ -142,7 +147,7 @@ impl Model { self } - pub fn set_output_price(mut self, output_price: Option) -> Self { + pub fn set_output_price(&mut self, output_price: Option) -> &mut Self { match output_price { None => self.output_price = None, _ => self.output_price = output_price, @@ -150,7 +155,7 @@ impl Model { self } - pub fn set_supports_vision(mut self, supports_vision: bool) -> Self { + pub fn set_supports_vision(&mut self, supports_vision: bool) -> &mut Self { if supports_vision { self.capabilities |= ModelCapabilities::Vision; } else { @@ -160,9 +165,9 @@ impl Model { } pub fn set_extra_fields( - mut self, + &mut self, extra_fields: &Option>, - ) -> Self { + ) -> &mut Self { self.extra_fields = extra_fields.clone(); self } diff --git a/src/config/mod.rs b/src/config/mod.rs index 85c3056..0c89324 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -418,6 +418,13 @@ impl Config { .map_or_else(|| String::from("no"), |v| v.to_string()); let items = vec![ ("model", self.model.id()), + ( + "max_output_tokens", + self.model + .max_output_tokens + .map(|v| format!("{v} (current model)")) + .unwrap_or_else(|| "-".into()), + ), ("temperature", format_option_value(&self.temperature)), ("top_p", format_option_value(&self.top_p)), ("dry_run", self.dry_run.to_string()), @@ -497,23 +504,28 @@ impl Config { .map(|v| (v.clone(), String::new())) .collect(), ".set" => vec![ - "temperature ", - "top_p ", + "max_output_tokens", + "temperature", + "top_p", "compress_threshold", - "save ", - "save_session ", - "highlight ", - "dry_run ", - "auto_copy ", + "save", + "save_session", + "highlight", + "dry_run", + "auto_copy", ] .into_iter() - .map(|v| (v.to_string(), String::new())) + .map(|v| (format!("{v} "), String::new())) .collect(), _ => vec![], }; (values, args[0]) } else if args.len() == 2 { let values = match args[0] { + "max_output_tokens" => match self.model.show_max_output_tokens() { + Some(v) => vec![v.to_string()], + None => vec![], + }, "save" => complete_bool(self.save), "save_session" => { let save_session = if let Some(session) = &self.session { @@ -549,6 +561,10 @@ impl Config { let key = parts[0]; let value = parts[1]; match key { + "max_output_tokens" => { + let value = parse_value(value)?; + self.model.set_max_output_tokens(value); + } "temperature" => { let value = parse_value(value)?; self.set_temperature(value); diff --git a/src/serve.rs b/src/serve.rs index c5f8fd3..d259126 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -164,7 +164,7 @@ impl Server { let mut client = init_client(&config)?; if max_tokens.is_some() { - client.set_model(client.model().clone().set_max_output_tokens(max_tokens)); + client.model_mut().set_max_output_tokens(max_tokens); } let abort = create_abort_signal(); let http_client = client.build_client()?;