refactor: remove role field from session struct (#356)

pull/357/head
sigoden 3 months ago committed by GitHub
parent 992d570041
commit f3210d622a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -316,7 +316,8 @@ impl Config {
pub fn set_role_obj(&mut self, role: Role) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.update_role(Some(role.clone()))?;
session.guard_empty()?;
session.set_temperature(role.temperature);
}
self.temperature = role.temperature;
self.role = Some(role);
@ -324,9 +325,6 @@ impl Config {
}
pub fn clear_role(&mut self) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.update_role(None)?;
}
self.temperature = self.default_temperature;
self.role = None;
Ok(())
@ -335,7 +333,7 @@ impl Config {
pub fn get_state(&self) -> State {
if let Some(session) = &self.session {
if session.is_empty() {
if session.role.is_some() {
if self.role.is_some() {
State::EmptySessionWithRole
} else {
State::EmptySession
@ -592,13 +590,13 @@ impl Config {
self.session = Some(Session::new(
TEMP_SESSION_NAME,
self.model.clone(),
self.role.clone(),
self.temperature,
));
}
Some(name) => {
let session_path = Self::session_file(name)?;
if !session_path.exists() {
self.session = Some(Session::new(name, self.model.clone(), self.role.clone()));
self.session = Some(Session::new(name, self.model.clone(), self.temperature));
} else {
let session = Session::load(name, &session_path)?;
let model = session.model().to_string();

@ -1,5 +1,4 @@
use super::input::resolve_data_url;
use super::role::Role;
use super::{Input, Model};
use crate::client::{Message, MessageContent, MessageRole};
@ -34,14 +33,11 @@ pub struct Session {
#[serde(skip)]
pub compressing: bool,
#[serde(skip)]
pub role: Option<Role>,
#[serde(skip)]
pub model: Model,
}
impl Session {
pub fn new(name: &str, model: Model, role: Option<Role>) -> Self {
let temperature = role.as_ref().and_then(|v| v.temperature);
pub fn new(name: &str, model: Model, temperature: Option<f64>) -> Self {
Self {
model_id: model.id(),
temperature,
@ -53,7 +49,6 @@ impl Session {
path: None,
dirty: false,
compressing: false,
role,
model,
}
}
@ -189,13 +184,6 @@ impl Session {
(tokens, percent)
}
pub fn update_role(&mut self, role: Option<Role>) -> Result<()> {
self.guard_empty()?;
self.temperature = role.as_ref().and_then(|v| v.temperature);
self.role = role;
Ok(())
}
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
}
@ -216,7 +204,6 @@ impl Session {
role: MessageRole::System,
content: MessageContent::Text(prompt),
});
self.role = None;
self.dirty = true;
}
@ -260,13 +247,13 @@ impl Session {
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
self.messages.is_empty() && self.compressed_messages.is_empty()
}
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
let mut need_add_msg = true;
if self.messages.is_empty() {
if let Some(role) = self.role.as_ref() {
if let Some(role) = input.role() {
self.messages.extend(role.build_messages(input));
need_add_msg = false;
}
@ -282,7 +269,6 @@ impl Session {
role: MessageRole::Assistant,
content: MessageContent::Text(output.to_string()),
});
self.role = None;
self.dirty = true;
Ok(())
}
@ -304,7 +290,7 @@ impl Session {
let mut need_add_msg = true;
let len = messages.len();
if len == 0 {
if let Some(role) = self.role.as_ref() {
if let Some(role) = input.role() {
messages = role.build_messages(input);
need_add_msg = false;
}

Loading…
Cancel
Save