refactor: config::Input (#503)

pull/504/head
sigoden 3 weeks ago committed by GitHub
parent 154c1e0b4b
commit 5284a18248
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -290,11 +290,11 @@ pub trait Client: Sync + Send {
async fn send_message(&self, input: Input) -> Result<(String, CompletionDetails)> {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
let content = input.echo_messages();
return Ok((content, CompletionDetails::default()));
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, false)?;
let data = input.prepare_send_data(false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
@ -315,7 +315,7 @@ pub trait Client: Sync + Send {
ret = async {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
let content = input.echo_messages();
let tokens = tokenize(&content);
for token in tokens {
tokio::time::sleep(Duration::from_millis(10)).await;
@ -324,7 +324,7 @@ pub trait Client: Sync + Send {
return Ok(());
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, true)?;
let data = input.prepare_send_data(true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;

@ -134,21 +134,3 @@ pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::InputContext;
// NOTE(review): this test module is deleted by this diff — it targets the
// old `Input::from_str(text, context)` signature, which now also takes a
// `&GlobalConfig` and an `Option<InputContext>`.
// Verifies that a plain-text user input serializes to the expected
// role/content JSON payload.
#[test]
fn test_serde() {
assert_eq!(
serde_json::to_string(&Message::new(&Input::from_str(
"Hello World",
InputContext::default()
)))
.unwrap(),
"{\"role\":\"user\",\"content\":\"Hello World\"}"
);
}
}

@ -1,7 +1,9 @@
use super::role::Role;
use super::session::Session;
use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{ImageUrl, MessageContent, MessageContentPart, ModelCapabilities};
use crate::client::{
init_client, Client, ImageUrl, Message, MessageContent, MessageContentPart, ModelCapabilities,
SendData,
};
use crate::utils::{base64_encode, sha256};
use anyhow::{bail, Context, Result};
@ -24,6 +26,7 @@ lazy_static! {
#[derive(Debug, Clone)]
pub struct Input {
config: GlobalConfig,
text: String,
medias: Vec<String>,
data_urls: HashMap<String, String>,
@ -31,16 +34,22 @@ pub struct Input {
}
impl Input {
pub fn from_str(text: &str, context: InputContext) -> Self {
pub fn from_str(config: &GlobalConfig, text: &str, context: Option<InputContext>) -> Self {
Self {
config: config.clone(),
text: text.to_string(),
medias: Default::default(),
data_urls: Default::default(),
context,
context: context.unwrap_or_else(|| InputContext::from_config(config)),
}
}
pub fn new(text: &str, files: Vec<String>, context: InputContext) -> Result<Self> {
pub fn new(
config: &GlobalConfig,
text: &str,
files: Vec<String>,
context: Option<InputContext>,
) -> Result<Self> {
let mut texts = vec![text.to_string()];
let mut medias = vec![];
let mut data_urls = HashMap::new();
@ -78,10 +87,11 @@ impl Input {
}
Ok(Self {
config: config.clone(),
text: texts.join("\n"),
medias,
data_urls,
context,
context: context.unwrap_or_else(|| InputContext::from_config(config)),
})
}
@ -101,6 +111,61 @@ impl Input {
self.text = text;
}
// Builds the chat client for this input from the embedded global config.
// NOTE(review): assumes `init_client` selects the platform/model from
// `self.config` — confirm against `crate::client::init_client`.
pub fn create_client(&self) -> Result<Box<dyn Client>> {
init_client(&self.config)
}
// Assembles the request payload (`SendData`) for this input: the rendered
// message list plus sampling parameters, with the model's input-token
// limit enforced before anything is sent.
pub fn prepare_send_data(&self, stream: bool) -> Result<SendData> {
let messages = self.build_messages()?;
// Fail early if the rendered messages exceed the model's context window.
self.config.read().model.max_input_tokens_limit(&messages)?;
// Sampling parameters resolve session-first, then role, then the global
// config defaults. The read guard created in the if-let scrutinee lives
// for the whole if-let, so the session borrow stays valid.
let (temperature, top_p) = if let Some(session) = self.session(&self.config.read().session)
{
(session.temperature(), session.top_p())
} else if let Some(role) = self.role() {
(role.temperature, role.top_p)
} else {
let config = self.config.read();
(config.temperature, config.top_p)
};
Ok(SendData {
messages,
temperature,
top_p,
stream,
})
}
/// In dry-run mode, prints how many tokens this input would consume.
/// Does nothing outside dry-run; message-building failures are silently
/// ignored (best effort), matching the previous behavior.
pub fn maybe_print_input_tokens(&self) {
    if !self.config.read().dry_run {
        return;
    }
    if let Ok(messages) = self.build_messages() {
        let tokens = self.config.read().model.total_tokens(&messages);
        println!(">>> This message consumes {tokens} tokens. <<<");
    }
}
// Renders this input into the message list sent to the model: an active
// session takes precedence, then the role; otherwise the input becomes a
// single user message.
pub fn build_messages(&self) -> Result<Vec<Message>> {
let messages = if let Some(session) = self.session(&self.config.read().session) {
session.build_messages(self)
} else if let Some(role) = self.role() {
role.build_messages(self)
} else {
let message = Message::new(self);
vec![message]
};
Ok(messages)
}
// Produces the text echoed back in dry-run mode, using the same
// session > role > raw-render precedence as `build_messages`.
pub fn echo_messages(&self) -> String {
if let Some(session) = self.session(&self.config.read().session) {
session.echo_messages(self)
} else if let Some(role) = self.role() {
role.echo_messages(self)
} else {
self.render()
}
}
// Role attached to this input's context, if any.
pub fn role(&self) -> Option<&Role> {
self.context.role.as_ref()
}
@ -207,6 +272,11 @@ impl InputContext {
Self { role, session }
}
pub fn from_config(config: &GlobalConfig) -> Self {
let config = config.read();
InputContext::new(config.role.clone(), config.session.is_some())
}
pub fn role(role: Role) -> Self {
Self {
role: Some(role),

@ -7,7 +7,7 @@ pub use self::role::{Role, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE};
use self::session::{Session, TEMP_SESSION_NAME};
use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, Message, Model, SendData,
create_client_config, list_client_types, list_models, ClientConfig, Model,
OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::render::{MarkdownRender, RenderOptions};
@ -305,7 +305,7 @@ impl Config {
Ok(())
}
pub fn get_state(&self) -> State {
pub fn state(&self) -> State {
if let Some(session) = &self.session {
if session.is_empty() {
if self.role.is_some() {
@ -359,28 +359,6 @@ impl Config {
}
}
// NOTE(review): removed by this diff — logic moved to
// `Input::echo_messages`, which reads the config held by the input itself.
// Renders the dry-run echo text with session > role > raw-render
// precedence.
pub fn echo_messages(&self, input: &Input) -> String {
if let Some(session) = input.session(&self.session) {
session.echo_messages(input)
} else if let Some(role) = input.role() {
role.echo_messages(input)
} else {
input.render()
}
}
// NOTE(review): removed by this diff — logic moved to
// `Input::build_messages`. Builds the model message list with
// session > role > single-user-message precedence.
pub fn build_messages(&self, input: &Input) -> Result<Vec<Message>> {
let messages = if let Some(session) = input.session(&self.session) {
session.build_messages(input)
} else if let Some(role) = input.role() {
role.build_messages(input)
} else {
let message = Message::new(input);
vec![message]
};
Ok(messages)
}
pub fn set_wrap(&mut self, value: &str) -> Result<()> {
if value == "no" {
self.wrap = None;
@ -402,7 +380,7 @@ impl Config {
None => bail!("No model '{}'", value),
Some(model) => {
if let Some(session) = self.session.as_mut() {
session.set_model(model.clone())?;
session.set_model(&model);
}
self.model = model;
Ok(())
@ -625,7 +603,7 @@ impl Config {
self.session = Some(Session::new(self, name));
} else {
let session = Session::load(name, &session_path)?;
let model_id = session.model().to_string();
let model_id = session.model_id().to_string();
self.session = Some(session);
self.set_model(&model_id)?;
}
@ -787,44 +765,6 @@ impl Config {
render_prompt(right_prompt, &variables)
}
// NOTE(review): removed by this diff — replaced by
// `Input::prepare_send_data`, which folds the two session/role lookups
// into one and checks the token limit before resolving parameters.
// Builds the `SendData` payload for a request.
pub fn prepare_send_data(&self, input: &Input, stream: bool) -> Result<SendData> {
let messages = self.build_messages(input)?;
// temperature: session > role > global config default
let temperature = if let Some(session) = input.session(&self.session) {
session.temperature()
} else if let Some(role) = input.role() {
role.temperature
} else {
self.temperature
};
// top_p: same precedence as temperature
let top_p = if let Some(session) = input.session(&self.session) {
session.top_p()
} else if let Some(role) = input.role() {
role.top_p
} else {
self.top_p
};
// Reject messages that exceed the model's context window.
self.model.max_input_tokens_limit(&messages)?;
Ok(SendData {
messages,
temperature,
top_p,
stream,
})
}
// NOTE(review): removed by this diff — callers now pass `None` and
// `Input` derives the context via `InputContext::from_config`.
// Snapshots the current role and session presence into an InputContext.
pub fn input_context(&self) -> InputContext {
InputContext::new(self.role.clone(), self.session.is_some())
}
// NOTE(review): removed by this diff — replaced by
// `Input::maybe_print_input_tokens`. In dry-run mode, prints the token
// count of the pending input; build failures are ignored (best effort).
pub fn maybe_print_send_tokens(&self, input: &Input) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
let tokens = self.model.total_tokens(&messages);
println!(">>> This message consumes {tokens} tokens. <<<");
}
}
}
fn generate_prompt_context(&self) -> HashMap<&str, String> {
let mut output = HashMap::new();
output.insert("model", self.model.id());

@ -74,7 +74,7 @@ impl Session {
&self.name
}
pub fn model(&self) -> &str {
pub fn model_id(&self) -> &str {
&self.model_id
}
@ -112,7 +112,7 @@ impl Session {
let (tokens, percent) = self.tokens_and_percent();
let mut data = json!({
"path": self.path,
"model": self.model(),
"model": self.model_id(),
});
if let Some(temperature) = self.temperature() {
data["temperature"] = temperature.into();
@ -240,14 +240,13 @@ impl Session {
}
}
pub fn set_model(&mut self, model: Model) -> Result<()> {
pub fn set_model(&mut self, model: &Model) {
let model_id = model.id();
if self.model_id != model_id {
self.model_id = model_id;
self.dirty = true;
}
self.model = model;
Ok(())
self.model = model.clone();
}
pub fn compress(&mut self, prompt: String) {

@ -12,7 +12,7 @@ mod utils;
extern crate log;
use crate::cli::Cli;
use crate::client::{ensure_model_capabilities, init_client, list_models, send_stream};
use crate::client::{ensure_model_capabilities, list_models, send_stream};
use crate::config::{
Config, GlobalConfig, Input, InputContext, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE,
SHELL_ROLE,
@ -138,9 +138,9 @@ async fn start_directive(
no_stream: bool,
code_mode: bool,
) -> Result<()> {
let mut client = init_client(config)?;
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
config.read().maybe_print_send_tokens(&input);
input.maybe_print_input_tokens();
let is_terminal_stdout = stdout().is_terminal();
let extract_code = !is_terminal_stdout && code_mode;
let output = if no_stream || extract_code {
@ -176,8 +176,8 @@ async fn start_interactive(config: &GlobalConfig) -> Result<()> {
#[async_recursion::async_recursion]
async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
let client = init_client(config)?;
config.read().maybe_print_send_tokens(&input);
let client = input.create_client()?;
input.maybe_print_input_tokens();
let is_terminal_stdout = stdout().is_terminal();
let ret = if is_terminal_stdout {
let (spinner_tx, spinner_rx) = oneshot::channel();
@ -223,7 +223,7 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
}
"📙 Explain" => {
let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
let input = Input::from_str(&eval_str, InputContext::role(role));
let input = Input::from_str(config, &eval_str, Some(InputContext::role(role)));
let abort = create_abort_signal();
send_stream(&input, client.as_ref(), config, abort).await?;
continue;
@ -254,11 +254,10 @@ fn aggregate_text(text: Option<String>) -> Result<Option<String>> {
}
fn create_input(config: &GlobalConfig, text: Option<String>, file: &[String]) -> Result<Input> {
let input_context = config.read().input_context();
let input = if file.is_empty() {
Input::from_str(&text.unwrap_or_default(), input_context)
Input::from_str(config, &text.unwrap_or_default(), None)
} else {
Input::new(&text.unwrap_or_default(), file.to_vec(), input_context)?
Input::new(config, &text.unwrap_or_default(), file.to_vec(), None)?
};
if input.is_empty() {
bail!("No input");

@ -27,7 +27,7 @@ impl Completer for ReplCompleter {
return suggestions;
}
let state = self.config.read().get_state();
let state = self.config.read().state();
let commands: Vec<_> = self
.commands

@ -6,7 +6,7 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::{ensure_model_capabilities, init_client, send_stream};
use crate::client::{ensure_model_capabilities, send_stream};
use crate::config::{GlobalConfig, Input, InputContext, State};
use crate::render::render_error;
use crate::utils::{create_abort_signal, set_text, AbortSignal};
@ -176,7 +176,11 @@ impl Repl {
Some(args) => match args.split_once(|c| c == '\n' || c == ' ') {
Some((name, text)) => {
let role = self.config.read().retrieve_role(name.trim())?;
let input = Input::from_str(text.trim(), InputContext::role(role));
let input = Input::from_str(
&self.config,
text.trim(),
Some(InputContext::role(role)),
);
self.ask(input).await?;
}
None => {
@ -218,7 +222,7 @@ impl Repl {
Some(args) => {
let (files, text) = split_files_text(args);
let files = shell_words::split(files).with_context(|| "Invalid args")?;
let input = Input::new(text, files, self.config.read().input_context())?;
let input = Input::new(&self.config, text, files, None)?;
self.ask(input).await?;
}
None => println!("Usage: .file <files>... [-- <text>...]"),
@ -244,7 +248,7 @@ impl Repl {
_ => unknown_command()?,
},
None => {
let input = Input::from_str(line, self.config.read().input_context());
let input = Input::from_str(&self.config, line, None);
self.ask(input).await?;
}
}
@ -261,8 +265,8 @@ impl Repl {
while self.config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
self.config.read().maybe_print_send_tokens(&input);
let mut client = init_client(&self.config)?;
input.maybe_print_input_tokens();
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let output = send_stream(&input, client.as_ref(), &self.config, self.abort.clone()).await?;
self.config.write().save_message(input, &output)?;
@ -442,11 +446,8 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
async fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(
config.read().summarize_prompt(),
config.read().input_context(),
);
let mut client = init_client(config)?;
let input = Input::from_str(config, config.read().summarize_prompt(), None);
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let (summary, _) = client.send_message(input).await?;
config.write().compress_session(&summary);

Loading…
Cancel
Save