From 37a0cd08a92f07ef24e39bdab7c8bced5b59c146 Mon Sep 17 00:00:00 2001 From: sigoden Date: Mon, 29 Apr 2024 08:33:17 +0800 Subject: [PATCH] refactor: rename some structs (#457) --- src/client/bedrock.rs | 22 ++++++++-------- src/client/claude.rs | 14 +++++----- src/client/cohere.rs | 14 +++++----- src/client/common.rs | 26 ++++++++----------- src/client/ernie.rs | 18 ++++++------- src/client/groq.rs | 6 +---- src/client/mistral.rs | 6 +---- src/client/mod.rs | 4 +-- src/client/moonshot.rs | 6 +---- src/client/ollama.rs | 10 +++---- src/client/openai.rs | 14 +++++----- src/client/qianwen.rs | 18 ++++++------- .../{reply_handler.rs => sse_handler.rs} | 14 +++++----- src/client/vertexai.rs | 18 ++++++------- src/render/mod.rs | 4 +-- src/render/stream.rs | 26 +++++++++---------- src/serve.rs | 24 ++++++++--------- 17 files changed, 114 insertions(+), 130 deletions(-) rename src/client/{reply_handler.rs => sse_handler.rs} (81%) diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index dd94a41..6d9379d 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -1,7 +1,7 @@ use super::claude::{claude_build_body, claude_extract_completion}; use super::{ - catch_error, generate_prompt, BedrockClient, Client, CompletionStats, ExtraConfig, Model, - ModelConfig, PromptFormat, PromptType, ReplyHandler, SendData, LLAMA2_PROMPT_FORMAT, + catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model, + ModelConfig, PromptFormat, PromptType, SendData, SseHandler, LLAMA2_PROMPT_FORMAT, LLAMA3_PROMPT_FORMAT, }; @@ -45,7 +45,7 @@ impl Client for BedrockClient { &self, client: &ReqwestClient, data: SendData, - ) -> Result<(String, CompletionStats)> { + ) -> Result<(String, CompletionDetails)> { let model_category = ModelCategory::from_str(&self.model.name)?; let builder = self.request_builder(client, data, &model_category)?; send_message(builder, &model_category).await @@ -54,7 +54,7 @@ impl Client for BedrockClient { async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyHandler, + handler: &mut SseHandler, data: SendData, ) -> Result<()> { let model_category = ModelCategory::from_str(&self.model.name)?; @@ -132,7 +132,7 @@ impl BedrockClient { async fn send_message( builder: RequestBuilder, model_category: &ModelCategory, -) -> Result<(String, CompletionStats)> { +) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -150,7 +150,7 @@ async fn send_message( async fn send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, model_category: &ModelCategory, ) -> Result<()> { let res = builder.send().await?; @@ -275,23 +275,23 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result { Ok(body) } -fn llama_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { +fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["generation"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - let stats = CompletionStats { + let details = CompletionDetails { id: None, input_tokens: data["prompt_token_count"].as_u64(), output_tokens: data["generation_token_count"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } -fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionStats)> { +fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["outputs"][0]["text"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - Ok((text.to_string(), CompletionStats::default())) + Ok((text.to_string(), CompletionDetails::default())) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src/client/claude.rs b/src/client/claude.rs index 8bd87ee..2b770dc 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, extract_system_message, ClaudeClient, CompletionStats, ExtraConfig, ImageUrl, - MessageContent, MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData, + catch_error, extract_system_message, ClaudeClient, CompletionDetails, ExtraConfig, ImageUrl, + MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; @@ -54,7 +54,7 @@ impl_client_trait!( claude_send_message_streaming ); -pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -66,7 +66,7 @@ pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, Com pub async fn claude_send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, ) -> Result<()> { let mut es = builder.eventsource()?; while let Some(event) = es.next().await { @@ -191,15 +191,15 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result { Ok(body) } -pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { +pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["content"][0]["text"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - let stats = CompletionStats { + let details = CompletionDetails { id: data["id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["input_tokens"].as_u64(), output_tokens: data["usage"]["output_tokens"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } diff --git a/src/client/cohere.rs b/src/client/cohere.rs index 6186718..de2a11c 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionStats, - ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, + catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails, + ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; @@ -47,7 +47,7 @@ impl CohereClient { impl_client_trait!(CohereClient, send_message, send_message_streaming); -async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -58,7 +58,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat cohere_extract_completion(&data) } -async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { +async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> { let res = builder.send().await?; let status = res.status(); if status != 200 { @@ -156,15 +156,15 @@ fn build_body(data: SendData, model: &Model) -> Result { Ok(body) } -fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { +fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["text"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - let stats = CompletionStats { + let details = CompletionDetails { id: data["generation_id"].as_str().map(|v| v.to_string()), input_tokens: data["meta"]["billed_units"]["input_tokens"].as_u64(), output_tokens: data["meta"]["billed_units"]["output_tokens"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } diff --git a/src/client/common.rs b/src/client/common.rs index a58c962..64e32ff 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -1,4 +1,4 @@ -use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, ReplyHandler}; +use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, SseHandler}; use crate::{ config::{GlobalConfig, Input}, @@ -260,7 +260,7 @@ macro_rules! impl_client_trait { &self, client: &reqwest::Client, data: $crate::client::SendData, - ) -> anyhow::Result<(String, $crate::client::CompletionStats)> { + ) -> anyhow::Result<(String, $crate::client::CompletionDetails)> { let builder = self.request_builder(client, data)?; $send_message(builder).await } @@ -268,7 +268,7 @@ macro_rules! impl_client_trait { async fn send_message_streaming_inner( &self, client: &reqwest::Client, - handler: &mut $crate::client::ReplyHandler, + handler: &mut $crate::client::SseHandler, data: $crate::client::SendData, ) -> Result<()> { let builder = self.request_builder(client, data)?; @@ -330,11 +330,11 @@ pub trait Client: Sync + Send { Ok(client) } - async fn send_message(&self, input: Input) -> Result<(String, CompletionStats)> { + async fn send_message(&self, input: Input) -> Result<(String, CompletionDetails)> { let global_config = self.config().0; if global_config.read().dry_run { let content = global_config.read().echo_messages(&input); - return Ok((content, CompletionStats::default())); + return Ok((content, CompletionDetails::default())); } let client = self.build_client()?; let data = global_config.read().prepare_send_data(&input, false)?; @@ -343,11 +343,7 @@ pub trait Client: Sync + Send { .with_context(|| "Failed to get answer") } - async fn send_message_streaming( - &self, - input: &Input, - handler: &mut ReplyHandler, - ) -> Result<()> { + async fn send_message_streaming(&self, input: &Input, handler: &mut SseHandler) -> Result<()> { async fn watch_abort(abort: AbortSignal) { loop { if abort.aborted() { @@ -388,12 +384,12 @@ pub trait Client: Sync + Send { &self, client: &ReqwestClient, data: SendData, - ) -> Result<(String, CompletionStats)>; + ) -> Result<(String, CompletionDetails)>; async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyHandler, + handler: &mut SseHandler, data: SendData, ) -> Result<()>; } @@ -419,7 +415,7 @@ pub struct SendData { } #[derive(Debug, Clone, Default)] -pub struct CompletionStats { +pub struct CompletionDetails { pub id: Option, pub input_tokens: Option, pub output_tokens: Option, @@ -459,7 +455,7 @@ pub async fn send_stream( abort: AbortSignal, ) -> Result { let (tx, rx) = unbounded_channel(); - let mut stream_handler = ReplyHandler::new(tx, abort.clone()); + let mut stream_handler = SseHandler::new(tx, abort.clone()); let (send_ret, rend_ret) = tokio::join!( client.send_message_streaming(input, &mut stream_handler), @@ -486,7 +482,7 @@ pub async fn send_stream( #[allow(unused)] pub async fn send_message_as_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, f: F, ) -> Result<()> where diff --git a/src/client/ernie.rs b/src/client/ernie.rs index dc00f2f..cbfdf22 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -1,6 +1,6 @@ use super::{ - maybe_catch_error, patch_system_message, Client, CompletionStats, ErnieClient, ExtraConfig, - Model, ModelConfig, PromptType, ReplyHandler, SendData, + maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig, + Model, ModelConfig, PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; @@ -83,7 +83,7 @@ impl Client for ErnieClient { &self, client: &ReqwestClient, data: SendData, - ) -> Result<(String, CompletionStats)> { + ) -> Result<(String, CompletionDetails)> { self.prepare_access_token().await?; let builder = self.request_builder(client, data)?; send_message(builder).await @@ -92,7 +92,7 @@ impl Client for ErnieClient { async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyHandler, + handler: &mut SseHandler, data: SendData, ) -> Result<()> { self.prepare_access_token().await?; @@ -101,13 +101,13 @@ impl Client for ErnieClient { } } -async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; extract_completion_text(&data) } -async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { +async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> { let mut es = builder.eventsource()?; while let Some(event) = es.next().await { match event { @@ -184,16 +184,16 @@ fn build_body(data: SendData, model: &Model) -> Value { body } -fn extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> { +fn extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["result"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - let stats = CompletionStats { + let details = CompletionDetails { id: data["id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["prompt_tokens"].as_u64(), output_tokens: data["usage"]["completion_tokens"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } async fn fetch_access_token( diff --git a/src/client/groq.rs b/src/client/groq.rs index 24abc6e..23ca33d 100644 --- a/src/client/groq.rs +++ b/src/client/groq.rs @@ -1,5 +1 @@ -openai_compatible_client!( - GroqConfig, - GroqClient, - "https://api.groq.com/openai/v1", -); +openai_compatible_client!(GroqConfig, GroqClient, "https://api.groq.com/openai/v1",); diff --git a/src/client/mistral.rs b/src/client/mistral.rs index 7279d2b..351502d 100644 --- a/src/client/mistral.rs +++ b/src/client/mistral.rs @@ -1,5 +1 @@ -openai_compatible_client!( - MistralConfig, - MistralClient, - "https://api.mistral.ai/v1", -); +openai_compatible_client!(MistralConfig, MistralClient, "https://api.mistral.ai/v1",); diff --git a/src/client/mod.rs b/src/client/mod.rs index a5aaab1..0916801 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,13 +3,13 @@ mod common; mod message; mod model; mod prompt_format; -mod reply_handler; +mod sse_handler; pub use common::*; pub use message::*; pub use model::*; pub use prompt_format::*; -pub use reply_handler::*; +pub use sse_handler::*; register_client!( (openai, "openai", OpenAIConfig, OpenAIClient), diff --git a/src/client/moonshot.rs b/src/client/moonshot.rs index 471b2b1..903d60f 100644 --- a/src/client/moonshot.rs +++ b/src/client/moonshot.rs @@ -1,5 +1 @@ -openai_compatible_client!( - MoonshotConfig, - MoonshotClient, - "https://api.moonshot.cn/v1", -); +openai_compatible_client!(MoonshotConfig, MoonshotClient, "https://api.moonshot.cn/v1",); diff --git a/src/client/ollama.rs b/src/client/ollama.rs index ec83cbd..ca9d05d 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, message::*, CompletionStats, ExtraConfig, Model, ModelConfig, OllamaClient, - PromptType, ReplyHandler, SendData, + catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient, + PromptType, SendData, SseHandler, }; use crate::utils::PromptKind; @@ -59,7 +59,7 @@ impl OllamaClient { impl_client_trait!(OllamaClient, send_message, send_message_streaming); -async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data = res.json().await?; @@ -69,10 +69,10 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat let text = data["message"]["content"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - Ok((text.to_string(), CompletionStats::default())) + Ok((text.to_string(), CompletionDetails::default())) } -async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> { +async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> { let res = builder.send().await?; let status = res.status(); if status != 200 { diff --git a/src/client/openai.rs b/src/client/openai.rs index eba0992..4c4eae2 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,6 +1,6 @@ use super::{ - catch_error, CompletionStats, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, - ReplyHandler, SendData, + catch_error, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType, + SendData, SseHandler, }; use crate::utils::PromptKind; @@ -52,7 +52,7 @@ impl OpenAIClient { } } -pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -65,7 +65,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, Com pub async fn openai_send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, ) -> Result<()> { let mut es = builder.eventsource()?; while let Some(event) = es.next().await { @@ -140,16 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value { body } -pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionStats)> { +pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> { let text = data["choices"][0]["message"]["content"] .as_str() .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - let stats = CompletionStats { + let details = CompletionDetails { id: data["id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["prompt_tokens"].as_u64(), output_tokens: data["usage"]["completion_tokens"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } impl_client_trait!( diff --git a/src/client/qianwen.rs b/src/client/qianwen.rs index 8a5a6d5..1f23385 100644 --- a/src/client/qianwen.rs +++ b/src/client/qianwen.rs @@ -1,6 +1,6 @@ use super::{ - maybe_catch_error, message::*, Client, CompletionStats, ExtraConfig, Model, ModelConfig, - PromptType, QianwenClient, ReplyHandler, SendData, + maybe_catch_error, message::*, Client, CompletionDetails, ExtraConfig, Model, ModelConfig, + PromptType, QianwenClient, SendData, SseHandler, }; use crate::utils::{sha256sum, PromptKind}; @@ -77,7 +77,7 @@ impl Client for QianwenClient { &self, client: &ReqwestClient, mut data: SendData, - ) -> Result<(String, CompletionStats)> { + ) -> Result<(String, CompletionDetails)> { let api_key = self.get_api_key()?; patch_messages(&self.model.name, &api_key, &mut data.messages).await?; let builder = self.request_builder(client, data)?; @@ -87,7 +87,7 @@ impl Client for QianwenClient { async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyHandler, + handler: &mut SseHandler, mut data: SendData, ) -> Result<()> { let api_key = self.get_api_key()?; @@ -97,7 +97,7 @@ impl Client for QianwenClient { } } -async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionStats)> { +async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionDetails)> { let data: Value = builder.send().await?.json().await?; maybe_catch_error(&data)?; @@ -106,7 +106,7 @@ async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, C async fn send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, is_vl: bool, ) -> Result<()> { let mut es = builder.eventsource()?; @@ -210,7 +210,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool Ok((body, has_upload)) } -fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionStats)> { +fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionDetails)> { let err = || anyhow!("Invalid response data: {data}"); let text = if is_vl { data["output"]["choices"][0]["message"]["content"][0]["text"] @@ -219,13 +219,13 @@ fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, Complet } else { data["output"]["text"].as_str().ok_or_else(err)? }; - let stats = CompletionStats { + let details = CompletionDetails { id: data["request_id"].as_str().map(|v| v.to_string()), input_tokens: data["usage"]["input_tokens"].as_u64(), output_tokens: data["usage"]["output_tokens"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } /// Patch messsages, upload embedded images to oss diff --git a/src/client/reply_handler.rs b/src/client/sse_handler.rs similarity index 81% rename from src/client/reply_handler.rs rename to src/client/sse_handler.rs index e024685..dbf78c2 100644 --- a/src/client/reply_handler.rs +++ b/src/client/sse_handler.rs @@ -3,14 +3,14 @@ use crate::utils::AbortSignal; use anyhow::{Context, Result}; use tokio::sync::mpsc::UnboundedSender; -pub struct ReplyHandler { - sender: UnboundedSender, +pub struct SseHandler { + sender: UnboundedSender, buffer: String, abort: AbortSignal, } -impl ReplyHandler { - pub fn new(sender: UnboundedSender, abort: AbortSignal) -> Self { +impl SseHandler { + pub fn new(sender: UnboundedSender, abort: AbortSignal) -> Self { Self { sender, abort, @@ -26,7 +26,7 @@ impl ReplyHandler { self.buffer.push_str(text); let ret = self .sender - .send(ReplyEvent::Text(text.to_string())) + .send(SseEvent::Text(text.to_string())) .with_context(|| "Failed to send ReplyEvent:Text"); self.safe_ret(ret)?; Ok(()) @@ -36,7 +36,7 @@ impl ReplyHandler { // debug!("ReplyDone"); let ret = self .sender - .send(ReplyEvent::Done) + .send(SseEvent::Done) .with_context(|| "Failed to send ReplyEvent::Done"); self.safe_ret(ret)?; Ok(()) @@ -59,7 +59,7 @@ impl ReplyHandler { } #[derive(Debug)] -pub enum ReplyEvent { +pub enum SseEvent { Text(String), Done, } diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index f4079f0..529c822 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -1,7 +1,7 @@ use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming}; use super::{ - catch_error, json_stream, message::*, patch_system_message, Client, CompletionStats, - ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient, + catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails, + ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler, VertexAIClient, }; use crate::utils::PromptKind; @@ -85,7 +85,7 @@ impl Client for VertexAIClient { &self, client: &ReqwestClient, data: SendData, - ) -> Result<(String, CompletionStats)> { + ) -> Result<(String, CompletionDetails)> { let model_category = ModelCategory::from_str(&self.model.name)?; self.prepare_access_token().await?; let builder = self.request_builder(client, data, &model_category)?; @@ -98,7 +98,7 @@ impl Client for VertexAIClient { async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyHandler, + handler: &mut SseHandler, data: SendData, ) -> Result<()> { let model_category = ModelCategory::from_str(&self.model.name)?; @@ -111,7 +111,7 @@ impl Client for VertexAIClient { } } -pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> { +pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; @@ -123,7 +123,7 @@ pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, Com pub async fn gemini_send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyHandler, + handler: &mut SseHandler, ) -> Result<()> { let res = builder.send().await?; let status = res.status(); @@ -141,14 +141,14 @@ pub async fn gemini_send_message_streaming( Ok(()) } -fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> { +fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> { let text = gemini_extract_text(data)?; - let stats = CompletionStats { + let details = CompletionDetails { id: None, input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(), output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(), }; - Ok((text.to_string(), stats)) + Ok((text.to_string(), details)) } fn gemini_extract_text(data: &Value) -> Result<&str> { diff --git a/src/render/mod.rs b/src/render/mod.rs index 146577b..9487d69 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -5,7 +5,7 @@ pub use self::markdown::{MarkdownRender, RenderOptions}; use self::stream::{markdown_stream, raw_stream}; use crate::utils::AbortSignal; -use crate::{client::ReplyEvent, config::GlobalConfig}; +use crate::{client::SseEvent, config::GlobalConfig}; use anyhow::Result; use is_terminal::IsTerminal; @@ -14,7 +14,7 @@ use std::io::stdout; use tokio::sync::mpsc::UnboundedReceiver; pub async fn render_stream( - rx: UnboundedReceiver, + rx: UnboundedReceiver, config: &GlobalConfig, abort: AbortSignal, ) -> Result<()> { diff --git a/src/render/stream.rs b/src/render/stream.rs index 6007690..f35831c 100644 --- a/src/render/stream.rs +++ b/src/render/stream.rs @@ -1,4 +1,4 @@ -use super::{MarkdownRender, ReplyEvent}; +use super::{MarkdownRender, SseEvent}; use crate::utils::{run_spinner, AbortSignal}; @@ -17,7 +17,7 @@ use textwrap::core::display_width; use tokio::sync::{mpsc::UnboundedReceiver, oneshot}; pub async fn markdown_stream( - rx: UnboundedReceiver, + rx: UnboundedReceiver, render: &mut MarkdownRender, abort: &AbortSignal, ) -> Result<()> { @@ -31,18 +31,18 @@ pub async fn markdown_stream( ret } -pub async fn raw_stream(mut rx: UnboundedReceiver, abort: &AbortSignal) -> Result<()> { +pub async fn raw_stream(mut rx: UnboundedReceiver, abort: &AbortSignal) -> Result<()> { loop { if abort.aborted() { return Ok(()); } if let Some(evt) = rx.recv().await { match evt { - ReplyEvent::Text(text) => { + SseEvent::Text(text) => { print!("{}", text); stdout().flush()?; } - ReplyEvent::Done => { + SseEvent::Done => { break; } } @@ -52,7 +52,7 @@ pub async fn raw_stream(mut rx: UnboundedReceiver, abort: &AbortSign } async fn markdown_stream_inner( - mut rx: UnboundedReceiver, + mut rx: UnboundedReceiver, render: &mut MarkdownRender, abort: &AbortSignal, writer: &mut Stdout, @@ -76,7 +76,7 @@ async fn markdown_stream_inner( } match reply_event { - ReplyEvent::Text(mut text) => { + SseEvent::Text(mut text) => { // tab width hacking text = text.replace('\t', " "); @@ -127,7 +127,7 @@ async fn markdown_stream_inner( writer.flush()?; } - ReplyEvent::Done => { + SseEvent::Done => { break 'outer; } } @@ -156,15 +156,15 @@ async fn markdown_stream_inner( Ok(()) } -async fn gather_events(rx: &mut UnboundedReceiver) -> Vec { +async fn gather_events(rx: &mut UnboundedReceiver) -> Vec { let mut texts = vec![]; let mut done = false; tokio::select! { _ = async { while let Some(reply_event) = rx.recv().await { match reply_event { - ReplyEvent::Text(v) => texts.push(v), - ReplyEvent::Done => { + SseEvent::Text(v) => texts.push(v), + SseEvent::Done => { done = true; break; } @@ -175,10 +175,10 @@ async fn gather_events(rx: &mut UnboundedReceiver) -> Vec, + mut rx: UnboundedReceiver, tx: &UnboundedSender, is_first: &mut bool, ) { @@ -196,10 +196,10 @@ impl Server { *is_first = false; } match reply_event { - ReplyEvent::Text(text) => { + SseEvent::Text(text) => { let _ = tx.send(ResEvent::Text(text)); } - ReplyEvent::Done => { + SseEvent::Done => { let _ = tx.send(ResEvent::Done); } } @@ -251,7 +251,7 @@ impl Server { .body(BodyExt::boxed(StreamBody::new(stream)))?; Ok(res) } else { - let (content, stats) = client.send_message_inner(&http_client, send_data).await?; + let (content, details) = client.send_message_inner(&http_client, send_data).await?; let res = Response::builder() .header("Content-Type", "application/json") .body( @@ -260,7 +260,7 @@ impl Server { &model_name, created, &content, - &stats, + &details, )) .boxed(), )?; @@ -357,11 +357,11 @@ fn ret_non_stream( model: &str, created: i64, content: &str, - stats: &CompletionStats, + details: &CompletionDetails, ) -> Bytes { - let id = stats.id.as_deref().unwrap_or(id); - let input_tokens = stats.input_tokens.unwrap_or_default(); - let output_tokens = stats.output_tokens.unwrap_or_default(); + let id = details.id.as_deref().unwrap_or(id); + let input_tokens = details.input_tokens.unwrap_or_default(); + let output_tokens = details.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; let res_body = json!({ "id": id,