refactor: handling of system message (#432)

pull/433/head
sigoden 2 months ago committed by GitHub
parent 0a4c0413ef
commit 4f8d895154
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,5 +1,5 @@
use super::{
patch_system_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
extract_sytem_message, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
MessageContentPart, Model, ModelConfig, PromptType, ReplyHandler, SendData,
};
@ -141,7 +141,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
stream,
} = data;
patch_system_message(&mut messages);
let system_message = extract_sytem_message(&mut messages);
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
@ -196,6 +196,10 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
"messages": messages,
});
if let Some(system) = system_message {
body["system"] = system.into();
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}

@ -1,5 +1,5 @@
use super::{
json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
extract_sytem_message, json_stream, message::*, Client, CohereClient, ExtraConfig, Model,
ModelConfig, PromptType, ReplyHandler, SendData,
};
@ -129,7 +129,7 @@ pub(crate) fn build_body(data: SendData, model: &Model) -> Result<Value> {
stream,
} = data;
patch_system_message(&mut messages);
let system_message = extract_sytem_message(&mut messages);
let mut image_urls = vec![];
let mut messages: Vec<Value> = messages
@ -174,6 +174,10 @@ pub(crate) fn build_body(data: SendData, model: &Model) -> Result<Value> {
"message": message,
});
if let Some(preamble) = system_message {
body["preamble"] = preamble.into();
}
if let Some(max_tokens) = model.max_output_tokens {
body["max_tokens"] = max_tokens.into();
}

@ -414,6 +414,14 @@ pub fn patch_system_message(messages: &mut Vec<Message>) {
}
}
pub fn extract_sytem_message(messages: &mut Vec<Message>) -> Option<String> {
if messages[0].role.is_system() {
let system_message = messages.remove(0);
return Some(system_message.content.to_text());
}
None
}
pub async fn json_stream<S, F>(mut stream: S, mut handle: F) -> Result<()>
where
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,

@ -85,6 +85,21 @@ impl MessageContent {
}
}
}
pub fn to_text(&self) -> String {
match self {
MessageContent::Text(text) => text.to_string(),
MessageContent::Array(list) => {
let mut parts = vec![];
for item in list {
if let MessageContentPart::Text { text } = item {
parts.push(text.clone())
}
}
parts.join("\n\n")
}
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]

@ -1,6 +1,6 @@
use super::{
message::*, patch_system_message, Client, ExtraConfig, Model, ModelConfig, OllamaClient,
PromptType, ReplyHandler, SendData,
message::*, Client, ExtraConfig, Model, ModelConfig, OllamaClient, PromptType, ReplyHandler,
SendData,
};
use crate::utils::PromptKind;
@ -121,13 +121,11 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
fn build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
mut messages,
messages,
temperature,
stream,
} = data;
patch_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()

Loading…
Cancel
Save