feat: support function calling (#514)

* feat: support function calling

* fix on Windows OS

* implement multi-steps function calling

* fix on Windows OS

* add error for client not support function calling

* refactor message data structure and make the Claude client support function calling

* support reuse previous call results

* improve error handling for function calling

* use prefix `may_` as indicator for `execute` type functions
pull/518/head
sigoden 2 weeks ago committed by GitHub
parent 1348a62e5f
commit b4a40e3fed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

3
.gitignore vendored

@ -1,2 +1,3 @@
/target
/tmp
/tmp
*.log

@ -24,6 +24,34 @@ test-platform-env() {
cargo run -- "$@"
}
# @cmd Test function calling
# @option --model[?`_choice_model`]
# @option --preset[=default|weather|multi-weathers]
# @flag -S --no-stream
# @arg text~
test-function-calling() {
    # Exercise the function-calling feature by running aichat with the
    # %functions% role, forwarding the chosen model / stream flags.
    args=(--role %functions%)
    if [[ -n "$argc_model" ]]; then
        args+=("--model" "$argc_model")
    fi
    if [[ -n "$argc_no_stream" ]]; then
        args+=("-S")
    fi
    if [[ -z "$argc_text" ]]; then
        # No free-form text supplied: derive the prompt from the preset.
        case "$argc_preset" in
        multi-weathers)
            # fix: corrected city-name typo "Pairs" -> "Paris"
            text="what is the weather in London and Paris?"
            ;;
        weather|*)
            # `weather` and the default preset share the single-city prompt.
            text="what is the weather in London?"
            ;;
        esac
    else
        text="${argc_text[*]}"
    fi
    cargo run -- "${args[@]}" "$text"
}
# @cmd Test clients
# @arg clients+[`_choice_client`]
test-clients() {
@ -36,7 +64,7 @@ test-clients() {
}
# @cmd Test proxy server
# @option -m --model[`_choice_model`]
# @option -m --model[?`_choice_model`]
# @flag -S --no-stream
# @arg text~
test-server() {
@ -153,10 +181,7 @@ chat-gemini() {
_wrapper curl -i "https://generativelanguage.googleapis.com/v1beta/models/${argc_model}:${method}?key=${GEMINI_API_KEY}" \
-i -X POST \
-H 'Content-Type: application/json' \
-d '{
"safetySettings":[{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}],
"contents": '"$(_build_msg_gemini $*)"'
}'
-d "$(_build_body gemini "$@")"
}
# @cmd List gemini models
@ -177,14 +202,9 @@ chat-claude() {
-X POST \
-H 'content-type: application/json' \
-H 'anthropic-version: 2023-06-01' \
-H 'anthropic-beta: tools-2024-05-16' \
-H "x-api-key: $CLAUDE_API_KEY" \
-d '{
"model": "'$argc_model'",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096,
"stream": '$stream'
}
'
-d "$(_build_body claude "$@")"
}
# @cmd Chat with cohere api
@ -221,11 +241,7 @@ chat-ollama() {
_wrapper curl -i http://localhost:11434/api/chat \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"model": "'$argc_model'",
"stream": '$stream',
"messages": '"$(_build_msg $*)"'
}'
-d "$(_build_body ollama "$@")"
}
# @cmd Chat with vertexai api
@ -246,10 +262,7 @@ chat-vertexai() {
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
-d '{
"contents": '"$(_build_msg_gemini $*)"',
"generationConfig": {}
}'
-d "$(_build_body vertexai "$@")"
}
# @cmd Chat with vertexai-claude api
@ -266,12 +279,7 @@ chat-vertexai-claude() {
-X POST \
-H "Authorization: Bearer $api_key" \
-H 'Content-Type: application/json' \
-d '{
"anthropic_version": "vertex-2023-10-16",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096,
"stream": '$stream'
}'
-d "$(_build_body vertexai-claude "$@")"
}
# @cmd Chat with bedrock api
@ -285,11 +293,7 @@ chat-bedrock() {
body='{"prompt":"'"$*"'"}'
;;
anthropic.*)
body='{
"anthropic_version": "vertex-2023-10-16",
"messages": '"$(_build_msg $*)"',
"max_tokens": 4096
}'
body="$(_build_body bedrock-claude "$@")"
;;
*)
_die "Invalid model: $argc_model"
@ -314,10 +318,7 @@ chat-cloudflare() {
_wrapper curl -i "$url" \
-X POST \
-H "Authorization: Bearer $CLOUDFLARE_API_KEY" \
-d '{
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}'
-d "$(_build_body cloudflare "$@")"
}
# @cmd Chat with replicate api
@ -331,12 +332,8 @@ chat-replicate() {
-X POST \
-H "Authorization: Bearer $REPLICATE_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"stream": '$stream',
"input": {
"prompt": "'"$*"'"
}
}')"
-d "$(_build_body replicate "$@")" \
)"
echo "$res"
if [[ -n "$argc_no_stream" ]]; then
prediction_url="$(echo "$res" | jq -r '.urls.get')"
@ -373,10 +370,7 @@ chat-ernie() {
url="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/$argc_model?access_token=$ACCESS_TOKEN"
_wrapper curl -i "$url" \
-X POST \
-d '{
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}'
-d "$(_build_body ernie "$@")"
}
@ -397,13 +391,7 @@ chat-qianwen() {
-X POST \
-H "Authorization: Bearer $QIANWEN_API_KEY" \
-H 'Content-Type: application/json' $stream_args \
-d '{
"model": "'$argc_model'",
"parameters": '"$parameters_args"',
"input":{
"messages": '"$(_build_msg $*)"'
}
}'
-d "$(_build_body qianwen "$@")"
}
_argc_before() {
@ -420,12 +408,7 @@ _openai_chat() {
-X POST \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $api_key" \
--data '{
"model": "'$argc_model'",
"messages": '"$(_build_msg $*)"',
"stream": '$stream'
}
'
-d "$(_build_body openai "$@")"
}
_openai_models() {
@ -460,35 +443,112 @@ _choice_openai_compatible_platform() {
done
}
_build_msg() {
    # Produce the `messages` JSON payload for a chat request.
    # With arguments: wrap them into a single-element user-message array.
    # Without arguments: replay the previously saved tmp/messages.json.
    if [[ $# -gt 0 ]]; then
        echo '
[
{
"role": "user",
"content": "'"$*"'"
}
]
'
    else
        cat tmp/messages.json
    fi
}
_build_body() {
kind="$1"
if [[ "$#" -eq 1 ]]; then
file="${BODY_FILE:-"tmp/body/$1.json"}"
if [[ -f "$file" ]]; then
cat "$file" | \
sed \
-e 's/"model": ".*"/"model": "'"$argc_model"'"/' \
-e 's/"stream": \(true\|false\)/"stream": '$stream'/' \
_build_msg_gemini() {
if [[ $# -eq 0 ]]; then
cat tmp/messages.gemini.json
fi
else
echo '
[{
"role": "user",
"parts": [
shift
case "$kind" in
openai|ollama)
echo '{
"model": "'$argc_model'",
"messages": [
{
"text": "'"$*"'"
"role": "user",
"content": "'"$*"'"
}
]
}]
'
],
"stream": '$stream'
}'
;;
claude)
echo '{
"model": "'$argc_model'",
"messages": [
{
"role": "user",
"content": "'"$*"'"
}
],
"max_tokens": 4096,
"stream": '$stream'
}'
;;
vertexai-claude|bedrock-claude)
echo '{
"anthropic_version": "vertex-2023-10-16",
"messages": [
{
"role": "user",
"content": "'"$*"'"
}
],
"max_tokens": 4096,
"stream": '$stream'
}'
;;
gemini|vertexai)
echo '{
"contents": [{
"role": "user",
"parts": [
{
"text": "'"$*"'"
}
]
}],
"safetySettings":[{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}]
}'
;;
ernie|cloudflare)
echo '{
"messages": [
{
"role": "user",
"content": "'"$*"'"
}
],
"stream": '$stream'
}'
;;
replicate)
echo '{
"stream": '$stream',
"input": {
"prompt": "'"$*"'"
}
}'
;;
qianwen)
echo '{
"model": "'$argc_model'",
"parameters": '"$parameters_args"',
"input":{
"messages": [
{
"role": "user",
"content": "'"$*"'"
}
]
}
}'
;;
*)
_die "Unsupported build body for $kind"
;;
esac
fi
}

13
Cargo.lock generated

@ -38,7 +38,6 @@ dependencies = [
"aws-smithy-eventstream",
"base64 0.22.1",
"bincode",
"bitflags 2.5.0",
"bstr",
"bytes",
"chrono",
@ -59,6 +58,7 @@ dependencies = [
"log",
"mime_guess",
"nu-ansi-term 0.50.0",
"num_cpus",
"parking_lot",
"reedline",
"reqwest",
@ -72,6 +72,7 @@ dependencies = [
"simplelog",
"syntect",
"textwrap",
"threadpool",
"time",
"tokio",
"tokio-graceful",
@ -1071,6 +1072,7 @@ checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown",
"serde",
]
[[package]]
@ -2229,6 +2231,15 @@ dependencies = [
"once_cell",
]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]]
name = "time"
version = "0.3.36"

@ -44,19 +44,20 @@ log = "0.4.20"
shell-words = "1.1.0"
mime_guess = "2.0.4"
sha2 = "0.10.8"
bitflags = "2.4.1"
unicode-width = "0.1.11"
async-recursion = "1.1.0"
async-recursion = "1.1.1"
http = "1.1.0"
http-body-util = "0.1"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["server-auto", "client-legacy"] }
time = { version = "0.3.36", features = ["macros"] }
indexmap = "2.2.6"
indexmap = { version = "2.2.6", features = ["serde"] }
hmac = "0.12.1"
aws-smithy-eventstream = "0.60.4"
urlencoding = "2.1.3"
unicode-segmentation = "1.11.0"
num_cpus = "1.16.0"
threadpool = "1.8.1"
[dependencies.reqwest]
version = "0.12.0"

@ -15,6 +15,9 @@ prelude: null # Set a default role or session to start with (
# if unset fallback to $EDITOR and $VISUAL
buffer_editor: null
# Controls the function calling feature. For setup instructions, visit https://github.com/sigoden/llm-functions.
function_calling: false
# Compress session when token count reaches or exceeds this threshold (must be at least 1000)
compress_threshold: 1000
# Text prompt used for creating a concise summary of session message

@ -15,33 +15,39 @@
max_output_tokens: 4096
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: gpt-3.5-turbo-1106
max_input_tokens: 16385
max_output_tokens: 4096
input_price: 1
output_price: 2
supports_function_calling: true
- name: gpt-4o
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 5
output_price: 15
supports_vision: true
supports_function_calling: true
- name: gpt-4-turbo
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_vision: true
supports_function_calling: true
- name: gpt-4-turbo-preview
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_function_calling: true
- name: gpt-4-1106-preview
max_input_tokens: 128000
max_output_tokens: 4096
input_price: 10
output_price: 30
supports_function_calling: true
- name: gpt-4-vision-preview
max_input_tokens: 128000
max_output_tokens: 4096
@ -73,6 +79,7 @@
max_output_tokens: 2048
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: gemini-1.0-pro-vision-latest
max_input_tokens: 12288
max_output_tokens: 4096
@ -85,12 +92,14 @@
input_price: 0.35
output_price: 0.53
supports_vision: true
supports_function_calling: true
- name: gemini-1.5-pro-latest
max_input_tokens: 1048576
max_output_tokens: 8192
input_price: 3.5
output_price: 10.5
supports_vision: true
supports_function_calling: true
- platform: claude
# docs:
@ -106,6 +115,7 @@
input_price: 15
output_price: 75
supports_vision: true
supports_function_calling: true
- name: claude-3-sonnet-20240229
max_input_tokens: 200000
max_output_tokens: 4096
@ -113,6 +123,7 @@
input_price: 3
output_price: 15
supports_vision: true
supports_function_calling: true
- name: claude-3-haiku-20240307
max_input_tokens: 200000
max_output_tokens: 4096
@ -120,6 +131,7 @@
input_price: 0.25
output_price: 1.25
supports_vision: true
supports_function_calling: true
- platform: mistral
# docs:
@ -149,6 +161,7 @@
max_input_tokens: 32000
input_price: 8
output_price: 24
supports_function_calling: true
- platform: cohere
# docs:
@ -163,11 +176,13 @@
max_output_tokens: 4000
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: command-r-plus
max_input_tokens: 128000
max_output_tokens: 4000
input_price: 3
output_price: 15
supports_function_calling: true
- platform: perplexity
# docs:
@ -242,12 +257,13 @@
# notes:
# - get max_output_tokens info from models doc
models:
- name: gemini-1.0-pro
- name: gemini-1.0-pro-002
max_input_tokens: 24568
max_output_tokens: 8192
input_price: 0.125
output_price: 0.375
- name: gemini-1.0-pro-vision
supports_function_calling: true
- name: gemini-1.0-pro-vision-001
max_input_tokens: 14336
max_output_tokens: 2048
input_price: 0.125
@ -387,6 +403,7 @@
# docs:
# - https://replicate.com/docs
# - https://replicate.com/pricing
# - https://replicate.com/docs/reference/http
# notes:
# - max_output_tokens is required but unknown
models:
@ -695,20 +712,24 @@
max_input_tokens: 16385
input_price: 0.5
output_price: 1.5
supports_function_calling: true
- name: openai/gpt-4o
max_input_tokens: 128000
input_price: 5
output_price: 15
supports_vision: true
supports_function_calling: true
- name: openai/gpt-4-turbo
max_input_tokens: 128000
input_price: 10
output_price: 30
supports_vision: true
supports_function_calling: true
- name: openai/gpt-4-turbo-preview
max_input_tokens: 128000
input_price: 10
output_price: 30
supports_function_calling: true
- name: openai/gpt-4-vision-preview
max_input_tokens: 128000
max_output_tokens: 4096

