feat: add `.set max_output_tokens` (#468)

pull/469/head
sigoden 3 weeks ago committed by GitHub
parent e7fa6c5a20
commit 8a65337d59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -115,8 +115,7 @@ macro_rules! register_client {
None
$(.or_else(|| $client::init(config)))+
.ok_or_else(|| {
let model = config.read().model.clone();
anyhow::anyhow!("Unknown client '{}'", &model.client_name)
anyhow::anyhow!("Unknown client '{}'", &config.read().model.client_name)
})
}
@ -254,6 +253,10 @@ macro_rules! client_common_fns {
&self.model
}
fn model_mut(&mut self) -> &mut Model {
&mut self.model
}
fn set_model(&mut self, model: Model) {
self.model = model;
}
@ -323,6 +326,8 @@ pub trait Client: Sync + Send {
fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;
fn set_model(&mut self, model: Model);
fn build_client(&self) -> Result<ReqwestClient> {

@ -46,14 +46,16 @@ impl Model {
models
.iter()
.map(|v| {
Model::new(client_name, &v.name)
let mut model = Model::new(client_name, &v.name);
model
.set_max_input_tokens(v.max_input_tokens)
.set_max_output_tokens(v.max_output_tokens)
.set_ref_max_output_tokens(v.ref_max_output_tokens)
.set_input_price(v.input_price)
.set_output_price(v.output_price)
.set_supports_vision(v.supports_vision)
.set_extra_fields(&v.extra_fields)
.set_extra_fields(&v.extra_fields);
model
})
.collect()
}
@ -95,8 +97,7 @@ impl Model {
pub fn description(&self) -> String {
let max_input_tokens = format_option_value(&self.max_input_tokens);
let max_output_tokens =
format_option_value(&self.max_output_tokens.or(self.ref_max_output_tokens));
let max_output_tokens = format_option_value(&self.show_max_output_tokens());
let input_price = format_option_value(&self.input_price);
let output_price = format_option_value(&self.output_price);
let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
@ -110,7 +111,11 @@ impl Model {
)
}
pub fn set_max_input_tokens(mut self, max_input_tokens: Option<usize>) -> Self {
pub fn show_max_output_tokens(&self) -> Option<isize> {
self.max_output_tokens.or(self.ref_max_output_tokens)
}
pub fn set_max_input_tokens(&mut self, max_input_tokens: Option<usize>) -> &mut Self {
match max_input_tokens {
None | Some(0) => self.max_input_tokens = None,
_ => self.max_input_tokens = max_input_tokens,
@ -118,7 +123,7 @@ impl Model {
self
}
pub fn set_max_output_tokens(mut self, max_output_tokens: Option<isize>) -> Self {
pub fn set_max_output_tokens(&mut self, max_output_tokens: Option<isize>) -> &mut Self {
match max_output_tokens {
None | Some(0) => self.max_output_tokens = None,
_ => self.max_output_tokens = max_output_tokens,
@ -126,7 +131,7 @@ impl Model {
self
}
pub fn set_ref_max_output_tokens(mut self, ref_max_output_tokens: Option<isize>) -> Self {
pub fn set_ref_max_output_tokens(&mut self, ref_max_output_tokens: Option<isize>) -> &mut Self {
match ref_max_output_tokens {
None | Some(0) => self.ref_max_output_tokens = None,
_ => self.ref_max_output_tokens = ref_max_output_tokens,
@ -134,7 +139,7 @@ impl Model {
self
}
pub fn set_input_price(mut self, input_price: Option<f64>) -> Self {
pub fn set_input_price(&mut self, input_price: Option<f64>) -> &mut Self {
match input_price {
None => self.input_price = None,
_ => self.input_price = input_price,
@ -142,7 +147,7 @@ impl Model {
self
}
pub fn set_output_price(mut self, output_price: Option<f64>) -> Self {
pub fn set_output_price(&mut self, output_price: Option<f64>) -> &mut Self {
match output_price {
None => self.output_price = None,
_ => self.output_price = output_price,
@ -150,7 +155,7 @@ impl Model {
self
}
pub fn set_supports_vision(mut self, supports_vision: bool) -> Self {
pub fn set_supports_vision(&mut self, supports_vision: bool) -> &mut Self {
if supports_vision {
self.capabilities |= ModelCapabilities::Vision;
} else {
@ -160,9 +165,9 @@ impl Model {
}
pub fn set_extra_fields(
mut self,
&mut self,
extra_fields: &Option<serde_json::Map<String, serde_json::Value>>,
) -> Self {
) -> &mut Self {
self.extra_fields = extra_fields.clone();
self
}

@ -418,6 +418,13 @@ impl Config {
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model.id()),
(
"max_output_tokens",
self.model
.max_output_tokens
.map(|v| format!("{v} (current model)"))
.unwrap_or_else(|| "-".into()),
),
("temperature", format_option_value(&self.temperature)),
("top_p", format_option_value(&self.top_p)),
("dry_run", self.dry_run.to_string()),
@ -497,23 +504,28 @@ impl Config {
.map(|v| (v.clone(), String::new()))
.collect(),
".set" => vec![
"temperature ",
"top_p ",
"max_output_tokens",
"temperature",
"top_p",
"compress_threshold",
"save ",
"save_session ",
"highlight ",
"dry_run ",
"auto_copy ",
"save",
"save_session",
"highlight",
"dry_run",
"auto_copy",
]
.into_iter()
.map(|v| (v.to_string(), String::new()))
.map(|v| (format!("{v} "), String::new()))
.collect(),
_ => vec![],
};
(values, args[0])
} else if args.len() == 2 {
let values = match args[0] {
"max_output_tokens" => match self.model.show_max_output_tokens() {
Some(v) => vec![v.to_string()],
None => vec![],
},
"save" => complete_bool(self.save),
"save_session" => {
let save_session = if let Some(session) = &self.session {
@ -549,6 +561,10 @@ impl Config {
let key = parts[0];
let value = parts[1];
match key {
"max_output_tokens" => {
let value = parse_value(value)?;
self.model.set_max_output_tokens(value);
}
"temperature" => {
let value = parse_value(value)?;
self.set_temperature(value);

@ -164,7 +164,7 @@ impl Server {
let mut client = init_client(&config)?;
if max_tokens.is_some() {
client.set_model(client.model().clone().set_max_output_tokens(max_tokens));
client.model_mut().set_max_output_tokens(max_tokens);
}
let abort = create_abort_signal();
let http_client = client.build_client()?;

Loading…
Cancel
Save