refactor: rename some structs (#457)

pull/458/head
sigoden 3 weeks ago committed by GitHub
parent 865be2bf75
commit 37a0cd08a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionStats, ExtraConfig, Model,
ModelConfig, PromptFormat, PromptType, ReplyHandler, SendData, LLAMA2_PROMPT_FORMAT,
catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptFormat, PromptType, SendData, SseHandler, LLAMA2_PROMPT_FORMAT,
LLAMA3_PROMPT_FORMAT,
};
@ -45,7 +45,7 @@ impl Client for BedrockClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionStats)> {
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await
@ -54,7 +54,7 @@ impl Client for BedrockClient {
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
@ -132,7 +132,7 @@ impl BedrockClient {
async fn send_message(
builder: RequestBuilder,
model_category: &ModelCategory,
) -> Result<(String, CompletionStats)> {
) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -150,7 +150,7 @@ async fn send_message(
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
model_category: &ModelCategory,
) -> Result<()> {
let res = builder.send().await?;
@ -275,23 +275,23 @@ fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionStats)> {
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["generation"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats {
let details = CompletionDetails {
id: None,
input_tokens: data["prompt_token_count"].as_u64(),
output_tokens: data["generation_token_count"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}
fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionStats)> {
fn mistral_extrat_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["outputs"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionStats::default()))
Ok((text.to_string(), CompletionDetails::default()))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

@ -1,6 +1,6 @@
use super::{
catch_error, extract_system_message, ClaudeClient, CompletionStats, ExtraConfig, ImageUrl,
MessageContent, MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
catch_error, extract_system_message, ClaudeClient, CompletionDetails, ExtraConfig, ImageUrl,
MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
@ -54,7 +54,7 @@ impl_client_trait!(
claude_send_message_streaming
);
pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -66,7 +66,7 @@ pub async fn claude_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
@ -191,15 +191,15 @@ pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionStats)> {
pub fn claude_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["content"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats {
let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}

@ -1,6 +1,6 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionStats,
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData,
catch_error, extract_system_message, json_stream, message::*, CohereClient, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
@ -47,7 +47,7 @@ impl CohereClient {
impl_client_trait!(CohereClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -58,7 +58,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat
cohere_extract_completion(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if status != 200 {
@ -156,15 +156,15 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionStats)> {
fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats {
let details = CompletionDetails {
id: data["generation_id"].as_str().map(|v| v.to_string()),
input_tokens: data["meta"]["billed_units"]["input_tokens"].as_u64(),
output_tokens: data["meta"]["billed_units"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}

@ -1,4 +1,4 @@
use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, ReplyHandler};
use super::{openai::OpenAIConfig, ClientConfig, ClientModel, Message, Model, SseHandler};
use crate::{
config::{GlobalConfig, Input},
@ -260,7 +260,7 @@ macro_rules! impl_client_trait {
&self,
client: &reqwest::Client,
data: $crate::client::SendData,
) -> anyhow::Result<(String, $crate::client::CompletionStats)> {
) -> anyhow::Result<(String, $crate::client::CompletionDetails)> {
let builder = self.request_builder(client, data)?;
$send_message(builder).await
}
@ -268,7 +268,7 @@ macro_rules! impl_client_trait {
async fn send_message_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::client::ReplyHandler,
handler: &mut $crate::client::SseHandler,
data: $crate::client::SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
@ -330,11 +330,11 @@ pub trait Client: Sync + Send {
Ok(client)
}
async fn send_message(&self, input: Input) -> Result<(String, CompletionStats)> {
async fn send_message(&self, input: Input) -> Result<(String, CompletionDetails)> {
let global_config = self.config().0;
if global_config.read().dry_run {
let content = global_config.read().echo_messages(&input);
return Ok((content, CompletionStats::default()));
return Ok((content, CompletionDetails::default()));
}
let client = self.build_client()?;
let data = global_config.read().prepare_send_data(&input, false)?;
@ -343,11 +343,7 @@ pub trait Client: Sync + Send {
.with_context(|| "Failed to get answer")
}
async fn send_message_streaming(
&self,
input: &Input,
handler: &mut ReplyHandler,
) -> Result<()> {
async fn send_message_streaming(&self, input: &Input, handler: &mut SseHandler) -> Result<()> {
async fn watch_abort(abort: AbortSignal) {
loop {
if abort.aborted() {
@ -388,12 +384,12 @@ pub trait Client: Sync + Send {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionStats)>;
) -> Result<(String, CompletionDetails)>;
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
data: SendData,
) -> Result<()>;
}
@ -419,7 +415,7 @@ pub struct SendData {
}
#[derive(Debug, Clone, Default)]
pub struct CompletionStats {
pub struct CompletionDetails {
pub id: Option<String>,
pub input_tokens: Option<u64>,
pub output_tokens: Option<u64>,
@ -459,7 +455,7 @@ pub async fn send_stream(
abort: AbortSignal,
) -> Result<String> {
let (tx, rx) = unbounded_channel();
let mut stream_handler = ReplyHandler::new(tx, abort.clone());
let mut stream_handler = SseHandler::new(tx, abort.clone());
let (send_ret, rend_ret) = tokio::join!(
client.send_message_streaming(input, &mut stream_handler),
@ -486,7 +482,7 @@ pub async fn send_stream(
#[allow(unused)]
pub async fn send_message_as_streaming<F, Fut>(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
f: F,
) -> Result<()>
where

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, patch_system_message, Client, CompletionStats, ErnieClient, ExtraConfig,
Model, ModelConfig, PromptType, ReplyHandler, SendData,
maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig,
Model, ModelConfig, PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
@ -83,7 +83,7 @@ impl Client for ErnieClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionStats)> {
) -> Result<(String, CompletionDetails)> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message(builder).await
@ -92,7 +92,7 @@ impl Client for ErnieClient {
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
self.prepare_access_token().await?;
@ -101,13 +101,13 @@ impl Client for ErnieClient {
}
}
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
extract_completion_text(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
@ -184,16 +184,16 @@ fn build_body(data: SendData, model: &Model) -> Value {
body
}
fn extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> {
fn extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["result"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats {
let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}
async fn fetch_access_token(

@ -1,5 +1 @@
openai_compatible_client!(
GroqConfig,
GroqClient,
"https://api.groq.com/openai/v1",
);
openai_compatible_client!(GroqConfig, GroqClient, "https://api.groq.com/openai/v1",);

@ -1,5 +1 @@
openai_compatible_client!(
MistralConfig,
MistralClient,
"https://api.mistral.ai/v1",
);
openai_compatible_client!(MistralConfig, MistralClient, "https://api.mistral.ai/v1",);

@ -3,13 +3,13 @@ mod common;
mod message;
mod model;
mod prompt_format;
mod reply_handler;
mod sse_handler;
pub use common::*;
pub use message::*;
pub use model::*;
pub use prompt_format::*;
pub use reply_handler::*;
pub use sse_handler::*;
register_client!(
(openai, "openai", OpenAIConfig, OpenAIClient),

@ -1,5 +1 @@
openai_compatible_client!(
MoonshotConfig,
MoonshotClient,
"https://api.moonshot.cn/v1",
);
openai_compatible_client!(MoonshotConfig, MoonshotClient, "https://api.moonshot.cn/v1",);

@ -1,6 +1,6 @@
use super::{
catch_error, message::*, CompletionStats, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, ReplyHandler, SendData,
catch_error, message::*, CompletionDetails, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
@ -59,7 +59,7 @@ impl OllamaClient {
impl_client_trait!(OllamaClient, send_message, send_message_streaming);
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data = res.json().await?;
@ -69,10 +69,10 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionStat
let text = data["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionStats::default()))
Ok((text.to_string(), CompletionDetails::default()))
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHandler) -> Result<()> {
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if status != 200 {

@ -1,6 +1,6 @@
use super::{
catch_error, CompletionStats, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType,
ReplyHandler, SendData,
catch_error, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType,
SendData, SseHandler,
};
use crate::utils::PromptKind;
@ -52,7 +52,7 @@ impl OpenAIClient {
}
}
pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -65,7 +65,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
@ -140,16 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
body
}
pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionStats)> {
pub fn openai_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let stats = CompletionStats {
let details = CompletionDetails {
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}
impl_client_trait!(

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, message::*, Client, CompletionStats, ExtraConfig, Model, ModelConfig,
PromptType, QianwenClient, ReplyHandler, SendData,
maybe_catch_error, message::*, Client, CompletionDetails, ExtraConfig, Model, ModelConfig,
PromptType, QianwenClient, SendData, SseHandler,
};
use crate::utils::{sha256sum, PromptKind};
@ -77,7 +77,7 @@ impl Client for QianwenClient {
&self,
client: &ReqwestClient,
mut data: SendData,
) -> Result<(String, CompletionStats)> {
) -> Result<(String, CompletionDetails)> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
@ -87,7 +87,7 @@ impl Client for QianwenClient {
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
mut data: SendData,
) -> Result<()> {
let api_key = self.get_api_key()?;
@ -97,7 +97,7 @@ impl Client for QianwenClient {
}
}
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionStats)> {
async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, CompletionDetails)> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
@ -106,7 +106,7 @@ async fn send_message(builder: RequestBuilder, is_vl: bool) -> Result<(String, C
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
is_vl: bool,
) -> Result<()> {
let mut es = builder.eventsource()?;
@ -210,7 +210,7 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool
Ok((body, has_upload))
}
fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionStats)> {
fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, CompletionDetails)> {
let err = || anyhow!("Invalid response data: {data}");
let text = if is_vl {
data["output"]["choices"][0]["message"]["content"][0]["text"]
@ -219,13 +219,13 @@ fn extract_completion_text(data: &Value, is_vl: bool) -> Result<(String, Complet
} else {
data["output"]["text"].as_str().ok_or_else(err)?
};
let stats = CompletionStats {
let details = CompletionDetails {
id: data["request_id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}
/// Patch messsages, upload embedded images to oss

@ -3,14 +3,14 @@ use crate::utils::AbortSignal;
use anyhow::{Context, Result};
use tokio::sync::mpsc::UnboundedSender;
pub struct ReplyHandler {
sender: UnboundedSender<ReplyEvent>,
pub struct SseHandler {
sender: UnboundedSender<SseEvent>,
buffer: String,
abort: AbortSignal,
}
impl ReplyHandler {
pub fn new(sender: UnboundedSender<ReplyEvent>, abort: AbortSignal) -> Self {
impl SseHandler {
pub fn new(sender: UnboundedSender<SseEvent>, abort: AbortSignal) -> Self {
Self {
sender,
abort,
@ -26,7 +26,7 @@ impl ReplyHandler {
self.buffer.push_str(text);
let ret = self
.sender
.send(ReplyEvent::Text(text.to_string()))
.send(SseEvent::Text(text.to_string()))
.with_context(|| "Failed to send ReplyEvent:Text");
self.safe_ret(ret)?;
Ok(())
@ -36,7 +36,7 @@ impl ReplyHandler {
// debug!("ReplyDone");
let ret = self
.sender
.send(ReplyEvent::Done)
.send(SseEvent::Done)
.with_context(|| "Failed to send ReplyEvent::Done");
self.safe_ret(ret)?;
Ok(())
@ -59,7 +59,7 @@ impl ReplyHandler {
}
#[derive(Debug)]
pub enum ReplyEvent {
pub enum SseEvent {
Text(String),
Done,
}

@ -1,7 +1,7 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionStats,
ExtraConfig, Model, ModelConfig, PromptType, ReplyHandler, SendData, VertexAIClient,
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler, VertexAIClient,
};
use crate::utils::PromptKind;
@ -85,7 +85,7 @@ impl Client for VertexAIClient {
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionStats)> {
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?;
let builder = self.request_builder(client, data, &model_category)?;
@ -98,7 +98,7 @@ impl Client for VertexAIClient {
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
@ -111,7 +111,7 @@ impl Client for VertexAIClient {
}
}
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionStats)> {
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, CompletionDetails)> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -123,7 +123,7 @@ pub async fn gemini_send_message(builder: RequestBuilder) -> Result<(String, Com
pub async fn gemini_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyHandler,
handler: &mut SseHandler,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
@ -141,14 +141,14 @@ pub async fn gemini_send_message_streaming(
Ok(())
}
fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionStats)> {
fn gemini_extract_completion_text(data: &Value) -> Result<(String, CompletionDetails)> {
let text = gemini_extract_text(data)?;
let stats = CompletionStats {
let details = CompletionDetails {
id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
};
Ok((text.to_string(), stats))
Ok((text.to_string(), details))
}
fn gemini_extract_text(data: &Value) -> Result<&str> {

@ -5,7 +5,7 @@ pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream};
use crate::utils::AbortSignal;
use crate::{client::ReplyEvent, config::GlobalConfig};
use crate::{client::SseEvent, config::GlobalConfig};
use anyhow::Result;
use is_terminal::IsTerminal;
@ -14,7 +14,7 @@ use std::io::stdout;
use tokio::sync::mpsc::UnboundedReceiver;
pub async fn render_stream(
rx: UnboundedReceiver<ReplyEvent>,
rx: UnboundedReceiver<SseEvent>,
config: &GlobalConfig,
abort: AbortSignal,
) -> Result<()> {

@ -1,4 +1,4 @@
use super::{MarkdownRender, ReplyEvent};
use super::{MarkdownRender, SseEvent};
use crate::utils::{run_spinner, AbortSignal};
@ -17,7 +17,7 @@ use textwrap::core::display_width;
use tokio::sync::{mpsc::UnboundedReceiver, oneshot};
pub async fn markdown_stream(
rx: UnboundedReceiver<ReplyEvent>,
rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
) -> Result<()> {
@ -31,18 +31,18 @@ pub async fn markdown_stream(
ret
}
pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSignal) -> Result<()> {
pub async fn raw_stream(mut rx: UnboundedReceiver<SseEvent>, abort: &AbortSignal) -> Result<()> {
loop {
if abort.aborted() {
return Ok(());
}
if let Some(evt) = rx.recv().await {
match evt {
ReplyEvent::Text(text) => {
SseEvent::Text(text) => {
print!("{}", text);
stdout().flush()?;
}
ReplyEvent::Done => {
SseEvent::Done => {
break;
}
}
@ -52,7 +52,7 @@ pub async fn raw_stream(mut rx: UnboundedReceiver<ReplyEvent>, abort: &AbortSign
}
async fn markdown_stream_inner(
mut rx: UnboundedReceiver<ReplyEvent>,
mut rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender,
abort: &AbortSignal,
writer: &mut Stdout,
@ -76,7 +76,7 @@ async fn markdown_stream_inner(
}
match reply_event {
ReplyEvent::Text(mut text) => {
SseEvent::Text(mut text) => {
// tab width hacking
text = text.replace('\t', " ");
@ -127,7 +127,7 @@ async fn markdown_stream_inner(
writer.flush()?;
}
ReplyEvent::Done => {
SseEvent::Done => {
break 'outer;
}
}
@ -156,15 +156,15 @@ async fn markdown_stream_inner(
Ok(())
}
async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent> {
async fn gather_events(rx: &mut UnboundedReceiver<SseEvent>) -> Vec<SseEvent> {
let mut texts = vec![];
let mut done = false;
tokio::select! {
_ = async {
while let Some(reply_event) = rx.recv().await {
match reply_event {
ReplyEvent::Text(v) => texts.push(v),
ReplyEvent::Done => {
SseEvent::Text(v) => texts.push(v),
SseEvent::Done => {
done = true;
break;
}
@ -175,10 +175,10 @@ async fn gather_events(rx: &mut UnboundedReceiver<ReplyEvent>) -> Vec<ReplyEvent
};
let mut events = vec![];
if !texts.is_empty() {
events.push(ReplyEvent::Text(texts.join("")))
events.push(SseEvent::Text(texts.join("")))
}
if done {
events.push(ReplyEvent::Done)
events.push(SseEvent::Done)
}
events
}

@ -1,7 +1,7 @@
use crate::{
client::{
init_client, ClientConfig, CompletionStats, Message, Model, ReplyEvent, ReplyHandler,
SendData,
init_client, ClientConfig, CompletionDetails, Message, Model, SendData, SseEvent,
SseHandler,
},
config::{Config, GlobalConfig},
utils::create_abort_signal,
@ -184,9 +184,9 @@ impl Server {
tokio::spawn(async move {
let mut is_first = true;
let (tx2, rx2) = unbounded_channel();
let mut handler = ReplyHandler::new(tx2, abort);
let mut handler = SseHandler::new(tx2, abort);
async fn map_event(
mut rx: UnboundedReceiver<ReplyEvent>,
mut rx: UnboundedReceiver<SseEvent>,
tx: &UnboundedSender<ResEvent>,
is_first: &mut bool,
) {
@ -196,10 +196,10 @@ impl Server {
*is_first = false;
}
match reply_event {
ReplyEvent::Text(text) => {
SseEvent::Text(text) => {
let _ = tx.send(ResEvent::Text(text));
}
ReplyEvent::Done => {
SseEvent::Done => {
let _ = tx.send(ResEvent::Done);
}
}
@ -251,7 +251,7 @@ impl Server {
.body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res)
} else {
let (content, stats) = client.send_message_inner(&http_client, send_data).await?;
let (content, details) = client.send_message_inner(&http_client, send_data).await?;
let res = Response::builder()
.header("Content-Type", "application/json")
.body(
@ -260,7 +260,7 @@ impl Server {
&model_name,
created,
&content,
&stats,
&details,
))
.boxed(),
)?;
@ -357,11 +357,11 @@ fn ret_non_stream(
model: &str,
created: i64,
content: &str,
stats: &CompletionStats,
details: &CompletionDetails,
) -> Bytes {
let id = stats.id.as_deref().unwrap_or(id);
let input_tokens = stats.input_tokens.unwrap_or_default();
let output_tokens = stats.output_tokens.unwrap_or_default();
let id = details.id.as_deref().unwrap_or(id);
let input_tokens = details.input_tokens.unwrap_or_default();
let output_tokens = details.output_tokens.unwrap_or_default();
let total_tokens = input_tokens + output_tokens;
let res_body = json!({
"id": id,

Loading…
Cancel
Save