refactor: abstract event stream handling (#458)

pull/459/head
sigoden 3 weeks ago committed by GitHub
parent 37a0cd08a9
commit 68882ecd4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,14 +1,13 @@
use super::{
catch_error, extract_system_message, ClaudeClient, CompletionDetails, ExtraConfig, ImageUrl,
MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData, SseHandler,
catch_error, extract_system_message, sse_stream, ClaudeClient, CompletionDetails, ExtraConfig,
ImageUrl, MessageContent, MessageContentPart, Model, ModelConfig, PromptType, SendData,
SseHandler,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
@ -68,50 +67,19 @@ pub async fn claude_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
}
}
}
}
Err(err) => {
match err {
EventSourceError::StreamEnded => {}
EventSourceError::InvalidStatusCode(status, res) => {
let text = res.text().await?;
let data: Value = match text.parse() {
Ok(data) => data,
Err(_) => {
bail!(
"Invalid response data: {text} (status: {})",
status.as_u16()
);
}
};
catch_error(&data, status.as_u16())?;
}
EventSourceError::InvalidContentType(_, res) => {
let text = res.text().await?;
bail!("The API server should return data as 'text/event-stream', but it isn't. Check the client config. {text}");
}
_ => {
bail!("{}", err);
}
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
}
es.close();
}
}
}
Ok(false)
};
Ok(())
sse_stream(builder, handle).await
}
pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {

@ -28,7 +28,7 @@ impl CohereClient {
[("api_key", "API Key:", false, PromptKind::String)];
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let api_key = self.get_api_key()?;
let body = build_body(data, &self.model)?;
@ -36,10 +36,7 @@ impl CohereClient {
debug!("Cohere Request: {url} {body}");
let mut builder = client.post(url).json(&body);
if let Some(api_key) = api_key {
builder = builder.bearer_auth(api_key);
}
let builder = client.post(url).bearer_auth(api_key).json(&body);
Ok(builder)
}
@ -55,7 +52,7 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
catch_error(&data, status.as_u16())?;
}
cohere_extract_completion(&data)
extract_completion(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
@ -156,7 +153,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
Ok(body)
}
fn cohere_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;

@ -11,6 +11,7 @@ use async_trait::async_trait;
use futures_util::{Stream, StreamExt};
use lazy_static::lazy_static;
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{env, future::Future, time::Duration};
@ -531,6 +532,53 @@ pub fn maybe_catch_error(data: &Value) -> Result<()> {
Ok(())
}
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
F: FnMut(&str) -> Result<bool>,
{
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
if handle(&message.data)? {
break;
}
}
Err(err) => {
match err {
EventSourceError::StreamEnded => {}
EventSourceError::InvalidStatusCode(status, res) => {
let text = res.text().await?;
let data: Value = match text.parse() {
Ok(data) => data,
Err(_) => {
bail!(
"Invalid response data: {text} (status: {})",
status.as_u16()
);
}
};
catch_error(&data, status.as_u16())?;
}
EventSourceError::InvalidContentType(header_value, res) => {
let text = res.text().await?;
bail!(
"Invalid response event-stream. content-type: {}, data: {text}",
header_value.to_str().unwrap_or_default()
);
}
_ => {
bail!("{}", err);
}
}
es.close();
}
}
}
Ok(())
}
pub async fn json_stream<S, F>(mut stream: S, mut handle: F) -> Result<()>
where
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,

