feat: allow use of temporary role in a session (#348)

pull/350/head
sigoden 3 months ago committed by GitHub
parent 8f14498969
commit aed243c3aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -102,11 +102,16 @@ pub struct ImageUrl {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::InputContext;
#[test]
fn test_serde() {
assert_eq!(
serde_json::to_string(&Message::new(&Input::from_str("Hello World"))).unwrap(),
serde_json::to_string(&Message::new(&Input::from_str(
"Hello World",
InputContext::default()
)))
.unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}"
);
}

@ -1,3 +1,6 @@
use super::role::Role;
use super::session::Session;
use crate::client::{ImageUrl, MessageContent, MessageContentPart, ModelCapabilities};
use crate::utils::sha256sum;
@ -25,18 +28,20 @@ pub struct Input {
text: String,
medias: Vec<String>,
data_urls: HashMap<String, String>,
context: InputContext,
}
impl Input {
pub fn from_str(text: &str) -> Self {
pub fn from_str(text: &str, context: InputContext) -> Self {
Self {
text: text.to_string(),
medias: Default::default(),
data_urls: Default::default(),
context,
}
}
pub fn new(text: &str, files: Vec<String>) -> Result<Self> {
pub fn new(text: &str, files: Vec<String>, context: InputContext) -> Result<Self> {
let mut texts = vec![text.to_string()];
let mut medias = vec![];
let mut data_urls = HashMap::new();
@ -72,13 +77,38 @@ impl Input {
text: texts.join("\n"),
medias,
data_urls,
context,
})
}
pub fn is_empty(&self) -> bool {
self.text.is_empty() && self.medias.is_empty()
}
pub fn data_urls(&self) -> HashMap<String, String> {
self.data_urls.clone()
}
pub fn role(&self) -> Option<&Role> {
self.context.role.as_ref()
}
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
if self.context.in_session {
session.as_ref()
} else {
None
}
}
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
if self.context.in_session {
session.as_mut()
} else {
None
}
}
pub fn summary(&self) -> String {
let text: String = self
.text
@ -154,6 +184,18 @@ impl Input {
}
}
#[derive(Debug, Clone, Default)]
pub struct InputContext {
role: Option<Role>,
in_session: bool,
}
impl InputContext {
pub fn new(role: Option<Role>, in_session: bool) -> Self {
Self { role, in_session }
}
}
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
if data_url.starts_with("data:") {
let hash = sha256sum(&data_url);

@ -2,7 +2,7 @@ mod input;
mod role;
mod session;
pub use self::input::Input;
pub use self::input::{Input, InputContext};
use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};
@ -226,7 +226,7 @@ impl Config {
return Ok(());
}
if let Some(session) = self.session.as_mut() {
if let Some(session) = input.session_mut(&mut self.session) {
session.add_message(&input, output)?;
return Ok(());
}
@ -241,7 +241,7 @@ impl Config {
let timestamp = now();
let summary = input.summary();
let input_markdown = input.render();
let output = match self.role.as_ref() {
let output = match input.role() {
None => {
format!("# CHAT: {summary} [{timestamp}]\n{input_markdown}\n--------\n{output}\n--------\n\n",)
}
@ -369,9 +369,9 @@ impl Config {
}
pub fn echo_messages(&self, input: &Input) -> String {
if let Some(session) = self.session.as_ref() {
if let Some(session) = input.session(&self.session) {
session.echo_messages(input)
} else if let Some(role) = self.role.as_ref() {
} else if let Some(role) = input.role() {
role.echo_messages(input)
} else {
input.render()
@ -379,9 +379,9 @@ impl Config {
}
pub fn build_messages(&self, input: &Input) -> Result<Vec<Message>> {
let messages = if let Some(session) = self.session.as_ref() {
let messages = if let Some(session) = input.session(&self.session) {
session.build_emssages(input)
} else if let Some(role) = self.role.as_ref() {
} else if let Some(role) = input.role() {
role.build_messages(input)
} else {
let message = Message::new(input);
@ -762,6 +762,10 @@ impl Config {
})
}
pub fn input_context(&self) -> InputContext {
InputContext::new(self.role.clone(), self.has_session())
}
pub fn maybe_print_send_tokens(&self, input: &Input) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {

@ -114,7 +114,11 @@ fn start_directive(
if let Some(session) = &config.read().session {
session.guard_save()?;
}
let input = Input::new(text, include.unwrap_or_default())?;
let input = Input::new(
text,
include.unwrap_or_default(),
config.read().input_context(),
)?;
let mut client = init_client(config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
config.read().maybe_print_send_tokens(&input);
@ -148,7 +152,7 @@ fn start_interactive(config: &GlobalConfig) -> Result<()> {
}
fn execute(config: &GlobalConfig, text: &str) -> Result<()> {
let input = Input::from_str(text);
let input = Input::from_str(text, config.read().input_context());
let client = init_client(config)?;
config.read().maybe_print_send_tokens(&input);
let mut eval_str = client.send_message(input.clone())?;
@ -192,7 +196,7 @@ fn execute(config: &GlobalConfig, text: &str) -> Result<()> {
if !describe {
config.write().set_describe_command_role()?;
}
let input = Input::from_str(&eval_str);
let input = Input::from_str(&eval_str, config.read().input_context());
let abort = create_abort_signal();
render_stream(&input, client.as_ref(), config, abort)?;
describe = true;

@ -7,7 +7,7 @@ use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::{ensure_model_capabilities, init_client};
use crate::config::{GlobalConfig, Input, State};
use crate::config::{GlobalConfig, Input, InputContext, State};
use crate::render::{render_error, render_stream};
use crate::utils::{create_abort_signal, set_text, AbortSignal};
@ -174,21 +174,10 @@ impl Repl {
".role" => match args {
Some(args) => match args.split_once(|c| c == '\n' || c == ' ') {
Some((name, text)) => {
if self.config.read().has_session() {
bail!(r#"Cannot perform this action in a session"#);
} else {
let name = name.trim();
let text = text.trim();
let old_role =
self.config.read().role.as_ref().map(|v| v.name.to_string());
self.config.write().set_role(name)?;
let ask_ret = self.ask(text, vec![]);
match old_role {
Some(old_role) => self.config.write().set_role(&old_role)?,
None => self.config.write().clear_role()?,
}
ask_ret?;
}
let role = self.config.read().retrieve_role(name.trim())?;
let input =
Input::from_str(text.trim(), InputContext::new(Some(role), false));
self.ask(input)?;
}
None => {
self.config.write().set_role(args)?;
@ -219,7 +208,8 @@ impl Repl {
None => (args, ""),
};
let files = shell_words::split(files).with_context(|| "Invalid args")?;
self.ask(text, files)?;
let input = Input::new(text, files, self.config.read().input_context())?;
self.ask(input)?;
}
None => println!("Usage: .file <files>...[ -- <text>...]"),
},
@ -250,7 +240,8 @@ impl Repl {
_ => unknown_command()?,
},
None => {
self.ask(line, vec![])?;
let input = Input::from_str(line, self.config.read().input_context());
self.ask(input)?;
}
}
@ -259,18 +250,13 @@ impl Repl {
Ok(false)
}
fn ask(&self, text: &str, files: Vec<String>) -> Result<()> {
if text.is_empty() && files.is_empty() {
fn ask(&self, input: Input) -> Result<()> {
if input.is_empty() {
return Ok(());
}
while self.config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
let input = if files.is_empty() {
Input::from_str(text)
} else {
Input::new(text, files)?
};
self.config.read().maybe_print_send_tokens(&input);
let mut client = init_client(&self.config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
@ -435,7 +421,10 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(&config.read().summarize_prompt);
let input = Input::from_str(
&config.read().summarize_prompt,
config.read().input_context(),
);
let mut client = init_client(config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let summary = client.send_message(input)?;

Loading…
Cancel
Save