refactor message data structure and make claude client supporting function calling

pull/514/head
sigoden 2 weeks ago
parent f71aa4f5dd
commit bac9b447ca

3
.gitignore vendored

@ -1,2 +1,3 @@
/target
/tmp
/tmp
*.log

@ -24,6 +24,34 @@ test-platform-env() {
cargo run -- "$@"
}
# @cmd Test function calling
# @option --model[`_choice_model`]
# @option --preset[=default|weather|multi-weathers]
# @flag -S --no-stream
# @arg text~
test-function-calling() {
args=(--role %functions%)
if [[ -n "$argc_model" ]]; then
args+=("--model" "$argc_model")
fi
if [[ -n "$argc_no_stream" ]]; then
args+=("-S")
fi
if [[ -z "$argc_text" ]]; then
case "$argc_preset" in
multi-weathers)
text="what is the weather in London and Pairs?"
;;
weather|*)
text="what is the weather in London?"
;;
esac
else
text="${argc_text[*]}"
fi
cargo run -- "${args[@]}" "$text"
}
# @cmd Test clients
# @arg clients+[`_choice_client`]
test-clients() {
@ -174,6 +202,7 @@ chat-claude() {
-X POST \
-H 'content-type: application/json' \
-H 'anthropic-version: 2023-06-01' \
-H 'anthropic-beta: tools-2024-05-16' \
-H "x-api-key: $CLAUDE_API_KEY" \
-d "$(_build_body claude "$@")"
}

@ -1,6 +1,7 @@
# notes:
# - do not submit pull requests to add new models; this list will be updated in batches with new releases.
# - do not add any open-source LLMs except for the following: Mixtral, LLama-3, Gemma, Qwen, Phi-3, DeepSeek, Command-R, dbrx, Yi.
# - only model that supports parallel fucntion calling can have the `supports_function_calling` property set to true.
- platform: openai
# docs:
@ -79,7 +80,6 @@
max_output_tokens: 2048
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: gemini-1.0-pro-vision-latest
max_input_tokens: 12288
max_output_tokens: 4096
@ -92,7 +92,6 @@
input_price: 0.35
output_price: 0.53
supports_vision: true
supports_function_calling: true
- name: gemini-1.5-pro-latest
max_input_tokens: 1048576
max_output_tokens: 8192
@ -115,6 +114,7 @@
input_price: 15
output_price: 75
supports_vision: true
supports_function_calling: true
- name: claude-3-sonnet-20240229
max_input_tokens: 200000
max_output_tokens: 4096
@ -150,17 +150,14 @@
max_input_tokens: 64000
input_price: 2
output_price: 6
supports_function_calling: true
- name: mistral-small-latest
max_input_tokens: 32000
input_price: 2
output_price: 6
supports_function_calling: true
- name: mistral-large-latest
max_input_tokens: 32000
input_price: 8
output_price: 24
supports_function_calling: true
- platform: cohere
# docs:
@ -261,7 +258,6 @@
max_output_tokens: 8192
input_price: 0.125
output_price: 0.375
supports_function_calling: true
- name: gemini-1.0-pro-vision
max_input_tokens: 14336
max_output_tokens: 2048

@ -1,10 +1,10 @@
use super::{
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionOutput, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction, PromptKind,
SendData, SsMmessage, SseHandler,
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, CompletionOutput,
ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction,
PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -36,7 +36,9 @@ impl ClaudeClient {
debug!("Claude Request: {url} {body}");
let mut builder = client.post(url).json(&body);
builder = builder.header("anthropic-version", "2023-06-01");
builder = builder
.header("anthropic-version", "2023-06-01")
.header("anthropic-beta", "tools-2024-05-16");
if let Some(api_key) = api_key {
builder = builder.header("x-api-key", api_key)
}
@ -66,14 +68,58 @@ pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
match typ {
"content_block_start" => {
if let (Some("tool_use"), Some(name), Some(id)) = (
data["content_block"]["type"].as_str(),
data["content_block"]["name"].as_str(),
data["content_block"]["id"].as_str(),
) {
if !function_name.is_empty() {
let arguments: Value = function_arguments
.parse()
.context("Invalid call arguments: must be json")?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_name = name.into();
function_arguments.clear();
function_id = id.into();
}
}
"content_block_delta" => {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
} else if let (true, Some(partial_json)) = (
!function_name.is_empty(),
data["delta"]["partial_json"].as_str(),
) {
function_arguments.push_str(partial_json);
}
}
"content_block_stop" => {
if !function_name.is_empty() {
let arguments: Value = function_arguments
.parse()
.context("Invalid call arguments: must be json")?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
}
_ => {}
}
}
Ok(false)
@ -87,61 +133,94 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
mut messages,
temperature,
top_p,
functions: _,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut is_tool_call = false;
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"type": "text", "text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"type": "text", "text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"content": text,
})],
MessageContent::Array(list) => {
let content: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => {
json!({"type": "text", "text": text})
}
}
})
.collect(),
MessageContent::ToolCall(_) => {
is_tool_call = true;
vec![]
},
};
json!({ "role": role, "content": content })
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect();
vec![json!({
"role": role,
"content": content,
})]
}
MessageContent::ToolResults((tool_call_results, text)) => {
let mut tool_call = vec![];
let mut tool_result = vec![];
if !text.is_empty() {
tool_call.push(json!({
"type": "text",
"text": text,
}))
}
for tool_call_result in tool_call_results {
tool_call.push(json!({
"type": "tool_use",
"id": tool_call_result.call.id,
"name": tool_call_result.call.name,
"input": tool_call_result.call.arguments,
}));
tool_result.push(json!({
"type": "tool_result",
"tool_use_id": tool_call_result.call.id,
"content": tool_call_result.output.to_string(),
}));
}
vec![
json!({
"role": "assistant",
"content": tool_call,
}),
json!({
"role": "user",
"content": tool_result,
}),
]
}
}
})
.collect();
if is_tool_call {
bail!("The client does not support function calling",);
}
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
@ -168,17 +247,58 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"name": v.name,
"description": v.description,
"input_schema": v.parameters,
})
})
.collect();
}
Ok(body)
}
pub fn claude_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["content"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let text = data["content"][0]["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["content"].as_array().map(|content| {
content
.iter()
.filter(|content| matches!(content["type"].as_str(), Some("tool_use")))
.collect::<Vec<&Value>>()
}) {
tool_calls = calls
.into_iter()
.filter_map(|call| {
if let (Some(name), Some(input), Some(id)) = (
call["name"].as_str(),
call.get("input"),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
} else {
None
}
})
.collect();
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),

