refactor: seperate temperature from role/session (#378)

pull/379/head
sigoden 3 months ago committed by GitHub
parent 8da9fa5f4c
commit 3180d1d485
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -47,8 +47,7 @@ pub struct Config {
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
/// LLM temperature
#[serde(rename(serialize = "temperature", deserialize = "temperature"))]
pub default_temperature: Option<f64>,
pub temperature: Option<f64>,
/// Dry-run flag
pub dry_run: bool,
/// Whether to save the message
@ -94,15 +93,13 @@ pub struct Config {
pub model: Model,
#[serde(skip)]
pub last_message: Option<(Input, String)>,
#[serde(skip)]
pub temperature: Option<f64>,
}
impl Default for Config {
fn default() -> Self {
Self {
model_id: None,
default_temperature: None,
temperature: None,
save: true,
highlight: true,
dry_run: false,
@ -125,7 +122,6 @@ impl Default for Config {
session: None,
model: Default::default(),
last_message: None,
temperature: None,
}
}
}
@ -157,8 +153,6 @@ impl Config {
config.set_wrap(&wrap)?;
}
config.temperature = config.default_temperature;
config.load_roles()?;
config.setup_model()?;
@ -322,13 +316,11 @@ impl Config {
session.guard_empty()?;
session.set_temperature(role.temperature);
}
self.temperature = role.temperature;
self.role = Some(role);
Ok(())
}
pub fn clear_role(&mut self) -> Result<()> {
self.temperature = self.default_temperature;
self.role = None;
Ok(())
}
@ -351,14 +343,13 @@ impl Config {
}
}
pub fn get_temperature(&self) -> Option<f64> {
self.temperature
}
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
if let Some(session) = self.session.as_mut() {
session.set_temperature(value);
} else if let Some(role) = self.role.as_mut() {
role.set_temperature(value);
} else {
self.temperature = value;
}
}
@ -621,7 +612,6 @@ impl Config {
} else {
let session = Session::load(name, &session_path)?;
let model = session.model().to_string();
self.temperature = session.temperature();
self.session = Some(session);
self.set_model(&model)?;
}
@ -649,7 +639,6 @@ impl Config {
pub fn end_session(&mut self, interactive: bool) -> Result<()> {
if let Some(mut session) = self.session.take() {
self.last_message = None;
self.temperature = self.default_temperature;
if session.dirty {
// If it's a temporary session, we'll always prompt to save on exit
// If it's named, we'll save automatically if they've set the save flag and prompt if they haven't
@ -786,10 +775,17 @@ impl Config {
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
let messages = self.build_messages(input)?;
let temperature = if let Some(session) = input.session(&self.session) {
session.temperature()
} else if let Some(role) = input.role() {
role.temperature
} else {
self.temperature
};
self.model.max_input_tokens_limit(&messages)?;
Ok(SendData {
messages,
temperature: self.get_temperature(),
temperature,
stream,
})
}

@ -89,6 +89,10 @@ For example if the prompt is "Hello world Python", you should return "print('Hel
self.prompt.contains(INPUT_PLACEHOLDER)
}
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
}
pub fn complete_prompt_args(&mut self, name: &str) {
self.name = name.to_string();
self.prompt = complete_prompt_args(&self.prompt, &self.name);

Loading…
Cancel
Save