feat: enhence roles with messages (#495)

pull/496/head
sigoden 2 weeks ago committed by GitHub
parent 5d73768acc
commit bc65e880be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -13,13 +13,13 @@ pub struct PromptFormat<'a> {
pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
begin: "",
system_pre_message: "### System\n",
system_pre_message: "",
system_post_message: "\n",
user_pre_message: "### User\n",
user_pre_message: "### Instruction:\n",
user_post_message: "\n",
assistant_pre_message: "### Assistant\n",
assistant_pre_message: "### Response:\n",
assistant_post_message: "\n",
end: "### Assistant\n",
end: "### Response:\n",
};
pub const MISTRAL_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {

@ -372,7 +372,7 @@ impl Config {
pub fn build_messages(&self, input: &Input) -> Result<Vec<Message>> {
let messages = if let Some(session) = input.session(&self.session) {
session.build_emssages(input)
session.build_messages(input)
} else if let Some(role) = input.role() {
role.build_messages(input)
} else {

@ -86,9 +86,15 @@ APPLY MARKDOWN formatting when possible."#
pub fn code() -> Self {
Self {
name: CODE_ROLE.into(),
prompt: r#"Provide only code, without comments or explanations.
If there is a lack of details, provide most logical solution, without requesting further clarification."#
.into(),
prompt: r#"Provide only code without comments or explanations.
### INPUT:
async sleep in js
### OUTPUT:
async function timeout(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
"#
.into(),
temperature: None,
top_p: None,
}
@ -146,16 +152,33 @@ If there is a lack of details, provide most logical solution, without requesting
content,
}]
} else {
vec![
Message {
let mut messages = vec![];
let (system, cases) = parse_structure_prompt(&self.prompt);
if !system.is_empty() {
messages.push(Message {
role: MessageRole::System,
content: MessageContent::Text(self.prompt.clone()),
},
Message {
role: MessageRole::User,
content,
},
]
content: MessageContent::Text(system.to_string()),
})
}
if !cases.is_empty() {
messages.extend(cases.into_iter().flat_map(|(i, o)| {
vec![
Message {
role: MessageRole::User,
content: MessageContent::Text(i.to_string()),
},
Message {
role: MessageRole::Assistant,
content: MessageContent::Text(o.to_string()),
},
]
}));
}
messages.push(Message {
role: MessageRole::User,
content,
});
messages
}
}
}
@ -168,6 +191,54 @@ fn complete_prompt_args(prompt: &str, name: &str) -> String {
prompt
}
fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
let mut text = prompt;
let mut search_input = true;
let mut system = None;
let mut parts = vec![];
loop {
let search = if search_input {
"### INPUT:"
} else {
"### OUTPUT:"
};
match text.find(search) {
Some(idx) => {
if system.is_none() {
system = Some(&text[..idx])
} else {
parts.push(&text[..idx])
}
search_input = !search_input;
text = &text[(idx + search.len())..];
}
None => {
if !text.is_empty() {
if system.is_none() {
system = Some(text)
} else {
parts.push(text)
}
}
break;
}
}
}
let parts_len = parts.len();
if parts_len > 0 && parts_len % 2 == 0 {
let cases: Vec<(&str, &str)> = parts
.iter()
.step_by(2)
.zip(parts.iter().skip(1).step_by(2))
.map(|(i, o)| (i.trim(), o.trim()))
.collect();
let system = system.map(|v| v.trim()).unwrap_or_default();
return (system, cases);
}
(prompt, vec![])
}
#[cfg(test)]
mod tests {
use super::*;
@ -183,4 +254,43 @@ mod tests {
"convert foo to bar"
);
}
#[test]
fn test_parse_structure_prompt1() {
let prompt = r#"
System message
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
assert_eq!(
parse_structure_prompt(prompt),
("System message", vec![("Input 1", "Output 1")])
);
}
#[test]
fn test_parse_structure_prompt2() {
let prompt = r#"
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
assert_eq!(
parse_structure_prompt(prompt),
("", vec![("Input 1", "Output 1")])
);
}
#[test]
fn test_parse_structure_prompt3() {
let prompt = r#"
System message
### INPUT:
Input 1
"#;
assert_eq!(parse_structure_prompt(prompt), (prompt, vec![]));
}
}

@ -323,11 +323,11 @@ impl Session {
}
pub fn echo_messages(&self, input: &Input) -> String {
let messages = self.build_emssages(input);
let messages = self.build_messages(input);
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
}
pub fn build_emssages(&self, input: &Input) -> Vec<Message> {
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
let mut need_add_msg = true;
let len = messages.len();

Loading…
Cancel
Save