From 50eac8b59409f73655fc2b8cf8d783c6347209c3 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Tue, 30 Apr 2024 07:07:09 +0800
Subject: [PATCH] feat: support replicate client (#466)

---
 Argcfile.sh                 |  41 ++++++++
 config.example.yaml         |  13 ++-
 models.yaml                 |  28 ++++++
 src/client/common.rs        |   3 +
 src/client/mod.rs           |   1 +
 src/client/prompt_format.rs |  82 +++++++++++++--
 src/client/replicate.rs     | 193 ++++++++++++++++++++++++++++++++++++
 7 files changed, 350 insertions(+), 11 deletions(-)
 create mode 100644 src/client/replicate.rs

diff --git a/Argcfile.sh b/Argcfile.sh
index 7ba83d3..46ab0ec 100755
--- a/Argcfile.sh
+++ b/Argcfile.sh
@@ -325,6 +325,47 @@ chat-cloudflare() {
 }'
 }
 
+# @cmd Chat with replicate api
+# @env REPLICATE_API_KEY!
+# @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL
+# @flag -S --no-stream
+# @arg text~
+chat-replicate() {
+    url="https://api.replicate.com/v1/models/$argc_model/predictions"
+    res="$(_wrapper curl -s $REPLICATE_CURL_ARGS "$url" \
+-X POST \
+-H "Authorization: Bearer $REPLICATE_API_KEY" \
+-H "Content-Type: application/json" \
+-d '{
+    "stream": '$stream',
+    "input": {
+        "prompt": "'"$*"'"
+    }
+}')"
+    echo "$res"
+    if [[ -n "$argc_no_stream" ]]; then
+        prediction_url="$(echo "$res" | jq -r '.urls.get')"
+        while true; do
+            output="$(_wrapper curl $REPLICATE_CURL_ARGS -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
+            prediction_status=$(printf "%s" "$output" | jq -r .status)
+            if [[ "$prediction_status" == "succeeded" ]]; then
+                echo "$output"
+                break
+            fi
+            if [[ "$prediction_status" == "failed" ]]; then
+                exit 1
+            fi
+            sleep 2
+        done
+    else
+        stream_url="$(echo "$res" | jq -r '.urls.stream')"
+        _wrapper curl -i $REPLICATE_CURL_ARGS --no-buffer "$stream_url" \
+-H "Accept: text/event-stream"
+
+    fi
+
+}
+
 # @cmd Chat with ernie api
 # @meta require-tools jq
 # @env ERNIE_API_KEY!
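The prediction lifecycle that chat-replicate drives (create a prediction, then either poll urls.get or stream urls.stream) can also be reproduced with plain curl and jq outside of Argcfile.sh. The sketch below is illustrative only and is not part of the patch; it assumes REPLICATE_API_KEY is exported and uses a placeholder model and prompt.

# Create a prediction; the response carries urls.get and urls.stream.
res="$(curl -s -X POST "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions" \
    -H "Authorization: Bearer $REPLICATE_API_KEY" \
    -H "Content-Type: application/json" \
    -d '{"input": {"prompt": "Say hello in one sentence"}}')"

# Poll until the prediction settles, then join the output chunks.
get_url="$(echo "$res" | jq -r '.urls.get')"
while true; do
    out="$(curl -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$get_url")"
    case "$(echo "$out" | jq -r '.status')" in
        succeeded) echo "$out" | jq -r '.output | join("")'; break ;;
        failed|canceled) echo "$out" | jq -r '.error' >&2; break ;;
        *) sleep 2 ;;
    esac
done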
diff --git a/config.example.yaml b/config.example.yaml
index f838d04..ebf59ae 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -101,16 +101,21 @@ clients:
     # Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
     block_threshold: BLOCK_ONLY_HIGH
 
-  - type: cloudflare
-    account_id: xxx # ENV: {client_name}_ACCOUNT_ID
-    api_key: xxx # ENV: {client_name}_API_KEY
-
   # See https://docs.aws.amazon.com/bedrock/latest/userguide/
   - type: bedrock
     access_key_id: xxx # ENV: {client_name}_ACCESS_KEY_ID
     secret_access_key: xxx # ENV: {client_name}_SECRET_ACCESS_KEY
     region: xxx # ENV: {client_name}_REGION
 
+  # See https://developers.cloudflare.com/workers-ai/
+  - type: cloudflare
+    account_id: xxx # ENV: {client_name}_ACCOUNT_ID
+    api_key: xxx # ENV: {client_name}_API_KEY
+
+  # See https://replicate.com/docs
+  - type: replicate
+    api_key: xxx # ENV: {client_name}_API_KEY
+
   # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
   - type: ernie
     api_key: xxx # ENV: {client_name}_API_KEY
diff --git a/models.yaml b/models.yaml
index ed53521..8c2fa98 100644
--- a/models.yaml
+++ b/models.yaml
@@ -367,6 +367,34 @@
       input_price: 0.11
       output_price: 0.19
 
+- type: replicate
+  # docs:
+  #   - https://replicate.com/docs
+  #   - https://replicate.com/pricing
+  # notes:
+  #   - max_output_tokens is required but unknown
+  models:
+    - name: meta/meta-llama-3-70b-instruct
+      max_input_tokens: 8192
+      max_output_tokens: 4096
+      input_price: 0.65
+      output_price: 2.75
+    - name: meta/meta-llama-3-8b-instruct
+      max_input_tokens: 8192
+      max_output_tokens: 4096
+      input_price: 0.05
+      output_price: 0.25
+    - name: mistralai/mistral-7b-instruct-v0.2
+      max_input_tokens: 32000
+      max_output_tokens: 8192
+      input_price: 0.05
+      output_price: 0.25
+    - name: mistralai/mixtral-8x7b-instruct-v0.1
+      max_input_tokens: 32000
+      max_output_tokens: 8192
+      input_price: 0.3
+      output_price: 1
+
 - type: ernie
   # docs:
   #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
diff --git a/src/client/common.rs b/src/client/common.rs
index 9fb0029..70ffe87 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -520,6 +520,9 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> {
         {
             bail!("{message} (status: {status})")
         }
+    } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
+    {
+        bail!("{detail} (status: {status})");
     } else if let Some(error) = data["error"].as_str() {
         bail!("{error}");
     } else if let Some(message) = data["message"].as_str() {
diff --git a/src/client/mod.rs b/src/client/mod.rs
index a311efe..82e18a5 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -29,6 +29,7 @@ register_client!(
     (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
     (bedrock, "bedrock", BedrockConfig, BedrockClient),
     (cloudflare, "cloudflare", CloudflareConfig, CloudflareClient),
+    (replicate, "replicate", ReplicateConfig, ReplicateClient),
     (ernie, "ernie", ErnieConfig, ErnieClient),
     (qianwen, "qianwen", QianwenConfig, QianwenClient),
     (moonshot, "moonshot", MoonshotConfig, MoonshotClient),
diff --git a/src/client/prompt_format.rs b/src/client/prompt_format.rs
index 8ba3682..fca87ac 100644
--- a/src/client/prompt_format.rs
+++ b/src/client/prompt_format.rs
@@ -1,46 +1,94 @@
 use super::message::*;
 
 pub struct PromptFormat<'a> {
-    pub bos_token: &'a str,
+    pub begin: &'a str,
     pub system_pre_message: &'a str,
     pub system_post_message: &'a str,
     pub user_pre_message: &'a str,
     pub user_post_message: &'a str,
     pub assistant_pre_message: &'a str,
     pub assistant_post_message: &'a str,
+    pub end: &'a str,
 }
 
+pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+    begin: "",
+    system_pre_message: "### System\n",
+    system_post_message: "\n",
+    user_pre_message: "### User\n",
+    user_post_message: "\n",
+    assistant_pre_message: "### Assistant\n",
+    assistant_post_message: "\n",
+    end: "### Assistant\n",
+};
+
 pub const LLAMA2_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    bos_token: "",
+    begin: "",
     system_pre_message: "[INST] <<SYS>>",
     system_post_message: "<</SYS>> [/INST]",
     user_pre_message: "[INST]",
     user_post_message: "[/INST]",
     assistant_pre_message: "",
-    assistant_post_message: "",
+    assistant_post_message: "",
+    end: "",
 };
 
 pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    bos_token: "<|begin_of_text|>",
+    begin: "<|begin_of_text|>",
     system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
     system_post_message: "<|eot_id|>",
     user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
     user_post_message: "<|eot_id|>",
     assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
     assistant_post_message: "<|eot_id|>",
+    end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
 };
 
+pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+    begin: "",
+    system_pre_message: "<|system|>\n",
+    system_post_message: "<|end|>\n",
+    user_pre_message: "<|user|>\n",
+    user_post_message: "<|end|>\n",
+    assistant_pre_message: "<|assistant|>\n",
+    assistant_post_message: "<|end|>\n",
+    end: "<|assistant|>\n",
+};
+
+pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+    begin: "",
+    system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
+    system_post_message: "<|END_OF_TURN_TOKEN|>",
+    user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
+    user_post_message: "<|END_OF_TURN_TOKEN|>",
+    assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+    assistant_post_message: "<|END_OF_TURN_TOKEN|>",
+    end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+};
+
+pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+    begin: "",
+    system_pre_message: "<|im_start|>system\n",
+    system_post_message: "<|im_end|>",
+    user_pre_message: "<|im_start|>user\n",
+    user_post_message: "<|im_end|>",
+    assistant_pre_message: "<|im_start|>assistant\n",
+    assistant_post_message: "<|im_end|>",
+    end: "<|im_start|>assistant\n",
+};
+
 pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
     let PromptFormat {
-        bos_token,
+        begin,
         system_pre_message,
         system_post_message,
         user_pre_message,
         user_post_message,
         assistant_pre_message,
         assistant_post_message,
+        end,
     } = format;
-    let mut prompt = bos_token.to_string();
+    let mut prompt = begin.to_string();
     let mut image_urls = vec![];
     for message in messages {
         let role = &message.role;
@@ -76,6 +124,26 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Re
     if !image_urls.is_empty() {
         anyhow::bail!("The model does not support images: {:?}", image_urls);
     }
-    prompt.push_str(assistant_pre_message);
+    prompt.push_str(end);
     Ok(prompt)
 }
+
+pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
+    if model_name.contains("llama3") || model_name.contains("llama-3") {
+        LLAMA3_PROMPT_FORMAT
+    } else if model_name.contains("llama2")
+        || model_name.contains("llama-2")
+        || model_name.contains("mistral")
+        || model_name.contains("mixtral")
+    {
+        LLAMA2_PROMPT_FORMAT
+    } else if model_name.contains("phi3") || model_name.contains("phi-3") {
+        PHI3_PROMPT_FORMAT
+    } else if model_name.contains("command-r") {
+        COMMAND_R_PROMPT_FORMAT
+    } else if model_name.contains("qwen") {
+        QWEN_PROMPT_FORMAT
+    } else {
+        GENERIC_PROMPT_FORMAT
+    }
+}
diff --git a/src/client/replicate.rs b/src/client/replicate.rs
new file mode 100644
index 0000000..e399927
--- /dev/null
+++ b/src/client/replicate.rs
@@ -0,0 +1,193 @@
+use std::time::Duration;
+
+use super::{
+    catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
+    ExtraConfig, Model, ModelConfig, PromptType, ReplicateClient, SendData, SsMmessage, SseHandler,
+};
+
+use crate::utils::PromptKind;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use reqwest::{Client as ReqwestClient, RequestBuilder};
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+const API_BASE: &str = "https://api.replicate.com/v1";
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct ReplicateConfig {
+    pub name: Option<String>,
+    pub api_key: Option<String>,
+    #[serde(default)]
+    pub models: Vec<ModelConfig>,
+    pub extra: Option<ExtraConfig>,
+}
+
+impl ReplicateClient {
+    config_get_fn!(api_key, get_api_key);
+
+    pub const PROMPTS: [PromptType<'static>; 1] =
+        [("api_key", "API Key:", true, PromptKind::String)];
+
+    fn request_builder(
+        &self,
+        client: &ReqwestClient,
+        data: SendData,
+        api_key: &str,
+    ) -> Result<RequestBuilder> {
+        let body = build_body(data, &self.model)?;
+
+        let url = format!("{API_BASE}/models/{}/predictions", self.model.name);
+
+        debug!("Replicate Request: {url} {body}");
+
+        let builder = client.post(url).bearer_auth(api_key).json(&body);
+
+        Ok(builder)
+    }
+}
+
+#[async_trait]
+impl Client for ReplicateClient {
+    client_common_fns!();
+
+    async fn send_message_inner(
+        &self,
+        client: &ReqwestClient,
+        data: SendData,
+    ) -> Result<(String, CompletionDetails)> {
+        let api_key = self.get_api_key()?;
+        let builder = self.request_builder(client, data, &api_key)?;
+        send_message(client, builder, &api_key).await
+    }
+
+    async fn send_message_streaming_inner(
+        &self,
+        client: &ReqwestClient,
+        handler: &mut SseHandler,
+        data: SendData,
+    ) -> Result<()> {
+        let api_key = self.get_api_key()?;
+        let builder = self.request_builder(client, data, &api_key)?;
+        send_message_streaming(client, builder, handler).await
+    }
+}
+
+async fn send_message(
+    client: &ReqwestClient,
+    builder: RequestBuilder,
+    api_key: &str,
+) -> Result<(String, CompletionDetails)> {
+    let res = builder.send().await?;
+    let status = res.status();
+    let data: Value = res.json().await?;
+    if !status.is_success() {
+        catch_error(&data, status.as_u16())?;
+    }
+    let prediction_url = data["urls"]["get"]
+        .as_str()
+        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
+    loop {
+        tokio::time::sleep(Duration::from_millis(500)).await;
+        let prediction_data: Value = client
+            .get(prediction_url)
+            .bearer_auth(api_key)
+            .send()
+            .await?
+            .json()
+            .await?;
+        let err = || anyhow!("Invalid response data: {prediction_data}");
+        let status = prediction_data["status"].as_str().ok_or_else(err)?;
+        if status == "succeeded" {
+            return extract_completion(&prediction_data);
+        } else if status == "failed" || status == "canceled" {
+            return Err(err());
+        }
+    }
+}
+
+async fn send_message_streaming(
+    client: &ReqwestClient,
+    builder: RequestBuilder,
+    handler: &mut SseHandler,
+) -> Result<()> {
+    let res = builder.send().await?;
+    let status = res.status();
+    let data: Value = res.json().await?;
+    if !status.is_success() {
+        catch_error(&data, status.as_u16())?;
+    }
+    let stream_url = data["urls"]["stream"]
+        .as_str()
+        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
+
+    let sse_builder = client.get(stream_url).header("accept", "text/event-stream");
+
+    let handle = |message: SsMmessage| -> Result<bool> {
+        if message.event == "done" {
+            return Ok(true);
+        }
+        handler.text(&message.data)?;
+        Ok(false)
+    };
+    sse_stream(sse_builder, handle).await
+}
+
+fn build_body(data: SendData, model: &Model) -> Result<Value> {
+    let SendData {
+        messages,
+        temperature,
+        top_p,
+        stream,
+    } = data;
+
+    let prompt = generate_prompt(&messages, smart_prompt_format(&model.name))?;
+
+    let mut input = json!({
+        "prompt": prompt,
+        "prompt_template": "{prompt}"
+    });
+
+    if let Some(v) = model.max_output_tokens {
+        input["max_tokens"] = v.into();
+        input["max_new_tokens"] = v.into();
+    }
+    if let Some(v) = temperature {
+        input["temperature"] = v.into();
+    }
+    if let Some(v) = top_p {
+        input["top_p"] = v.into();
+    }
+
+    let mut body = json!({
+        "input": input,
+    });
+
+    if stream {
+        body["stream"] = true.into();
+    }
+
+    Ok(body)
+}
+
+fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
+    let text = data["output"]
+        .as_array()
+        .map(|parts| {
+            parts
+                .iter()
+                .filter_map(|v| v.as_str().map(|v| v.to_string()))
+                .collect::<Vec<String>>()
+                .join("")
+        })
+        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
+
+    let details = CompletionDetails {
+        id: data["id"].as_str().map(|v| v.to_string()),
+        input_tokens: data["metrics"]["input_token_count"].as_u64(),
+        output_tokens: data["metrics"]["output_token_count"].as_u64(),
+    };
+
+    Ok((text.to_string(), details))
+}
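Once the patch is applied and a replicate client is configured (or the key is supplied through the REPLICATE_API_KEY environment variable for a client using the default name), the new models can be exercised from the aichat CLI. A rough usage sketch, assuming the usual -m client:model selector and the -S/--no-stream flag:

export REPLICATE_API_KEY=xxx

# Streaming (default): the client follows urls.stream as a text/event-stream
# and stops on the "done" event.
aichat -m replicate:meta/meta-llama-3-8b-instruct 'Say hello in one sentence'

# Non-streaming: the client polls urls.get every 500ms until the prediction
# reaches "succeeded", then joins the "output" array into the final reply.
aichat -S -m replicate:meta/meta-llama-3-8b-instruct 'Say hello in one sentence'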