refactor: sse handling (#465)

pull/466/head
sigoden 3 weeks ago committed by GitHub
parent ffb0af8236
commit 4d4a100fe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
use super::{
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData,
SseHandler,
SsMmessage, SseHandler,
};
use crate::utils::PromptKind;
@ -67,8 +67,8 @@ pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {

@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CloudflareClient, CompletionDetails, ExtraConfig, Model, ModelConfig,
PromptType, SendData, SseHandler,
PromptType, SendData, SsMmessage, SseHandler,
};
use crate::utils::PromptKind;
@ -64,11 +64,11 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let handle = |data: &str| -> Result<bool> {
if data == "[DONE]" {
let handle = |message: SsMmessage| -> Result<bool> {
if message.data == "[DONE]" {
return Ok(true);
}
let data: Value = serde_json::from_str(data)?;
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["response"].as_str() {
handler.text(text)?;
}

@ -541,16 +541,26 @@ pub fn maybe_catch_error(data: &Value) -> Result<()> {
Ok(())
}
#[derive(Debug)]
pub struct SsMmessage {
pub event: String,
pub data: String,
}
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
F: FnMut(&str) -> Result<bool>,
F: FnMut(SsMmessage) -> Result<bool>,
{
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
if handle(&message.data)? {
let message = SsMmessage {
event: message.event,
data: message.data,
};
if handle(message)? {
break;
}
}

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler,
ExtraConfig, Model, ModelConfig, PromptType, SendData, SsMmessage, SseHandler,
};
use crate::utils::PromptKind;
@ -106,8 +106,8 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}

@ -1,6 +1,6 @@
use super::{
catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient,
PromptType, SendData, SseHandler,
PromptType, SendData, SsMmessage, SseHandler,
};
use crate::utils::PromptKind;
@ -65,11 +65,11 @@ pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let handle = |data: &str| -> Result<bool> {
if data == "[DONE]" {
let handle = |message: SsMmessage| -> Result<bool> {
if message.data == "[DONE]" {
return Ok(true);
}
let data: Value = serde_json::from_str(data)?;
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
handler.text(text)?;
}

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptType, QianwenClient, SendData, SseHandler,
ModelConfig, PromptType, QianwenClient, SendData, SsMmessage, SseHandler,
};
use crate::utils::{sha256sum, PromptKind};
@ -107,8 +107,8 @@ async fn send_message_streaming(
handler: &mut SseHandler,
is_vl: bool,
) -> Result<()> {
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
let handle = |message: SsMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
maybe_catch_error(&data)?;
if is_vl {
if let Some(text) =

Loading…
Cancel
Save