@ -1,7 +1,6 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SseHandler, ToolCall,
ToolCallResult,
};
use anyhow::{bail, Result};
@ -103,67 +102,44 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
let system_message = extract_system_message(&mut messages);
let mut image_urls = vec![];
let mut tool_calls: Vec<MessageToolCall> = vec![];
let mut tool_call_results: Vec<ToolCallResult> = vec![];
let mut tool_results = None;
let mut messages: Vec<Value> = messages
.into_iter()
.filter_map(|message| {
if message.role == MessageRole::Tool {
if let MessageContent::ToolCall(result) = message.content {
tool_call_results.push(result);
let Message { role, content } = message;
let role = match role {
MessageRole::User => "USER",
_ => "CHATBOT",
};
match content {
MessageContent::Text(text) => Some(json!({
"role": role,
"message": text,
})),
MessageContent::Array(list) => {
let list: Vec<String> = list
.into_iter()
.filter_map(|item| match item {
MessageContentPart::Text { text } => Some(text),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
image_urls.push(url.clone());
None
}
})
.collect();
Some(json!({ "role": role, "message": list.join("\n\n") }))
}
None
} else if !message.tool_calls.is_empty() {
tool_calls = message.tool_calls;
None
} else {
let role = match message.role {
MessageRole::User => "USER",
_ => "CHATBOT",
};
match message.content {
MessageContent::Text(text) => Some(json!({
"role": role,
"message": text,
})),
MessageContent::Array(list) => {
let list: Vec<String> = list
.into_iter()
.filter_map(|item| match item {
MessageContentPart::Text { text } => Some(text),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
image_urls.push(url.clone());
None
}
})
.collect();
Some(json!({ "role": role, "message": list.join("\n\n") }))
}
MessageContent::ToolCall(_) => None,
MessageContent::ToolResults((tool_call_results, _)) => {
tool_results = Some(tool_call_results);
None
}
}
})
.collect();
let tool_results: Vec<Value> = tool_calls
.into_iter()
.zip(tool_call_results)
.map(|(tool_call, tool_call_result)| {
json!({
"call": {
"name": tool_call.function.name,
"parameters": tool_call.function.arguments,
},
"outputs": [
tool_call_result.output,
]
})
})
.collect();
if !image_urls.is_empty() {
bail!("The model does not support images: {:?}", image_urls);
}
@ -175,6 +151,25 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
"message": message,
});
if let Some(tool_results) = tool_results {
let tool_results: Vec<_> = tool_results
.into_iter()
.map(|tool_call_result| {
json!({
"call": {
"name": tool_call_result.call.name,
"parameters": tool_call_result.call.arguments,
},
"outputs": [
tool_call_result.output,
]
})
})
.collect();
body["tool_results"] = json!(tool_results);
}
if let Some(v) = system_message {
body["preamble"] = v.into();
}
@ -219,17 +214,15 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
})
.collect();
}
if !tool_results.is_empty() {
body["tool_results"] = json!(tool_results);
}
Ok(body)
}
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["text"].as_str().unwrap_or_default();
let tool_calls = if let Some(tool_calls) = data["tool_calls"].as_array() {
tool_calls
let mut tool_calls = vec![];
if let Some(calls) = data["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
if let (Some(name), Some(parameters)) =
@ -241,9 +234,8 @@ fn extract_completion(data: &Value) -> Result<CompletionOutput> {
}
})
.collect()
} else {
vec![]
};
}
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}

