refactor: simplify impl client trait (#445)

pull/439/head
sigoden 4 weeks ago committed by GitHub
parent a21e1213cc
commit 740ca2413a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -4,7 +4,6 @@ use super::{AzureOpenAIClient, ExtraConfig, Model, ModelConfig, PromptType, Send
use crate::utils::PromptKind;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -17,8 +16,6 @@ pub struct AzureOpenAIConfig {
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(AzureOpenAIClient);
impl AzureOpenAIClient {
list_models_fn!(AzureOpenAIConfig);
config_get_fn!(api_base, get_api_base);
@ -55,3 +52,9 @@ impl AzureOpenAIClient {
Ok(builder)
}
}
impl_client_trait!(
AzureOpenAIClient,
crate::client::openai::openai_send_message,
crate::client::openai::openai_send_message_streaming
);

@ -1,12 +1,11 @@
use super::{
catch_error, extract_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl,
MessageContent, MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
catch_error, extract_system_message, ClaudeClient, ExtraConfig, ImageUrl, MessageContent,
MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
@ -24,26 +23,6 @@ pub struct ClaudeConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for ClaudeClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
claude_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
claude_send_message_streaming(builder, handler).await
}
}
impl ClaudeClient {
list_models_fn!(
ClaudeConfig,
@ -79,6 +58,12 @@ impl ClaudeClient {
}
}
impl_client_trait!(
ClaudeClient,
claude_send_message,
claude_send_message_streaming
);
pub async fn claude_send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();

@ -1,12 +1,11 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
catch_error, extract_system_message, json_stream, message::*, CohereClient,
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
use crate::utils::PromptKind;
use anyhow::{bail, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -22,26 +21,6 @@ pub struct CohereConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for CohereClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
}
}
impl CohereClient {
list_models_fn!(
CohereConfig,
@ -74,7 +53,9 @@ impl CohereClient {
}
}
pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
impl_client_trait!(CohereClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -85,10 +66,7 @@ pub(crate) async fn send_message(builder: RequestBuilder) -> Result<String> {
Ok(output.to_string())
}
pub(crate) async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if status != 200 {

@ -139,7 +139,6 @@ macro_rules! openai_compatible_module {
use $crate::utils::PromptKind;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -154,7 +153,12 @@ macro_rules! openai_compatible_module {
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!($client);
impl_client_trait!(
$client,
$crate::client::openai::openai_send_message,
$crate::client::openai::openai_send_message_streaming
);
impl $client {
list_models_fn!(
@ -218,9 +222,9 @@ macro_rules! client_common_fns {
}
#[macro_export]
macro_rules! openai_compatible_client {
($client:ident) => {
#[async_trait]
macro_rules! impl_client_trait {
($client:ident, $send_message:path, $send_message_streaming:path) => {
#[async_trait::async_trait]
impl $crate::client::Client for $crate::client::$client {
client_common_fns!();
@ -230,7 +234,7 @@ macro_rules! openai_compatible_client {
data: $crate::client::SendData,
) -> anyhow::Result<String> {
let builder = self.request_builder(client, data)?;
$crate::client::openai::openai_send_message(builder).await
$send_message(builder).await
}
async fn send_message_streaming_inner(
@ -240,7 +244,7 @@ macro_rules! openai_compatible_client {
data: $crate::client::SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
$crate::client::openai::openai_send_message_streaming(builder, handler).await
$send_message_streaming(builder, handler).await
}
}
};

@ -30,28 +30,6 @@ pub struct ErnieConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for ErnieClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
}
}
impl ErnieClient {
list_models_fn!(
ErnieConfig,
@ -118,6 +96,28 @@ impl ErnieClient {
}
}
#[async_trait]
impl Client for ErnieClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
}
}
async fn send_message(builder: RequestBuilder) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;

@ -1,12 +1,9 @@
use super::vertexai::{gemini_build_body, gemini_send_message, gemini_send_message_streaming};
use super::{
Client, ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
use super::vertexai::gemini_build_body;
use super::{ExtraConfig, GeminiClient, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -22,26 +19,6 @@ pub struct GeminiConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for GeminiClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
}
}
impl GeminiClient {
list_models_fn!(
GeminiConfig,
@ -80,3 +57,9 @@ impl GeminiClient {
Ok(builder)
}
}
impl_client_trait!(
GeminiClient,
crate::client::vertexai::gemini_send_message,
crate::client::vertexai::gemini_send_message_streaming
);

@ -1,12 +1,11 @@
use super::{
catch_error, message::*, Client, ExtraConfig, Model, ModelConfig, OllamaClient, PromptType,
catch_error, message::*, ExtraConfig, Model, ModelConfig, OllamaClient, PromptType,
ReplyHandler, SendData,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -22,26 +21,6 @@ pub struct OllamaConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for OllamaClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
}
}
impl OllamaClient {
list_models_fn!(OllamaConfig);
config_get_fn!(api_key, get_api_key);
@ -79,6 +58,8 @@ impl OllamaClient {
}
}
impl_client_trait!(OllamaClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();

@ -5,7 +5,6 @@ use super::{
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
@ -25,8 +24,6 @@ pub struct OpenAIConfig {
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(OpenAIClient);
impl OpenAIClient {
list_models_fn!(
OpenAIConfig,
@ -159,3 +156,9 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
}
body
}
impl_client_trait!(
OpenAIClient,
openai_send_message,
openai_send_message_streaming
);

@ -4,7 +4,6 @@ use super::{ExtraConfig, Model, ModelConfig, OpenAICompatibleClient, PromptType,
use crate::utils::PromptKind;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
@ -18,8 +17,6 @@ pub struct OpenAICompatibleConfig {
pub extra: Option<ExtraConfig>,
}
openai_compatible_client!(OpenAICompatibleClient);
impl OpenAICompatibleClient {
list_models_fn!(OpenAICompatibleConfig);
config_get_fn!(api_key, get_api_key);
@ -61,3 +58,9 @@ impl OpenAICompatibleClient {
Ok(builder)
}
}
impl_client_trait!(
OpenAICompatibleClient,
crate::client::openai::openai_send_message,
crate::client::openai::openai_send_message_streaming
);

@ -33,34 +33,6 @@ pub struct QianwenConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for QianwenClient {
client_common_fns!();
async fn send_message_inner(
&self,
client: &ReqwestClient,
mut data: SendData,
) -> Result<String> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message(builder, self.is_vl()).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
mut data: SendData,
) -> Result<()> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler, self.is_vl()).await
}
}
impl QianwenClient {
list_models_fn!(
QianwenConfig,
@ -324,3 +296,31 @@ async fn upload(model: &str, api_key: &str, url: &str) -> Result<String> {
}
Ok(format!("oss://{key}"))
}
#[async_trait]
impl Client for QianwenClient {
client_common_fns!();
async fn send_message_inner(
&self,
client: &ReqwestClient,
mut data: SendData,
) -> Result<String> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message(builder, self.is_vl()).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
mut data: SendData,
) -> Result<()> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler, self.is_vl()).await
}
}

@ -26,28 +26,6 @@ pub struct VertexAIConfig {
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for VertexAIClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
}
}
impl VertexAIClient {
list_models_fn!(
VertexAIConfig,
@ -100,6 +78,28 @@ impl VertexAIClient {
}
}
#[async_trait]
impl Client for VertexAIClient {
client_common_fns!();
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
}
}
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<String> {
let res = builder.send().await?;
let status = res.status();

Loading…
Cancel
Save