|
|
|
@ -1,21 +1,84 @@
|
|
|
|
|
use crate::{config::GlobalConfig, utils::exec_command};
|
|
|
|
|
use crate::{
|
|
|
|
|
client::{MessageToolCall, MessageToolCallFunction},
|
|
|
|
|
config::GlobalConfig,
|
|
|
|
|
utils::{dimmed_text, error_text, exec_command, spawn_command},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
|
|
use fancy_regex::Regex;
|
|
|
|
|
use indexmap::{IndexMap, IndexSet};
|
|
|
|
|
use inquire::Confirm;
|
|
|
|
|
use is_terminal::IsTerminal;
|
|
|
|
|
use lazy_static::lazy_static;
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
use serde_json::Value;
|
|
|
|
|
use std::{collections::HashMap, fs, path::Path};
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
|
use std::{collections::HashMap, fs, io::stdout, path::Path, sync::mpsc::channel};
|
|
|
|
|
use threadpool::ThreadPool;
|
|
|
|
|
|
|
|
|
|
const BIN_DIR_NAME: &str = "bin";
|
|
|
|
|
const DECLARATIONS_FILE_PATH: &str = "functions.json";
|
|
|
|
|
|
|
|
|
|
pub fn run_tool_calls(config: &GlobalConfig, calls: &[ToolCall]) -> Result<()> {
|
|
|
|
|
for call in calls {
|
|
|
|
|
call.run(config)?;
|
|
|
|
|
lazy_static! {
|
|
|
|
|
static ref THREAD_POOL: ThreadPool = ThreadPool::new(num_cpus::get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn run_tool_calls(config: &GlobalConfig, calls: Vec<ToolCall>) -> Result<Vec<ToolCallResult>> {
|
|
|
|
|
let mut output = vec![];
|
|
|
|
|
if calls.is_empty() {
|
|
|
|
|
return Ok(output);
|
|
|
|
|
}
|
|
|
|
|
let parallel = calls.len() > 1 && calls.iter().all(|v| !v.is_execute());
|
|
|
|
|
if parallel {
|
|
|
|
|
let (tx, rx) = channel();
|
|
|
|
|
let calls_len = calls.len();
|
|
|
|
|
for (index, call) in calls.into_iter().enumerate() {
|
|
|
|
|
let tx = tx.clone();
|
|
|
|
|
let config = config.clone();
|
|
|
|
|
THREAD_POOL.execute(move || {
|
|
|
|
|
let result = call.run(&config);
|
|
|
|
|
let _ = tx.send((index, call, result));
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
let mut list: Vec<(usize, ToolCall, Result<Option<Value>>)> =
|
|
|
|
|
rx.iter().take(calls_len).collect();
|
|
|
|
|
list.sort_by_key(|v| v.0);
|
|
|
|
|
for (_, call, result) in list {
|
|
|
|
|
let result = result?;
|
|
|
|
|
if let Some(result) = result {
|
|
|
|
|
output.push(ToolCallResult::new(call, result));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for call in calls {
|
|
|
|
|
if let Some(result) = call.run(config)? {
|
|
|
|
|
output.push(ToolCallResult::new(call, result));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
|
|
|
pub struct ToolCallResult {
|
|
|
|
|
pub call: ToolCall,
|
|
|
|
|
pub output: Value,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl ToolCallResult {
|
|
|
|
|
pub fn new(call: ToolCall, output: Value) -> Self {
|
|
|
|
|
Self { call, output }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn build_message(&self) -> MessageToolCall {
|
|
|
|
|
MessageToolCall {
|
|
|
|
|
id: self.call.id.clone(),
|
|
|
|
|
typ: "function".into(),
|
|
|
|
|
function: MessageToolCallFunction {
|
|
|
|
|
name: self.call.name.clone(),
|
|
|
|
|
arguments: self.call.arguments.clone(),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
|
|
@ -62,19 +125,19 @@ impl Function {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn filtered_declarations(&self, filters: &[String]) -> Vec<FunctionDeclaration> {
|
|
|
|
|
if filters.is_empty() {
|
|
|
|
|
vec![]
|
|
|
|
|
} else if filters.len() == 1 && filters[0] == "*" {
|
|
|
|
|
self.declarations.clone()
|
|
|
|
|
} else if let Ok(re) = Regex::new(&filters.join("|")) {
|
|
|
|
|
self.declarations
|
|
|
|
|
.iter()
|
|
|
|
|
.filter(|v| re.is_match(&v.name).unwrap_or_default())
|
|
|
|
|
.cloned()
|
|
|
|
|
.collect()
|
|
|
|
|
pub fn filtered_declarations(&self, filter: Option<&str>) -> Option<Vec<FunctionDeclaration>> {
|
|
|
|
|
let filter = filter?;
|
|
|
|
|
let regex = Regex::new(&format!("^({filter})$")).ok()?;
|
|
|
|
|
let output: Vec<FunctionDeclaration> = self
|
|
|
|
|
.declarations
|
|
|
|
|
.iter()
|
|
|
|
|
.filter(|v| regex.is_match(&v.name).unwrap_or_default())
|
|
|
|
|
.cloned()
|
|
|
|
|
.collect();
|
|
|
|
|
if output.is_empty() {
|
|
|
|
|
None
|
|
|
|
|
} else {
|
|
|
|
|
vec![]
|
|
|
|
|
Some(output)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -107,37 +170,43 @@ pub struct JsonSchema {
|
|
|
|
|
pub required: Option<Vec<String>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
|
|
|
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
|
|
|
|
pub struct ToolCall {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub args: Value,
|
|
|
|
|
pub arguments: Value,
|
|
|
|
|
pub id: Option<String>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl ToolCall {
|
|
|
|
|
pub fn new(name: String, args: Value) -> Self {
|
|
|
|
|
Self { name, args }
|
|
|
|
|
pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
name,
|
|
|
|
|
arguments,
|
|
|
|
|
id,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn run(&self, config: &GlobalConfig) -> Result<()> {
|
|
|
|
|
pub fn run(&self, config: &GlobalConfig) -> Result<Option<Value>> {
|
|
|
|
|
let name = self.name.clone();
|
|
|
|
|
if !config.read().function.names.contains(&name) {
|
|
|
|
|
bail!("Invalid call: {name} {}", self.args);
|
|
|
|
|
bail!("Unexpected call: {name} {}", self.arguments);
|
|
|
|
|
}
|
|
|
|
|
let args = if self.args.is_object() {
|
|
|
|
|
self.args.clone()
|
|
|
|
|
} else if let Some(args) = self.args.as_str() {
|
|
|
|
|
let args: Value =
|
|
|
|
|
serde_json::from_str(args).map_err(|_| anyhow!("Invalid call args: {args}"))?;
|
|
|
|
|
let arguments = if self.arguments.is_object() {
|
|
|
|
|
self.arguments.clone()
|
|
|
|
|
} else if let Some(arguments) = self.arguments.as_str() {
|
|
|
|
|
let args: Value = serde_json::from_str(arguments)
|
|
|
|
|
.map_err(|_| anyhow!("Invalid call arguments: {arguments}"))?;
|
|
|
|
|
args
|
|
|
|
|
} else {
|
|
|
|
|
bail!("Invalid call args: {}", self.args);
|
|
|
|
|
bail!("Invalid call arguments: {}", self.arguments);
|
|
|
|
|
};
|
|
|
|
|
let args = convert_args(&args);
|
|
|
|
|
let arguments = convert_arguments(&arguments);
|
|
|
|
|
|
|
|
|
|
let prompt_text = format!(
|
|
|
|
|
"call {} {}",
|
|
|
|
|
"Call {} {}",
|
|
|
|
|
name,
|
|
|
|
|
args.iter()
|
|
|
|
|
arguments
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|v| shell_words::quote(v).to_string())
|
|
|
|
|
.collect::<Vec<String>>()
|
|
|
|
|
.join(" ")
|
|
|
|
@ -150,32 +219,45 @@ impl ToolCall {
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let ans = Confirm::new(&prompt_text).with_default(true).prompt()?;
|
|
|
|
|
if ans {
|
|
|
|
|
#[cfg(windows)]
|
|
|
|
|
let name = {
|
|
|
|
|
let mut name = name;
|
|
|
|
|
let bin_dir = config.read().function.bin_dir.clone();
|
|
|
|
|
if let Ok(exts) = std::env::var("PATHEXT") {
|
|
|
|
|
if let Some(cmd_path) = exts
|
|
|
|
|
.split(';')
|
|
|
|
|
.map(|ext| bin_dir.join(format!("{}{}", self.name, ext)))
|
|
|
|
|
.find(|path| path.exists())
|
|
|
|
|
{
|
|
|
|
|
name = cmd_path.display().to_string();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
name
|
|
|
|
|
let output = if self.is_execute() {
|
|
|
|
|
let proceed = if stdout().is_terminal() {
|
|
|
|
|
Confirm::new(&prompt_text).with_default(true).prompt()?
|
|
|
|
|
} else {
|
|
|
|
|
println!("{}", dimmed_text(&prompt_text));
|
|
|
|
|
true
|
|
|
|
|
};
|
|
|
|
|
exec_command(&name, &args, envs)?;
|
|
|
|
|
}
|
|
|
|
|
if proceed {
|
|
|
|
|
#[cfg(windows)]
|
|
|
|
|
let name = polyfill_cmd_name(name, &config.read().function.bin_dir);
|
|
|
|
|
spawn_command(&name, &arguments, envs)?;
|
|
|
|
|
}
|
|
|
|
|
None
|
|
|
|
|
} else {
|
|
|
|
|
println!("{}", dimmed_text(&prompt_text));
|
|
|
|
|
#[cfg(windows)]
|
|
|
|
|
let name = polyfill_cmd_name(name, &config.read().function.bin_dir);
|
|
|
|
|
let (success, stdout, stderr) = exec_command(&name, &arguments, envs)?;
|
|
|
|
|
if stderr.is_empty() {
|
|
|
|
|
eprintln!("{}", error_text(&stderr));
|
|
|
|
|
}
|
|
|
|
|
if success && !stdout.is_empty() {
|
|
|
|
|
serde_json::from_str(&stdout)
|
|
|
|
|
.ok()
|
|
|
|
|
.or_else(|| Some(json!({"output": stdout})))
|
|
|
|
|
} else {
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
pub fn is_execute(&self) -> bool {
|
|
|
|
|
self.name.starts_with("execute_") || self.name.contains("__execute_")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn convert_args(args: &Value) -> Vec<String> {
|
|
|
|
|
fn convert_arguments(args: &Value) -> Vec<String> {
|
|
|
|
|
let mut options: Vec<String> = Vec::new();
|
|
|
|
|
|
|
|
|
|
if let Value::Object(map) = args {
|
|
|
|
@ -215,6 +297,21 @@ fn prepend_env_path(bin_dir: &Path) -> Result<String> {
|
|
|
|
|
Ok(new_path)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(windows)]
|
|
|
|
|
fn polyfill_cmd_name(name: &str, bin_dir: &std::path::Path) -> String {
|
|
|
|
|
let mut name = name.to_string();
|
|
|
|
|
if let Ok(exts) = std::env::var("PATHEXT") {
|
|
|
|
|
if let Some(cmd_path) = exts
|
|
|
|
|
.split(';')
|
|
|
|
|
.map(|ext| bin_dir.join(format!("{}{}", name, ext)))
|
|
|
|
|
.find(|path| path.exists())
|
|
|
|
|
{
|
|
|
|
|
name = cmd_path.display().to_string();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
name
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod tests {
|
|
|
|
|
|
|
|
|
@ -228,7 +325,7 @@ mod tests {
|
|
|
|
|
"baz": ["v1", "v2"]
|
|
|
|
|
});
|
|
|
|
|
assert_eq!(
|
|
|
|
|
convert_args(&args),
|
|
|
|
|
convert_arguments(&args),
|
|
|
|
|
vec!["--foo", "--bar", "val", "--baz", "v1", "--baz", "v2"]
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|