@ -1,7 +1,5 @@
use super::openai::openai_build_body;
use super::{
AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
};
use super::{AzureOpenAIClient, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@ -12,7 +10,7 @@ pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -41,7 +39,8 @@ impl AzureOpenAIClient {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
&api_base, self.model.name
&api_base,
self.model.name()
);
debug!("AzureOpenAI Request: {url} {body}");

@ -1,8 +1,8 @@
use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptAction, PromptFormat, PromptKind, SendData, SseHandler,
LLAMA3_PROMPT_FORMAT, MISTRAL_PROMPT_FORMAT,
catch_error, generate_prompt, BedrockClient, Client, CompletionOutput, ExtraConfig, Model,
ModelData, PromptAction, PromptFormat, PromptKind, SendData, SseHandler, LLAMA3_PROMPT_FORMAT,
MISTRAL_PROMPT_FORMAT,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
@ -30,7 +30,7 @@ pub struct BedrockConfig {
pub secret_access_key: Option<String>,
pub region: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -42,8 +42,8 @@ impl Client for BedrockClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
) -> Result<CompletionOutput> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await
}
@ -54,7 +54,7 @@ impl Client for BedrockClient {
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message_streaming(builder, handler, &model_category).await
}
@ -91,7 +91,7 @@ impl BedrockClient {
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let model_name = &self.model.name;
let model_name = &self.model.name();
let uri = if data.stream {
format!("/model/{model_name}/invoke-with-response-stream")
} else {
@ -129,7 +129,7 @@ impl BedrockClient {
async fn send_message(
builder: RequestBuilder,
model_category: &ModelCategory,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -138,6 +138,7 @@ async fn send_message(
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
match model_category {
ModelCategory::Anthropic => claude_extract_completion(&data),
ModelCategory::MetaLlama3 => llama_extract_completion(&data),
@ -172,7 +173,7 @@ async fn send_message_streaming(
let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
})?;
// debug!("bedrock chunk: {data}");
debug!("stream-data: {data}");
match model_category {
ModelCategory::Anthropic => {
if let Some(typ) = data["type"].as_str() {
@ -230,6 +231,7 @@ fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Res
messages,
temperature,
top_p,
functions: _,
stream: _,
} = data;
let prompt = generate_prompt(&messages, pt)?;
@ -253,6 +255,7 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream: _,
} = data;
let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?;
@ -271,23 +274,25 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["generation"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
id: None,
input_tokens: data["prompt_token_count"].as_u64(),
output_tokens: data["generation_token_count"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
fn mistral_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn mistral_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["outputs"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
Ok(CompletionOutput::new(text))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

@ -1,10 +1,10 @@
use super::{
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptAction, PromptKind,
SendData, SsMmessage, SseHandler,
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, CompletionOutput,
ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData, PromptAction,
PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -16,7 +16,7 @@ pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -36,7 +36,9 @@ impl ClaudeClient {
debug!("Claude Request: {url} {body}");
let mut builder = client.post(url).json(&body);
builder = builder.header("anthropic-version", "2023-06-01");
builder = builder
.header("anthropic-version", "2023-06-01")
.header("anthropic-beta", "tools-2024-05-16");
if let Some(api_key) = api_key {
builder = builder.header("x-api-key", api_key)
}
@ -51,13 +53,14 @@ impl_client_trait!(
claude_send_message_streaming
);
pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
pub async fn claude_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
claude_extract_completion(&data)
}
@ -65,13 +68,59 @@ pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
match typ {
"content_block_start" => {
if let (Some("tool_use"), Some(name), Some(id)) = (
data["content_block"]["type"].as_str(),
data["content_block"]["name"].as_str(),
data["content_block"]["id"].as_str(),
) {
if !function_name.is_empty() {
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_name = name.into();
function_arguments.clear();
function_id = id.into();
}
}
"content_block_delta" => {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
} else if let (true, Some(partial_json)) = (
!function_name.is_empty(),
data["delta"]["partial_json"].as_str(),
) {
function_arguments.push_str(partial_json);
}
}
"content_block_stop" => {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
}
_ => {}
}
}
Ok(false)
@ -85,46 +134,91 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
mut messages,
temperature,
top_p,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"type": "text", "text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"type": "text", "text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"content": text,
})],
MessageContent::Array(list) => {
let content: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => {
json!({"type": "text", "text": text})
}
}
})
.collect(),
};
json!({ "role": role, "content": content })
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect();
vec![json!({
"role": role,
"content": content,
})]
}
MessageContent::ToolResults((tool_call_results, text)) => {
let mut tool_call = vec![];
let mut tool_result = vec![];
if !text.is_empty() {
tool_call.push(json!({
"type": "text",
"text": text,
}))
}
for tool_call_result in tool_call_results {
tool_call.push(json!({
"type": "tool_use",
"id": tool_call_result.call.id,
"name": tool_call_result.call.name,
"input": tool_call_result.call.arguments,
}));
tool_result.push(json!({
"type": "tool_result",
"tool_use_id": tool_call_result.call.id,
"content": tool_call_result.output.to_string(),
}));
}
vec![
json!({
"role": "assistant",
"content": tool_call,
}),
json!({
"role": "user",
"content": tool_result,
}),
]
}
}
})
.collect();
@ -136,7 +230,7 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
}
let mut body = json!({
"model": &model.name,
"model": model.name(),
"messages": messages,
});
if let Some(v) = system_message {
@ -154,18 +248,61 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"name": v.name,
"description": v.description,
"input_schema": v.parameters,
})
})
.collect();
}
Ok(body)
}
pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["content"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
pub fn claude_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["content"][0]["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["content"].as_array().map(|content| {
content
.iter()
.filter(|content| matches!(content["type"].as_str(), Some("tool_use")))
.collect::<Vec<&Value>>()
}) {
tool_calls = calls
.into_iter()
.filter_map(|call| {
if let (Some(name), Some(input), Some(id)) = (
call["name"].as_str(),
call.get("input"),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
} else {
None
}
})
.collect();
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}

@ -1,5 +1,5 @@
use super::{
catch_error, sse_stream, CloudflareClient, CompletionDetails, ExtraConfig, Model, ModelConfig,
catch_error, sse_stream, CloudflareClient, CompletionOutput, ExtraConfig, Model, ModelData,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
@ -16,7 +16,7 @@ pub struct CloudflareConfig {
pub account_id: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -37,7 +37,7 @@ impl CloudflareClient {
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self.model.name
self.model.name()
);
debug!("Cloudflare Request: {url} {body}");
@ -50,7 +50,7 @@ impl CloudflareClient {
impl_client_trait!(CloudflareClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -58,6 +58,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_completion(&data)
}
@ -67,6 +68,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
return Ok(true);
}
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["response"].as_str() {
handler.text(text)?;
}
@ -80,11 +82,12 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream,
} = data;
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"messages": messages,
});
@ -104,10 +107,10 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["result"]["response"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
Ok(CompletionOutput::new(text))
}

