diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 9e96692..ab8c995 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -4,7 +4,6 @@ use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, Send use crate::utils::PromptKind; use anyhow::Result; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -17,8 +16,6 @@ pub struct AzureOpenAIConfig { pub extra: Option, } -openai_compatible_client!(AzureOpenAIClient); - impl AzureOpenAIClient { list_models_fn!(AzureOpenAIConfig); config_get_fn!(api_base, get_api_base); @@ -55,3 +52,9 @@ impl AzureOpenAIClient { Ok(builder) } } + +impl_client_trait!( + AzureOpenAIClient, + crate::client::openai::openai_send_message, + crate::client::openai::openai_send_message_streaming +); diff --git a/src/client/claude.rs b/src/client/claude.rs index 1094348..d3fc68d 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -1,12 +1,11 @@ use super::{ - catch_error, extract_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, - MessageContent, MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData, + catch_error, extract_system_message, ClaudeClient, ExtraConfig, ImageUrl, MessageContent, + MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData, }; use crate::utils::PromptKind; use anyhow::{anyhow, bail, Result}; -use async_trait::async_trait; use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; @@ -24,26 +23,6 @@ pub struct ClaudeConfig { pub extra: Option, } -#[async_trait] -impl Client for ClaudeClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - claude_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - claude_send_message_streaming(builder, handler).await - } -} - impl ClaudeClient { list_models_fn!( ClaudeConfig, @@ -79,6 +58,12 @@ impl ClaudeClient { } } +impl_client_trait!( + ClaudeClient, + claude_send_message, + claude_send_message_streaming +); + pub async fn claude_send_message(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status(); diff --git a/src/client/cohere.rs b/src/client/cohere.rs index 27acb6c..3f97ece 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -1,12 +1,11 @@ use super::{ - catch_error, extract_system_message, json_stream, message::*, Client, CohereClient, + catch_error, extract_system_message, json_stream, message::*, CohereClient, ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, }; use crate::utils::PromptKind; use anyhow::{bail, Result}; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -22,26 +21,6 @@ pub struct CohereConfig { pub extra: Option, } -#[async_trait] -impl Client for CohereClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - send_message_streaming(builder, handler).await - } -} - impl CohereClient { list_models_fn!( CohereConfig, @@ -74,7 +53,9 @@ impl CohereClient { } } -pub(crate) async fn send_message(builder: RequestBuilder) -> Result { +impl_client_trait!(CohereClient, send_message, send_message_streaming); + +async fn send_message(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -85,10 +66,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result { Ok(output.to_string()) } -pub(crate) async fn send_message_streaming( - builder: RequestBuilder, - handler: &mut ReplyHandler, -) -> Result<()> { +async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { let res = builder.send().await?; let status = res.status(); if status != 200 { diff --git a/src/client/common.rs b/src/client/common.rs index 55e7b62..4509805 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -139,7 +139,6 @@ macro_rules! openai_compatible_module { use $crate::utils::PromptKind; use anyhow::Result; - use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -154,7 +153,12 @@ macro_rules! openai_compatible_module { pub extra: Option, } - openai_compatible_client!($client); + impl_client_trait!( + $client, + $crate::client::openai::openai_send_message, + $crate::client::openai::openai_send_message_streaming + ); + impl $client { list_models_fn!( @@ -218,9 +222,9 @@ macro_rules! client_common_fns { } #[macro_export] -macro_rules! openai_compatible_client { - ($client:ident) => { - #[async_trait] +macro_rules! impl_client_trait { + ($client:ident, $send_message:path, $send_message_streaming:path) => { + #[async_trait::async_trait] impl $crate::client::Client for $crate::client::$client { client_common_fns!(); @@ -230,7 +234,7 @@ macro_rules! openai_compatible_client { data: $crate::client::SendData, ) -> anyhow::Result { let builder = self.request_builder(client, data)?; - $crate::client::openai::openai_send_message(builder).await + $send_message(builder).await } async fn send_message_streaming_inner( @@ -240,7 +244,7 @@ macro_rules! openai_compatible_client { data: $crate::client::SendData, ) -> Result<()> { let builder = self.request_builder(client, data)?; - $crate::client::openai::openai_send_message_streaming(builder, handler).await + $send_message_streaming(builder, handler).await } } }; diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 22f9362..5cc546c 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -30,28 +30,6 @@ pub struct ErnieConfig { pub extra: Option, } -#[async_trait] -impl Client for ErnieClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - send_message_streaming(builder, handler).await - } -} - impl ErnieClient { list_models_fn!( ErnieConfig, @@ -118,6 +96,28 @@ impl ErnieClient { } } +#[async_trait] +impl Client for ErnieClient { + client_common_fns!(); + + async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { + self.prepare_access_token().await?; + let builder = self.request_builder(client, data)?; + send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut ReplyHandler, + data: SendData, + ) -> Result<()> { + self.prepare_access_token().await?; + let builder = self.request_builder(client, data)?; + send_message_streaming(builder, handler).await + } +} + async fn send_message(builder: RequestBuilder) -> Result { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 05e19a3..4b6f053 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -1,12 +1,9 @@ -use super::vertexai::{gemini_build_body, gemini_send_message, gemini_send_message_streaming}; -use super::{ - Client, ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, ReplyHandler, SendData, -}; +use super::vertexai::gemini_build_body; +use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, SendData}; use crate::utils::PromptKind; use anyhow::Result; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -22,26 +19,6 @@ pub struct GeminiConfig { pub extra: Option, } -#[async_trait] -impl Client for GeminiClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - gemini_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - gemini_send_message_streaming(builder, handler).await - } -} - impl GeminiClient { list_models_fn!( GeminiConfig, @@ -80,3 +57,9 @@ impl GeminiClient { Ok(builder) } } + +impl_client_trait!( + GeminiClient, + crate::client::vertexai::gemini_send_message, + crate::client::vertexai::gemini_send_message_streaming +); diff --git a/src/client/ollama.rs b/src/client/ollama.rs index f2340c6..a801242 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -1,12 +1,11 @@ use super::{ - catch_error, message::*, Client, ExtraConfig, Model, ModelConfig, OllamaClient, PromptType, + catch_error, message::*, ExtraConfig, Model, ModelConfig, OllamaClient, PromptType, ReplyHandler, SendData, }; use crate::utils::PromptKind; use anyhow::{anyhow, bail, Result}; -use async_trait::async_trait; use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -22,26 +21,6 @@ pub struct OllamaConfig { pub extra: Option, } -#[async_trait] -impl Client for OllamaClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - let builder = self.request_builder(client, data)?; - send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - let builder = self.request_builder(client, data)?; - send_message_streaming(builder, handler).await - } -} - impl OllamaClient { list_models_fn!(OllamaConfig); config_get_fn!(api_key, get_api_key); @@ -79,6 +58,8 @@ impl OllamaClient { } } +impl_client_trait!(OllamaClient, send_message, send_message_streaming); + async fn send_message(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status(); diff --git a/src/client/openai.rs b/src/client/openai.rs index bb9c07c..a0abc25 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -5,7 +5,6 @@ use super::{ use crate::utils::PromptKind; use anyhow::{anyhow, bail, Result}; -use async_trait::async_trait; use futures_util::StreamExt; use reqwest::{Client as ReqwestClient, RequestBuilder}; use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; @@ -25,8 +24,6 @@ pub struct OpenAIConfig { pub extra: Option, } -openai_compatible_client!(OpenAIClient); - impl OpenAIClient { list_models_fn!( OpenAIConfig, @@ -159,3 +156,9 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value { } body } + +impl_client_trait!( + OpenAIClient, + openai_send_message, + openai_send_message_streaming +); diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs index b4cb5f4..1626b17 100644 --- a/src/client/openai_compatible.rs +++ b/src/client/openai_compatible.rs @@ -4,7 +4,6 @@ use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType, use crate::utils::PromptKind; use anyhow::Result; -use async_trait::async_trait; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; @@ -18,8 +17,6 @@ pub struct OpenAICompatibleConfig { pub extra: Option, } -openai_compatible_client!(OpenAICompatibleClient); - impl OpenAICompatibleClient { list_models_fn!(OpenAICompatibleConfig); config_get_fn!(api_key, get_api_key); @@ -61,3 +58,9 @@ impl OpenAICompatibleClient { Ok(builder) } } + +impl_client_trait!( + OpenAICompatibleClient, + crate::client::openai::openai_send_message, + crate::client::openai::openai_send_message_streaming +); diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 68e3666..386716f 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -33,34 +33,6 @@ pub struct QianwenConfig { pub extra: Option, } -#[async_trait] -impl Client for QianwenClient { - client_common_fns!(); - - async fn send_message_inner( - &self, - client: &ReqwestClient, - mut data: SendData, - ) -> Result { - let api_key = self.get_api_key()?; - patch_messages(&self.model.name, &api_key, &mut data.messages).await?; - let builder = self.request_builder(client, data)?; - send_message(builder, self.is_vl()).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - mut data: SendData, - ) -> Result<()> { - let api_key = self.get_api_key()?; - patch_messages(&self.model.name, &api_key, &mut data.messages).await?; - let builder = self.request_builder(client, data)?; - send_message_streaming(builder, handler, self.is_vl()).await - } -} - impl QianwenClient { list_models_fn!( QianwenConfig, @@ -324,3 +296,31 @@ async fn upload(model: &str, api_key: &str, url: &str) -> Result { } Ok(format!("oss://{key}")) } + +#[async_trait] +impl Client for QianwenClient { + client_common_fns!(); + + async fn send_message_inner( + &self, + client: &ReqwestClient, + mut data: SendData, + ) -> Result { + let api_key = self.get_api_key()?; + patch_messages(&self.model.name, &api_key, &mut data.messages).await?; + let builder = self.request_builder(client, data)?; + send_message(builder, self.is_vl()).await + } + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut ReplyHandler, + mut data: SendData, + ) -> Result<()> { + let api_key = self.get_api_key()?; + patch_messages(&self.model.name, &api_key, &mut data.messages).await?; + let builder = self.request_builder(client, data)?; + send_message_streaming(builder, handler, self.is_vl()).await + } +} diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 0c0608e..566067f 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -26,28 +26,6 @@ pub struct VertexAIConfig { pub extra: Option, } -#[async_trait] -impl Client for VertexAIClient { - client_common_fns!(); - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { - self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - gemini_send_message(builder).await - } - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyHandler, - data: SendData, - ) -> Result<()> { - self.prepare_access_token().await?; - let builder = self.request_builder(client, data)?; - gemini_send_message_streaming(builder, handler).await - } -} - impl VertexAIClient { list_models_fn!( VertexAIConfig, @@ -100,6 +78,28 @@ impl VertexAIClient { } } +#[async_trait] +impl Client for VertexAIClient { + client_common_fns!(); + + async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { + self.prepare_access_token().await?; + let builder = self.request_builder(client, data)?; + gemini_send_message(builder).await + } + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut ReplyHandler, + data: SendData, + ) -> Result<()> { + self.prepare_access_token().await?; + let builder = self.request_builder(client, data)?; + gemini_send_message_streaming(builder, handler).await + } +} + pub async fn gemini_send_message(builder: RequestBuilder) -> Result { let res = builder.send().await?; let status = res.status();