You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
aichat/src/config/input.rs

386 lines
12 KiB
Rust

use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{
init_client, list_models, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
MessageContentPart, MessageRole, Model,
};
use crate::function::{ToolCallResult, ToolResults};
use crate::utils::{base64_encode, sha256};
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use mime_guess::from_path;
use std::{
collections::HashMap,
fs::File,
io::Read,
path::{Path, PathBuf},
};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
/// Lowercase file extensions that `is_image_ext` classifies as images.
const IMAGE_EXTS: [&str; 5] = [
    "png",  //
    "jpeg", //
    "jpg",  //
    "webp", //
    "gif",  //
];
lazy_static! {
    /// Matches strings that begin with a scheme-like prefix of two or more
    /// characters from `[A-Za-z0-9_-]` followed by `:/` (e.g. `https://`).
    /// Because at least two characters are required, a single-letter prefix
    /// such as `C:/` does not match.
    static ref URL_RE: Regex = {
        let pattern = r"^[A-Za-z0-9_-]{2,}:/";
        Regex::new(pattern).unwrap()
    };
}
/// One user turn to be sent to a model: prompt text, attached media, any
/// accumulated tool-call results, and the role/session context it was built in.
#[derive(Debug, Clone)]
pub struct Input {
    /// Handle to the global configuration (accessed via `.read()`).
    config: GlobalConfig,
    /// Prompt text, including the contents of any attached non-image files.
    text: String,
    /// Image attachments: base64 `data:` URLs for local files, or the
    /// original argument string for non-local images.
    medias: Vec<String>,
    /// Maps `sha256(data_url)` of each embedded local image to its file path,
    /// so the image can be displayed by path instead of by encoded blob.
    data_urls: HashMap<String, String>,
    /// Tool-call results gathered for this turn plus the latest output string;
    /// appended to the built messages as an assistant message.
    tool_call: Option<ToolResults>,
    /// Role and session settings this input is evaluated under.
    context: InputContext,
}
impl Input {
    /// Creates an input from plain text with no attached files, media, or tool results.
    ///
    /// When `context` is `None`, the role/session context is taken from the
    /// current global config via [`InputContext::from_config`].
    pub fn from_str(config: &GlobalConfig, text: &str, context: Option<InputContext>) -> Self {
        Self {
            config: config.clone(),
            text: text.to_string(),
            medias: Default::default(),
            data_urls: Default::default(),
            tool_call: None,
            context: context.unwrap_or_else(|| InputContext::from_config(config)),
        }
    }

    /// Creates an input from `text` plus file arguments.
    ///
    /// Each entry of `files` is classified by extension (`is_image_ext`) and by
    /// whether `resolve_local_file` maps it to a local path:
    /// - local image: read and embedded as a base64 `data:` URL; the URL's
    ///   sha256 is mapped to the file path in `data_urls`;
    /// - local non-image: its contents are appended to the text, each wrapped in
    ///   a `~~~~~~` fence labelled with the file argument when more than one
    ///   non-image file is given;
    /// - non-local image: kept verbatim as a media entry;
    /// - non-local non-image: rejected.
    ///
    /// All text parts (starting with `text`) are joined with `\n`.
    ///
    /// # Errors
    /// Fails if a local file cannot be read, or if a non-local non-image file is given.
    pub fn new(
        config: &GlobalConfig,
        text: &str,
        files: Vec<String>,
        context: Option<InputContext>,
    ) -> Result<Self> {
        let mut texts = vec![text.to_string()];
        let mut medias = vec![];
        let mut data_urls = HashMap::new();
        let files: Vec<_> = files
            .iter()
            .map(|f| (f, is_image_ext(Path::new(f))))
            .collect();
        // Label file contents only when several text files need to be told apart.
        let include_filepath = files.iter().filter(|(_, is_image)| !*is_image).count() > 1;
        for (file_item, is_image) in files {
            match resolve_local_file(file_item) {
                Some(file_path) => {
                    if is_image {
                        let data_url = read_media_to_data_url(&file_path)
                            .with_context(|| format!("Unable to read media file '{file_item}'"))?;
                        data_urls.insert(sha256(&data_url), file_path.display().to_string());
                        medias.push(data_url)
                    } else {
                        let text = read_file(&file_path)
                            .with_context(|| format!("Unable to read file '{file_item}'"))?;
                        if include_filepath {
                            texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
                        } else {
                            texts.push(text);
                        }
                    }
                }
                None => {
                    if is_image {
                        medias.push(file_item.to_string())
                    } else {
                        bail!("Unable to use remote file '{file_item}'");
                    }
                }
            }
        }
        Ok(Self {
            config: config.clone(),
            text: texts.join("\n"),
            medias,
            data_urls,
            tool_call: Default::default(),
            context: context.unwrap_or_else(|| InputContext::from_config(config)),
        })
    }

