feat: allow binding model to the role (#505)

pull/506/head
sigoden 3 weeks ago committed by GitHub
parent 5284a18248
commit 79d0bba640
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -70,8 +70,7 @@ macro_rules! register_client {
impl $client {
pub const NAME: &'static str = $name;
pub fn init(global_config: &$crate::config::GlobalConfig) -> Option<Box<dyn Client>> {
let model = global_config.read().model.clone();
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let config = global_config.read().clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config {
if Self::name(c) == &model.client_name {
@ -84,7 +83,7 @@ macro_rules! register_client {
Some(Box::new(Self {
global_config: global_config.clone(),
config,
model,
model: model.clone(),
}))
}
@ -109,11 +108,12 @@ macro_rules! register_client {
)+
pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result<Box<dyn Client>> {
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
let model = model.unwrap_or_else(|| config.read().model.clone());
None
$(.or_else(|| $client::init(config)))+
$(.or_else(|| $client::init(config, &model)))+
.ok_or_else(|| {
anyhow::anyhow!("Unknown client '{}'", &config.read().model.client_name)
anyhow::anyhow!("Unknown client '{}'", model.client_name)
})
}

@ -1,8 +1,8 @@
use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{
init_client, Client, ImageUrl, Message, MessageContent, MessageContentPart, ModelCapabilities,
SendData,
init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart, Model,
ModelCapabilities, SendData,
};
use crate::utils::{base64_encode, sha256};
@ -111,8 +111,23 @@ impl Input {
self.text = text;
}
pub fn model(&self) -> Model {
let model = self.config.read().model.clone();
if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
if model.id() != model_id {
if let Some(model) = list_models(&self.config.read())
.into_iter()
.find(|v| v.id() == model_id)
{
return model.clone();
}
}
};
model
}
pub fn create_client(&self) -> Result<Box<dyn Client>> {
init_client(&self.config)
init_client(&self.config, Some(self.model()))
}
pub fn prepare_send_data(&self, stream: bool) -> Result<SendData> {

@ -54,7 +54,8 @@ const RIGHT_PROMPT: &str = "{color.purple}{?session {?consume_tokens {consume_to
#[serde(default)]
pub struct Config {
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
#[serde(default)]
pub model_id: String,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub dry_run: bool,
@ -91,7 +92,7 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Self {
model_id: None,
model_id: Default::default(),
temperature: None,
top_p: None,
save: false,
@ -296,12 +297,16 @@ impl Config {
session.set_temperature(role.temperature);
session.set_top_p(role.top_p);
}
if let Some(model_id) = &role.model_id {
self.set_model(model_id)?;
}
self.role = Some(role);
Ok(())
}
pub fn clear_role(&mut self) -> Result<()> {
self.role = None;
self.restore_model()?;
Ok(())
}
@ -381,6 +386,8 @@ impl Config {
Some(model) => {
if let Some(session) = self.session.as_mut() {
session.set_model(&model);
} else if let Some(role) = self.role.as_mut() {
role.set_model(&model);
}
self.model = model;
Ok(())
@ -388,12 +395,28 @@ impl Config {
}
}
pub fn set_model_id(&mut self) {
self.model_id = self.model.id()
}
pub fn restore_model(&mut self) -> Result<()> {
let origin_model_id = self.model_id.clone();
self.set_model(&origin_model_id)
}
pub fn system_info(&self) -> Result<String> {
let display_path = |path: &Path| path.display().to_string();
let wrap = self
.wrap
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let (temperature, top_p) = if let Some(session) = &self.session {
(session.temperature(), session.top_p())
} else if let Some(role) = &self.role {
(role.temperature, role.top_p)
} else {
(self.temperature, self.top_p)
};
let items = vec![
("model", self.model.id()),
(
@ -403,8 +426,8 @@ impl Config {
.map(|v| format!("{v} (current model)"))
.unwrap_or_else(|| "-".into()),
),
("temperature", format_option_value(&self.temperature)),
("top_p", format_option_value(&self.top_p)),
("temperature", format_option_value(&temperature)),
("top_p", format_option_value(&top_p)),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("save_session", format_option_value(&self.save_session)),
@ -645,6 +668,7 @@ impl Config {
}
Self::save_session_to_file(&mut session)?;
}
self.restore_model()?;
}
Ok(())
}
@ -926,18 +950,19 @@ impl Config {
}
fn setup_model(&mut self) -> Result<()> {
let model = match &self.model_id {
Some(v) => v.clone(),
None => {
let models = list_models(self);
if models.is_empty() {
bail!("No available model");
}
models[0].id()
let model_id = if self.model_id.is_empty() {
let models = list_models(self);
if models.is_empty() {
bail!("No available model");
}
let model_id = models[0].id();
self.model_id.clone_from(&model_id);
model_id
} else {
self.model_id.clone()
};
self.set_model(&model)?;
self.set_model(&model_id)?;
Ok(())
}
@ -1046,6 +1071,10 @@ impl State {
pub fn in_role() -> Vec<Self> {
vec![Self::Role, Self::EmptySessionWithRole]
}
pub fn is_normal(&self) -> bool {
self == &Self::Normal
}
}
fn create_config_file(config_path: &Path) -> Result<()> {

@ -1,6 +1,6 @@
use super::Input;
use crate::{
client::{Message, MessageContent, MessageRole},
client::{Message, MessageContent, MessageRole, Model},
utils::{detect_os, detect_shell},
};
@ -18,6 +18,8 @@ pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
pub struct Role {
pub name: String,
pub prompt: String,
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
}
@ -28,6 +30,7 @@ impl Role {
name: TEMP_ROLE.into(),
prompt: prompt.into(),
temperature: None,
model_id: None,
top_p: None,
}
}
@ -62,6 +65,7 @@ async function timeout(ms) {
.map(|(name, prompt)| Self {
name: name.into(),
prompt,
model_id: None,
temperature: None,
top_p: None,
})
@ -78,6 +82,10 @@ async function timeout(ms) {
self.prompt.contains(INPUT_PLACEHOLDER)
}
pub fn set_model(&mut self, model: &Model) {
self.model_id = Some(model.id());
}
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
}

@ -97,6 +97,7 @@ async fn main() -> Result<()> {
}
if let Some(model) = &cli.model {
config.write().set_model(model)?;
config.write().set_model_id();
}
if cli.save_session {
config.write().set_save_session(Some(true));

@ -163,6 +163,9 @@ impl Repl {
".model" => match args {
Some(name) => {
self.config.write().set_model(name)?;
if self.config.read().state().is_normal() {
self.config.write().set_model_id();
}
}
None => println!("Usage: .model <name>"),
},

@ -249,7 +249,7 @@ impl Server {
config.write().set_model(&model_name)?;
}
let mut client = init_client(&config)?;
let mut client = init_client(&config, None)?;
if max_tokens.is_some() {
client.model_mut().set_max_tokens(max_tokens, true);
}

Loading…
Cancel
Save