feat: add `.set max_output_tokens` (#468)

pull/469/head
sigoden 1 month ago committed by GitHub
parent e7fa6c5a20
commit 8a65337d59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -115,8 +115,7 @@ macro_rules! register_client {
None None
$(.or_else(|| $client::init(config)))+ $(.or_else(|| $client::init(config)))+
.ok_or_else(|| { .ok_or_else(|| {
let model = config.read().model.clone(); anyhow::anyhow!("Unknown client '{}'", &config.read().model.client_name)
anyhow::anyhow!("Unknown client '{}'", &model.client_name)
}) })
} }
@ -254,6 +253,10 @@ macro_rules! client_common_fns {
&self.model &self.model
} }
fn model_mut(&mut self) -> &mut Model {
&mut self.model
}
fn set_model(&mut self, model: Model) { fn set_model(&mut self, model: Model) {
self.model = model; self.model = model;
} }
@ -323,6 +326,8 @@ pub trait Client: Sync + Send {
fn model(&self) -> &Model; fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;
fn set_model(&mut self, model: Model); fn set_model(&mut self, model: Model);
fn build_client(&self) -> Result<ReqwestClient> { fn build_client(&self) -> Result<ReqwestClient> {

@ -46,14 +46,16 @@ impl Model {
models models
.iter() .iter()
.map(|v| { .map(|v| {
Model::new(client_name, &v.name) let mut model = Model::new(client_name, &v.name);
model
.set_max_input_tokens(v.max_input_tokens) .set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens) .set_max_output_tokens(v.max_output_tokens)
.set_ref_max_output_tokens(v.ref_max_output_tokens) .set_ref_max_output_tokens(v.ref_max_output_tokens)
.set_input_price(v.input_price) .set_input_price(v.input_price)
.set_output_price(v.output_price) .set_output_price(v.output_price)
.set_supports_vision(v.supports_vision) .set_supports_vision(v.supports_vision)
.set_extra_fields(&v.extra_fields) .set_extra_fields(&v.extra_fields);
model
}) })
.collect() .collect()
} }
@ -95,8 +97,7 @@ impl Model {
pub fn description(&self) -> String { pub fn description(&self) -> String {
let max_input_tokens = format_option_value(&self.max_input_tokens); let max_input_tokens = format_option_value(&self.max_input_tokens);
let max_output_tokens = let max_output_tokens = format_option_value(&self.show_max_output_tokens());
format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens));
let input_price = format_option_value(&self.input_price); let input_price = format_option_value(&self.input_price);
let output_price = format_option_value(&self.output_price); let output_price = format_option_value(&self.output_price);
let vision = if self.capabilities.contains(ModelCapabilities::Vision) { let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
@ -110,7 +111,11 @@ impl Model {
) )
} }
pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self { pub fn show_max_output_tokens(&self) -> Option<isize> {
self.max_output_tokens.or(self.ref_max_output_tokens)
}
pub fn set_max_input_tokens(&mut self, max_input_tokens: Option<usize>) -> &mut Self {
match max_input_tokens { match max_input_tokens {
None | Some(0) => self.max_input_tokens = None, None | Some(0) => self.max_input_tokens = None,
_ => self.max_input_tokens = max_input_tokens, _ => self.max_input_tokens = max_input_tokens,
@ -118,7 +123,7 @@ impl Model {
self self
} }
pub fn set_max_output_tokens(mut self, max_output_tokens: Option<isize>) -> Self { pub fn set_max_output_tokens(&mut self, max_output_tokens: Option<isize>) -> &mut Self {
match max_output_tokens { match max_output_tokens {
None | Some(0) => self.max_output_tokens = None, None | Some(0) => self.max_output_tokens = None,
_ => self.max_output_tokens = max_output_tokens, _ => self.max_output_tokens = max_output_tokens,
@ -126,7 +131,7 @@ impl Model {
self self
} }
pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option<isize>) -> Self { pub fn set_ref_max_output_tokens(&mut self, ref_max_output_tokens: Option<isize>) -> &mut Self {
match ref_max_output_tokens { match ref_max_output_tokens {
None | Some(0) => self.ref_max_output_tokens = None, None | Some(0) => self.ref_max_output_tokens = None,
_ => self.ref_max_output_tokens = ref_max_output_tokens, _ => self.ref_max_output_tokens = ref_max_output_tokens,
@ -134,7 +139,7 @@ impl Model {
self self
} }
pub fn set_input_price(mut self, input_price: Option<f64>) -> Self { pub fn set_input_price(&mut self, input_price: Option<f64>) -> &mut Self {
match input_price { match input_price {
None => self.input_price = None, None => self.input_price = None,
_ => self.input_price = input_price, _ => self.input_price = input_price,
@ -142,7 +147,7 @@ impl Model {
self self
} }
pub fn set_output_price(mut self, output_price: Option<f64>) -> Self { pub fn set_output_price(&mut self, output_price: Option<f64>) -> &mut Self {
match output_price { match output_price {
None => self.output_price = None, None => self.output_price = None,
_ => self.output_price = output_price, _ => self.output_price = output_price,
@ -150,7 +155,7 @@ impl Model {
self self
} }
pub fn set_supports_vision(mut self, supports_vision: bool) -> Self { pub fn set_supports_vision(&mut self, supports_vision: bool) -> &mut Self {
if supports_vision { if supports_vision {
self.capabilities |= ModelCapabilities::Vision; self.capabilities |= ModelCapabilities::Vision;
} else { } else {
@ -160,9 +165,9 @@ impl Model {
} }
pub fn set_extra_fields( pub fn set_extra_fields(
mut self, &mut self,
extra_fields: &Option<serde_json::Map<String, serde_json::Value>>, extra_fields: &Option<serde_json::Map<String, serde_json::Value>>,
) -> Self { ) -> &mut Self {
self.extra_fields = extra_fields.clone(); self.extra_fields = extra_fields.clone();
self self
} }

@ -418,6 +418,13 @@ impl Config {
.map_or_else(|| String::from("no"), |v| v.to_string()); .map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![ let items = vec![
("model", self.model.id()), ("model", self.model.id()),
(
"max_output_tokens",
self.model
.max_output_tokens
.map(|v| format!("{v} (current model)"))
.unwrap_or_else(|| "-".into()),
),
("temperature", format_option_value(&self.temperature)), ("temperature", format_option_value(&self.temperature)),
("top_p", format_option_value(&self.top_p)), ("top_p", format_option_value(&self.top_p)),
("dry_run", self.dry_run.to_string()), ("dry_run", self.dry_run.to_string()),
@ -497,23 +504,28 @@ impl Config {
.map(|v| (v.clone(), String::new())) .map(|v| (v.clone(), String::new()))
.collect(), .collect(),
".set" => vec![ ".set" => vec![
"temperature ", "max_output_tokens",
"top_p ", "temperature",
"top_p",
"compress_threshold", "compress_threshold",
"save ", "save",
"save_session ", "save_session",
"highlight ", "highlight",
"dry_run ", "dry_run",
"auto_copy ", "auto_copy",
] ]
.into_iter() .into_iter()
.map(|v| (v.to_string(), String::new())) .map(|v| (format!("{v} "), String::new()))
.collect(), .collect(),
_ => vec![], _ => vec![],
}; };
(values, args[0]) (values, args[0])
} else if args.len() == 2 { } else if args.len() == 2 {
let values = match args[0] { let values = match args[0] {
"max_output_tokens" => match self.model.show_max_output_tokens() {
Some(v) => vec![v.to_string()],
None => vec![],
},
"save" => complete_bool(self.save), "save" => complete_bool(self.save),
"save_session" => { "save_session" => {
let save_session = if let Some(session) = &self.session { let save_session = if let Some(session) = &self.session {
@ -549,6 +561,10 @@ impl Config {
let key = parts[0]; let key = parts[0];
let value = parts[1]; let value = parts[1];
match key { match key {
"max_output_tokens" => {
let value = parse_value(value)?;
self.model.set_max_output_tokens(value);
}
"temperature" => { "temperature" => {
let value = parse_value(value)?; let value = parse_value(value)?;
self.set_temperature(value); self.set_temperature(value);

@ -164,7 +164,7 @@ impl Server {
let mut client = init_client(&config)?; let mut client = init_client(&config)?;
if max_tokens.is_some() { if max_tokens.is_some() {
client.set_model(client.model().clone().set_max_output_tokens(max_tokens)); client.model_mut().set_max_output_tokens(max_tokens);
} }
let abort = create_abort_signal(); let abort = create_abort_signal();
let http_client = client.build_client()?; let http_client = client.build_client()?;

Loading…
Cancel
Save