@ -445,12 +445,10 @@ pub async fn send_stream(
let (output, calls) = handler.take();
match send_ret {
Ok(_) => {
let not_tool_call = calls.is_empty();
let tool_call_results = run_tool_calls(config, calls)?;
if not_tool_call && !output.ends_with('\n') {
if !output.is_empty() && !output.ends_with('\n') {
println!();
}
Ok((output, tool_call_results))
Ok((output, run_tool_calls(config, calls)?))
}
Err(err) => {
if !output.is_empty() {

@ -1,4 +1,4 @@
use crate::function::ToolCallResult;
use super::ToolResults;
use serde::{Deserialize, Serialize};
@ -6,12 +6,6 @@ use serde::{Deserialize, Serialize};
pub struct Message {
pub role: MessageRole,
pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_calls: Vec<MessageToolCall>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl Default for Message {
@ -19,20 +13,13 @@ impl Default for Message {
Self {
role: MessageRole::User,
content: MessageContent::Text(String::new()),
name: None,
tool_calls: Default::default(),
tool_call_id: None,
}
}
}
impl Message {
pub fn new(role: MessageRole, content: MessageContent) -> Self {
Self {
role,
content,
..Default::default()
}
Self { role, content }
}
}
@ -42,7 +29,6 @@ pub enum MessageRole {
System,
Assistant,
User,
Tool,
}
#[allow(dead_code)]
@ -62,7 +48,7 @@ pub enum MessageContent {
Text(String),
Array(Vec<MessageContentPart>),
// Note: This type is primarily for convenience and does not exist in OpenAI's API.
ToolCall(ToolCallResult),
ToolResults(ToolResults),
}
impl MessageContent {
@ -86,7 +72,7 @@ impl MessageContent {
}
format!(".file {}{}", files.join(" "), concated_text)
}
MessageContent::ToolCall(_) => String::new(),
MessageContent::ToolResults(_) => String::new(),
}
}
@ -102,7 +88,7 @@ impl MessageContent {
*text = replace_fn(text)
}
}
MessageContent::ToolCall(_) => {}
MessageContent::ToolResults(_) => {}
}
}
@ -118,7 +104,7 @@ impl MessageContent {
}
parts.join("\n\n")
}
MessageContent::ToolCall(_) => String::new(),
MessageContent::ToolResults(_) => String::new(),
}
}
}
@ -135,20 +121,6 @@ pub struct ImageUrl {
pub url: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageToolCall {
pub id: Option<String>,
#[serde(rename = "type")]
pub typ: String,
pub function: MessageToolCallFunction,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageToolCallFunction {
pub name: String,
pub arguments: serde_json::Value,
}
pub fn patch_system_message(messages: &mut Vec<Message>) {
if messages[0].role.is_system() {
let system_message = messages.remove(0);

@ -6,7 +6,7 @@ mod model;
mod prompt_format;
mod sse_handler;
pub use crate::function::{ToolCall, ToolCallResult};
pub use crate::function::{ToolCall, ToolResults};
pub use crate::utils::PromptKind;
pub use common::*;
pub use message::*;

@ -164,7 +164,7 @@ impl Model {
.map(|v| match &v.content {
MessageContent::Text(text) => estimate_token_length(text),
MessageContent::Array(_) => 0,
MessageContent::ToolCall(_) => 0,
MessageContent::ToolResults(_) => 0,
})
.sum()
}

@ -145,8 +145,8 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
}
let content = content.join("\n\n");
json!({ "role": role, "content": content, "images": images })
},
MessageContent::ToolCall(_) => {
}
MessageContent::ToolResults(_) => {
is_tool_call = true;
json!({ "role": role })
}

@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CompletionOutput, message::*, ExtraConfig, Model, ModelData, OpenAIClient,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
catch_error, message::*, sse_stream, CompletionOutput, ExtraConfig, Model, ModelData,
OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{bail, Result};
@ -127,19 +127,38 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
stream,
} = data;
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let mut new_message = json!(&message);
let content = match message.content {
MessageContent::ToolCall(result) => {
MessageContent::Text(json!(result.output).to_string())
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults((tool_call_results, text)) => {
let tool_calls: Vec<_> = tool_call_results.iter().map(|tool_call_result| {
json!({
"id": tool_call_result.call.id,
"type": "function",
"function": {
"name": tool_call_result.call.name,
"arguments": tool_call_result.call.arguments,
},
})
}).collect();
let mut messages = vec![
json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls })
];
for tool_call_result in tool_call_results {
messages.push(
json!({
"role": "tool",
"content": tool_call_result.output.to_string(),
"tool_call_id": tool_call_result.call.id,
})
);
}
messages
},
_ => message.content,
};
new_message["content"] = json!(content);
new_message
_ => vec![json!({ "role": role, "content": content })]
}
})
.collect();
@ -180,29 +199,28 @@ pub fn openai_extract_completion(data: &Value) -> Result<CompletionOutput> {
.as_str()
.unwrap_or_default();
let tool_calls =
if let Some(tools_call) = data["choices"][0]["message"]["tool_calls"].as_array() {
tools_call
.iter()
.filter_map(|call| {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
json!(arguments),
Some(id.to_string()),
))
} else {
None
}
})
.collect()
} else {
vec![]
};
let mut tool_calls = vec![];
if let Some(tools_call) = data["choices"][0]["message"]["tool_calls"].as_array() {
tool_calls = tools_call
.iter()
.filter_map(|call| {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
json!(arguments),
Some(id.to_string()),
))
} else {
None
}
})
.collect()
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}