@ -1,16 +1,14 @@
use super::{
maybe_catch_error, patch_system_message, Client, CompletionDetails, ErnieClient, ExtraConfig,
Model, ModelConfig, PromptType, SendData, SseHandler,
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionDetails, ErnieClient,
ExtraConfig, Model, ModelConfig, PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use chrono::Utc;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::env;
@ -108,49 +106,15 @@ async fn send_message(builder: RequestBuilder) -> Result<(String, CompletionDeta
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
}
Err(err) => {
match err {
EventSourceError::InvalidContentType(header_value, res) => {
let content_type = header_value
.to_str()
.map_err(|_| anyhow!("Invalid response header"))?;
if content_type.contains("application/json") {
let data: Value = res.json().await?;
maybe_catch_error(&data)?;
bail!("Invalid response data: {data}");
} else {
let text = res.text().await?;
if let Some(text) = text.strip_prefix("data: ") {
let data: Value = serde_json::from_str(text)?;
if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
} else {
bail!("Invalid response data: {text}")
}
}
}
EventSourceError::StreamEnded => {}
_ => {
bail!("{}", err);
}
}
es.close();
}
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
}
Ok(false)
};
Ok(())
sse_stream(builder, handle).await
}
fn build_body(data: SendData, model: &Model) -> Value {

@ -1,14 +1,12 @@
use super::{
catch_error, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient, PromptType,
SendData, SseHandler,
catch_error, sse_stream, CompletionDetails, ExtraConfig, Model, ModelConfig, OpenAIClient,
PromptType, SendData, SseHandler,
};
use crate::utils::PromptKind;
use anyhow::{anyhow, bail, Result};
use futures_util::StreamExt;
use anyhow::{anyhow, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
@ -67,49 +65,18 @@ pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
if message.data == "[DONE]" {
break;
}
let data: Value = serde_json::from_str(&message.data)?;
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
handler.text(text)?;
}
}
Err(err) => {
match err {
EventSourceError::InvalidStatusCode(status, res) => {
let text = res.text().await?;
let data: Value = match text.parse() {
Ok(data) => data,
Err(_) => {
bail!(
"Invalid response data: {text} (status: {})",
status.as_u16()
);
}
};
catch_error(&data, status.as_u16())?;
}
EventSourceError::StreamEnded => {}
EventSourceError::InvalidContentType(_, res) => {
let text = res.text().await?;
bail!("The API server should return data as 'text/event-stream', but it isn't. Check the client config. {text}");
}
_ => {
bail!("{}", err);
}
}
es.close();
}
let handle = |data: &str| -> Result<bool> {
if data == "[DONE]" {
return Ok(true);
}
}
let data: Value = serde_json::from_str(data)?;
if let Some(text) = data["choices"][0]["delta"]["content"].as_str() {
handler.text(text)?;
}
Ok(false)
};
Ok(())
sse_stream(builder, handle).await
}
pub fn openai_build_body(data: SendData, model: &Model) -> Value {

@ -1,6 +1,6 @@
use super::{
maybe_catch_error, message::*, Client, CompletionDetails, ExtraConfig, Model, ModelConfig,
PromptType, QianwenClient, SendData, SseHandler,
maybe_catch_error, message::*, sse_stream, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptType, QianwenClient, SendData, SseHandler,
};
use crate::utils::{sha256sum, PromptKind};
@ -8,12 +8,10 @@ use crate::utils::{sha256sum, PromptKind};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use futures_util::StreamExt;
use reqwest::{
multipart::{Form, Part},
Client as ReqwestClient, RequestBuilder,
};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::borrow::BorrowMut;
@ -109,37 +107,22 @@ async fn send_message_streaming(
handler: &mut SseHandler,
is_vl: bool,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let data: Value = serde_json::from_str(&message.data)?;
maybe_catch_error(&data)?;
if is_vl {
if let Some(text) =
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
{
handler.text(text)?;
}
} else if let Some(text) = data["output"]["text"].as_str() {
handler.text(text)?;
}
}
Err(err) => {
match err {
EventSourceError::StreamEnded => {}
_ => {
bail!("{}", err);
}
}
es.close();
let handle = |data: &str| -> Result<bool> {
let data: Value = serde_json::from_str(data)?;
maybe_catch_error(&data)?;
if is_vl {
if let Some(text) =
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
{
handler.text(text)?;
}
} else if let Some(text) = data["output"]["text"].as_str() {
handler.text(text)?;
}
}
Ok(false)
};
Ok(())
sse_stream(builder, handle).await
}
fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {

Loading…
Cancel
Save