@ -1,9 +1,9 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SseHandler, ToolCall,
};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -15,7 +15,7 @@ pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -42,7 +42,7 @@ impl CohereClient {
impl_client_trait!(CohereClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -50,6 +50,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_completion(&data)
}
@ -62,10 +63,25 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
} else {
let handle = |data: &str| -> Result<()> {
let data: Value = serde_json::from_str(data)?;
debug!("stream-data: {data}");
if let Some("text-generation") = data["event_type"].as_str() {
if let Some(text) = data["text"].as_str() {
handler.text(text)?;
}
} else if let Some("tool-calls-generation") = data["event_type"].as_str() {
if let Some(tool_calls) = data["tool_calls"].as_array() {
for call in tool_calls {
if let (Some(name), Some(args)) =
(call["name"].as_str(), call["parameters"].as_object())
{
handler.tool_call(ToolCall::new(
name.to_string(),
json!(args),
None,
))?;
}
}
}
}
Ok(())
};
@ -79,24 +95,28 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
mut messages,
temperature,
top_p,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut image_urls = vec![];
let mut tool_results = None;
let mut messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = match message.role {
.filter_map(|message| {
let Message { role, content } = message;
let role = match role {
MessageRole::User => "USER",
_ => "CHATBOT",
};
match message.content {
MessageContent::Text(text) => json!({
match content {
MessageContent::Text(text) => Some(json!({
"role": role,
"message": text,
}),
})),
MessageContent::Array(list) => {
let list: Vec<String> = list
.into_iter()
@ -110,7 +130,11 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
}
})
.collect();
json!({ "role": role, "message": list.join("\n\n") })
Some(json!({ "role": role, "message": list.join("\n\n") }))
}
MessageContent::ToolResults((tool_call_results, _)) => {
tool_results = Some(tool_call_results);
None
}
}
})
@ -123,10 +147,29 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
let message = message["message"].as_str().unwrap_or_default();
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"message": message,
});
if let Some(tool_results) = tool_results {
let tool_results: Vec<_> = tool_results
.into_iter()
.map(|tool_call_result| {
json!({
"call": {
"name": tool_call_result.call.name,
"parameters": tool_call_result.call.arguments,
},
"outputs": [
tool_call_result.output,
]
})
})
.collect();
body["tool_results"] = json!(tool_results);
}
if let Some(v) = system_message {
body["preamble"] = v.into();
}
@ -148,18 +191,60 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
let required = v.parameters.required.clone().unwrap_or_default();
let mut parameter_definitions = json!({});
if let Some(properties) = &v.parameters.properties {
for (key, value) in properties {
let mut value: Value = json!(value);
if value.is_object() && required.iter().any(|x| x == key) {
value["required"] = true.into();
}
parameter_definitions[key] = value;
}
}
json!({
"name": v.name,
"description": v.description,
"parameter_definitions": parameter_definitions,
})
})
.collect();
}
Ok(body)
}
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
if let Some(calls) = data["tool_calls"].as_array() {
tool_calls = calls
.iter()
.filter_map(|call| {
if let (Some(name), Some(parameters)) =
(call["name"].as_str(), call["parameters"].as_object())
{
Some(ToolCall::new(name.to_string(), json!(parameters), None))
} else {
None
}
})
.collect()
}
let details = CompletionDetails {
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
text: text.to_string(),
tool_calls,
id: data["generation_id"].as_str().map(|v| v.to_string()),
input_tokens: data["meta"]["billed_units"]["input_tokens"].as_u64(),
output_tokens: data["meta"]["billed_units"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}

@ -2,6 +2,7 @@ use super::{openai::OpenAIConfig, BuiltinModels, ClientConfig, Message, Model, S
use crate::{
config::{GlobalConfig, Input},
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolCallResult},
render::{render_error, render_stream},
utils::{prompt_input_integer, prompt_input_string, tokenize, AbortSignal, PromptKind},
};
@ -52,7 +53,7 @@ macro_rules! register_client {
pub enum ClientModel {
$(
#[serde(rename = $name)]
$config { models: Vec<ModelConfig> },
$config { models: Vec<ModelData> },
)+
#[serde(other)]
Unknown,
@ -73,7 +74,7 @@ macro_rules! register_client {
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let config = global_config.read().clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config {
if Self::name(c) == &model.client_name {
if Self::name(c) == model.client_name() {
return Some(c.clone())
}
}
@ -113,24 +114,10 @@ macro_rules! register_client {
None
$(.or_else(|| $client::init(config, &model)))+
.ok_or_else(|| {
anyhow::anyhow!("Unknown client '{}'", model.client_name)
anyhow::anyhow!("Unknown client '{}'", model.client_name())
})
}
pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> {
if !client.model().capabilities.contains(capabilities) {
let models = client.list_models();
if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) {
client.set_model(model);
} else {
anyhow::bail!(
"The current model is incapable of doing that."
);
}
}
Ok(())
}
pub fn list_client_types() -> Vec<&'static str> {
let mut client_types: Vec<_> = vec![$($client::NAME,)+];
client_types.extend($crate::client::OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
@ -213,7 +200,7 @@ macro_rules! impl_client_trait {
&self,
client: &reqwest::Client,
data: $crate::client::SendData,
) -> anyhow::Result<(String, $crate::client::CompletionDetails)> {
) -> anyhow::Result<$crate::client::CompletionOutput> {
let builder = self.request_builder(client, data)?;
$send_message(builder).await
}
@ -261,14 +248,16 @@ macro_rules! unsupported_model {
pub trait Client: Sync + Send {
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
fn list_models(&self) -> Vec<Model>;
fn name(&self) -> &str;
#[allow(unused)]
fn list_models(&self) -> Vec<Model>;
fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;
#[allow(unused)]
fn set_model(&mut self, model: Model);
fn build_client(&self) -> Result<ReqwestClient> {
@ -287,14 +276,15 @@ pub trait Client: Sync + Send {
Ok(client)
}
async fn send_message(&self, input: Input) -> Result<(String, CompletionDetails)> {
async fn send_message(&self, input: Input) -> Result<CompletionOutput> {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = input.echo_messages();
return Ok((content, CompletionDetails::default()));
return Ok(CompletionOutput::new(&content));
}
let client = self.build_client()?;
let data = input.prepare_send_data(false)?;
let data = input.prepare_send_data(self.model(), false)?;
self.send_message_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
@ -324,7 +314,7 @@ pub trait Client: Sync + Send {
return Ok(());
}
let client = self.build_client()?;
let data = input.prepare_send_data(true)?;
let data = input.prepare_send_data(self.model(), true)?;
self.send_message_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
@ -341,7 +331,7 @@ pub trait Client: Sync + Send {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)>;
) -> Result<CompletionOutput>;
async fn send_message_streaming_inner(
&self,
@ -368,16 +358,28 @@ pub struct SendData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub functions: Option<Vec<FunctionDeclaration>>,
pub stream: bool,
}
#[derive(Debug, Clone, Default)]
pub struct CompletionDetails {
pub struct CompletionOutput {
pub text: String,
pub tool_calls: Vec<ToolCall>,
pub id: Option<String>,
pub input_tokens: Option<u64>,
pub output_tokens: Option<u64>,
}
impl CompletionOutput {
pub fn new(text: &str) -> Self {
Self {
text: text.to_string(),
..Default::default()
}
}
}
pub type PromptAction<'a> = (&'a str, &'a str, bool, PromptKind);
pub fn create_config(prompts: &[PromptAction], client: &str) -> Result<(String, Value)> {
@ -429,22 +431,24 @@ pub async fn send_stream(
client: &dyn Client,
config: &GlobalConfig,
abort: AbortSignal,
) -> Result<String> {
) -> Result<(String, Vec<ToolCallResult>)> {
let (tx, rx) = unbounded_channel();
let mut stream_handler = SseHandler::new(tx, abort.clone());
let mut handler = SseHandler::new(tx, abort.clone());
let (send_ret, rend_ret) = tokio::join!(
client.send_message_streaming(input, &mut stream_handler),
client.send_message_streaming(input, &mut handler),
render_stream(rx, config, abort.clone()),
);
if let Err(err) = rend_ret {
render_error(err, config.read().highlight);
}
let output = stream_handler.get_buffer().to_string();
let (output, calls) = handler.take();
match send_ret {
Ok(_) => {
println!();
Ok(output)
if !output.is_empty() && !output.ends_with('\n') {
println!();
}
Ok((output, eval_tool_calls(config, calls)?))
}
Err(err) => {
if !output.is_empty() {

@ -1,7 +1,7 @@
use super::access_token::*;
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
};
use anyhow::{anyhow, Context, Result};
@ -20,7 +20,7 @@ pub struct ErnieConfig {
pub api_key: Option<String>,
pub secret_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -36,7 +36,7 @@ impl ErnieClient {
let url = format!(
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
&self.model.name,
&self.model.name(),
);
debug!("Ernie Request: {url} {body}");
@ -78,7 +78,7 @@ impl Client for ErnieClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message(builder).await
@ -96,15 +96,17 @@ impl Client for ErnieClient {
}
}
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
/// Sends a non-streaming Ernie chat request and parses the completion.
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let data: Value = builder.send().await?.json().await?;
// Surface API errors embedded in the JSON payload before extraction.
maybe_catch_error(&data)?;
debug!("non-stream-data: {data}");
extract_completion_text(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
@ -119,6 +121,7 @@ fn build_body(data: SendData, model: &Model) -> Value {
mut messages,
temperature,
top_p,
functions: _,
stream,
} = data;
@ -145,16 +148,18 @@ fn build_body(data: SendData, model: &Model) -> Value {
body
}
fn extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
fn extract_completion_text(data: &Value) -> Result<CompletionOutput> {
let text = data["result"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
async fn fetch_access_token(

@ -1,5 +1,5 @@
use super::vertexai::gemini_build_body;
use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptAction, PromptKind, SendData};
use super::{ExtraConfig, GeminiClient, Model, ModelData, PromptAction, PromptKind, SendData};
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
@ -14,7 +14,7 @@ pub struct GeminiConfig {
#[serde(rename = "safetySettings")]
pub safety_settings: Option<serde_json::Value>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -34,7 +34,7 @@ impl GeminiClient {
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
let model = &self.model.name;
let model = &self.model.name();
let url = format!("{API_BASE}{}:{}?key={}", model, func, api_key);

@ -1,4 +1,4 @@
use crate::config::Input;
use super::ToolResults;
use serde::{Deserialize, Serialize};
@ -8,15 +8,21 @@ pub struct Message {
pub content: MessageContent,
}
impl Message {
pub fn new(input: &Input) -> Self {
impl Default for Message {
fn default() -> Self {
Self {
role: MessageRole::User,
content: input.to_message_content(),
content: MessageContent::Text(String::new()),
}
}
}
impl Message {
pub fn new(role: MessageRole, content: MessageContent) -> Self {
Self { role, content }
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
@ -34,10 +40,6 @@ impl MessageRole {
pub fn is_user(&self) -> bool {
matches!(self, MessageRole::User)
}
pub fn is_assistant(&self) -> bool {
matches!(self, MessageRole::Assistant)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -45,6 +47,8 @@ impl MessageRole {
pub enum MessageContent {
Text(String),
Array(Vec<MessageContentPart>),
// Note: This type is primarily for convenience and does not exist in OpenAI's API.
ToolResults(ToolResults),
}
impl MessageContent {
@ -68,6 +72,7 @@ impl MessageContent {
}
format!(".file {}{}", files.join(" "), concated_text)
}
MessageContent::ToolResults(_) => String::new(),
}
}
@ -83,6 +88,7 @@ impl MessageContent {
*text = replace_fn(text)
}
}
MessageContent::ToolResults(_) => {}
}
}
@ -98,6 +104,7 @@ impl MessageContent {
}
parts.join("\n\n")
}
MessageContent::ToolResults(_) => String::new(),
}
}
}

@ -6,6 +6,7 @@ mod model;
mod prompt_format;
mod sse_handler;
pub use crate::function::{ToolCall, ToolResults};
pub use crate::utils::PromptKind;
pub use common::*;
pub use message::*;

@ -10,15 +10,8 @@ const BASIS_TOKENS: usize = 2;
#[derive(Debug, Clone)]
pub struct Model {
pub client_name: String,
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
pub pass_max_tokens: bool,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
pub capabilities: ModelCapabilities,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
client_name: String,
data: ModelData,
}
impl Default for Model {
@ -31,30 +24,16 @@ impl Model {
pub fn new(client_name: &str, name: &str) -> Self {
Self {
client_name: client_name.into(),
name: name.into(),
max_input_tokens: None,
max_output_tokens: None,
pass_max_tokens: false,
input_price: None,
output_price: None,
capabilities: ModelCapabilities::Text,
extra_fields: None,
data: ModelData::new(name),
}
}
pub fn from_config(client_name: &str, models: &[ModelConfig]) -> Vec<Self> {
pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
models
.iter()
.map(|v| {
let mut model = Model::new(client_name, &v.name);
model
.set_max_input_tokens(v.max_input_tokens)
.set_max_tokens(v.max_output_tokens, v.pass_max_tokens)
.set_input_price(v.input_price)
.set_output_price(v.output_price)
.set_supports_vision(v.supports_vision)
.set_extra_fields(&v.extra_fields);
model
.map(|v| Model {
client_name: client_name.to_string(),
data: v.clone(),
})
.collect()
}
@ -77,7 +56,7 @@ impl Model {
model = Some((*found).clone());
} else if let Some(found) = models.iter().find(|v| v.client_name == client_name) {
let mut found = (*found).clone();
found.name = model_name.to_string();
found.data.name = model_name.to_string();
model = Some(found)
}
}
@ -91,99 +70,101 @@ impl Model {
}
pub fn id(&self) -> String {
format!("{}:{}", self.client_name, self.name)
format!("{}:{}", self.client_name, self.data.name)
}
/// Name of the client (platform) this model belongs to.
pub fn client_name(&self) -> &str {
&self.client_name
}
/// The model's own name, without the `client:` prefix used by `id()`.
pub fn name(&self) -> &str {
&self.data.name
}
/// Shared access to the underlying model configuration data.
pub fn data(&self) -> &ModelData {
&self.data
}
/// Mutable access to the underlying model configuration data.
pub fn data_mut(&mut self) -> &mut ModelData {
&mut self.data
}
pub fn description(&self) -> String {
let max_input_tokens = format_option_value(&self.max_input_tokens);
let max_output_tokens = format_option_value(&self.max_output_tokens);
let input_price = format_option_value(&self.input_price);
let output_price = format_option_value(&self.output_price);
let vision = if self.capabilities.contains(ModelCapabilities::Vision) {
"👁"
} else {
""
let ModelData {
max_input_tokens,
max_output_tokens,
input_price,
output_price,
supports_vision,
supports_function_calling,
..
} = &self.data;
let max_input_tokens = format_option_value(max_input_tokens);
let max_output_tokens = format_option_value(max_output_tokens);
let input_price = format_option_value(input_price);
let output_price = format_option_value(output_price);
let mut capabilities = vec![];
if *supports_vision {
capabilities.push('👁');
};
if *supports_function_calling {
capabilities.push('⚒');
};
let capabilities: String = capabilities
.into_iter()
.map(|v| format!("{v} "))
.collect::<Vec<String>>()
.join("");
format!(
"{:>8} / {:>8} | {:>6} / {:>6} {}",
max_input_tokens, max_output_tokens, input_price, output_price, vision
"{:>8} / {:>8} | {:>6} / {:>6} {:>6}",
max_input_tokens, max_output_tokens, input_price, output_price, capabilities
)
}
/// Maximum number of input (prompt) tokens the model accepts, if configured.
pub fn max_input_tokens(&self) -> Option<usize> {
self.data.max_input_tokens
}
/// Maximum number of output tokens the model may generate, if configured.
pub fn max_output_tokens(&self) -> Option<isize> {
self.data.max_output_tokens
}
pub fn supports_vision(&self) -> bool {
self.capabilities.contains(ModelCapabilities::Vision)
self.data.supports_vision
}
/// Whether the model is declared (via config) to support function calling.
pub fn supports_function_calling(&self) -> bool {
self.data.supports_function_calling
}
pub fn max_tokens_param(&self) -> Option<isize> {
if self.pass_max_tokens {
self.max_output_tokens
if self.data.pass_max_tokens {
self.data.max_output_tokens
} else {
None
}
}
pub fn set_max_input_tokens(&mut self, max_input_tokens: Option<usize>) -> &mut Self {
match max_input_tokens {
None | Some(0) => self.max_input_tokens = None,
_ => self.max_input_tokens = max_input_tokens,
}
self
}
pub fn set_max_tokens(
&mut self,
max_output_tokens: Option<isize>,
pass_max_tokens: bool,
) -> &mut Self {
match max_output_tokens {
None | Some(0) => self.max_output_tokens = None,
_ => self.max_output_tokens = max_output_tokens,
}
self.pass_max_tokens = pass_max_tokens;
self
}
pub fn set_input_price(&mut self, input_price: Option<f64>) -> &mut Self {
match input_price {
None => self.input_price = None,
_ => self.input_price = input_price,
}
self
}
pub fn set_output_price(&mut self, output_price: Option<f64>) -> &mut Self {
match output_price {
None => self.output_price = None,
_ => self.output_price = output_price,
None | Some(0) => self.data.max_output_tokens = None,
_ => self.data.max_output_tokens = max_output_tokens,
}
self
}
pub fn set_supports_vision(&mut self, supports_vision: bool) -> &mut Self {
if supports_vision {
self.capabilities |= ModelCapabilities::Vision;
} else {
self.capabilities &= !ModelCapabilities::Vision;
}
self
}
pub fn set_extra_fields(
&mut self,
extra_fields: &Option<serde_json::Map<String, serde_json::Value>>,
) -> &mut Self {
self.extra_fields.clone_from(extra_fields);
self.data.pass_max_tokens = pass_max_tokens;
self
}
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
messages
.iter()
.map(|v| {
match &v.content {
MessageContent::Text(text) => estimate_token_length(text),
MessageContent::Array(_) => 0, // TODO
}
.map(|v| match &v.content {
MessageContent::Text(text) => estimate_token_length(text),
MessageContent::Array(_) => 0,
MessageContent::ToolResults(_) => 0,
})
.sum()
}
@ -203,7 +184,7 @@ impl Model {
pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
if let Some(max_input_tokens) = self.max_input_tokens {
if let Some(max_input_tokens) = self.data.max_input_tokens {
if total_tokens >= max_input_tokens {
bail!("Exceed max input tokens limit")
}
@ -212,7 +193,7 @@ impl Model {
}
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) {
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.data.extra_fields) {
for (key, extra_field) in extra_fields {
if body.contains_key(key) {
if let (Some(sub_body), Some(extra_field)) =
@ -232,30 +213,33 @@ impl Model {
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ModelData {
pub name: String,
pub max_input_tokens: Option<usize>,
pub max_output_tokens: Option<isize>,
#[serde(default)]
pub pass_max_tokens: bool,
pub input_price: Option<f64>,
pub output_price: Option<f64>,
#[serde(default)]
pub supports_vision: bool,
#[serde(default)]
pub pass_max_tokens: bool,
pub supports_function_calling: bool,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
}
impl ModelData {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
..Default::default()
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct BuiltinModels {
pub platform: String,
pub models: Vec<ModelConfig>,
}
bitflags::bitflags! {
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelCapabilities: u32 {
const Text = 0b00000001;
const Vision = 0b00000010;
}
pub models: Vec<ModelData>,
}

@ -1,5 +1,5 @@
use super::{
catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient,
catch_error, message::*, CompletionOutput, ExtraConfig, Model, ModelData, OllamaClient,
PromptAction, PromptKind, SendData, SseHandler,
};
@ -15,7 +15,7 @@ pub struct OllamaConfig {
pub api_base: Option<String>,
pub api_auth: Option<String>,
pub chat_endpoint: Option<String>,
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -59,17 +59,18 @@ impl OllamaClient {
impl_client_trait!(OllamaClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
let text = data["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
Ok(CompletionOutput::new(text))
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
@ -86,6 +87,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
continue;
}
let data: Value = serde_json::from_slice(&chunk)?;
debug!("stream-data: {data}");
if data["done"].is_boolean() {
if let Some(text) = data["message"]["content"].as_str() {
handler.text(text)?;
@ -103,10 +105,13 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream,
} = data;
let mut is_tool_call = false;
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
@ -141,10 +146,18 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
let content = content.join("\n\n");
json!({ "role": role, "content": content, "images": images })
}
MessageContent::ToolResults(_) => {
is_tool_call = true;
json!({ "role": role })
}
}
})
.collect();
if is_tool_call {
bail!("The client does not support function calling",);
}
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
@ -153,7 +166,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
}
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"messages": messages,
"stream": stream,
"options": {},

@ -1,9 +1,9 @@
use super::{
catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient,
PromptAction, PromptKind, SendData, SsMmessage, SseHandler,
catch_error, message::*, sse_stream, CompletionOutput, ExtraConfig, Model, ModelData,
OpenAIClient, PromptAction, PromptKind, SendData, SsMmessage, SseHandler, ToolCall,
};
use anyhow::{anyhow, Result};
use anyhow::{bail, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -17,7 +17,7 @@ pub struct OpenAIConfig {
pub api_base: Option<String>,
pub organization_id: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -48,7 +48,7 @@ impl OpenAIClient {
}
}
pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
pub async fn openai_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -56,6 +56,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, Com
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
openai_extract_completion(&data)
}
@ -63,13 +64,53 @@ pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut function_index = 0;
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let handle = |message: SsMmessage| -> Result<bool> {
if message.data == "[DONE]" {
if !function_name.is_empty() {
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
Some(function_id.clone()),
))?;
}
return Ok(true);
}
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
handler.text(text)?;
} else if let (Some(function), index, id) = (
data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(),
data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(),
data["choices"][0]["delta"]["tool_calls"][0]["id"].as_str(),
) {
let index = index.unwrap_or_default();
if index != function_index {
if !function_name.is_empty() {
handler.tool_call(ToolCall::new(
function_name.clone(),
json!(function_arguments),
Some(function_id.clone()),
))?;
}
function_name.clear();
function_arguments.clear();
function_id.clear();
function_index = index;
}
if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
function_name = name.to_string();
}
if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
function_arguments.push_str(arguments);
}
if let Some(id) = id {
function_id = id.to_string();
}
}
Ok(false)
};
@ -82,11 +123,47 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
messages,
temperature,
top_p,
functions,
stream,
} = data;
let messages: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults((tool_call_results, text)) => {
let tool_calls: Vec<_> = tool_call_results.iter().map(|tool_call_result| {
json!({
"id": tool_call_result.call.id,
"type": "function",
"function": {
"name": tool_call_result.call.name,
"arguments": tool_call_result.call.arguments,
},
})
}).collect();
let mut messages = vec![
json!({ "role": MessageRole::Assistant, "content": text, "tool_calls": tool_calls })
];
for tool_call_result in tool_call_results {
messages.push(
json!({
"role": "tool",
"content": tool_call_result.output.to_string(),
"tool_call_id": tool_call_result.call.id,
})
);
}
messages
},
_ => vec![json!({ "role": role, "content": content })]
}
})
.collect();
let mut body = json!({
"model": &model.name,
"model": &model.name(),
"messages": messages,
});
@ -102,19 +179,59 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"type": "function",
"function": v,
})
})
.collect();
body["tool_choice"] = "auto".into();
}
body
}
pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
pub fn openai_extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
.unwrap_or_default();
let mut tool_calls = vec![];
if let Some(tools_call) = data["choices"][0]["message"]["tool_calls"].as_array() {
tool_calls = tools_call
.iter()
.filter_map(|call| {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
Some(ToolCall::new(
name.to_string(),
json!(arguments),
Some(id.to_string()),
))
} else {
None
}
})
.collect()
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
impl_client_trait!(

@ -2,7 +2,7 @@ use crate::client::OPENAI_COMPATIBLE_PLATFORMS;
use super::openai::openai_build_body;
use super::{
ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
ExtraConfig, Model, ModelData, OpenAICompatibleClient, PromptAction, PromptKind, SendData,
};
use anyhow::Result;
@ -16,7 +16,7 @@ pub struct OpenAICompatibleConfig {
pub api_key: Option<String>,
pub chat_endpoint: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -44,7 +44,7 @@ impl OpenAICompatibleClient {
match OPENAI_COMPATIBLE_PLATFORMS
.into_iter()
.find_map(|(name, api_base)| {
if name == self.model.client_name {
if name == self.model.client_name() {
Some(api_base.to_string())
} else {
None

@ -108,6 +108,7 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Re
}
parts.join("\n\n")
}
MessageContent::ToolResults(_) => String::new(),
};
match role {
MessageRole::System => prompt.push_str(&format!(

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
maybe_catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model,
ModelData, PromptAction, PromptKind, QianwenClient, SendData, SsMmessage, SseHandler,
};
use crate::utils::{base64_decode, sha256};
@ -26,7 +26,7 @@ pub struct QianwenConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -62,7 +62,7 @@ impl QianwenClient {
}
fn is_vl(&self) -> bool {
self.model.name.starts_with("qwen-vl")
self.model.name().starts_with("qwen-vl")
}
}
@ -74,9 +74,9 @@ impl Client for QianwenClient {
&self,
client: &ReqwestClient,
mut data: SendData,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message(builder, self.is_vl()).await
}
@ -88,16 +88,17 @@ impl Client for QianwenClient {
mut data: SendData,
) -> Result<()> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler, self.is_vl()).await
}
}
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionDetails)> {
/// Sends a non-streaming Qianwen request; `is_vl` selects the
/// vision-language response layout when extracting the completion.
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<CompletionOutput> {
let data: Value = builder.send().await?.json().await?;
// Surface API errors embedded in the JSON payload before extraction.
maybe_catch_error(&data)?;
debug!("stream-data: {data}");
extract_completion_text(&data, is_vl)
}
@ -109,6 +110,7 @@ async fn send_message_streaming(
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
maybe_catch_error(&data)?;
debug!("stream-data: {data}");
if is_vl {
if let Some(text) =
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
@ -129,10 +131,12 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
messages,
temperature,
top_p,
functions: _,
stream,
} = data;
let mut has_upload = false;
let mut is_tool_call = false;
let input = if is_vl {
let messages: Vec<Value> = messages
.into_iter()
@ -154,6 +158,10 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
}
})
.collect(),
MessageContent::ToolResults(_) => {
is_tool_call = true;
vec![]
}
};
json!({ "role": role, "content": content })
})
@ -167,6 +175,9 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
"messages": messages,
})
};
if is_tool_call {
bail!("The client does not support function calling",);
}
let mut parameters = json!({});
if stream {
@ -184,7 +195,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
}
let body = json!({
"model": &model.name,
"model": &model.name(),
"input": input,
"parameters": parameters
});
@ -192,7 +203,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
Ok((body, has_upload))
}
fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionDetails)> {
fn extract_completion_text(data: &Value, is_vl: bool) -> Result<CompletionOutput> {
let err = || anyhow!("Invalid response data: {data}");
let text = if is_vl {
data["output"]["choices"][0]["message"]["content"][0]["text"]
@ -201,13 +212,15 @@ fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, Complet
} else {
data["output"]["text"].as_str().ok_or_else(err)?
};
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["request_id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}
/// Patch messages, upload embedded images to oss

@ -1,9 +1,9 @@
use std::time::Duration;
use super::{
catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, ReplicateClient, SendData,
SsMmessage, SseHandler,
catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionOutput,
ExtraConfig, Model, ModelData, PromptAction, PromptKind, ReplicateClient, SendData, SsMmessage,
SseHandler,
};
use anyhow::{anyhow, Result};
@ -19,7 +19,7 @@ pub struct ReplicateConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -37,7 +37,7 @@ impl ReplicateClient {
) -> Result<RequestBuilder> {
let body = build_body(data, &self.model)?;
let url = format!("{API_BASE}/models/{}/predictions", self.model.name);
let url = format!("{API_BASE}/models/{}/predictions", self.model.name());
debug!("Replicate Request: {url} {body}");
@ -55,7 +55,7 @@ impl Client for ReplicateClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
let api_key = self.get_api_key()?;
let builder = self.request_builder(client, data, &api_key)?;
send_message(client, builder, &api_key).await
@ -77,7 +77,7 @@ async fn send_message(
client: &ReqwestClient,
builder: RequestBuilder,
api_key: &str,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -96,6 +96,7 @@ async fn send_message(
.await?
.json()
.await?;
debug!("non-stream-data: {prediction_data}");
let err = || anyhow!("Invalid response data: {prediction_data}");
let status = prediction_data["status"].as_str().ok_or_else(err)?;
if status == "succeeded" {
@ -138,10 +139,11 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
messages,
temperature,
top_p,
functions: _,
stream,
} = data;
let prompt = generate_prompt(&messages, smart_prompt_format(&model.name))?;
let prompt = generate_prompt(&messages, smart_prompt_format(model.name()))?;
let mut input = json!({
"prompt": prompt,
@ -170,7 +172,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
let text = data["output"]
.as_array()
.map(|parts| {
@ -182,11 +184,13 @@ fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
})
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
let output = CompletionOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["metrics"]["input_token_count"].as_u64(),
output_tokens: data["metrics"]["output_token_count"].as_u64(),
};
Ok((text.to_string(), details))
Ok(output)
}

@ -3,10 +3,13 @@ use crate::utils::AbortSignal;
use anyhow::{Context, Result};
use tokio::sync::mpsc::UnboundedSender;
use super::ToolCall;
pub struct SseHandler {
sender: UnboundedSender<SseEvent>,
buffer: String,
abort: AbortSignal,
buffer: String,
tool_calls: Vec<ToolCall>,
}
impl SseHandler {
@ -15,11 +18,12 @@ impl SseHandler {
sender,
abort,
buffer: String::new(),
tool_calls: Vec::new(),
}
}
pub fn text(&mut self, text: &str) -> Result<()> {
// debug!("ReplyText: {}", text);
// debug!("HandleText: {}", text);
if text.is_empty() {
return Ok(());
}
@ -33,7 +37,7 @@ impl SseHandler {
}
pub fn done(&mut self) -> Result<()> {
// debug!("ReplyDone");
// debug!("HandleDone");
let ret = self
.sender
.send(SseEvent::Done)
@ -42,14 +46,23 @@ impl SseHandler {
Ok(())
}
pub fn get_buffer(&self) -> &str {
&self.buffer
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
// debug!("HandleCall: {:?}", call);
self.tool_calls.push(call);
Ok(())
}
pub fn get_abort(&self) -> AbortSignal {
self.abort.clone()
}
pub fn take(self) -> (String, Vec<ToolCall>) {
let Self {
buffer, tool_calls, ..
} = self;
(buffer, tool_calls)
}
fn safe_ret(&self, ret: Result<()>) -> Result<()> {
if ret.is_err() && self.abort.aborted() {
return Ok(());

@ -1,8 +1,7 @@
use super::access_token::*;
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
VertexAIClient,
access_token::*, catch_error, json_stream, message::*, patch_system_message, Client,
CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
SseHandler, ToolCall, VertexAIClient,
};
use anyhow::{anyhow, bail, Context, Result};
@ -22,7 +21,7 @@ pub struct VertexAIConfig {
#[serde(rename = "safetySettings")]
pub safety_settings: Option<Value>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -46,7 +45,7 @@ impl VertexAIClient {
true => "streamGenerateContent",
false => "generateContent",
};
let url = format!("{base_url}/google/models/{}:{func}", self.model.name);
let url = format!("{base_url}/google/models/{}:{func}", self.model.name());
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
@ -66,7 +65,7 @@ impl Client for VertexAIClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
@ -84,13 +83,14 @@ impl Client for VertexAIClient {
}
}
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
gemini_extract_completion_text(&data)
}
@ -105,8 +105,28 @@ pub async fn gemini_send_message_streaming(
catch_error(&data, status.as_u16())?;
} else {
let handle = |value: &str| -> Result<()> {
let value: Value = serde_json::from_str(value)?;
handler.text(gemini_extract_text(&value)?)?;
let data: Value = serde_json::from_str(value)?;
debug!("stream-data: {data}");
if let Some(text) = data["candidates"][0]["content"]["parts"][0]["text"].as_str() {
if !text.is_empty() {
handler.text(text)?;
}
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Content Blocked")
} else if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
for part in parts {
if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
}
}
}
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
@ -114,30 +134,45 @@ pub async fn gemini_send_message_streaming(
Ok(())
}
fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
let text = gemini_extract_text(data)?;
let details = CompletionDetails {
fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
let text = data["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.unwrap_or_default();
let mut tool_calls = vec![];
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
tool_calls = parts
.iter()
.filter_map(|part| {
if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
Some(ToolCall::new(name.to_string(), json!(args), None))
} else {
None
}
})
.collect()
}
if text.is_empty() && tool_calls.is_empty() {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Content Blocked")
} else {
bail!("Invalid response data: {data}");
}
}
let output = CompletionOutput {
text: text.to_string(),
tool_calls,
id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
};
Ok((text.to_string(), details))
}
fn gemini_extract_text(data: &Value) -> Result<&str> {
match data["candidates"][0]["content"]["parts"][0]["text"].as_str() {
Some(text) => Ok(text),
None => {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked by safety settingsconsider adjusting `safetySettings` in the client configuration")
} else {
bail!("Invalid response data: {data}")
}
}
}
Ok(output)
}
pub(crate) fn gemini_build_body(
@ -149,6 +184,7 @@ pub(crate) fn gemini_build_body(
mut messages,
temperature,
top_p,
functions,
stream: _,
} = data;
@ -157,34 +193,60 @@ pub(crate) fn gemini_build_body(
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = match message.role {
.flat_map(|message| {
let Message { role, content } = message;
let role = match role {
MessageRole::User => "user",
_ => "model",
};
match message.content {
MessageContent::Text(text) => json!({
"role": role,
"parts": [{ "text": text }]
}),
MessageContent::Array(list) => {
let list: Vec<Value> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
json!({ "inline_data": { "mime_type": mime_type, "data": data } })
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"parts": [{ "text": text }]
})],
MessageContent::Array(list) => {
let parts: Vec<Value> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
json!({ "inline_data": { "mime_type": mime_type, "data": data } })
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
},
})
.collect();
vec![json!({ "role": role, "parts": parts })]
},
MessageContent::ToolResults((tool_call_results, _)) => {
let function_call_parts: Vec<Value> = tool_call_results.iter().map(|tool_call_result| {
json!({
"functionCall": {
"name": tool_call_result.call.name,
"args": tool_call_result.call.arguments,
}
})
}).collect();
let function_response_parts: Vec<Value> = tool_call_results.into_iter().map(|tool_call_result| {
json!({
"functionResponse": {
"name": tool_call_result.call.name,
"response": {
"name": tool_call_result.call.name,
"content": tool_call_result.output,
}
}
},
})
.collect();
json!({ "role": role, "parts": list })
})
}).collect();
vec![
json!({ "role": "model", "parts": function_call_parts }),
json!({ "role": "function", "parts": function_response_parts }),
]
}
}
}
})
.collect();
@ -211,6 +273,10 @@ pub(crate) fn gemini_build_body(
body["generationConfig"]["topP"] = v.into();
}
if let Some(functions) = functions {
body["tools"] = json!([{ "functionDeclarations": *functions }]);
}
Ok(body)
}

@ -2,7 +2,7 @@ use super::access_token::*;
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::vertexai::prepare_gcloud_access_token;
use super::{
Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
Client, CompletionOutput, ExtraConfig, Model, ModelData, PromptAction, PromptKind, SendData,
SseHandler, VertexAIClaudeClient,
};
@ -18,7 +18,7 @@ pub struct VertexAIClaudeConfig {
pub location: Option<String>,
pub adc_file: Option<String>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub models: Vec<ModelData>,
pub extra: Option<ExtraConfig>,
}
@ -39,7 +39,7 @@ impl VertexAIClaudeClient {
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
let url = format!(
"{base_url}/anthropic/models/{}:streamRawPredict",
self.model.name
self.model.name()
);
let mut body = claude_build_body(data, &self.model)?;
@ -64,7 +64,7 @@ impl Client for VertexAIClaudeClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
) -> Result<CompletionOutput> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
claude_send_message(builder).await

@ -1,9 +1,10 @@
use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{
init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart, Model,
ModelCapabilities, SendData,
init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart,
MessageRole, Model, SendData,
};
use crate::function::{ToolCallResult, ToolResults};
use crate::utils::{base64_encode, sha256};
use anyhow::{bail, Context, Result};
@ -30,6 +31,7 @@ pub struct Input {
text: String,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_call: Option<ToolResults>,
context: InputContext,
}
@ -40,6 +42,7 @@ impl Input {
text: text.to_string(),
medias: Default::default(),
data_urls: Default::default(),
tool_call: None,
context: context.unwrap_or_else(|| InputContext::from_config(config)),
}
}
@ -91,6 +94,7 @@ impl Input {
text: texts.join("\n"),
medias,
data_urls,
tool_call: Default::default(),
context: context.unwrap_or_else(|| InputContext::from_config(config)),
})
}
@ -111,6 +115,21 @@ impl Input {
self.text = text;
}
pub fn merge_tool_call(
mut self,
output: String,
tool_call_results: Vec<ToolCallResult>,
) -> Self {
match self.tool_call.as_mut() {
Some(exist_tool_call_results) => {
exist_tool_call_results.0.extend(tool_call_results);
exist_tool_call_results.1 = output;
}
None => self.tool_call = Some((tool_call_results, output)),
}
self
}
pub fn model(&self) -> Model {
let model = self.config.read().model.clone();
if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
@ -130,7 +149,10 @@ impl Input {
init_client(&self.config, Some(self.model()))
}
pub fn prepare_send_data(&self, stream: bool) -> Result<SendData> {
pub fn prepare_send_data(&self, model: &Model, stream: bool) -> Result<SendData> {
if !self.medias.is_empty() && !model.supports_vision() {
bail!("The current model does not support vision.");
}
let messages = self.build_messages()?;
self.config.read().model.max_input_tokens_limit(&messages)?;
let (temperature, top_p) = if let Some(session) = self.session(&self.config.read().session)
@ -142,23 +164,41 @@ impl Input {
let config = self.config.read();
(config.temperature, config.top_p)
};
let mut functions = None;
if self.config.read().function_calling && model.supports_function_calling() {
let config = self.config.read();
let function_filter = if let Some(session) = self.session(&config.session) {
session.function_filter()
} else if let Some(role) = self.role() {
role.function_filter.as_deref()
} else {
None
};
functions = config.function.filtered_declarations(function_filter);
};
Ok(SendData {
messages,
temperature,
top_p,
functions,
stream,
})
}
pub fn build_messages(&self) -> Result<Vec<Message>> {
let messages = if let Some(session) = self.session(&self.config.read().session) {
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
session.build_messages(self)
} else if let Some(role) = self.role() {
role.build_messages(self)
} else {
let message = Message::new(self);
vec![message]
vec![Message::new(MessageRole::User, self.message_content())]
};
if let Some(tool_results) = &self.tool_call {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults(tool_results.clone()),
))
}
Ok(messages)
}
@ -234,7 +274,7 @@ impl Input {
format!(".file {}{}", files.join(" "), text)
}
pub fn to_message_content(&self) -> MessageContent {
pub fn message_content(&self) -> MessageContent {
if self.medias.is_empty() {
MessageContent::Text(self.text.clone())
} else {
@ -257,14 +297,6 @@ impl Input {
MessageContent::Array(list)
}
}
pub fn required_capabilities(&self) -> ModelCapabilities {
if !self.medias.is_empty() {
ModelCapabilities::Vision
} else {
ModelCapabilities::Text
}
}
}
#[derive(Debug, Clone, Default)]

@ -10,6 +10,7 @@ use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, Model,
OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::function::{Function, ToolCallResult};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{
format_option_value, fuzzy_match, get_env_name, light_theme_from_colorfgbg, now, render_prompt,
@ -41,6 +42,7 @@ const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const MESSAGES_FILE_NAME: &str = "messages.md";
const SESSIONS_DIR_NAME: &str = "sessions";
const FUNCTIONS_DIR_NAME: &str = "functions";
const CLIENTS_FIELD: &str = "clients";
@ -69,6 +71,7 @@ pub struct Config {
pub keybindings: Keybindings,
pub prelude: Option<String>,
pub buffer_editor: Option<String>,
pub function_calling: bool,
pub compress_threshold: usize,
pub summarize_prompt: Option<String>,
pub summary_prompt: Option<String>,
@ -84,6 +87,8 @@ pub struct Config {
#[serde(skip)]
pub model: Model,
#[serde(skip)]
pub function: Function,
#[serde(skip)]
pub working_mode: WorkingMode,
#[serde(skip)]
pub last_message: Option<(Input, String)>,
@ -106,6 +111,7 @@ impl Default for Config {
keybindings: Default::default(),
prelude: None,
buffer_editor: None,
function_calling: false,
compress_threshold: 2000,
summarize_prompt: None,
summary_prompt: None,
@ -116,6 +122,7 @@ impl Default for Config {
role: None,
session: None,
model: Default::default(),
function: Default::default(),
working_mode: WorkingMode::Command,
last_message: None,
}
@ -142,6 +149,8 @@ impl Config {
config.set_wrap(&wrap)?;
}
config.function = Function::init(&Self::functions_dir()?)?;
config.working_mode = working_mode;
config.load_roles()?;
@ -212,15 +221,20 @@ impl Config {
Ok(path)
}
pub fn save_message(&mut self, input: Input, output: &str) -> Result<()> {
pub fn save_message(
&mut self,
input: &Input,
output: &str,
tool_call_results: &[ToolCallResult],
) -> Result<()> {
self.last_message = Some((input.clone(), output.to_string()));
if self.dry_run {
if self.dry_run || output.is_empty() || !tool_call_results.is_empty() {
return Ok(());
}
if let Some(session) = input.session_mut(&mut self.session) {
session.add_message(&input, output)?;
session.add_message(input, output)?;
return Ok(());
}
@ -275,6 +289,10 @@ impl Config {
Self::local_path(SESSIONS_DIR_NAME)
}
pub fn functions_dir() -> Result<PathBuf> {
Self::local_path(FUNCTIONS_DIR_NAME)
}
pub fn session_file(name: &str) -> Result<PathBuf> {
let mut path = Self::sessions_dir()?;
path.push(&format!("{name}.yaml"));
@ -294,8 +312,7 @@ impl Config {
pub fn set_role_obj(&mut self, role: Role) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.guard_empty()?;
session.set_temperature(role.temperature);
session.set_top_p(role.top_p);
session.set_role_properties(&role);
}
if let Some(model_id) = &role.model_id {
self.set_model(model_id)?;
@ -428,6 +445,8 @@ impl Config {
),
("temperature", format_option_value(&temperature)),
("top_p", format_option_value(&top_p)),
("function_calling", self.function_calling.to_string()),
("compress_threshold", self.compress_threshold.to_string()),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("save_session", format_option_value(&self.save_session)),
@ -438,11 +457,11 @@ impl Config {
("auto_copy", self.auto_copy.to_string()),
("keybindings", self.keybindings.stringify().into()),
("prelude", format_option_value(&self.prelude)),
("compress_threshold", self.compress_threshold.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
("messages_file", display_path(&Self::messages_file()?)),
("sessions_dir", display_path(&Self::sessions_dir()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
];
let output = items
.iter()
@ -508,6 +527,7 @@ impl Config {
"max_output_tokens",
"temperature",
"top_p",
"function_calling",
"compress_threshold",
"save",
"save_session",
@ -523,10 +543,11 @@ impl Config {
(values, args[0])
} else if args.len() == 2 {
let values = match args[0] {
"max_output_tokens" => match self.model.max_output_tokens {
"max_output_tokens" => match self.model.max_output_tokens() {
Some(v) => vec![v.to_string()],
None => vec![],
},
"function_calling" => complete_bool(self.function_calling),
"save" => complete_bool(self.save),
"save_session" => {
let save_session = if let Some(session) = &self.session {
@ -574,6 +595,10 @@ impl Config {
let value = parse_value(value)?;
self.set_top_p(value);
}
"function_calling" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.function_calling = value;
}
"compress_threshold" => {
let value = parse_value(value)?;
self.set_compress_threshold(value);
@ -792,11 +817,14 @@ impl Config {
fn generate_prompt_context(&self) -> HashMap<&str, String> {
let mut output = HashMap::new();
output.insert("model", self.model.id());
output.insert("client_name", self.model.client_name.clone());
output.insert("model_name", self.model.name.clone());
output.insert("client_name", self.model.client_name().to_string());
output.insert("model_name", self.model.name().to_string());
output.insert(
"max_input_tokens",
self.model.max_input_tokens.unwrap_or_default().to_string(),
self.model
.max_input_tokens()
.unwrap_or_default()
.to_string(),
);
if let Some(temperature) = self.temperature {
if temperature != 0.0 {
@ -884,8 +912,8 @@ impl Config {
}
fn load_config_file(config_path: &Path) -> Result<Self> {
let ctx = || format!("Failed to load config at {}", config_path.display());
let content = read_to_string(config_path).with_context(ctx)?;
let content = read_to_string(config_path)
.with_context(|| format!("Failed to load config at {}", config_path.display()))?;
let config: Self = serde_yaml::from_str(&content).map_err(|err| {
let err_msg = err.to_string();
let err_msg = if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {

@ -1,4 +1,5 @@
use super::Input;
use crate::{
client::{Message, MessageContent, MessageRole, Model},
utils::{detect_os, detect_shell},
@ -18,10 +19,17 @@ pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
pub struct Role {
pub name: String,
pub prompt: String,
#[serde(rename(serialize = "model", deserialize = "model"))]
#[serde(
rename(serialize = "model", deserialize = "model"),
skip_serializing_if = "Option::is_none"
)]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_filter: Option<String>,
}
impl Role {
@ -32,12 +40,13 @@ impl Role {
temperature: None,
model_id: None,
top_p: None,
function_filter: None,
}
}
pub fn builtin() -> Vec<Role> {
[
(SHELL_ROLE, shell_prompt()),
(SHELL_ROLE, shell_prompt(), None),
(
EXPLAIN_SHELL_ROLE,
r#"Provide a terse, single sentence description of the given shell command.
@ -45,6 +54,7 @@ Describe each argument and option of the command.
Provide short responses in about 80 words.
APPLY MARKDOWN formatting when possible."#
.into(),
None,
),
(
CODE_ROLE,
@ -59,15 +69,18 @@ async function timeout(ms) {
```
"#
.into(),
None,
),
("%functions%", String::new(), Some(".*".into())),
]
.into_iter()
.map(|(name, prompt)| Self {
.map(|(name, prompt, function_filter)| Self {
name: name.into(),
prompt,
model_id: None,
temperature: None,
top_p: None,
function_filter,
})
.collect()
}
@ -78,7 +91,11 @@ async function timeout(ms) {
Ok(output.trim_end().to_string())
}
pub fn embedded(&self) -> bool {
pub fn empty_prompt(&self) -> bool {
self.prompt.is_empty()
}
pub fn embedded_prompt(&self) -> bool {
self.prompt.contains(INPUT_PLACEHOLDER)
}
@ -111,7 +128,9 @@ async function timeout(ms) {
pub fn echo_messages(&self, input: &Input) -> String {
let input_markdown = input.render();
if self.embedded() {
if self.empty_prompt() {
input_markdown
} else if self.embedded_prompt() {
self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
} else {
format!("{}\n\n{}", self.prompt, input.render())
@ -119,41 +138,30 @@ async function timeout(ms) {
}
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut content = input.to_message_content();
if self.embedded() {
let mut content = input.message_content();
if self.empty_prompt() {
vec![Message::new(MessageRole::User, content)]
} else if self.embedded_prompt() {
content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
vec![Message {
role: MessageRole::User,
content,
}]
vec![Message::new(MessageRole::User, content)]
} else {
let mut messages = vec![];
let (system, cases) = parse_structure_prompt(&self.prompt);
if !system.is_empty() {
messages.push(Message {
role: MessageRole::System,
content: MessageContent::Text(system.to_string()),
})
messages.push(Message::new(
MessageRole::System,
MessageContent::Text(system.to_string()),
));
}
if !cases.is_empty() {
messages.extend(cases.into_iter().flat_map(|(i, o)| {
vec![
Message {
role: MessageRole::User,
content: MessageContent::Text(i.to_string()),
},
Message {
role: MessageRole::Assistant,
content: MessageContent::Text(o.to_string()),
},
Message::new(MessageRole::User, MessageContent::Text(i.to_string())),
Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())),
]
}));
}
messages.push(Message {
role: MessageRole::User,
content,
});
messages.push(Message::new(MessageRole::User, content));
messages
}
}

@ -1,5 +1,5 @@
use super::input::resolve_data_url;
use super::{Config, Input, Model};
use super::{Config, Input, Model, Role};
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
@ -17,15 +17,20 @@ pub const TEMP_SESSION_NAME: &str = "temp";
pub struct Session {
#[serde(rename(serialize = "model", deserialize = "model"))]
model_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
function_filter: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
save_session: Option<bool>,
messages: Vec<Message>,
#[serde(default)]
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
data_urls: HashMap<String, String>,
#[serde(default)]
#[serde(default, skip_serializing_if = "Vec::is_empty")]
compressed_messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
compress_threshold: Option<usize>,
#[serde(skip)]
pub name: String,
@ -41,13 +46,14 @@ pub struct Session {
impl Session {
pub fn new(config: &Config, name: &str) -> Self {
Self {
let mut session = Self {
model_id: config.model.id(),
temperature: config.temperature,
top_p: config.top_p,
function_filter: None,
save_session: config.save_session,
messages: vec![],
compressed_messages: vec![],
messages: Default::default(),
compressed_messages: Default::default(),
compress_threshold: None,
data_urls: Default::default(),
name: name.to_string(),
@ -55,7 +61,11 @@ impl Session {
dirty: false,
compressing: false,
model: config.model.clone(),
};
if let Some(role) = &config.role {
session.set_role_properties(role);
}
session
}
pub fn load(name: &str, path: &Path) -> Result<Self> {
@ -86,6 +96,10 @@ impl Session {
self.top_p
}
pub fn function_filter(&self) -> Option<&str> {
self.function_filter.as_deref()
}
pub fn save_session(&self) -> Option<bool> {
self.save_session
}
@ -120,12 +134,15 @@ impl Session {
if let Some(top_p) = self.top_p() {
data["top_p"] = top_p.into();
}
if let Some(function_filter) = self.function_filter() {
data["function_filter"] = function_filter.into();
}
if let Some(save_session) = self.save_session() {
data["save_session"] = save_session.into();
}
data["total_tokens"] = tokens.into();
if let Some(context_window) = self.model.max_input_tokens {
data["max_input_tokens"] = context_window.into();
if let Some(max_input_tokens) = self.model.max_input_tokens() {
data["max_input_tokens"] = max_input_tokens.into();
}
if percent != 0.0 {
data["total/max"] = format!("{}%", percent).into();
@ -153,6 +170,10 @@ impl Session {
items.push(("top_p", top_p.to_string()));
}
if let Some(function_filter) = self.function_filter() {
items.push(("function_filter", function_filter.into()));
}
if let Some(save_session) = self.save_session() {
items.push(("save_session", save_session.to_string()));
}
@ -161,7 +182,7 @@ impl Session {
items.push(("compress_threshold", compress_threshold.to_string()));
}
if let Some(max_input_tokens) = self.model.max_input_tokens {
if let Some(max_input_tokens) = self.model.max_input_tokens() {
items.push(("max_input_tokens", max_input_tokens.to_string()));
}
@ -202,7 +223,7 @@ impl Session {
pub fn tokens_and_percent(&self) -> (usize, f32) {
let tokens = self.tokens();
let max_input_tokens = self.model.max_input_tokens.unwrap_or_default();
let max_input_tokens = self.model.max_input_tokens().unwrap_or_default();
let percent = if max_input_tokens == 0 {
0.0
} else {
@ -226,6 +247,16 @@ impl Session {
}
}
pub fn set_functions(&mut self, function_filter: Option<&str>) {
self.function_filter = function_filter.map(|v| v.to_string());
}
pub fn set_role_properties(&mut self, role: &Role) {
self.set_temperature(role.temperature);
self.set_top_p(role.top_p);
self.set_functions(role.function_filter.as_deref());
}
pub fn set_save_session(&mut self, value: Option<bool>) {
if self.save_session != value {
self.save_session = value;
@ -251,10 +282,10 @@ impl Session {
pub fn compress(&mut self, prompt: String) {
self.compressed_messages.append(&mut self.messages);
self.messages.push(Message {
role: MessageRole::System,
content: MessageContent::Text(prompt),
});
self.messages.push(Message::new(
MessageRole::System,
MessageContent::Text(prompt),
));
self.dirty = true;
}
@ -300,16 +331,14 @@ impl Session {
}
}
if need_add_msg {
self.messages.push(Message {
role: MessageRole::User,
content: input.to_message_content(),
});
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
self.messages.push(Message {
role: MessageRole::Assistant,
content: MessageContent::Text(output.to_string()),
});
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
self.dirty = true;
Ok(())
}
@ -340,10 +369,7 @@ impl Session {
.extend(self.compressed_messages[self.compressed_messages.len() - 2..].to_vec());
}
if need_add_msg {
messages.push(Message {
role: MessageRole::User,
content: input.to_message_content(),
});
messages.push(Message::new(MessageRole::User, input.message_content()));
}
messages
}

@ -0,0 +1,367 @@
use crate::{
config::GlobalConfig,
utils::{dimmed_text, indent_text, run_command, run_command_with_output, warning_text},
};
use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex;
use indexmap::{IndexMap, IndexSet};
use inquire::Confirm;
use is_terminal::IsTerminal;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{
collections::{HashMap, HashSet},
fs,
io::stdout,
path::Path,
sync::mpsc::channel,
};
use threadpool::ThreadPool;
// Subdirectory of the functions dir that holds the executable tool commands.
const BIN_DIR_NAME: &str = "bin";
// File (relative to the functions dir) containing the JSON function declarations.
const DECLARATIONS_FILE_PATH: &str = "functions.json";

lazy_static! {
    // Shared pool used by `eval_tool_calls` to evaluate independent tool
    // calls in parallel, sized to the number of available CPUs.
    static ref THREAD_POOL: ThreadPool = ThreadPool::new(num_cpus::get());
}

// Accumulated tool-call results paired with the text output produced so far
// (see `Input::merge_tool_call`, which extends the Vec and replaces the String).
pub type ToolResults = (Vec<ToolCallResult>, String);
/// Evaluate a batch of tool calls and collect their results.
///
/// Duplicate calls are removed first. When there is more than one call and
/// none of them is an `execute`-type call, the calls are dispatched to the
/// shared thread pool and evaluated concurrently; the results are restored
/// to submission order before being returned. Otherwise every call runs
/// sequentially on the current thread.
///
/// # Errors
/// Returns the first error produced by any individual call evaluation.
pub fn eval_tool_calls(
    config: &GlobalConfig,
    mut calls: Vec<ToolCall>,
) -> Result<Vec<ToolCallResult>> {
    if calls.is_empty() {
        return Ok(vec![]);
    }
    calls = ToolCall::dedup(calls);
    // Only fan out when every call is safe to run off-thread
    // (execute-type calls may prompt the user and must stay sequential).
    let run_parallel = calls.len() > 1 && !calls.iter().any(|call| call.is_execute_type());
    if !run_parallel {
        let mut results = Vec::with_capacity(calls.len());
        for call in calls {
            let value = call.eval(config)?;
            results.push(ToolCallResult::new(call, value));
        }
        return Ok(results);
    }
    let (tx, rx) = channel();
    let total = calls.len();
    for (index, call) in calls.into_iter().enumerate() {
        let tx = tx.clone();
        let config = config.clone();
        THREAD_POOL.execute(move || {
            let outcome = call.eval(&config);
            // A send failure only means the receiver is gone; nothing to do.
            let _ = tx.send((index, call, outcome));
        });
    }
    let mut completed: Vec<(usize, ToolCall, Result<Value>)> = rx.iter().take(total).collect();
    // Workers finish in nondeterministic order; restore submission order.
    completed.sort_by_key(|entry| entry.0);
    let mut results = Vec::with_capacity(total);
    for (_, call, outcome) in completed {
        results.push(ToolCallResult::new(call, outcome?));
    }
    Ok(results)
}
/// Returns `true` when at least one tool call produced a non-null output
/// worth sending back to the model; `false` for an empty slice.
pub fn need_send_call_results(arr: &[ToolCallResult]) -> bool {
    !arr.iter().all(|v| v.output.is_null())
}
/// A tool call together with the value its evaluation produced.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallResult {
    // The call that was evaluated.
    pub call: ToolCall,
    // The JSON value the call produced; `Value::Null` when there is nothing
    // to send back (e.g. execute-type calls).
    pub output: Value,
}
impl ToolCallResult {
pub fn new(call: ToolCall, output: Value) -> Self {
Self { call, output }
}
}
/// Registry of function declarations loaded from the functions directory.
#[derive(Debug, Clone, Default)]
pub struct Function {
    // Names of all declared functions, in declaration order.
    names: IndexSet<String>,
    // Parsed declarations from the declarations file.
    declarations: Vec<FunctionDeclaration>,
    // Directory holding the function binaries; only kept on Windows, where
    // it is needed to resolve PATHEXT-suffixed executables.
    #[cfg(windows)]
    bin_dir: std::path::PathBuf,
    // PATH value with the bin dir prepended, when the bin dir exists.
    env_path: Option<String>,
}
impl Function {
    /// Loads function declarations from `<functions_dir>/functions.json`.
    ///
    /// When `<functions_dir>/bin` exists, a PATH value with that directory
    /// prepended is prepared so function binaries can be invoked by name.
    /// A missing declarations file yields an empty registry.
    ///
    /// # Errors
    /// Fails when the declarations file exists but cannot be read or parsed.
    pub fn init(functions_dir: &Path) -> Result<Self> {
        let bin_dir = functions_dir.join(BIN_DIR_NAME);
        let env_path = match bin_dir.exists() {
            true => prepend_env_path(&bin_dir).ok(),
            false => None,
        };
        let declarations_file = functions_dir.join(DECLARATIONS_FILE_PATH);
        let mut declarations: Vec<FunctionDeclaration> = vec![];
        if declarations_file.exists() {
            let ctx = || {
                format!(
                    "Failed to load function declarations at {}",
                    declarations_file.display()
                )
            };
            let content = fs::read_to_string(&declarations_file).with_context(ctx)?;
            declarations = serde_json::from_str(&content).with_context(ctx)?;
        }
        let names = declarations.iter().map(|v| v.name.clone()).collect();
        Ok(Self {
            names,
            declarations,
            #[cfg(windows)]
            bin_dir,
            env_path,
        })
    }

    /// Returns the declarations whose names fully match the `filter` regex.
    ///
    /// Yields `None` when no filter is given, the filter is not a valid
    /// regex, or nothing matches.
    pub fn filtered_declarations(&self, filter: Option<&str>) -> Option<Vec<FunctionDeclaration>> {
        let regex = Regex::new(&format!("^({})$", filter?)).ok()?;
        let matched: Vec<FunctionDeclaration> = self
            .declarations
            .iter()
            .filter(|v| regex.is_match(&v.name).unwrap_or_default())
            .cloned()
            .collect();
        match matched.is_empty() {
            true => None,
            false => Some(matched),
        }
    }
}
/// User configuration for the function-calling feature.
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionConfig {
    // Master switch for function calling.
    pub enable: bool,
    // Path of the declarations file — presumably overrides the default
    // `functions.json` location (TODO confirm against config loading).
    pub declarations_file: String,
    // Path of the functions directory.
    pub functions_dir: String,
}
/// A single function declaration as loaded from the declarations file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    // Unique function name; a `may_` prefix marks execute-type functions.
    pub name: String,
    // Human/model-readable description of what the function does.
    pub description: String,
    // JSON-schema description of the accepted arguments.
    pub parameters: JsonSchema,
}
/// Minimal JSON-schema node used to describe function parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchema {
    // Schema type (e.g. "object", "string"); renamed because `type` is a
    // Rust keyword.
    #[serde(rename = "type")]
    pub type_value: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    // Child schemas, present for "object" types; IndexMap preserves the
    // declaration order of the properties.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<IndexMap<String, JsonSchema>>,
    // Allowed values; renamed because `enum` is a Rust keyword.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_value: Option<Vec<String>>,
    // Names of required properties.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
/// A function/tool invocation requested by the model.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
    // Name of the declared function to invoke.
    pub name: String,
    // Arguments, either as a JSON object or as a string of encoded JSON.
    pub arguments: Value,
    // Provider-assigned call id, used to deduplicate repeated calls.
    pub id: Option<String>,
}
impl ToolCall {
    /// Deduplicates calls by id, keeping only the LAST occurrence of each id.
    /// Calls without an id are always kept; relative order is preserved.
    pub fn dedup(calls: Vec<Self>) -> Vec<Self> {
        let mut new_calls = vec![];
        let mut seen_ids = HashSet::new();
        // Walk in reverse so that, for duplicate ids, the last call wins.
        for call in calls.into_iter().rev() {
            if let Some(id) = &call.id {
                if !seen_ids.contains(id) {
                    seen_ids.insert(id.clone());
                    new_calls.push(call);
                }
            } else {
                new_calls.push(call);
            }
        }
        // Undo the reverse walk to restore the original ordering.
        new_calls.reverse();
        new_calls
    }

    /// Creates a new tool call.
    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self {
            name,
            arguments,
            id,
        }
    }

    /// Evaluates this call by running the matching external function command.
    ///
    /// Arguments are converted to `--key value` style CLI options. An
    /// execute-type call (see [`Self::is_execute_type`]) asks for
    /// confirmation on a terminal and always yields `Value::Null`; any other
    /// call captures stdout and returns it parsed as JSON, or wrapped as
    /// `{"output": <stdout>}` when the output is not valid JSON.
    ///
    /// # Errors
    /// Fails when the function is not declared, the arguments are invalid,
    /// the command cannot run, or a non-execute command exits unsuccessfully.
    pub fn eval(&self, config: &GlobalConfig) -> Result<Value> {
        let name = self.name.clone();
        if !config.read().function.names.contains(&name) {
            bail!("Unexpected call: {name} {}", self.arguments);
        }
        // Some providers return arguments as a JSON-encoded string rather
        // than an object; accept both forms.
        let arguments = if self.arguments.is_object() {
            self.arguments.clone()
        } else if let Some(arguments) = self.arguments.as_str() {
            let args: Value = serde_json::from_str(arguments)
                .map_err(|_| anyhow!("The {name} call has invalid arguments: {arguments}"))?;
            args
        } else {
            bail!("The {name} call has invalid arguments: {}", self.arguments);
        };
        let arguments = convert_arguments(&arguments);
        // Shell-quoted command line, shown to the user before running.
        let prompt_text = format!(
            "Call {} {}",
            name,
            arguments
                .iter()
                .map(|v| shell_words::quote(v).to_string())
                .collect::<Vec<String>>()
                .join(" ")
        );
        // Run function binaries with the functions bin dir prepended to PATH.
        let envs = if let Some(env_path) = config.read().function.env_path.clone() {
            let mut envs = HashMap::new();
            envs.insert("PATH".into(), env_path);
            Some(envs)
        } else {
            None
        };
        let output = if self.is_execute_type() {
            // Interactive sessions confirm before executing; non-interactive
            // runs proceed unconditionally (the prompt is just echoed).
            let proceed = if stdout().is_terminal() {
                Confirm::new(&prompt_text).with_default(true).prompt()?
            } else {
                println!("{}", dimmed_text(&prompt_text));
                true
            };
            if proceed {
                #[cfg(windows)]
                let name = polyfill_cmd_name(&name, &config.read().function.bin_dir);
                run_command(&name, &arguments, envs)?;
            }
            // Execute-type calls never feed a result back to the model.
            Value::Null
        } else {
            println!("{}", dimmed_text(&prompt_text));
            #[cfg(windows)]
            let name = polyfill_cmd_name(&name, &config.read().function.bin_dir);
            let (success, stdout, stderr) = run_command_with_output(&name, &arguments, envs)?;
            if success {
                if !stderr.is_empty() {
                    // Surface stderr as a warning even when the command
                    // succeeded.
                    eprintln!(
                        "{}",
                        warning_text(&format!("{prompt_text}:\n{}", indent_text(&stderr, 4)))
                    );
                }
                if !stdout.is_empty() {
                    // Prefer structured JSON output; fall back to raw text.
                    serde_json::from_str(&stdout)
                        .ok()
                        .unwrap_or_else(|| json!({"output": stdout}))
                } else {
                    Value::Null
                }
            } else {
                // Pick the most informative stream for the error message.
                let err = if stderr.is_empty() {
                    if stdout.is_empty() {
                        "Something wrong"
                    } else {
                        &stdout
                    }
                } else {
                    &stderr
                };
                bail!("{}", &format!("{prompt_text}:\n{}", indent_text(err, 4)));
            }
        };
        Ok(output)
    }

    /// Whether this call targets an execute-type function, indicated by a
    /// `may_` prefix (optionally after a `<prefix>__` namespace).
    pub fn is_execute_type(&self) -> bool {
        self.name.starts_with("may_") || self.name.contains("__may_")
    }
}
/// Converts a JSON arguments object into CLI options.
///
/// Each key becomes a `--kebab-case` flag: `true` booleans emit a bare flag,
/// strings and numbers emit the flag followed by the value, and arrays emit
/// one flag/value pair per string or number element. `false`, `null`, and
/// nested objects are ignored. Non-object inputs yield an empty list.
fn convert_arguments(args: &Value) -> Vec<String> {
    let mut options: Vec<String> = Vec::new();
    if let Value::Object(map) = args {
        for (key, value) in map {
            let key = key.replace('_', "-");
            match value {
                Value::Bool(true) => {
                    options.push(format!("--{key}"));
                }
                Value::String(s) => {
                    options.push(format!("--{key}"));
                    options.push(s.to_string());
                }
                // Fix: numeric argument values were previously dropped
                // silently, losing declared number/integer parameters.
                Value::Number(n) => {
                    options.push(format!("--{key}"));
                    options.push(n.to_string());
                }
                Value::Array(arr) => {
                    for item in arr {
                        match item {
                            Value::String(s) => {
                                options.push(format!("--{key}"));
                                options.push(s.to_string());
                            }
                            // Fix: also accept numeric array elements.
                            Value::Number(n) => {
                                options.push(format!("--{key}"));
                                options.push(n.to_string());
                            }
                            _ => {} // Ignore unsupported element types
                        }
                    }
                }
                _ => {} // Ignore `false`, `null`, and nested objects
            }
        }
    }
    options
}
/// Builds a PATH value with `bin_dir` placed before the current PATH,
/// using the platform's path-list separator.
///
/// # Errors
/// Fails when the PATH environment variable is not set.
fn prepend_env_path(bin_dir: &Path) -> Result<String> {
    let current_path = std::env::var("PATH").context("No PATH environment variable")?;
    let separator = if cfg!(target_os = "windows") { ";" } else { ":" };
    Ok(format!("{}{separator}{current_path}", bin_dir.display()))
}
#[cfg(windows)]
fn polyfill_cmd_name(name: &str, bin_dir: &std::path::Path) -> String {
    // On Windows, bare command names do not resolve scripts in the bin dir,
    // so probe each PATHEXT extension for an existing file and use its full
    // path when found; otherwise fall back to the original name.
    if let Ok(exts) = std::env::var("PATHEXT") {
        for ext in exts.split(';') {
            let candidate = bin_dir.join(format!("{}{}", name, ext));
            if candidate.exists() {
                return candidate.display().to_string();
            }
        }
    }
    name.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    // Booleans become bare flags; strings and array elements become repeated
    // `--flag value` pairs, with underscores in keys turned into dashes.
    #[test]
    fn test_convert_args() {
        let args = serde_json::json!({
            "foo": true,
            "bar": "val",
            "baz": ["v1", "v2"]
        });
        assert_eq!(
            convert_arguments(&args),
            vec!["--foo", "--bar", "val", "--baz", "v1", "--baz", "v2"]
        );
    }
}

@ -1,6 +1,7 @@
mod cli;
mod client;
mod config;
mod function;
mod logger;
mod render;
mod repl;
@ -12,17 +13,20 @@ mod utils;
extern crate log;
use crate::cli::Cli;
use crate::client::{ensure_model_capabilities, list_models, send_stream};
use crate::client::{list_models, send_stream, CompletionOutput};
use crate::config::{
Config, GlobalConfig, Input, InputContext, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE,
SHELL_ROLE,
};
use crate::function::eval_tool_calls;
use crate::render::{render_error, MarkdownRender};
use crate::repl::Repl;
use crate::utils::{create_abort_signal, extract_block, run_command, run_spinner, CODE_BLOCK_RE};
use anyhow::{bail, Result};
use async_recursion::async_recursion;
use clap::Parser;
use function::need_send_call_results;
use inquire::{Select, Text};
use is_terminal::IsTerminal;
use parking_lot::RwLock;
@ -30,6 +34,7 @@ use std::io::{stderr, stdin, stdout, Read};
use std::process;
use std::sync::Arc;
use tokio::sync::oneshot;
use utils::detect_shell;
#[tokio::main]
async fn main() -> Result<()> {
@ -113,7 +118,8 @@ async fn main() -> Result<()> {
bail!("No input");
}
let input = create_input(&config, text, file)?;
execute(&config, input).await?;
let (_, shell, shell_arg) = detect_shell();
shell_execute(&config, &shell, shell_arg, input).await?;
return Ok(());
}
config.write().apply_prelude()?;
@ -130,39 +136,56 @@ async fn main() -> Result<()> {
Ok(())
}
#[async_recursion]
async fn start_directive(
config: &GlobalConfig,
input: Input,
no_stream: bool,
code_mode: bool,
) -> Result<()> {
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let client = input.create_client()?;
let is_terminal_stdout = stdout().is_terminal();
let extract_code = !is_terminal_stdout && code_mode;
let output = if no_stream || extract_code {
let (output, _) = client.send_message(input.clone()).await?;
let output = if extract_code && output.trim_start().starts_with("```") {
extract_block(&output)
} else {
output.clone()
};
if is_terminal_stdout {
let render_options = config.read().get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
println!("{}", markdown_render.render(&output).trim());
let (output, tool_call_results) = if no_stream || extract_code {
let CompletionOutput {
text, tool_calls, ..
} = client.send_message(input.clone()).await?;
if !tool_calls.is_empty() {
(String::new(), eval_tool_calls(config, tool_calls)?)
} else {
println!("{}", output);
let text = if extract_code && text.trim_start().starts_with("```") {
extract_block(&text)
} else {
text.clone()
};
if is_terminal_stdout {
let render_options = config.read().get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
println!("{}", markdown_render.render(&text).trim());
} else {
println!("{}", text);
}
(text, vec![])
}
output
} else {
let abort = create_abort_signal();
send_stream(&input, client.as_ref(), config, abort).await?
};
// Save the message/session
config.write().save_message(input, &output)?;
config
.write()
.save_message(&input, &output, &tool_call_results)?;
config.write().end_session()?;
Ok(())
if need_send_call_results(&tool_call_results) {
start_directive(
config,
input.merge_tool_call(output, tool_call_results),
no_stream,
code_mode,
)
.await
} else {
Ok(())
}
}
async fn start_interactive(config: &GlobalConfig) -> Result<()> {
@ -171,7 +194,12 @@ async fn start_interactive(config: &GlobalConfig) -> Result<()> {
}
#[async_recursion::async_recursion]
async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
async fn shell_execute(
config: &GlobalConfig,
shell: &str,
shell_arg: &str,
mut input: Input,
) -> Result<()> {
let client = input.create_client()?;
let is_terminal_stdout = stdout().is_terminal();
let ret = if is_terminal_stdout {
@ -183,11 +211,11 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
} else {
client.send_message(input.clone()).await
};
let (mut eval_str, _) = ret?;
let mut eval_str = ret?.text;
if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
eval_str = extract_block(&eval_str);
}
config.write().save_message(input.clone(), &eval_str)?;
config.write().save_message(&input, &eval_str, &[])?;
config.read().maybe_copy(&eval_str);
let render_options = config.read().get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
@ -205,7 +233,8 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
match answer {
"✅ Execute" => {
let code = run_command(&eval_str)?;
debug!("{} {:?}", shell, &[shell_arg, &eval_str]);
let code = run_command(shell, &[shell_arg, &eval_str], None)?;
if code != 0 {
process::exit(code);
}
@ -214,7 +243,7 @@ async fn execute(config: &GlobalConfig, mut input: Input) -> Result<()> {
let revision = Text::new("Enter your revision:").prompt()?;
let text = format!("{}\n{revision}", input.text());
input.set_text(text);
return execute(config, input).await;
return shell_execute(config, shell, shell_arg, input).await;
}
"📙 Explain" => {
let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;

@ -4,12 +4,11 @@ mod stream;
pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream};
use crate::utils::AbortSignal;
use crate::utils::{error_text, AbortSignal};
use crate::{client::SseEvent, config::GlobalConfig};
use anyhow::Result;
use is_terminal::IsTerminal;
use nu_ansi_term::{Color, Style};
use std::io::stdout;
use tokio::sync::mpsc::UnboundedReceiver;
@ -30,8 +29,7 @@ pub async fn render_stream(
pub fn render_error(err: anyhow::Error, highlight: bool) {
let err = format!("{err:?}");
if highlight {
let style = Style::new().fg(Color::Red);
eprintln!("{}", style.paint(err));
eprintln!("{}", error_text(&err));
} else {
eprintln!("{err}");
}

@ -6,12 +6,14 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::{ensure_model_capabilities, send_stream};
use crate::client::send_stream;
use crate::config::{GlobalConfig, Input, InputContext, State};
use crate::function::need_send_call_results;
use crate::render::render_error;
use crate::utils::{create_abort_signal, set_text, AbortSignal};
use anyhow::{bail, Context, Result};
use async_recursion::async_recursion;
use fancy_regex::Regex;
use lazy_static::lazy_static;
use nu_ansi_term::Color;
@ -184,7 +186,7 @@ impl Repl {
text.trim(),
Some(InputContext::role(role)),
);
self.ask(input).await?;
ask(&self.config, self.abort.clone(), input).await?;
}
None => {
self.config.write().set_role(args)?;
@ -226,7 +228,7 @@ impl Repl {
let (files, text) = split_files_text(args);
let files = shell_words::split(files).with_context(|| "Invalid args")?;
let input = Input::new(&self.config, text, files, None)?;
self.ask(input).await?;
ask(&self.config, self.abort.clone(), input).await?;
}
None => println!("Usage: .file <files>... [-- <text>...]"),
},
@ -252,7 +254,7 @@ impl Repl {
},
None => {
let input = Input::from_str(&self.config, line, None);
self.ask(input).await?;
ask(&self.config, self.abort.clone(), input).await?;
}
}
@ -261,41 +263,6 @@ impl Repl {
Ok(false)
}
async fn ask(&self, input: Input) -> Result<()> {
if input.is_empty() {
return Ok(());
}
while self.config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let output = send_stream(&input, client.as_ref(), &self.config, self.abort.clone()).await?;
self.config.write().save_message(input, &output)?;
self.config.read().maybe_copy(&output);
if self.config.write().should_compress_session() {
let config = self.config.clone();
let color = if config.read().light_theme {
Color::LightGray
} else {
Color::DarkGray
};
print!(
"\n📢 {}{}{}\n",
color.normal().paint(
"Session compression is being activated because the current tokens exceed `"
),
color.italic().paint("compress_threshold"),
color.normal().paint("`."),
);
tokio::spawn(async move {
let _ = compress_session(&config).await;
config.write().end_compressing_session();
});
}
Ok(())
}
fn banner(&self) {
let version = env!("CARGO_PKG_VERSION");
print!(
@ -416,6 +383,53 @@ impl Validator for ReplValidator {
}
}
#[async_recursion]
async fn ask(config: &GlobalConfig, abort: AbortSignal, input: Input) -> Result<()> {
if input.is_empty() {
return Ok(());
}
while config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
let client = input.create_client()?;
let (output, tool_call_results) =
send_stream(&input, client.as_ref(), config, abort.clone()).await?;
config
.write()
.save_message(&input, &output, &tool_call_results)?;
config.read().maybe_copy(&output);
if config.write().should_compress_session() {
let config = config.clone();
let color = if config.read().light_theme {
Color::LightGray
} else {
Color::DarkGray
};
print!(
"\n📢 {}{}{}\n",
color.normal().paint(
"Session compression is being activated because the current tokens exceed `"
),
color.italic().paint("compress_threshold"),
color.normal().paint("`."),
);
tokio::spawn(async move {
let _ = compress_session(&config).await;
config.write().end_compressing_session();
});
}
if need_send_call_results(&tool_call_results) {
ask(
config,
abort,
input.merge_tool_call(output, tool_call_results),
)
.await
} else {
Ok(())
}
}
fn unknown_command() -> Result<()> {
bail!(r#"Unknown command. Type ".help" for additional help."#);
}
@ -449,9 +463,8 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
async fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(config, config.read().summarize_prompt(), None);
let mut client = input.create_client()?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let (summary, _) = client.send_message(input).await?;
let client = input.create_client()?;
let summary = client.send_message(input).await?.text;
config.write().compress_session(&summary);
Ok(())
}

@ -1,7 +1,7 @@
use crate::{
client::{
init_client, list_models, ClientConfig, CompletionDetails, Message, Model, SendData,
SseEvent, SseHandler,
init_client, list_models, ClientConfig, CompletionOutput, Message, Model, ModelData,
SendData, SseEvent, SseHandler,
},
config::{Config, GlobalConfig, Role},
utils::create_abort_signal,
@ -78,7 +78,7 @@ impl Server {
let roles = config.roles.clone();
let mut models = list_models(&config);
let mut default_model = model.clone();
default_model.name = DEFAULT_MODEL_NAME.into();
default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
models.insert(0, &default_model);
let models: Vec<Value> = models
.into_iter()
@ -89,14 +89,25 @@ impl Server {
} else {
model.id()
};
let ModelData {
max_input_tokens,
max_output_tokens,
pass_max_tokens,
input_price,
output_price,
supports_vision,
supports_function_calling,
..
} = model.data();
json!({
"id": id,
"max_input_tokens": model.max_input_tokens,
"max_output_tokens": model.max_output_tokens,
"pass_max_tokens": model.pass_max_tokens,
"input_price": model.input_price,
"output_price": model.output_price,
"supports_vision": model.supports_vision(),
"max_input_tokens": max_input_tokens,
"max_output_tokens": max_output_tokens,
"pass_max_tokens": pass_max_tokens,
"input_price": input_price,
"output_price": output_price,
"supports_vision": supports_vision,
"supports_function_calling": supports_function_calling,
})
})
.collect();
@ -263,6 +274,7 @@ impl Server {
messages,
temperature,
top_p,
functions: None,
stream,
};
@ -338,7 +350,7 @@ impl Server {
.body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res)
} else {
let (content, details) = client.send_message_inner(&http_client, send_data).await?;
let output = client.send_message_inner(&http_client, send_data).await?;
let res = Response::builder()
.header("Content-Type", "application/json")
.body(
@ -346,8 +358,7 @@ impl Server {
&completion_id,
&model_name,
created,
&content,
&details,
&output,
))
.boxed(),
)?;
@ -439,16 +450,10 @@ fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool)
Frame::data(Bytes::from(output))
}
fn ret_non_stream(
id: &str,
model: &str,
created: i64,
content: &str,
details: &CompletionDetails,
) -> Bytes {
let id = details.id.as_deref().unwrap_or(id);
let input_tokens = details.input_tokens.unwrap_or_default();
let output_tokens = details.output_tokens.unwrap_or_default();
fn ret_non_stream(id: &str, model: &str, created: i64, output: &CompletionOutput) -> Bytes {
let id = output.id.as_deref().unwrap_or(id);
let input_tokens = output.input_tokens.unwrap_or_default();
let output_tokens = output.output_tokens.unwrap_or_default();
let total_tokens = input_tokens + output_tokens;
let res_body = json!({
"id": id,
@ -460,7 +465,7 @@ fn ret_non_stream(
"index": 0,
"message": {
"role": "assistant",
"content": content,
"content": output.text,
},
"logprobs": null,
"finish_reason": "stop",

@ -0,0 +1,82 @@
use std::{collections::HashMap, env, ffi::OsStr, process::Command};
use anyhow::{Context, Result};
/// Returns the current OS name; on Linux the distro id from
/// `/etc/os-release` is appended as `linux/<id>` when available.
pub fn detect_os() -> String {
    let os = env::consts::OS;
    if os != "linux" {
        return os.to_string();
    }
    if let Ok(contents) = std::fs::read_to_string("/etc/os-release") {
        // Use the first `ID=` line, which names the distribution.
        if let Some(id) = contents.lines().find_map(|line| line.strip_prefix("ID=")) {
            return format!("{os}/{id}");
        }
    }
    os.to_string()
}
/// Detects the current shell, returning `(shell_name, shell_command, arg)`
/// where `arg` is the flag that passes a command string (e.g. `-c`).
///
/// Nushell is recognized via `NU_VERSION` on every platform. On Windows,
/// PowerShell sessions are recognized via `PSModulePath`; otherwise `cmd`
/// is assumed. On other platforms the `SHELL` variable is used, falling
/// back to `/bin/sh`.
pub fn detect_shell() -> (String, String, &'static str) {
    let is_nushell = env::var("NU_VERSION").is_ok();
    if env::consts::OS == "windows" {
        if is_nushell {
            return ("nushell".into(), "nu.exe".into(), "-c");
        }
        // A PSModulePath with 3+ entries indicates a PowerShell session.
        if let Ok(ps_module_path) = env::var("PSModulePath") {
            let ps_module_path = ps_module_path.to_lowercase();
            if ps_module_path.split(';').count() >= 3 {
                return if ps_module_path.contains("powershell\\7\\") {
                    ("pwsh".into(), "pwsh.exe".into(), "-c")
                } else {
                    ("powershell".into(), "powershell.exe".into(), "-Command")
                };
            }
        }
        return ("cmd".into(), "cmd.exe".into(), "/C");
    }
    if is_nushell {
        return ("nushell".into(), "nu".into(), "-c");
    }
    let shell_cmd = env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
    // The display name is the final path component of the shell command.
    let mut shell_name = match shell_cmd.rsplit_once('/') {
        Some((_, name)) => name.to_string(),
        None => shell_cmd.clone(),
    };
    if shell_name == "nu" {
        shell_name = "nushell".into();
    }
    (shell_name, shell_cmd, "-c")
}
/// Runs `cmd` with `args`, inheriting stdio, and returns its exit code.
///
/// Extra environment variables may be supplied via `envs`.
/// NOTE(review): a process terminated without an exit code (e.g. by a
/// signal) is reported as 0 — confirm callers treat that as acceptable.
///
/// # Errors
/// Fails when the process cannot be spawned.
pub fn run_command<T: AsRef<OsStr>>(
    cmd: &str,
    args: &[T],
    envs: Option<HashMap<String, String>>,
) -> Result<i32> {
    let mut command = Command::new(cmd);
    command.args(args.iter());
    if let Some(envs) = envs {
        command.envs(envs);
    }
    let status = command.status()?;
    Ok(status.code().unwrap_or_default())
}
/// Runs `cmd` with `args`, capturing its output.
///
/// Returns `(success, stdout, stderr)` where `success` reflects the exit
/// status and both streams are decoded as UTF-8.
///
/// # Errors
/// Fails when the process cannot be spawned or either stream is not valid
/// UTF-8.
pub fn run_command_with_output<T: AsRef<OsStr>>(
    cmd: &str,
    args: &[T],
    envs: Option<HashMap<String, String>>,
) -> Result<(bool, String, String)> {
    let mut command = Command::new(cmd);
    command.args(args.iter());
    if let Some(envs) = envs {
        command.envs(envs);
    }
    let output = command.output()?;
    let stdout = std::str::from_utf8(&output.stdout)
        .context("Invalid UTF-8 in stdout")?
        .to_string();
    let stderr = std::str::from_utf8(&output.stderr)
        .context("Invalid UTF-8 in stderr")?
        .to_string();
    Ok((output.status.success(), stdout, stderr))
}

@ -1,5 +1,6 @@
mod abort_signal;
mod clipboard;
mod command;
mod crypto;
mod prompt_input;
mod render_prompt;
@ -7,6 +8,7 @@ mod spinner;
pub use self::abort_signal::{create_abort_signal, AbortSignal};
pub use self::clipboard::set_text;
pub use self::command::*;
pub use self::crypto::*;
pub use self::prompt_input::*;
pub use self::render_prompt::render_prompt;
@ -15,7 +17,6 @@ pub use self::spinner::run_spinner;
use fancy_regex::Regex;
use lazy_static::lazy_static;
use std::env;
use std::process::Command;
lazy_static! {
pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
@ -79,67 +80,6 @@ pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option<bool> {
Some(light)
}
pub fn detect_os() -> String {
let os = env::consts::OS;
if os == "linux" {
if let Ok(contents) = std::fs::read_to_string("/etc/os-release") {
for line in contents.lines() {
if let Some(id) = line.strip_prefix("ID=") {
return format!("{os}/{id}");
}
}
}
}
os.to_string()
}
pub fn detect_shell() -> (String, String, &'static str) {
let os = env::consts::OS;
if os == "windows" {
if env::var("NU_VERSION").is_ok() {
("nushell".into(), "nu.exe".into(), "-c")
} else if let Some(ret) = env::var("PSModulePath").ok().and_then(|v| {
let v = v.to_lowercase();
if v.split(';').count() >= 3 {
if v.contains("powershell\\7\\") {
Some(("pwsh".into(), "pwsh.exe".into(), "-c"))
} else {
Some(("powershell".into(), "powershell.exe".into(), "-Command"))
}
} else {
None
}
}) {
ret
} else {
("cmd".into(), "cmd.exe".into(), "/C")
}
} else if env::var("NU_VERSION").is_ok() {
("nushell".into(), "nu".into(), "-c")
} else {
let shell_cmd = env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
let shell_name = match shell_cmd.rsplit_once('/') {
Some((_, name)) => name.to_string(),
None => shell_cmd.clone(),
};
let shell_name = if shell_name == "nu" {
"nushell".into()
} else {
shell_name
};
(shell_name, shell_cmd, "-c")
}
}
pub fn run_command(eval_str: &str) -> anyhow::Result<i32> {
let (_shell_name, shell_cmd, shell_arg) = detect_shell();
let status = Command::new(shell_cmd)
.arg(shell_arg)
.arg(eval_str)
.status()?;
Ok(status.code().unwrap_or_default())
}
pub fn extract_block(input: &str) -> String {
let output: String = CODE_BLOCK_RE
.captures_iter(input)
@ -183,6 +123,32 @@ pub fn fuzzy_match(text: &str, pattern: &str) -> bool {
pattern_index == pattern_chars.len()
}
pub fn error_text(input: &str) -> String {
nu_ansi_term::Style::new()
.fg(nu_ansi_term::Color::Red)
.paint(input)
.to_string()
}
pub fn warning_text(input: &str) -> String {
nu_ansi_term::Style::new()
.fg(nu_ansi_term::Color::Yellow)
.paint(input)
.to_string()
}
/// Dims `input` for low-emphasis terminal output.
pub fn dimmed_text(input: &str) -> String {
    let style = nu_ansi_term::Style::new().dimmed();
    style.paint(input).to_string()
}
/// Indents every line of `text` by `spaces` spaces.
///
/// Lines are re-joined with `\n`; a trailing newline in the input is not
/// preserved, and empty lines also receive the indentation.
pub fn indent_text(text: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut indented: Vec<String> = Vec::new();
    for line in text.lines() {
        indented.push(format!("{pad}{line}"));
    }
    indented.join("\n")
}
#[cfg(test)]
mod tests {
use super::*;

Loading…
Cancel
Save