refactor: rename some structs (#457)

pull/458/head
sigoden 1 month ago committed by GitHub
parent 865be2bf75
commit 37a0cd08a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
use super::claude::{claude_build_body, claude_extract_completion}; use super::claude::{claude_build_body, claude_extract_completion};
use super::{ use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionStats, ExtraConfig, Model, catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptFormat, PromptType, ReplyHandler, SendData, LLAMA2_PROMPT_FORMAT, ModelConfig, PromptFormat, PromptType, SendData, SseHandler, LLAMA2_PROMPT_FORMAT,
LLAMA3_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT,
}; };
@ -45,7 +45,7 @@ impl Client for BedrockClient {
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
data: SendData, data: SendData,
) -> Result<(String, CompletionStats)> { ) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?; let model_category = ModelCategory::from_str(&self.model.name)?;
let builder = self.request_builder(client, data, &model_category)?; let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await send_message(builder, &model_category).await
@ -54,7 +54,7 @@ impl Client for BedrockClient {
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
handler: &mut ReplyHandler, handler: &mut SseHandler,
data: SendData, data: SendData,
) -> Result<()> { ) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?; let model_category = ModelCategory::from_str(&self.model.name)?;
@ -132,7 +132,7 @@ impl BedrockClient {
async fn send_message( async fn send_message(
builder: RequestBuilder, builder: RequestBuilder,
model_category: &ModelCategory, model_category: &ModelCategory,
) -> Result<(String, CompletionStats)> { ) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data: Value = res.json().await?; let data: Value = res.json().await?;
@ -150,7 +150,7 @@ async fn send_message(
async fn send_message_streaming( async fn send_message_streaming(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
model_category: &ModelCategory, model_category: &ModelCategory,
) -> Result<()> { ) -> Result<()> {
let res = builder.send().await?; let res = builder.send().await?;
@ -275,23 +275,23 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body) Ok(body)
} }
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["generation"] let text = data["generation"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats { let details = CompletionDetails {
id: None, id: None,
input_tokens: data["prompt_token_count"].as_u64(), input_tokens: data["prompt_token_count"].as_u64(),
output_tokens: data["generation_token_count"].as_u64(), output_tokens: data["generation_token_count"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }
fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionStats)> { fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["outputs"][0]["text"] let text = data["outputs"][0]["text"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionStats::default())) Ok((text.to_string(), CompletionDetails::default()))
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]

@ -1,6 +1,6 @@
use super::{ use super::{
catch_error, extract_system_message, ClaudeClient, CompletionStats, ExtraConfig, ImageUrl, catch_error, extract_system_message, ClaudeClient, CompletionDetails, ExtraConfig, ImageUrl,
MessageContent, MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData, MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, SseHandler,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -54,7 +54,7 @@ impl_client_trait!(
claude_send_message_streaming claude_send_message_streaming
); );
pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data: Value = res.json().await?; let data: Value = res.json().await?;
@ -66,7 +66,7 @@ pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn claude_send_message_streaming( pub async fn claude_send_message_streaming(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
) -> Result<()> { ) -> Result<()> {
let mut es = builder.eventsource()?; let mut es = builder.eventsource()?;
while let Some(event) = es.next().await { while let Some(event) = es.next().await {
@ -191,15 +191,15 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body) Ok(body)
} }
pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["content"][0]["text"] let text = data["content"][0]["text"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats { let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()), id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(), input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(), output_tokens: data["usage"]["output_tokens"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }

@ -1,6 +1,6 @@
use super::{ use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionStats, catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -47,7 +47,7 @@ impl CohereClient {
impl_client_trait!(CohereClient, send_message, send_message_streaming); impl_client_trait!(CohereClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data: Value = res.json().await?; let data: Value = res.json().await?;
@ -58,7 +58,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat
cohere_extract_completion(&data) cohere_extract_completion(&data)
} }
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
if status != 200 { if status != 200 {
@ -156,15 +156,15 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body) Ok(body)
} }
fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["text"] let text = data["text"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats { let details = CompletionDetails {
id: data["generation_id"].as_str().map(|v| v.to_string()), id: data["generation_id"].as_str().map(|v| v.to_string()),
input_tokens: data["meta"]["billed_units"]["input_tokens"].as_u64(), input_tokens: data["meta"]["billed_units"]["input_tokens"].as_u64(),
output_tokens: data["meta"]["billed_units"]["output_tokens"].as_u64(), output_tokens: data["meta"]["billed_units"]["output_tokens"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }

@ -1,4 +1,4 @@
use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, ReplyHandler}; use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, SseHandler};
use crate::{ use crate::{
config::{GlobalConfig, Input}, config::{GlobalConfig, Input},
@ -260,7 +260,7 @@ macro_rules! impl_client_trait {
&self, &self,
client: &reqwest::Client, client: &reqwest::Client,
data: $crate::client::SendData, data: $crate::client::SendData,
) -> anyhow::Result<(String, $crate::client::CompletionStats)> { ) -> anyhow::Result<(String, $crate::client::CompletionDetails)> {
let builder = self.request_builder(client, data)?; let builder = self.request_builder(client, data)?;
$send_message(builder).await $send_message(builder).await
} }
@ -268,7 +268,7 @@ macro_rules! impl_client_trait {
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &reqwest::Client, client: &reqwest::Client,
handler: &mut $crate::client::ReplyHandler, handler: &mut $crate::client::SseHandler,
data: $crate::client::SendData, data: $crate::client::SendData,
) -> Result<()> { ) -> Result<()> {
let builder = self.request_builder(client, data)?; let builder = self.request_builder(client, data)?;
@ -330,11 +330,11 @@ pub trait Client: Sync + Send {
Ok(client) Ok(client)
} }
async fn send_message(&self, input: Input) -> Result<(String, CompletionStats)> { async fn send_message(&self, input: Input) -> Result<(String, CompletionDetails)> {
let global_config = self.config().0; let global_config = self.config().0;
if global_config.read().dry_run { if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input); let content = global_config.read().echo_messages(&input);
return Ok((content, CompletionStats::default())); return Ok((content, CompletionDetails::default()));
} }
let client = self.build_client()?; let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, false)?; let data = global_config.read().prepare_send_data(&input, false)?;
@ -343,11 +343,7 @@ pub trait Client: Sync + Send {
.with_context(|| "Failed to get answer") .with_context(|| "Failed to get answer")
} }
async fn send_message_streaming( async fn send_message_streaming(&self, input: &Input, handler: &mut SseHandler) -> Result<()> {
&self,
input: &Input,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn watch_abort(abort: AbortSignal) { async fn watch_abort(abort: AbortSignal) {
loop { loop {
if abort.aborted() { if abort.aborted() {
@ -388,12 +384,12 @@ pub trait Client: Sync + Send {
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
data: SendData, data: SendData,
) -> Result<(String, CompletionStats)>; ) -> Result<(String, CompletionDetails)>;
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
handler: &mut ReplyHandler, handler: &mut SseHandler,
data: SendData, data: SendData,
) -> Result<()>; ) -> Result<()>;
} }
@ -419,7 +415,7 @@ pub struct SendData {
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct CompletionStats { pub struct CompletionDetails {
pub id: Option<String>, pub id: Option<String>,
pub input_tokens: Option<u64>, pub input_tokens: Option<u64>,
pub output_tokens: Option<u64>, pub output_tokens: Option<u64>,
@ -459,7 +455,7 @@ pub async fn send_stream(
abort: AbortSignal, abort: AbortSignal,
) -> Result<String> { ) -> Result<String> {
let (tx, rx) = unbounded_channel(); let (tx, rx) = unbounded_channel();
let mut stream_handler = ReplyHandler::new(tx, abort.clone()); let mut stream_handler = SseHandler::new(tx, abort.clone());
let (send_ret, rend_ret) = tokio::join!( let (send_ret, rend_ret) = tokio::join!(
client.send_message_streaming(input, &mut stream_handler), client.send_message_streaming(input, &mut stream_handler),
@ -486,7 +482,7 @@ pub async fn send_stream(
#[allow(unused)] #[allow(unused)]
pub async fn send_message_as_streaming<F, Fut>( pub async fn send_message_as_streaming<F, Fut>(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
f: F, f: F,
) -> Result<()> ) -> Result<()>
where where

@ -1,6 +1,6 @@
use super::{ use super::{
maybe_catch_error, patch_system_message, Client, CompletionStats, ErnieClient, ExtraConfig, maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig,
Model, ModelConfig, PromptType, ReplyHandler, SendData, Model, ModelConfig, PromptType, SendData, SseHandler,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -83,7 +83,7 @@ impl Client for ErnieClient {
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
data: SendData, data: SendData,
) -> Result<(String, CompletionStats)> { ) -> Result<(String, CompletionDetails)> {
self.prepare_access_token().await?; self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?; let builder = self.request_builder(client, data)?;
send_message(builder).await send_message(builder).await
@ -92,7 +92,7 @@ impl Client for ErnieClient {
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
handler: &mut ReplyHandler, handler: &mut SseHandler,
data: SendData, data: SendData,
) -> Result<()> { ) -> Result<()> {
self.prepare_access_token().await?; self.prepare_access_token().await?;
@ -101,13 +101,13 @@ impl Client for ErnieClient {
} }
} }
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let data: Value = builder.send().await?.json().await?; let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?; maybe_catch_error(&data)?;
extract_completion_text(&data) extract_completion_text(&data)
} }
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let mut es = builder.eventsource()?; let mut es = builder.eventsource()?;
while let Some(event) = es.next().await { while let Some(event) = es.next().await {
match event { match event {
@ -184,16 +184,16 @@ fn build_body(data: SendData, model: &Model) -> Value {
body body
} }
fn extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> { fn extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["result"] let text = data["result"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats { let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()), id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(), input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(), output_tokens: data["usage"]["completion_tokens"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }
async fn fetch_access_token( async fn fetch_access_token(

@ -1,5 +1 @@
openai_compatible_client!( openai_compatible_client!(GroqConfig, GroqClient, "https://api.groq.com/openai/v1",);
GroqConfig,
GroqClient,
"https://api.groq.com/openai/v1",
);

@ -1,5 +1 @@
openai_compatible_client!( openai_compatible_client!(MistralConfig, MistralClient, "https://api.mistral.ai/v1",);
MistralConfig,
MistralClient,
"https://api.mistral.ai/v1",
);

@ -3,13 +3,13 @@ mod common;
mod message; mod message;
mod model; mod model;
mod prompt_format; mod prompt_format;
mod reply_handler; mod sse_handler;
pub use common::*; pub use common::*;
pub use message::*; pub use message::*;
pub use model::*; pub use model::*;
pub use prompt_format::*; pub use prompt_format::*;
pub use reply_handler::*; pub use sse_handler::*;
register_client!( register_client!(
(openai, "openai", OpenAIConfig, OpenAIClient), (openai, "openai", OpenAIConfig, OpenAIClient),

@ -1,5 +1 @@
openai_compatible_client!( openai_compatible_client!(MoonshotConfig, MoonshotClient, "https://api.moonshot.cn/v1",);
MoonshotConfig,
MoonshotClient,
"https://api.moonshot.cn/v1",
);

@ -1,6 +1,6 @@
use super::{ use super::{
catch_error, message::*, CompletionStats, ExtraConfig, Model, ModelConfig, OllamaClient, catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, ReplyHandler, SendData, PromptType, SendData, SseHandler,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -59,7 +59,7 @@ impl OllamaClient {
impl_client_trait!(OllamaClient, send_message, send_message_streaming); impl_client_trait!(OllamaClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data = res.json().await?; let data = res.json().await?;
@ -69,10 +69,10 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat
let text = data["message"]["content"] let text = data["message"]["content"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionStats::default())) Ok((text.to_string(), CompletionDetails::default()))
} }
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
if status != 200 { if status != 200 {

@ -1,6 +1,6 @@
use super::{ use super::{
catch_error, CompletionStats, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, catch_error, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType,
ReplyHandler, SendData, SendData, SseHandler,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -52,7 +52,7 @@ impl OpenAIClient {
} }
} }
pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data: Value = res.json().await?; let data: Value = res.json().await?;
@ -65,7 +65,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn openai_send_message_streaming( pub async fn openai_send_message_streaming(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
) -> Result<()> { ) -> Result<()> {
let mut es = builder.eventsource()?; let mut es = builder.eventsource()?;
while let Some(event) = es.next().await { while let Some(event) = es.next().await {
@ -140,16 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
body body
} }
pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["choices"][0]["message"]["content"] let text = data["choices"][0]["message"]["content"]
.as_str() .as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?; .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats { let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()), id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(), input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(), output_tokens: data["usage"]["completion_tokens"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }
impl_client_trait!( impl_client_trait!(

@ -1,6 +1,6 @@
use super::{ use super::{
maybe_catch_error, message::*, Client, CompletionStats, ExtraConfig, Model, ModelConfig, maybe_catch_error, message::*, Client, CompletionDetails, ExtraConfig, Model, ModelConfig,
PromptType, QianwenClient, ReplyHandler, SendData, PromptType, QianwenClient, SendData, SseHandler,
}; };
use crate::utils::{sha256sum, PromptKind}; use crate::utils::{sha256sum, PromptKind};
@ -77,7 +77,7 @@ impl Client for QianwenClient {
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
mut data: SendData, mut data: SendData,
) -> Result<(String, CompletionStats)> { ) -> Result<(String, CompletionDetails)> {
let api_key = self.get_api_key()?; let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?; patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?; let builder = self.request_builder(client, data)?;
@ -87,7 +87,7 @@ impl Client for QianwenClient {
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
handler: &mut ReplyHandler, handler: &mut SseHandler,
mut data: SendData, mut data: SendData,
) -> Result<()> { ) -> Result<()> {
let api_key = self.get_api_key()?; let api_key = self.get_api_key()?;
@ -97,7 +97,7 @@ impl Client for QianwenClient {
} }
} }
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionStats)> { async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionDetails)> {
let data: Value = builder.send().await?.json().await?; let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?; maybe_catch_error(&data)?;
@ -106,7 +106,7 @@ async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, C
async fn send_message_streaming( async fn send_message_streaming(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
is_vl: bool, is_vl: bool,
) -> Result<()> { ) -> Result<()> {
let mut es = builder.eventsource()?; let mut es = builder.eventsource()?;
@ -210,7 +210,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
Ok((body, has_upload)) Ok((body, has_upload))
} }
fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionStats)> { fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionDetails)> {
let err = || anyhow!("Invalid response data: {data}"); let err = || anyhow!("Invalid response data: {data}");
let text = if is_vl { let text = if is_vl {
data["output"]["choices"][0]["message"]["content"][0]["text"] data["output"]["choices"][0]["message"]["content"][0]["text"]
@ -219,13 +219,13 @@ fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, Complet
} else { } else {
data["output"]["text"].as_str().ok_or_else(err)? data["output"]["text"].as_str().ok_or_else(err)?
}; };
let stats = CompletionStats { let details = CompletionDetails {
id: data["request_id"].as_str().map(|v| v.to_string()), id: data["request_id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(), input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(), output_tokens: data["usage"]["output_tokens"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }
/// Patch messsages, upload embedded images to oss /// Patch messsages, upload embedded images to oss

@ -3,14 +3,14 @@ use crate::utils::AbortSignal;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::UnboundedSender;
pub struct ReplyHandler { pub struct SseHandler {
sender: UnboundedSender<ReplyEvent>, sender: UnboundedSender<SseEvent>,
buffer: String, buffer: String,
abort: AbortSignal, abort: AbortSignal,
} }
impl ReplyHandler { impl SseHandler {
pub fn new(sender: UnboundedSender<ReplyEvent>, abort: AbortSignal) -> Self { pub fn new(sender: UnboundedSender<SseEvent>, abort: AbortSignal) -> Self {
Self { Self {
sender, sender,
abort, abort,
@ -26,7 +26,7 @@ impl ReplyHandler {
self.buffer.push_str(text); self.buffer.push_str(text);
let ret = self let ret = self
.sender .sender
.send(ReplyEvent::Text(text.to_string())) .send(SseEvent::Text(text.to_string()))
.with_context(|| "Failed to send ReplyEvent:Text"); .with_context(|| "Failed to send ReplyEvent:Text");
self.safe_ret(ret)?; self.safe_ret(ret)?;
Ok(()) Ok(())
@ -36,7 +36,7 @@ impl ReplyHandler {
// debug!("ReplyDone"); // debug!("ReplyDone");
let ret = self let ret = self
.sender .sender
.send(ReplyEvent::Done) .send(SseEvent::Done)
.with_context(|| "Failed to send ReplyEvent::Done"); .with_context(|| "Failed to send ReplyEvent::Done");
self.safe_ret(ret)?; self.safe_ret(ret)?;
Ok(()) Ok(())
@ -59,7 +59,7 @@ impl ReplyHandler {
} }
#[derive(Debug)] #[derive(Debug)]
pub enum ReplyEvent { pub enum SseEvent {
Text(String), Text(String),
Done, Done,
} }

@ -1,7 +1,7 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming}; use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::{ use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionStats, catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient, ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler, VertexAIClient,
}; };
use crate::utils::PromptKind; use crate::utils::PromptKind;
@ -85,7 +85,7 @@ impl Client for VertexAIClient {
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
data: SendData, data: SendData,
) -> Result<(String, CompletionStats)> { ) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?; let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?; self.prepare_access_token().await?;
let builder = self.request_builder(client, data, &model_category)?; let builder = self.request_builder(client, data, &model_category)?;
@ -98,7 +98,7 @@ impl Client for VertexAIClient {
async fn send_message_streaming_inner( async fn send_message_streaming_inner(
&self, &self,
client: &ReqwestClient, client: &ReqwestClient,
handler: &mut ReplyHandler, handler: &mut SseHandler,
data: SendData, data: SendData,
) -> Result<()> { ) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?; let model_category = ModelCategory::from_str(&self.model.name)?;
@ -111,7 +111,7 @@ impl Client for VertexAIClient {
} }
} }
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
let data: Value = res.json().await?; let data: Value = res.json().await?;
@ -123,7 +123,7 @@ pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn gemini_send_message_streaming( pub async fn gemini_send_message_streaming(
builder: RequestBuilder, builder: RequestBuilder,
handler: &mut ReplyHandler, handler: &mut SseHandler,
) -> Result<()> { ) -> Result<()> {
let res = builder.send().await?; let res = builder.send().await?;
let status = res.status(); let status = res.status();
@ -141,14 +141,14 @@ pub async fn gemini_send_message_streaming(
Ok(()) Ok(())
} }
fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> { fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
let text = gemini_extract_text(data)?; let text = gemini_extract_text(data)?;
let stats = CompletionStats { let details = CompletionDetails {
id: None, id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(), input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(), output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
}; };
Ok((text.to_string(), stats)) Ok((text.to_string(), details))
} }
fn gemini_extract_text(data: &Value) -> Result<&str> { fn gemini_extract_text(data: &Value) -> Result<&str> {

@ -5,7 +5,7 @@ pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream}; use self::stream::{markdown_stream, raw_stream};
use crate::utils::AbortSignal; use crate::utils::AbortSignal;
use crate::{client::ReplyEvent, config::GlobalConfig}; use crate::{client::SseEvent, config::GlobalConfig};
use anyhow::Result; use anyhow::Result;
use is_terminal::IsTerminal; use is_terminal::IsTerminal;
@ -14,7 +14,7 @@ use std::io::stdout;
use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedReceiver;
pub async fn render_stream( pub async fn render_stream(
rx: UnboundedReceiver<ReplyEvent>, rx: UnboundedReceiver<SseEvent>,
config: &GlobalConfig, config: &GlobalConfig,
abort: AbortSignal, abort: AbortSignal,
) -> Result<()> { ) -> Result<()> {

@ -1,4 +1,4 @@
use super::{MarkdownRender, ReplyEvent}; use super::{MarkdownRender, SseEvent};
use crate::utils::{run_spinner, AbortSignal}; use crate::utils::{run_spinner, AbortSignal};
@ -17,7 +17,7 @@ use textwrap::core::display_width;
use tokio::sync::{mpsc::UnboundedReceiver, oneshot}; use tokio::sync::{mpsc::UnboundedReceiver, oneshot};
pub async fn markdown_stream( pub async fn markdown_stream(
rx: UnboundedReceiver<ReplyEvent>, rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender, render: &mut MarkdownRender,
abort: &AbortSignal, abort: &AbortSignal,
) -> Result<()> { ) -> Result<()> {
@ -31,18 +31,18 @@ pub async fn markdown_stream(
ret ret
} }
pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> { pub async fn raw_stream(mut rx: UnboundedReceiver<SseEvent>, abort: &AbortSignal) -> Result<()> {
loop { loop {
if abort.aborted() { if abort.aborted() {
return Ok(()); return Ok(());
} }
if let Some(evt) = rx.recv().await { if let Some(evt) = rx.recv().await {
match evt { match evt {
ReplyEvent::Text(text) => { SseEvent::Text(text) => {
print!("{}", text); print!("{}", text);
stdout().flush()?; stdout().flush()?;
} }
ReplyEvent::Done => { SseEvent::Done => {
break; break;
} }
} }
@ -52,7 +52,7 @@ pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSign
} }
async fn markdown_stream_inner( async fn markdown_stream_inner(
mut rx: UnboundedReceiver<ReplyEvent>, mut rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender, render: &mut MarkdownRender,
abort: &AbortSignal, abort: &AbortSignal,
writer: &mut Stdout, writer: &mut Stdout,
@ -76,7 +76,7 @@ async fn markdown_stream_inner(
} }
match reply_event { match reply_event {
ReplyEvent::Text(mut text) => { SseEvent::Text(mut text) => {
// tab width hacking // tab width hacking
text = text.replace('\t', " "); text = text.replace('\t', " ");
@ -127,7 +127,7 @@ async fn markdown_stream_inner(
writer.flush()?; writer.flush()?;
} }
ReplyEvent::Done => { SseEvent::Done => {
break 'outer; break 'outer;
} }
} }
@ -156,15 +156,15 @@ async fn markdown_stream_inner(
Ok(()) Ok(())
} }
async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent> { async fn gather_events(rx: &mut UnboundedReceiver<SseEvent>) -> Vec<SseEvent> {
let mut texts = vec![]; let mut texts = vec![];
let mut done = false; let mut done = false;
tokio::select! { tokio::select! {
_ = async { _ = async {
while let Some(reply_event) = rx.recv().await { while let Some(reply_event) = rx.recv().await {
match reply_event { match reply_event {
ReplyEvent::Text(v) => texts.push(v), SseEvent::Text(v) => texts.push(v),
ReplyEvent::Done => { SseEvent::Done => {
done = true; done = true;
break; break;
} }
@ -175,10 +175,10 @@ async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent
}; };
let mut events = vec![]; let mut events = vec![];
if !texts.is_empty() { if !texts.is_empty() {
events.push(ReplyEvent::Text(texts.join(""))) events.push(SseEvent::Text(texts.join("")))
} }
if done { if done {
events.push(ReplyEvent::Done) events.push(SseEvent::Done)
} }
events events
} }

@ -1,7 +1,7 @@
use crate::{ use crate::{
client::{ client::{
init_client, ClientConfig, CompletionStats, Message, Model, ReplyEvent, ReplyHandler, init_client, ClientConfig, CompletionDetails, Message, Model, SendData, SseEvent,
SendData, SseHandler,
}, },
config::{Config, GlobalConfig}, config::{Config, GlobalConfig},
utils::create_abort_signal, utils::create_abort_signal,
@ -184,9 +184,9 @@ impl Server {
tokio::spawn(async move { tokio::spawn(async move {
let mut is_first = true; let mut is_first = true;
let (tx2, rx2) = unbounded_channel(); let (tx2, rx2) = unbounded_channel();
let mut handler = ReplyHandler::new(tx2, abort); let mut handler = SseHandler::new(tx2, abort);
async fn map_event( async fn map_event(
mut rx: UnboundedReceiver<ReplyEvent>, mut rx: UnboundedReceiver<SseEvent>,
tx: &UnboundedSender<ResEvent>, tx: &UnboundedSender<ResEvent>,
is_first: &mut bool, is_first: &mut bool,
) { ) {
@ -196,10 +196,10 @@ impl Server {
*is_first = false; *is_first = false;
} }
match reply_event { match reply_event {
ReplyEvent::Text(text) => { SseEvent::Text(text) => {
let _ = tx.send(ResEvent::Text(text)); let _ = tx.send(ResEvent::Text(text));
} }
ReplyEvent::Done => { SseEvent::Done => {
let _ = tx.send(ResEvent::Done); let _ = tx.send(ResEvent::Done);
} }
} }
@ -251,7 +251,7 @@ impl Server {
.body(BodyExt::boxed(StreamBody::new(stream)))?; .body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res) Ok(res)
} else { } else {
let (content, stats) = client.send_message_inner(&http_client, send_data).await?; let (content, details) = client.send_message_inner(&http_client, send_data).await?;
let res = Response::builder() let res = Response::builder()
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body( .body(
@ -260,7 +260,7 @@ impl Server {
&model_name, &model_name,
created, created,
&content, &content,
&stats, &details,
)) ))
.boxed(), .boxed(),
)?; )?;
@ -357,11 +357,11 @@ fn ret_non_stream(
model: &str, model: &str,
created: i64, created: i64,
content: &str, content: &str,
stats: &CompletionStats, details: &CompletionDetails,
) -> Bytes { ) -> Bytes {
let id = stats.id.as_deref().unwrap_or(id); let id = details.id.as_deref().unwrap_or(id);
let input_tokens = stats.input_tokens.unwrap_or_default(); let input_tokens = details.input_tokens.unwrap_or_default();
let output_tokens = stats.output_tokens.unwrap_or_default(); let output_tokens = details.output_tokens.unwrap_or_default();
let total_tokens = input_tokens + output_tokens; let total_tokens = input_tokens + output_tokens;
let res_body = json!({ let res_body = json!({
"id": id, "id": id,

Loading…
Cancel
Save