@ -108,7 +108,7 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Re
}
parts.join("\n\n")
}
MessageContent::ToolCall(_) => String::new(),
MessageContent::ToolResults(_) => String::new(),
};
match role {
MessageRole::System => prompt.push_str(&format!(
@ -120,7 +120,6 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Re
MessageRole::User => {
prompt.push_str(&format!("{user_pre_message}{content}{user_post_message}"))
}
MessageRole::Tool => {}
}
}
if !image_urls.is_empty() {

@ -158,7 +158,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
}
})
.collect(),
MessageContent::ToolCall(_) => {
MessageContent::ToolResults(_) => {
is_tool_call = true;
vec![]
}

@ -139,8 +139,9 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
.as_str()
.unwrap_or_default();
let tool_calls = if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
parts
let mut tool_calls = vec![];
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
tool_calls = parts
.iter()
.filter_map(|part| {
if let (Some(name), Some(args)) = (
@ -153,9 +154,7 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
}
})
.collect()
} else {
vec![]
};
}
if text.is_empty() && tool_calls.is_empty() {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
@ -194,27 +193,17 @@ pub(crate) fn gemini_build_body(
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = match message.role {
.flat_map(|message| {
let Message { role, content } = message;
let role = match role {
MessageRole::User => "user",
_ => "model",
};
if !message.tool_calls.is_empty() {
let parts: Vec<Value> = message.tool_calls.iter().map(|tool_call| {
json!({
"functionCall": {
"name": tool_call.function.name,
"args": tool_call.function.arguments,
}
})
}).collect();
json!({ "role": role, "parts": parts })
} else {
match message.content {
MessageContent::Text(text) => json!({
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"parts": [{ "text": text }]
}),
})],
MessageContent::Array(list) => {
let parts: Vec<Value> = list
.into_iter()
@ -230,24 +219,34 @@ pub(crate) fn gemini_build_body(
},
})
.collect();
json!({ "role": role, "parts": parts })
vec![json!({ "role": role, "parts": parts })]
},
MessageContent::ToolCall(result) => {
let parts = vec![
MessageContent::ToolResults((tool_call_results, _)) => {
let function_call_parts: Vec<Value> = tool_call_results.iter().map(|tool_call_result| {
json!({
"functionCall": {
"name": tool_call_result.call.name,
"args": tool_call_result.call.arguments,
}
})
}).collect();
let function_response_parts: Vec<Value> = tool_call_results.into_iter().map(|tool_call_result| {
json!({
"functionResponse": {
"name": result.call.name,
"name": tool_call_result.call.name,
"response": {
"name": result.call.name,
"content": result.output,
"name": tool_call_result.call.name,
"content": tool_call_result.output,
}
}
})
];
json!({ "role": role, "parts": parts })
}).collect();
vec![
json!({ "role": "model", "parts": function_call_parts }),
json!({ "role": "function", "parts": function_response_parts }),
]
}
}
}
})
.collect();

