use super::{openai::OpenAIConfig, ClientConfig, Message, Model, ReplyHandler}; use crate::{ config::{GlobalConfig, Input}, render::{render_error, render_stream}, utils::{prompt_input_integer, prompt_input_string, tokenize, AbortSignal, PromptKind}, }; use anyhow::{bail, Context, Result}; use async_trait::async_trait; use futures_util::{Stream, StreamExt}; use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; use std::{env, future::Future, time::Duration}; use tokio::{sync::mpsc::unbounded_channel, time::sleep}; #[macro_export] macro_rules! register_client { ( $(($module:ident, $name:literal, $config:ident, $client:ident),)+ ) => { $( mod $module; )+ $( use self::$module::$config; )+ #[derive(Debug, Clone, serde::Deserialize)] #[serde(tag = "type")] pub enum ClientConfig { $( #[serde(rename = $name)] $config($config), )+ #[serde(other)] Unknown, } $( #[derive(Debug)] pub struct $client { global_config: $crate::config::GlobalConfig, config: $config, model: $crate::client::Model, } impl $client { pub const NAME: &'static str = $name; pub fn init(global_config: &$crate::config::GlobalConfig) -> Option> { let model = global_config.read().model.clone(); let config = global_config.read().clients.iter().find_map(|client_config| { if let ClientConfig::$config(c) = client_config { if Self::name(c) == &model.client_name { return Some(c.clone()) } } None })?; Some(Box::new(Self { global_config: global_config.clone(), config, model, })) } pub fn name(config: &$config) -> &str { config.name.as_deref().unwrap_or(Self::NAME) } } )+ pub fn init_client(config: &$crate::config::GlobalConfig) -> anyhow::Result> { None $(.or_else(|| $client::init(config)))+ .ok_or_else(|| { let model = config.read().model.clone(); anyhow::anyhow!("Unknown client '{}'", &model.client_name) }) } pub fn ensure_model_capabilities(client: &mut dyn Client, capabilities: $crate::client::ModelCapabilities) -> anyhow::Result<()> { if !client.model().capabilities.contains(capabilities) { let models = client.list_models(); if let Some(model) = models.into_iter().find(|v| v.capabilities.contains(capabilities)) { client.set_model(model); } else { anyhow::bail!( "The current model lacks the corresponding capability." ); } } Ok(()) } pub fn list_client_types() -> Vec<&'static str> { vec![$($client::NAME,)+] } pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> { $( if client == $client::NAME { return create_config(&$client::PROMPTS, $client::NAME) } )+ anyhow::bail!("Unknown client '{}'", client) } pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> { config .clients .iter() .flat_map(|v| match v { $(ClientConfig::$config(c) => $client::list_models(c),)+ ClientConfig::Unknown => vec![], }) .collect() } }; } #[macro_export] macro_rules! openai_compatible_client { ( $config:ident, $client:ident, $api_base:literal, [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)? )),+$(,)?] ) => { use $crate::client::openai::openai_build_body; use $crate::client::{ExtraConfig, $client, Model, ModelConfig, PromptType, SendData}; use $crate::utils::PromptKind; use anyhow::Result; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; const API_BASE: &str = $api_base; #[derive(Debug, Clone, Deserialize)] pub struct $config { pub name: Option, pub api_key: Option, #[serde(default)] pub models: Vec, pub extra: Option, } impl_client_trait!( $client, $crate::client::openai::openai_send_message, $crate::client::openai::openai_send_message_streaming ); impl $client { list_models_fn!( $config, [ $( ($name, $capabilities, $max_input_tokens $(, $max_output_tokens)?), )+ ] ); config_get_fn!(api_key, get_api_key); pub const PROMPTS: [PromptType<'static>; 1] = [("api_key", "API Key:", false, PromptKind::String)]; fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); let body = openai_build_body(data, &self.model); let url = format!("{API_BASE}/chat/completions"); debug!("Request: {url} {body}"); let mut builder = client.post(url).json(&body); if let Some(api_key) = api_key { builder = builder.bearer_auth(api_key); } Ok(builder) } } } } #[macro_export] macro_rules! client_common_fns { () => { fn config( &self, ) -> ( &$crate::config::GlobalConfig, &Option<$crate::client::ExtraConfig>, ) { (&self.global_config, &self.config.extra) } fn list_models(&self) -> Vec { Self::list_models(&self.config) } fn model(&self) -> &Model { &self.model } fn set_model(&mut self, model: Model) { self.model = model; } }; } #[macro_export] macro_rules! impl_client_trait { ($client:ident, $send_message:path, $send_message_streaming:path) => { #[async_trait::async_trait] impl $crate::client::Client for $crate::client::$client { client_common_fns!(); async fn send_message_inner( &self, client: &reqwest::Client, data: $crate::client::SendData, ) -> anyhow::Result { let builder = self.request_builder(client, data)?; $send_message(builder).await } async fn send_message_streaming_inner( &self, client: &reqwest::Client, handler: &mut $crate::client::ReplyHandler, data: $crate::client::SendData, ) -> Result<()> { let builder = self.request_builder(client, data)?; $send_message_streaming(builder, handler).await } } }; } #[macro_export] macro_rules! config_get_fn { ($field_name:ident, $fn_name:ident) => { fn $fn_name(&self) -> anyhow::Result { let api_key = self.config.$field_name.clone(); api_key .or_else(|| { let env_prefix = Self::name(&self.config); let env_name = format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); std::env::var(&env_name).ok() }) .ok_or_else(|| { anyhow::anyhow!("Miss '{}' in client configuration", stringify!($field_name)) }) } }; } #[macro_export] macro_rules! list_models_fn { ($config:ident) => { pub fn list_models(local_config: &$config) -> Vec { let client_name = Self::name(local_config); Model::from_config(client_name, &local_config.models) } }; ($config:ident, [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)? )),+$(,)?]) => { pub fn list_models(local_config: &$config) -> Vec { let client_name = Self::name(local_config); if local_config.models.is_empty() { vec![ $( Model::new(client_name, $name) .set_capabilities($capabilities.into()) .set_max_input_tokens(Some($max_input_tokens)) $(.set_max_output_tokens(Some($max_output_tokens)))? ),+ ] } else { Model::from_config(client_name, &local_config.models) } } }; } #[macro_export] macro_rules! unsupported_model { ($name:expr) => { anyhow::bail!("Unsupported model '{}'", $name) }; } #[async_trait] pub trait Client: Sync + Send { fn config(&self) -> (&GlobalConfig, &Option); fn list_models(&self) -> Vec; fn model(&self) -> &Model; fn set_model(&mut self, model: Model); fn build_client(&self) -> Result { let mut builder = ReqwestClient::builder(); let options = self.config().1; let timeout = options .as_ref() .and_then(|v| v.connect_timeout) .unwrap_or(10); let proxy = options.as_ref().and_then(|v| v.proxy.clone()); builder = set_proxy(builder, &proxy)?; let client = builder .connect_timeout(Duration::from_secs(timeout)) .build() .with_context(|| "Failed to build client")?; Ok(client) } async fn send_message(&self, input: Input) -> Result { let global_config = self.config().0; if global_config.read().dry_run { let content = global_config.read().echo_messages(&input); return Ok(content); } let client = self.build_client()?; let data = global_config.read().prepare_send_data(&input, false)?; self.send_message_inner(&client, data) .await .with_context(|| "Failed to get answer") } async fn send_message_streaming( &self, input: &Input, handler: &mut ReplyHandler, ) -> Result<()> { async fn watch_abort(abort: AbortSignal) { loop { if abort.aborted() { break; } sleep(Duration::from_millis(100)).await; } } let abort = handler.get_abort(); let input = input.clone(); tokio::select! { ret = async { let global_config = self.config().0; if global_config.read().dry_run { let content = global_config.read().echo_messages(&input); let tokens = tokenize(&content); for token in tokens { tokio::time::sleep(Duration::from_millis(10)).await; handler.text(&token)?; } return Ok(()); } let client = self.build_client()?; let data = global_config.read().prepare_send_data(&input, true)?; self.send_message_streaming_inner(&client, handler, data).await } => { handler.done()?; ret.with_context(|| "Failed to get answer") } _ = watch_abort(abort.clone()) => { handler.done()?; Ok(()) }, } } async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result; async fn send_message_streaming_inner( &self, client: &ReqwestClient, handler: &mut ReplyHandler, data: SendData, ) -> Result<()>; } impl Default for ClientConfig { fn default() -> Self { Self::OpenAIConfig(OpenAIConfig::default()) } } #[derive(Debug, Clone, Deserialize, Default)] pub struct ExtraConfig { pub proxy: Option, pub connect_timeout: Option, } #[derive(Debug)] pub struct SendData { pub messages: Vec, pub temperature: Option, pub top_p: Option, pub stream: bool, } pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind); pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value)> { let mut config = json!({ "type": client, }); let mut model = client.to_string(); for (path, desc, required, kind) in list { match kind { PromptKind::String => { let value = prompt_input_string(desc, *required)?; set_config_value(&mut config, path, kind, &value); if *path == "name" { model = value; } } PromptKind::Integer => { let value = prompt_input_integer(desc, *required)?; set_config_value(&mut config, path, kind, &value); } } } let clients = json!(vec![config]); Ok((model, clients)) } pub async fn send_stream( input: &Input, client: &dyn Client, config: &GlobalConfig, abort: AbortSignal, ) -> Result { let (tx, rx) = unbounded_channel(); let mut stream_handler = ReplyHandler::new(tx, abort.clone()); let (send_ret, rend_ret) = tokio::join!( client.send_message_streaming(input, &mut stream_handler), render_stream(rx, config, abort.clone()), ); if let Err(err) = rend_ret { render_error(err, config.read().highlight); } let output = stream_handler.get_buffer().to_string(); match send_ret { Ok(_) => { println!(); Ok(output) } Err(err) => { if !output.is_empty() { println!(); } Err(err) } } } #[allow(unused)] pub async fn send_message_as_streaming( builder: RequestBuilder, handler: &mut ReplyHandler, f: F, ) -> Result<()> where F: FnOnce(RequestBuilder) -> Fut, Fut: Future>, { let text = f(builder).await?; handler.text(&text)?; handler.done()?; Ok(()) } pub fn catch_error(data: &Value, status: u16) -> Result<()> { if (200..300).contains(&status) { return Ok(()); } debug!("Invalid response, status: {status}, data: {data}"); if let Some(error) = data["error"].as_object() { if let (Some(typ), Some(message)) = (error["type"].as_str(), error["message"].as_str()) { bail!("{message} (type: {typ})"); } } else if let Some(error) = data[0]["error"].as_object() { if let (Some(status), Some(message)) = (error["status"].as_str(), error["message"].as_str()) { bail!("{message} (status: {status})") } } else if let Some(error) = data["error"].as_str() { bail!("{error}"); } else if let Some(message) = data["message"].as_str() { bail!("{message}"); } bail!("Invalid response data: {data} (status: {status})"); } pub fn maybe_catch_error(data: &Value) -> Result<()> { if let (Some(code), Some(message)) = (data["code"].as_str(), data["message"].as_str()) { debug!("Invalid response: {}", data); bail!("{message} (code: {code})"); } else if let (Some(error_code), Some(error_msg)) = (data["error_code"].as_number(), data["error_msg"].as_str()) { debug!("Invalid response: {}", data); bail!("{error_msg} (error_code: {error_code})"); } Ok(()) } pub async fn json_stream(mut stream: S, mut handle: F) -> Result<()> where S: Stream> + Unpin, F: FnMut(&str) -> Result<()>, { let mut buffer = vec![]; let mut cursor = 0; let mut start = 0; let mut balances = vec![]; let mut quoting = false; let mut escape = false; while let Some(chunk) = stream.next().await { let chunk = chunk?; let chunk = std::str::from_utf8(&chunk)?; buffer.extend(chunk.chars()); for i in cursor..buffer.len() { let ch = buffer[i]; if quoting { if ch == '\\' { escape = !escape; } else { if !escape && ch == '"' { quoting = false; } escape = false; } continue; } match ch { '"' => { quoting = true; escape = false; } '{' => { if balances.is_empty() { start = i; } balances.push(ch); } '[' => { if start != 0 { balances.push(ch); } } '}' => { balances.pop(); if balances.is_empty() { let value: String = buffer[start..=i].iter().collect(); handle(&value)?; } } ']' => { balances.pop(); } _ => {} } } cursor = buffer.len(); } Ok(()) } fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) { let segs: Vec<&str> = path.split('.').collect(); match segs.as_slice() { [name] => json[name] = to_json(kind, value), [scope, name] => match scope.split_once('[') { None => { if json.get(scope).is_none() { let mut obj = json!({}); obj[name] = to_json(kind, value); json[scope] = obj; } else { json[scope][name] = to_json(kind, value); } } Some((scope, _)) => { if json.get(scope).is_none() { let mut obj = json!({}); obj[name] = to_json(kind, value); json[scope] = json!([obj]); } else { json[scope][0][name] = to_json(kind, value); } } }, _ => {} } } fn to_json(kind: &PromptKind, value: &str) -> Value { if value.is_empty() { return Value::Null; } match kind { PromptKind::String => value.into(), PromptKind::Integer => match value.parse::() { Ok(value) => value.into(), Err(_) => value.into(), }, } } fn set_proxy(builder: ClientBuilder, proxy: &Option) -> Result { let proxy = if let Some(proxy) = proxy { if proxy.is_empty() || proxy == "false" || proxy == "-" { return Ok(builder); } proxy.clone() } else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) { proxy } else { return Ok(builder); }; let builder = builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); Ok(builder) }