feat: compress session automatically (#333)

* feat: compress session automatically

* non-block

* update field description

* set compress_threshold

* update session::clear_messages

* able to override session compress_threshold

* enable compress_threshold by default

* make session compress_threshold optional
pull/336/head
sigoden 3 months ago committed by GitHub
parent 9e15a3409e
commit 3f693ea060
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -87,6 +87,7 @@ wrap_code: false # Whether wrap code block
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings. values: emacs, vi
prelude: '' # Set a default role or session (role:<name>, session:<name>)
compress_threshold: 1000 # Compress session if tokens exceed this value (valid when >=1000)
clients:
- type: openai
@ -296,6 +297,7 @@ Usage: .file <file>... [-- text...]
> .set highlight false
> .set save false
> .set auto_copy true
> .set compress_threshold 1000
```
## Command

@ -9,6 +9,13 @@ auto_copy: false # Automatically copy the last output to the cli
keybindings: emacs # REPL keybindings. (emacs, vi)
prelude: '' # Set a default role or session (role:<name>, session:<name>)
# Compress session if tokens exceed this value (valid when >=1000)
compress_threshold: 1000
# The prompt for summarizing session messages
summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use as a prompt for future context.'
# The prompt for the summary of the session
summary_prompt: 'This is a summary of the chat history as a recap: '
# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} '
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'

@ -17,7 +17,7 @@ impl Message {
}
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,

@ -67,6 +67,12 @@ pub struct Config {
pub keybindings: Keybindings,
/// Set a default role or session (role:<name>, session:<name>)
pub prelude: String,
/// Compress session if tokens exceed this value (>=1000)
pub compress_threshold: usize,
/// The prompt for summarizing session messages
pub summarize_prompt: String,
/// The prompt for the summary of the session
pub summary_prompt: String,
/// REPL left prompt
pub left_prompt: String,
/// REPL right prompt
@ -104,6 +110,9 @@ impl Default for Config {
auto_copy: false,
keybindings: Default::default(),
prelude: String::new(),
compress_threshold: 2000,
summarize_prompt: "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.".to_string(),
summary_prompt: "This is a summary of the chat history as a recap: ".into(),
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
.to_string(),
@ -345,12 +354,18 @@ impl Config {
self.temperature
}
pub fn set_temperature(&mut self, value: Option<f64>) -> Result<()> {
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
if let Some(session) = self.session.as_mut() {
session.set_temperature(value);
}
Ok(())
}
/// Stores `value` as the config-level compress threshold and, when a session
/// is active, forwards the same value to that session.
pub fn set_compress_threshold(&mut self, value: usize) {
    self.compress_threshold = value;
    if let Some(active) = &mut self.session {
        active.set_compress_threshold(value);
    }
}
pub fn echo_messages(&self, input: &Input) -> String {
@ -430,6 +445,7 @@ impl Config {
("auto_copy", self.auto_copy.to_string()),
("keybindings", self.keybindings.stringify().into()),
("prelude", prelude),
("compress_threshold", self.compress_threshold.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
("messages_file", display_path(&Self::messages_file()?)),
@ -445,7 +461,7 @@ impl Config {
pub fn role_info(&self) -> Result<String> {
if let Some(role) = &self.role {
role.info()
role.export()
} else {
bail!("No role")
}
@ -455,7 +471,7 @@ impl Config {
if let Some(session) = &self.session {
let render_options = self.get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
session.render(&mut markdown_render)
session.info(&mut markdown_render)
} else {
bail!("No session")
}
@ -465,7 +481,7 @@ impl Config {
if let Some(session) = &self.session {
session.export()
} else if let Some(role) = &self.role {
role.info()
role.export()
} else {
self.sys_info()
}
@ -486,6 +502,7 @@ impl Config {
".session" => self.list_sessions(),
".set" => vec![
"temperature ",
"compress_threshold",
"save ",
"highlight ",
"dry_run ",
@ -532,7 +549,11 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
Some(value)
};
self.set_temperature(value)?;
self.set_temperature(value);
}
"compress_threshold" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.set_compress_threshold(value);
}
"save" => {
let value = value.parse().with_context(|| "Invalid value")?;
@ -608,7 +629,7 @@ impl Config {
if let Some(mut session) = self.session.take() {
self.last_message = None;
self.temperature = self.default_temperature;
if session.should_save() {
if session.dirty {
let ans = Confirm::new("Save session?").with_default(false).prompt()?;
if !ans {
return Ok(());
@ -634,7 +655,7 @@ impl Config {
pub fn clear_session_messages(&mut self) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.clear_messgaes();
session.clear_messages();
}
Ok(())
}
@ -660,6 +681,35 @@ impl Config {
}
}
/// Decides whether the active session should be compressed now.
///
/// When a session is active and its `need_compress` check (given the
/// config-level threshold) succeeds, the session's `compressing` flag is set
/// and `true` is returned. In every other case nothing is changed and the
/// result is `false`.
pub fn should_compress_session(&mut self) -> bool {
    let threshold = self.compress_threshold;
    match self.session.as_mut() {
        Some(session) if session.need_compress(threshold) => {
            session.compressing = true;
            true
        }
        _ => false,
    }
}
/// Compresses the active session, if any, using `summary`.
///
/// The text handed to the session is the configured `summary_prompt`
/// immediately followed by `summary`. Does nothing when no session is active.
pub fn compress_session(&mut self, summary: &str) {
    if let Some(active) = &mut self.session {
        let recap = format!("{}{}", self.summary_prompt, summary);
        active.compress(recap);
    }
}
/// Reports the active session's `compressing` flag; `false` when there is no
/// active session.
pub fn is_compressing_session(&self) -> bool {
    match &self.session {
        Some(session) => session.compressing,
        None => false,
    }
}
/// Clears the `compressing` flag on the active session. Does nothing when no
/// session is active.
pub fn end_compressing_session(&mut self) {
    if let Some(active) = &mut self.session {
        active.compressing = false;
    }
}
pub fn get_render_options(&self) -> Result<RenderOptions> {
let theme = if self.highlight {
let theme_mode = if self.light_theme { "light" } else { "dark" };

@ -72,7 +72,7 @@ For example if the prompt is "Hello world Python", you should return "print('Hel
}
}
pub fn info(&self) -> Result<String> {
pub fn export(&self) -> Result<String> {
let output = serde_yaml::to_string(&self)
.with_context(|| format!("Unable to show info about role {}", &self.name))?;
Ok(output.trim_end().to_string())

@ -22,6 +22,9 @@ pub struct Session {
messages: Vec<Message>,
#[serde(default)]
data_urls: HashMap<String, String>,
#[serde(default)]
compressed_messages: Vec<Message>,
compress_threshold: Option<usize>,
#[serde(skip)]
pub name: String,
#[serde(skip)]
@ -29,6 +32,8 @@ pub struct Session {
#[serde(skip)]
pub dirty: bool,
#[serde(skip)]
pub compressing: bool,
#[serde(skip)]
pub role: Option<Role>,
#[serde(skip)]
pub model: Model,
@ -41,10 +46,13 @@ impl Session {
model_id: model.id(),
temperature,
messages: vec![],
compressed_messages: vec![],
compress_threshold: None,
data_urls: Default::default(),
name: name.to_string(),
path: None,
dirty: false,
compressing: false,
role,
model,
}
@ -74,6 +82,13 @@ impl Session {
self.temperature
}
/// Tells whether the session's messages exceed the effective threshold.
///
/// The session's own `compress_threshold` takes precedence; when it is unset,
/// `current_compress_threshold` is used instead. A threshold below 1000 never
/// triggers compression.
pub fn need_compress(&self, current_compress_threshold: usize) -> bool {
    let limit = match self.compress_threshold {
        Some(own) => own,
        None => current_compress_threshold,
    };
    if limit < 1000 {
        return false;
    }
    self.tokens() > limit
}
/// Token total for the session's current messages, as computed by the
/// session's model.
pub fn tokens(&self) -> usize {
    let model = &self.model;
    model.total_tokens(&self.messages)
}
@ -106,7 +121,7 @@ impl Session {
Ok(output)
}
pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
pub fn info(&self, render: &mut MarkdownRender) -> Result<String> {
let mut items = vec![];
if let Some(path) = &self.path {
@ -119,6 +134,10 @@ impl Session {
items.push(("temperature", temperature.to_string()));
}
if let Some(compress_threshold) = self.compress_threshold {
items.push(("compress_threshold", compress_threshold.to_string()));
}
if let Some(max_tokens) = self.model.max_tokens {
items.push(("max_tokens", max_tokens.to_string()));
}
@ -135,7 +154,7 @@ impl Session {
for message in &self.messages {
match message.role {
MessageRole::System => {
continue;
lines.push(render.render(&message.content.render_input(resolve_url_fn)));
}
MessageRole::Assistant => {
if let MessageContent::Text(text) = &message.content {
@ -181,14 +200,28 @@ impl Session {
self.temperature = value;
}
/// Sets a session-specific compress threshold, which `need_compress` prefers
/// over the value passed in by the caller.
pub fn set_compress_threshold(&mut self, value: usize) {
    let session_override = Some(value);
    self.compress_threshold = session_override;
}
/// Replaces the session's model and records the new model's id in
/// `model_id`. Always returns `Ok(())`.
pub fn set_model(&mut self, model: Model) -> Result<()> {
    // Read the id before `model` is moved into the session.
    let id = model.id();
    self.model_id = id;
    self.model = model;
    Ok(())
}
/// Archives the current messages and replaces them with one system message.
///
/// Every entry of `messages` is moved, in order, to the end of
/// `compressed_messages`. `messages` then contains only a system message
/// whose text is `prompt`. The session's role is cleared and the session is
/// marked dirty.
pub fn compress(&mut self, prompt: String) {
    // Move (not copy) the live history into the archive, leaving `messages` empty.
    let archived = self.messages.drain(..);
    self.compressed_messages.extend(archived);

    let recap = Message {
        role: MessageRole::System,
        content: MessageContent::Text(prompt),
    };
    self.messages.push(recap);

    self.role = None;
    self.dirty = true;
}
pub fn save(&mut self, session_path: &Path) -> Result<()> {
if !self.should_save() {
if !self.dirty {
return Ok(());
}
self.path = Some(session_path.display().to_string());
@ -208,10 +241,6 @@ impl Session {
Ok(())
}
pub fn should_save(&self) -> bool {
!self.is_empty() && self.dirty
}
pub fn guard_save(&self) -> Result<()> {
if self.path.is_none() {
bail!("Not found session '{}'", self.name)
@ -258,11 +287,9 @@ impl Session {
Ok(())
}
pub fn clear_messgaes(&mut self) {
if self.messages.is_empty() {
return;
}
pub fn clear_messages(&mut self) {
self.messages.clear();
self.compressed_messages.clear();
self.data_urls.clear();
self.dirty = true;
}
@ -275,12 +302,16 @@ impl Session {
pub fn build_emssages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
let mut need_add_msg = true;
if messages.is_empty() {
let len = messages.len();
if len == 0 {
if let Some(role) = self.role.as_ref() {
messages = role.build_messages(input);
need_add_msg = false;
}
};
} else if len == 1 && self.compressed_messages.len() >= 2 {
messages
.extend(self.compressed_messages[self.compressed_messages.len() - 2..].to_vec());
}
if need_add_msg {
messages.push(Message {
role: MessageRole::User,

@ -258,6 +258,9 @@ impl Repl {
if text.is_empty() && files.is_empty() {
return Ok(());
}
while self.config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
let input = if files.is_empty() {
Input::from_str(text)
} else {
@ -269,6 +272,14 @@ impl Repl {
let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
self.config.write().save_message(input, &output)?;
self.config.read().maybe_copy(&output);
if self.config.write().should_compress_session() {
let config = self.config.clone();
std::thread::spawn(move || -> anyhow::Result<()> {
let _ = compress_session(&config);
config.write().end_compressing_session();
Ok(())
});
}
Ok(())
}
@ -418,6 +429,15 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
}
/// Requests a summary of the session and hands it to the config.
///
/// An input is built from the configured `summarize_prompt` and sent through
/// a client obtained from `init_client`, after `ensure_model_capabilities`
/// has accepted the input's required capabilities. The reply is passed to
/// `compress_session` on the write-locked config.
///
/// # Errors
///
/// Propagates any error from `init_client`, `ensure_model_capabilities`, or
/// `send_message`.
fn compress_session(config: &GlobalConfig) -> Result<()> {
    // The read guard is a temporary dropped at the end of this statement.
    let request = Input::from_str(&config.read().summarize_prompt);
    let mut client = init_client(config)?;
    ensure_model_capabilities(client.as_mut(), request.required_capabilities())?;
    let reply = client.send_message(request)?;
    config.write().compress_session(&reply);
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

Loading…
Cancel
Save