    /// Returns `true` when there is neither text nor media.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty() && self.medias.is_empty()
    }

    /// Returns a copy of the `sha256(data URL) -> local file path` map.
    pub fn data_urls(&self) -> HashMap<String, String> {
        self.data_urls.clone()
    }

    /// Returns a copy of the input text.
    pub fn text(&self) -> String {
        self.text.clone()
    }

    /// Replaces the input text.
    pub fn set_text(&mut self, text: String) {
        self.text = text;
    }

    /// Records the results of a round of tool calls and returns `self`.
    ///
    /// New results are appended to any already recorded, and `output` replaces
    /// the previously stored output.
    pub fn merge_tool_call(
        mut self,
        output: String,
        tool_call_results: Vec<ToolCallResult>,
    ) -> Self {
        match self.tool_call.as_mut() {
            Some(exist_tool_call_results) => {
                exist_tool_call_results.0.extend(tool_call_results);
                exist_tool_call_results.1 = output;
            }
            None => self.tool_call = Some((tool_call_results, output)),
        }
        self
    }

    /// Returns the model this input should be sent to.
    ///
    /// If the role names a `model_id` that differs from the configured model's
    /// id and a model with that id is found by `list_models`, that model is
    /// returned. Otherwise the model from the global config is returned.
    pub fn model(&self) -> Model {
        let model = self.config.read().model.clone();
        if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
            if model.id() != model_id {
                if let Some(model) = list_models(&self.config.read())
                    .into_iter()
                    .find(|v| v.id() == model_id)
                {
                    return model.clone();
                }
            }
        };
        model
    }

    /// Creates a client for the model chosen by [`Input::model`].
    pub fn create_client(&self) -> Result<Box<dyn Client>> {
        init_client(&self.config, Some(self.model()))
    }

    /// Assembles the request payload for `model`.
    ///
    /// Sampling parameters (`temperature`, `top_p`) come from the active session,
    /// else the role, else the global config. Functions are included only when
    /// function calling is enabled in the config and supported by `model`. They
    /// are selected with the session's (else the role's) function matcher.
    ///
    /// # Errors
    /// Fails if media is attached but `model` lacks vision support, if building
    /// the messages fails, or if `model.max_input_tokens_limit` rejects them.
    pub fn prepare_completion_data(
        &self,
        model: &Model,
        stream: bool,
    ) -> Result<ChatCompletionsData> {
        if !self.medias.is_empty() && !model.supports_vision() {
            bail!("The current model does not support vision.");
        }
        // Built before taking the config guard below: `build_messages` acquires
        // its own read guard, and recursive read-locking can deadlock on fair
        // RwLocks (NOTE(review): assumes GlobalConfig wraps e.g. a parking_lot
        // RwLock — confirm).
        let messages = self.build_messages()?;
        // Check against the model actually being used (possibly the role's model),
        // not the globally configured one.
        model.max_input_tokens_limit(&messages)?;

        // A single guard for everything below, so the lock is never re-entered.
        let config = self.config.read();
        let session = self.session(&config.session);
        let role = self.role();

        let (temperature, top_p) = if let Some(session) = session {
            (session.temperature(), session.top_p())
        } else if let Some(role) = role {
            (role.temperature, role.top_p)
        } else {
            (config.temperature, config.top_p)
        };

        let functions = if config.function_calling && model.supports_function_calling() {
            let function_matcher = if let Some(session) = session {
                session.function_matcher()
            } else if let Some(role) = role {
                role.function_matcher.as_deref()
            } else {
                None
            };
            config.function.select(function_matcher)
        } else {
            None
        };

        Ok(ChatCompletionsData {
            messages,
            temperature,
            top_p,
            functions,
            stream,
        })
    }

    /// Builds the message list for this input.
    ///
    /// Messages come from the active session, else the role, else a single user
    /// message holding [`Input::message_content`]. Any recorded tool results are
    /// appended as an assistant message.
    pub fn build_messages(&self) -> Result<Vec<Message>> {
        let mut messages = if let Some(session) = self.session(&self.config.read().session) {
            session.build_messages(self)
        } else if let Some(role) = self.role() {
            role.build_messages(self)
        } else {
            vec![Message::new(MessageRole::User, self.message_content())]
        };
        if let Some(tool_results) = &self.tool_call {
            messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::ToolResults(tool_results.clone()),
            ))
        }
        Ok(messages)
    }

    /// Returns a printable form of the messages, delegating to the active
    /// session, else the role, else [`Input::render`].
    pub fn echo_messages(&self) -> String {
        if let Some(session) = self.session(&self.config.read().session) {
            session.echo_messages(self)
        } else if let Some(role) = self.role() {
            role.echo_messages(self)
        } else {
            self.render()
        }
    }

    /// Returns the role this input is evaluated under, if any.
    pub fn role(&self) -> Option<&Role> {
        self.context.role.as_ref()
    }

    /// Returns `session` only if this input was created in a session context.
    pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
        if self.context.session {
            session.as_ref()
        } else {
            None
        }
    }

    /// Mutable counterpart of [`Input::session`].
    pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
        if self.context.session {
            session.as_mut()
        } else {
            None
        }
    }

    /// Returns a one-line preview of the text.
    ///
    /// The text is trimmed and control characters become spaces. If the result
    /// is wider than 70 columns (CJK-aware width), it is cut to at most 67
    /// columns and `...` is appended.
    pub fn summary(&self) -> String {
        let text: String = self
            .text
            .trim()
            .chars()
            .map(|c| if c.is_control() { ' ' } else { c })
            .collect();
        if text.width_cjk() > 70 {
            let mut sum_width = 0;
            let mut chars = vec![];
            for c in text.chars() {
                sum_width += c.width_cjk().unwrap_or(1);
                if sum_width > 67 {
                    chars.extend(['.', '.', '.']);
                    break;
                }
                chars.push(c);
            }
            chars.into_iter().collect()
        } else {
            text
        }
    }

    /// Renders the input for display.
    ///
    /// Without media this is just the text. With media it is
    /// `.file <media...>` followed by ` -- <text>` when text is present; local
    /// images are shown by file path (via `resolve_data_url`) rather than as
    /// data URLs.
    pub fn render(&self) -> String {
        if self.medias.is_empty() {
            return self.text.clone();
        }
        let text = if self.text.is_empty() {
            String::new()
        } else {
            format!(" -- {}", self.text)
        };
        let files: Vec<String> = self
            .medias
            .iter()
            .cloned()
            .map(|url| resolve_data_url(&self.data_urls, url))
            .collect();
        format!(".file {}{}", files.join(" "), text)
    }

    /// Returns the user message content.
    ///
    /// Without media this is plain text. With media it is an array with the text
    /// part first (only if non-empty), followed by one image part per media entry.
    pub fn message_content(&self) -> MessageContent {
        if self.medias.is_empty() {
            MessageContent::Text(self.text.clone())
        } else {
            let mut list: Vec<MessageContentPart> = self
                .medias
                .iter()
                .cloned()
                .map(|url| MessageContentPart::ImageUrl {
                    image_url: ImageUrl { url },
                })
                .collect();
            if !self.text.is_empty() {
                list.insert(
                    0,
                    MessageContentPart::Text {
                        text: self.text.clone(),
                    },
                );
            }
            MessageContent::Array(list)
        }
    }
}
/// The role and session settings an `Input` is evaluated under.
#[derive(Debug, Clone, Default)]
pub struct InputContext {
    /// Role to apply, if any.
    role: Option<Role>,
    /// Whether the config's current session (if any) should be used;
    /// when `false`, `Input::session` always yields `None`.
    session: bool,
}
impl InputContext {
pub fn new(role: Option<Role>, session: bool) -> Self {
Self { role, session }
}
pub fn from_config(config: &GlobalConfig) -> Self {
let config = config.read();
InputContext::new(config.role.clone(), config.session.is_some())
}
pub fn role(role: Role) -> Self {
Self {
role: Some(role),
session: false,
}
}
}
/// Maps a media URL back to a printable form.
///
/// For a `data:` URL whose sha256 is a key in `data_urls`, returns the
/// recorded local file path. Every other URL is returned unchanged.
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
    if !data_url.starts_with("data:") {
        return data_url;
    }
    match data_urls.get(&sha256(&data_url)) {
        Some(path) => path.clone(),
        None => data_url,
    }
}
fn resolve_local_file(file: &str) -> Option<PathBuf> {
if let Ok(true) = URL_RE.is_match(file) {
return None;
}
let path = if let (Some(file), Some(home)) = (file.strip_prefix("~/"), dirs::home_dir()) {
home.join(file)
} else {
std::env::current_dir().ok()?.join(file)
};
Some(path)
}
/// Reports whether `path` has one of the `IMAGE_EXTS` extensions, compared
/// case-insensitively. Paths without an extension are never images.
fn is_image_ext(path: &Path) -> bool {
    match path.extension() {
        Some(ext) => {
            // Normalize once instead of once per candidate extension.
            let ext = ext.to_string_lossy().to_lowercase();
            IMAGE_EXTS.contains(&ext.as_str())
        }
        None => false,
    }
}
fn read_media_to_data_url<P: AsRef<Path>>(image_path: P) -> Result<String> {
let image_path = image_path.as_ref();
let mime_type = from_path(image_path).first_or_octet_stream().to_string();
let mut file = File::open(image_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let encoded_image = base64_encode(buffer);
let data_url = format!("data:{};base64,{}", mime_type, encoded_image);
Ok(data_url)
}
fn read_file<P: AsRef<Path>>(file_path: P) -> Result<String> {
let file_path = file_path.as_ref();
let mut text = String::new();
let mut file = File::open(file_path)?;
file.read_to_string(&mut text)?;
Ok(text)
}