diff --git a/src/client/prompt_format.rs b/src/client/prompt_format.rs index 61647a8..8a90924 100644 --- a/src/client/prompt_format.rs +++ b/src/client/prompt_format.rs @@ -13,13 +13,13 @@ pub struct PromptFormat<'a> { pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat { begin: "", - system_pre_message: "### System\n", + system_pre_message: "", system_post_message: "\n", - user_pre_message: "### User\n", + user_pre_message: "### Instruction:\n", user_post_message: "\n", - assistant_pre_message: "### Assistant\n", + assistant_pre_message: "### Response:\n", assistant_post_message: "\n", - end: "### Assistant\n", + end: "### Response:\n", }; pub const MISTRAL_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat { diff --git a/src/config/mod.rs b/src/config/mod.rs index 5ac44a8..e166969 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -372,7 +372,7 @@ impl Config { pub fn build_messages(&self, input: &Input) -> Result> { let messages = if let Some(session) = input.session(&self.session) { - session.build_emssages(input) + session.build_messages(input) } else if let Some(role) = input.role() { role.build_messages(input) } else { diff --git a/src/config/role.rs b/src/config/role.rs index b226622..b96e704 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -86,9 +86,15 @@ APPLY MARKDOWN formatting when possible."# pub fn code() -> Self { Self { name: CODE_ROLE.into(), - prompt: r#"Provide only code, without comments or explanations. -If there is a lack of details, provide most logical solution, without requesting further clarification."# - .into(), + prompt: r#"Provide only code without comments or explanations. +### INPUT: +async sleep in js +### OUTPUT: +async function timeout(ms) { + return new Promise(resolve => setTimeout(resolve, ms)); +} +"# + .into(), temperature: None, top_p: None, } @@ -146,16 +152,33 @@ If there is a lack of details, provide most logical solution, without requesting content, }] } else { - vec![ - Message { + let mut messages = vec![]; + let (system, cases) = parse_structure_prompt(&self.prompt); + if !system.is_empty() { + messages.push(Message { role: MessageRole::System, - content: MessageContent::Text(self.prompt.clone()), - }, - Message { - role: MessageRole::User, - content, - }, - ] + content: MessageContent::Text(system.to_string()), + }) + } + if !cases.is_empty() { + messages.extend(cases.into_iter().flat_map(|(i, o)| { + vec![ + Message { + role: MessageRole::User, + content: MessageContent::Text(i.to_string()), + }, + Message { + role: MessageRole::Assistant, + content: MessageContent::Text(o.to_string()), + }, + ] + })); + } + messages.push(Message { + role: MessageRole::User, + content, + }); + messages } } } @@ -168,6 +191,54 @@ fn complete_prompt_args(prompt: &str, name: &str) -> String { prompt } +fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) { + let mut text = prompt; + let mut search_input = true; + let mut system = None; + let mut parts = vec![]; + loop { + let search = if search_input { + "### INPUT:" + } else { + "### OUTPUT:" + }; + match text.find(search) { + Some(idx) => { + if system.is_none() { + system = Some(&text[..idx]) + } else { + parts.push(&text[..idx]) + } + search_input = !search_input; + text = &text[(idx + search.len())..]; + } + None => { + if !text.is_empty() { + if system.is_none() { + system = Some(text) + } else { + parts.push(text) + } + } + break; + } + } + } + let parts_len = parts.len(); + if parts_len > 0 && parts_len % 2 == 0 { + let cases: Vec<(&str, &str)> = parts + .iter() + .step_by(2) + .zip(parts.iter().skip(1).step_by(2)) + .map(|(i, o)| (i.trim(), o.trim())) + .collect(); + let system = system.map(|v| v.trim()).unwrap_or_default(); + return (system, cases); + } + + (prompt, vec![]) +} + #[cfg(test)] mod tests { use super::*; @@ -183,4 +254,43 @@ mod tests { "convert foo to bar" ); } + + #[test] + fn test_parse_structure_prompt1() { + let prompt = r#" +System message +### INPUT: +Input 1 +### OUTPUT: +Output 1 +"#; + assert_eq!( + parse_structure_prompt(prompt), + ("System message", vec![("Input 1", "Output 1")]) + ); + } + + #[test] + fn test_parse_structure_prompt2() { + let prompt = r#" +### INPUT: +Input 1 +### OUTPUT: +Output 1 +"#; + assert_eq!( + parse_structure_prompt(prompt), + ("", vec![("Input 1", "Output 1")]) + ); + } + + #[test] + fn test_parse_structure_prompt3() { + let prompt = r#" +System message +### INPUT: +Input 1 +"#; + assert_eq!(parse_structure_prompt(prompt), (prompt, vec![])); + } } diff --git a/src/config/session.rs b/src/config/session.rs index e6ad647..8e11d3f 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -323,11 +323,11 @@ impl Session { } pub fn echo_messages(&self, input: &Input) -> String { - let messages = self.build_emssages(input); + let messages = self.build_messages(input); serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) } - pub fn build_emssages(&self, input: &Input) -> Vec { + pub fn build_messages(&self, input: &Input) -> Vec { let mut messages = self.messages.clone(); let mut need_add_msg = true; let len = messages.len();