@ -4,7 +4,7 @@ use crate::client::{
init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart,
MessageRole, Model, SendData,
};
use crate::function::ToolCallResult;
use crate::function::{ToolCallResult, ToolResults};
use crate::utils::{base64_encode, sha256};
use anyhow::{bail, Context, Result};
@ -31,7 +31,7 @@ pub struct Input {
text: String,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_call_results: Vec<ToolCallResult>,
tool_call: Option<ToolResults>,
context: InputContext,
}
@ -42,7 +42,7 @@ impl Input {
text: text.to_string(),
medias: Default::default(),
data_urls: Default::default(),
tool_call_results: Default::default(),
tool_call: None,
context: context.unwrap_or_else(|| InputContext::from_config(config)),
}
}
@ -94,7 +94,7 @@ impl Input {
text: texts.join("\n"),
medias,
data_urls,
tool_call_results: Default::default(),
tool_call: Default::default(),
context: context.unwrap_or_else(|| InputContext::from_config(config)),
})
}
@ -115,15 +115,15 @@ impl Input {
self.text = text;
}
pub fn tool_call(mut self, tool_call_results: Vec<ToolCallResult>) -> Self {
self.tool_call_results = tool_call_results;
pub fn merge_tool_call(
mut self,
output: String,
tool_call_results: Vec<ToolCallResult>,
) -> Self {
self.tool_call = Some((tool_call_results, output));
self
}
pub fn is_tool_call(&self) -> bool {
!self.tool_call_results.is_empty()
}
pub fn model(&self) -> Model {
let model = self.config.read().model.clone();
if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
@ -185,14 +185,14 @@ impl Input {
} else if let Some(role) = self.role() {
role.build_messages(self)
} else {
let message = Message {
role: MessageRole::User,
content: self.message_content(),
..Default::default()
};
vec![message]
vec![Message::new(MessageRole::User, self.message_content())]
};
messages.extend(self.tool_messages());
if let Some(tool_results) = &self.tool_call {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults(tool_results.clone()),
))
}
Ok(messages)
}
@ -291,30 +291,6 @@ impl Input {
MessageContent::Array(list)
}
}
pub fn tool_messages(&self) -> Vec<Message> {
if !self.is_tool_call() {
return vec![];
}
let mut messages = vec![Message {
role: MessageRole::Assistant,
content: MessageContent::Text(String::new()),
tool_calls: self
.tool_call_results
.iter()
.map(|v| v.build_message())
.collect(),
..Default::default()
}];
messages.extend(self.tool_call_results.iter().map(|tool_call| Message {
role: MessageRole::Tool,
content: MessageContent::ToolCall(tool_call.clone()),
name: Some(tool_call.call.name.clone()),
tool_calls: Default::default(),
tool_call_id: tool_call.call.id.clone(),
}));
messages
}
}
#[derive(Debug, Clone, Default)]

