aichat/src/client/prompt_format.rs

use super::message::*;
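
/// Delimiters used to render a list of chat messages as a single
/// text-completion prompt. Each role gets a pre/post wrapper, plus
/// global `begin` and `end` strings.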
pub struct PromptFormat<'a> {
    pub begin: &'a str,
    pub system_pre_message: &'a str,
    pub system_post_message: &'a str,
    pub user_pre_message: &'a str,
    pub user_post_message: &'a str,
    pub assistant_pre_message: &'a str,
    pub assistant_post_message: &'a str,
    pub end: &'a str,
}
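
/// Fallback Alpaca-style "### Instruction:" / "### Response:" format.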
pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "",
    system_post_message: "\n",
    user_pre_message: "### Instruction:\n",
    user_post_message: "\n",
    assistant_pre_message: "### Response:\n",
    assistant_post_message: "\n",
    end: "### Response:\n",
};
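
/// Llama 2 / Mistral / Mixtral "[INST]" instruction format.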
pub const MISTRAL_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "[INST] <<SYS>>",
    system_post_message: "<</SYS>> [/INST]",
    user_pre_message: "[INST]",
    user_post_message: "[/INST]",
    assistant_pre_message: "",
    assistant_post_message: "",
    end: "",
};
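
/// Llama 3 header-token format.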
pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "<|begin_of_text|>",
    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
    system_post_message: "<|eot_id|>",
    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
    user_post_message: "<|eot_id|>",
    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
    assistant_post_message: "<|eot_id|>",
    end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
};
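
/// Phi-3 chat format.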
pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|system|>\n",
    system_post_message: "<|end|>\n",
    user_pre_message: "<|user|>\n",
    user_post_message: "<|end|>\n",
    assistant_pre_message: "<|assistant|>\n",
    assistant_post_message: "<|end|>\n",
    end: "<|assistant|>\n",
};
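
/// Cohere Command R turn-token format.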
pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
    system_post_message: "<|END_OF_TURN_TOKEN|>",
    user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
    user_post_message: "<|END_OF_TURN_TOKEN|>",
    assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
    assistant_post_message: "<|END_OF_TURN_TOKEN|>",
    end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
};
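
/// ChatML format, as used by Qwen.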
pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|im_start|>system\n",
    system_post_message: "<|im_end|>",
    user_pre_message: "<|im_start|>user\n",
    user_post_message: "<|im_end|>",
    assistant_pre_message: "<|im_start|>assistant\n",
    assistant_post_message: "<|im_end|>",
    end: "<|im_start|>assistant\n",
};
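
/// Renders `messages` into a single prompt string by wrapping each message's
/// text in the role-specific delimiters of `format`. Tool results are rendered
/// as empty strings, and the call fails if any message carries an image URL,
/// since a plain-text prompt cannot embed images.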
pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
    let PromptFormat {
        begin,
        system_pre_message,
        system_post_message,
        user_pre_message,
        user_post_message,
        assistant_pre_message,
        assistant_post_message,
        end,
    } = format;
    let mut prompt = begin.to_string();
    let mut image_urls = vec![];
    for message in messages {
        let role = &message.role;
        let content = match &message.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Array(list) => {
                let mut parts = vec![];
                for item in list {
                    match item {
                        MessageContentPart::Text { text } => parts.push(text.clone()),
                        MessageContentPart::ImageUrl {
                            image_url: ImageUrl { url },
                        } => {
                            image_urls.push(url.clone());
                        }
                    }
                }
                parts.join("\n\n")
            }
            MessageContent::ToolResults(_) => String::new(),
        };
        match role {
            MessageRole::System => prompt.push_str(&format!(
                "{system_pre_message}{content}{system_post_message}"
            )),
            MessageRole::Assistant => prompt.push_str(&format!(
                "{assistant_pre_message}{content}{assistant_post_message}"
            )),
            MessageRole::User => {
                prompt.push_str(&format!("{user_pre_message}{content}{user_post_message}"))
            }
        }
    }
    if !image_urls.is_empty() {
        anyhow::bail!("The model does not support images: {:?}", image_urls);
    }
    prompt.push_str(end);
    Ok(prompt)
}
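
/// Picks a prompt format by substring-matching the model name, falling back
/// to the generic instruction/response format when no known family matches.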
pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
    if model_name.contains("llama3") || model_name.contains("llama-3") {
        LLAMA3_PROMPT_FORMAT
    } else if model_name.contains("llama2")
        || model_name.contains("llama-2")
        || model_name.contains("mistral")
        || model_name.contains("mixtral")
    {
        MISTRAL_PROMPT_FORMAT
    } else if model_name.contains("phi3") || model_name.contains("phi-3") {
        PHI3_PROMPT_FORMAT
    } else if model_name.contains("command-r") {
        COMMAND_R_PROMPT_FORMAT
    } else if model_name.contains("qwen") {
        QWEN_PROMPT_FORMAT
    } else {
        GENERIC_PROMPT_FORMAT
    }
}
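
// A minimal usage sketch, not part of the original file. It assumes that
// `Message` from super::message exposes exactly the public `role` and
// `content` fields used by generate_prompt above; adjust the construction
// if the real struct carries additional fields or a constructor.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn llama3_prompt_wraps_roles_in_header_tokens() {
        let messages = vec![
            Message {
                role: MessageRole::System,
                content: MessageContent::Text("Be brief.".into()),
            },
            Message {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
            },
        ];
        let prompt = generate_prompt(&messages, LLAMA3_PROMPT_FORMAT).unwrap();
        // The prompt opens with `begin` followed by the system header...
        assert!(prompt.starts_with("<|begin_of_text|><|start_header_id|>system<|end_header_id|>"));
        // ...and closes with the `end` delimiter that cues the assistant's turn.
        assert!(prompt.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n"));
    }
}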