@ -213,7 +213,6 @@ impl Session {
message.content.render_input(resolve_url_fn)
));
}
MessageRole::Tool => {}
}
}
}

@ -1,5 +1,4 @@
use crate::{
client::{MessageToolCall, MessageToolCallFunction},
config::GlobalConfig,
utils::{dimmed_text, error_text, exec_command, spawn_command},
};
@ -22,6 +21,8 @@ lazy_static! {
static ref THREAD_POOL: ThreadPool = ThreadPool::new(num_cpus::get());
}
pub type ToolResults = (Vec<ToolCallResult>, String);
pub fn run_tool_calls(config: &GlobalConfig, calls: Vec<ToolCall>) -> Result<Vec<ToolCallResult>> {
let mut output = vec![];
if calls.is_empty() {
@ -68,17 +69,6 @@ impl ToolCallResult {
pub fn new(call: ToolCall, output: Value) -> Self {
Self { call, output }
}
pub fn build_message(&self) -> MessageToolCall {
MessageToolCall {
id: self.call.id.clone(),
typ: "function".into(),
function: MessageToolCallFunction {
name: self.call.name.clone(),
arguments: self.call.arguments.clone(),
},
}
}
}
#[derive(Debug, Clone, Default)]

@ -177,7 +177,7 @@ async fn start_directive(
if !tool_call_results.is_empty() {
start_directive(
config,
Input::tool_call(input, tool_call_results),
input.merge_tool_call(output, tool_call_results),
no_stream,
code_mode,
)

@ -420,7 +420,12 @@ async fn ask(config: &GlobalConfig, abort: AbortSignal, input: Input) -> Result<
if tool_call_results.is_empty() {
Ok(())
} else {
ask(config, abort, input.tool_call(tool_call_results)).await
ask(
config,
abort,
input.merge_tool_call(output, tool_call_results),
)
.await
}
}

Loading…
Cancel
Save