diff --git a/Argcfile.sh b/Argcfile.sh
index 7ba83d3..46ab0ec 100755
--- a/Argcfile.sh
+++ b/Argcfile.sh
@@ -325,6 +325,47 @@ chat-cloudflare() {
}'
}
+# @cmd Chat with replicate api
+# @env REPLICATE_API_KEY!
+# @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL
+# @flag -S --no-stream
+# @arg text~
+chat-replicate() {
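+ # Create a prediction, then poll its get URL or stream from its SSE URL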
+ if [[ -n "$argc_no_stream" ]]; then
+ stream="false"
+ else
+ stream="true"
+ fi
+ url="https://api.replicate.com/v1/models/$argc_model/predictions"
+ res="$(_wrapper curl -s $REPLICATE_CURL_ARGS "$url" \
+-X POST \
+-H "Authorization: Bearer $REPLICATE_API_KEY" \
+-H "Content-Type: application/json" \
+-d '{
+ "stream": '$stream',
+ "input": {
+ "prompt": "'"$*"'"
+ }
+}')"
+ echo "$res"
+ if [[ -n "$argc_no_stream" ]]; then
+ prediction_url="$(echo "$res" | jq -r '.urls.get')"
+ while true; do
+ output="$(_wrapper curl $REPLICATE_CURL_ARGS -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
+ prediction_status=$(printf "%s" "$output" | jq -r .status)
+ if [ "$prediction_status"=="succeeded" ]; then
+ echo "$output"
+ break
+ fi
+ if [ "$prediction_status"=="failed" ]; then
+ exit 1
+ fi
+ sleep 2
+ done
+ else
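+ # Streaming mode: consume server-sent events from the prediction's stream URL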
+ stream_url="$(echo "$res" | jq -r '.urls.stream')"
+ _wrapper curl -i $REPLICATE_CURL_ARGS --no-buffer "$stream_url" \
+-H "Accept: text/event-stream"
+ fi
+
+}
+
# @cmd Chat with ernie api
# @meta require-tools jq
# @env ERNIE_API_KEY!
diff --git a/config.example.yaml b/config.example.yaml
index f838d04..ebf59ae 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -101,16 +101,21 @@ clients:
# Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
block_threshold: BLOCK_ONLY_HIGH
- - type: cloudflare
- account_id: xxx # ENV: {client_name}_ACCOUNT_ID
- api_key: xxx # ENV: {client_name}_API_KEY
-
# See https://docs.aws.amazon.com/bedrock/latest/userguide/
- type: bedrock
access_key_id: xxx # ENV: {client_name}_ACCESS_KEY_ID
secret_access_key: xxx # ENV: {client_name}_SECRET_ACCESS_KEY
region: xxx # ENV: {client_name}_REGION
+ # See https://developers.cloudflare.com/workers-ai/
+ - type: cloudflare
+ account_id: xxx # ENV: {client_name}_ACCOUNT_ID
+ api_key: xxx # ENV: {client_name}_API_KEY
+
+ # See https://replicate.com/docs
+ - type: replicate
+ api_key: xxx # ENV: {client_name}_API_KEY
+
# See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
- type: ernie
api_key: xxx # ENV: {client_name}_API_KEY
diff --git a/models.yaml b/models.yaml
index ed53521..8c2fa98 100644
--- a/models.yaml
+++ b/models.yaml
@@ -367,6 +367,34 @@
input_price: 0.11
output_price: 0.19
+- type: replicate
+ # docs:
+ # - https://replicate.com/docs
+ # - https://replicate.com/pricing
+ # notes:
+ # - max_output_tokens is required by the API but not documented, so the values below are guesses
+ models:
+ - name: meta/meta-llama-3-70b-instruct
+ max_input_tokens: 8192
+ max_output_tokens: 4096
+ input_price: 0.65
+ output_price: 2.75
+ - name: meta/meta-llama-3-8b-instruct
+ max_input_tokens: 8192
+ max_output_tokens: 4096
+ input_price: 0.05
+ output_price: 0.25
+ - name: mistralai/mistral-7b-instruct-v0.2
+ max_input_tokens: 32000
+ max_output_tokens: 8192
+ input_price: 0.05
+ output_price: 0.25
+ - name: mistralai/mixtral-8x7b-instruct-v0.1
+ max_input_tokens: 32000
+ max_output_tokens: 8192
+ input_price: 0.3
+ output_price: 1
+
- type: ernie
# docs:
# - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
diff --git a/src/client/common.rs b/src/client/common.rs
index 9fb0029..70ffe87 100644
--- a/src/client/common.rs
+++ b/src/client/common.rs
@@ -520,6 +520,9 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> {
{
bail!("{message} (status: {status})")
}
+ } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
+ {
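+ // Replicate-style errors arrive as {"detail": "...", "status": ...}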
+ bail!("{detail} (status: {status})");
} else if let Some(error) = data["error"].as_str() {
bail!("{error}");
} else if let Some(message) = data["message"].as_str() {
diff --git a/src/client/mod.rs b/src/client/mod.rs
index a311efe..82e18a5 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -29,6 +29,7 @@ register_client!(
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
(bedrock, "bedrock", BedrockConfig, BedrockClient),
(cloudflare, "cloudflare", CloudflareConfig, CloudflareClient),
+ (replicate, "replicate", ReplicateConfig, ReplicateClient),
(ernie, "ernie", ErnieConfig, ErnieClient),
(qianwen, "qianwen", QianwenConfig, QianwenClient),
(moonshot, "moonshot", MoonshotConfig, MoonshotClient),
diff --git a/src/client/prompt_format.rs b/src/client/prompt_format.rs
index 8ba3682..fca87ac 100644
--- a/src/client/prompt_format.rs
+++ b/src/client/prompt_format.rs
@@ -1,46 +1,94 @@
use super::message::*;
pub struct PromptFormat<'a> {
- pub bos_token: &'a str,
+ pub begin: &'a str,
pub system_pre_message: &'a str,
pub system_post_message: &'a str,
pub user_pre_message: &'a str,
pub user_post_message: &'a str,
pub assistant_pre_message: &'a str,
pub assistant_post_message: &'a str,
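+ /// Appended after the last message to cue the assistant's reply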
+ pub end: &'a str,
}
+pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+ begin: "",
+ system_pre_message: "### System\n",
+ system_post_message: "\n",
+ user_pre_message: "### User\n",
+ user_post_message: "\n",
+ assistant_pre_message: "### Assistant\n",
+ assistant_post_message: "\n",
+ end: "### Assistant\n",
+};
+
pub const LLAMA2_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
- bos_token: "",
+ begin: "",
system_pre_message: "[INST] <>",
system_post_message: "<> [/INST]",
user_pre_message: "[INST]",
user_post_message: "[/INST]",
assistant_pre_message: "",
- assistant_post_message: "",
+ assistant_post_message: "",
+ end: "",
};
pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
- bos_token: "<|begin_of_text|>",
+ begin: "<|begin_of_text|>",
system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
system_post_message: "<|eot_id|>",
user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
user_post_message: "<|eot_id|>",
assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
assistant_post_message: "<|eot_id|>",
+ end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
+};
+
+pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+ begin: "",
+ system_pre_message: "<|system|>\n",
+ system_post_message: "<|end|>\n",
+ user_pre_message: "<|user|>\n",
+ user_post_message: "<|end|>\n",
+ assistant_pre_message: "<|assistant|>\n",
+ assistant_post_message: "<|end|>\n",
+ end: "<|assistant|>\n",
+};
+
+pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+ begin: "",
+ system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
+ system_post_message: "<|END_OF_TURN_TOKEN|>",
+ user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
+ user_post_message: "<|END_OF_TURN_TOKEN|>",
+ assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+ assistant_post_message: "<|END_OF_TURN_TOKEN|>",
+ end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+};
+
+pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
+ begin: "",
+ system_pre_message: "<|im_start|>system\n",
+ system_post_message: "<|im_end|>",
+ user_pre_message: "<|im_start|>user\n",
+ user_post_message: "<|im_end|>",
+ assistant_pre_message: "<|im_start|>assistant\n",
+ assistant_post_message: "<|im_end|>",
+ end: "<|im_start|>assistant\n",
};
pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
let PromptFormat {
- bos_token,
+ begin,
system_pre_message,
system_post_message,
user_pre_message,
user_post_message,
assistant_pre_message,
assistant_post_message,
+ end,
} = format;
- let mut prompt = bos_token.to_string();
+ let mut prompt = begin.to_string();
let mut image_urls = vec![];
for message in messages {
let role = &message.role;
@@ -76,6 +124,26 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Re
if !image_urls.is_empty() {
anyhow::bail!("The model does not support images: {:?}", image_urls);
}
- prompt.push_str(assistant_pre_message);
+ prompt.push_str(end);
Ok(prompt)
}
+
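+/// Pick a prompt format from well-known model-name substrings, falling back to a generic chat format.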
+pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
+ if model_name.contains("llama3") || model_name.contains("llama-3") {
+ LLAMA3_PROMPT_FORMAT
+ } else if model_name.contains("llama2")
+ || model_name.contains("llama-2")
+ || model_name.contains("mistral")
+ || model_name.contains("mixtral")
+ {
+ LLAMA2_PROMPT_FORMAT
+ } else if model_name.contains("phi3") || model_name.contains("phi-3") {
+ PHI3_PROMPT_FORMAT
+ } else if model_name.contains("command-r") {
+ COMMAND_R_PROMPT_FORMAT
+ } else if model_name.contains("qwen") {
+ QWEN_PROMPT_FORMAT
+ } else {
+ GENERIC_PROMPT_FORMAT
+ }
+}
diff --git a/src/client/replicate.rs b/src/client/replicate.rs
new file mode 100644
index 0000000..e399927
--- /dev/null
+++ b/src/client/replicate.rs
@@ -0,0 +1,193 @@
+use std::time::Duration;
+
+use super::{
+ catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
+ ExtraConfig, Model, ModelConfig, PromptType, ReplicateClient, SendData, SseMmessage, SseHandler,
+};
+
+use crate::utils::PromptKind;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use reqwest::{Client as ReqwestClient, RequestBuilder};
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+const API_BASE: &str = "https://api.replicate.com/v1";
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct ReplicateConfig {
+ pub name: Option<String>,
+ pub api_key: Option<String>,
+ #[serde(default)]
+ pub models: Vec<ModelConfig>,
+ pub extra: Option<ExtraConfig>,
+}
+
+impl ReplicateClient {
+ config_get_fn!(api_key, get_api_key);
+
+ pub const PROMPTS: [PromptType<'static>; 1] =
+ [("api_key", "API Key:", true, PromptKind::String)];
+
+ fn request_builder(
+ &self,
+ client: &ReqwestClient,
+ data: SendData,
+ api_key: &str,
+ ) -> Result<RequestBuilder> {
+ let body = build_body(data, &self.model)?;
+
+ let url = format!("{API_BASE}/models/{}/predictions", self.model.name);
+
+ debug!("Replicate Request: {url} {body}");
+
+ let builder = client.post(url).bearer_auth(api_key).json(&body);
+
+ Ok(builder)
+ }
+}
+
+#[async_trait]
+impl Client for ReplicateClient {
+ client_common_fns!();
+
+ async fn send_message_inner(
+ &self,
+ client: &ReqwestClient,
+ data: SendData,
+ ) -> Result<(String, CompletionDetails)> {
+ let api_key = self.get_api_key()?;
+ let builder = self.request_builder(client, data, &api_key)?;
+ send_message(client, builder, &api_key).await
+ }
+
+ async fn send_message_streaming_inner(
+ &self,
+ client: &ReqwestClient,
+ handler: &mut SseHandler,
+ data: SendData,
+ ) -> Result<()> {
+ let api_key = self.get_api_key()?;
+ let builder = self.request_builder(client, data, &api_key)?;
+ send_message_streaming(client, builder, handler).await
+ }
+}
+
+async fn send_message(
+ client: &ReqwestClient,
+ builder: RequestBuilder,
+ api_key: &str,
+) -> Result<(String, CompletionDetails)> {
+ let res = builder.send().await?;
+ let status = res.status();
+ let data: Value = res.json().await?;
+ if !status.is_success() {
+ catch_error(&data, status.as_u16())?;
+ }
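+ // The create-prediction response exposes follow-up URLs under "urls"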
+ let prediction_url = data["urls"]["get"]
+ .as_str()
+ .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
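+ // Poll every 500ms until the prediction succeeds, fails, or is canceled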
+ loop {
+ tokio::time::sleep(Duration::from_millis(500)).await;
+ let prediction_data: Value = client
+ .get(prediction_url)
+ .bearer_auth(api_key)
+ .send()
+ .await?
+ .json()
+ .await?;
+ let err = || anyhow!("Invalid response data: {prediction_data}");
+ let status = prediction_data["status"].as_str().ok_or_else(err)?;
+ if status == "succeeded" {
+ return extract_completion(&prediction_data);
+ } else if status == "failed" || status == "canceled" {
+ return Err(err());
+ }
+ }
+}
+
+async fn send_message_streaming(
+ client: &ReqwestClient,
+ builder: RequestBuilder,
+ handler: &mut SseHandler,
+) -> Result<()> {
+ let res = builder.send().await?;
+ let status = res.status();
+ let data: Value = res.json().await?;
+ if !status.is_success() {
+ catch_error(&data, status.as_u16())?;
+ }
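+ // Streamed output is served from a dedicated SSE endpoint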
+ let stream_url = data["urls"]["stream"]
+ .as_str()
+ .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
+
+ let sse_builder = client.get(stream_url).header("accept", "text/event-stream");
+
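+ // Replicate signals end-of-stream with a "done" event; text arrives in the data field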
+ let handle = |message: SseMmessage| -> Result<bool> {
+ if message.event == "done" {
+ return Ok(true);
+ }
+ handler.text(&message.data)?;
+ Ok(false)
+ };
+ sse_stream(sse_builder, handle).await
+}
+
+fn build_body(data: SendData, model: &Model) -> Result {
+ let SendData {
+ messages,
+ temperature,
+ top_p,
+ stream,
+ } = data;
+
+ let prompt = generate_prompt(&messages, smart_prompt_format(&model.name))?;
+
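+ // Send the fully formatted prompt and neutralize the model's server-side prompt template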
+ let mut input = json!({
+ "prompt": prompt,
+ "prompt_template": "{prompt}"
+ });
+
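+ // Different models name the output limit differently, so set both variants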
+ if let Some(v) = model.max_output_tokens {
+ input["max_tokens"] = v.into();
+ input["max_new_tokens"] = v.into();
+ }
+ if let Some(v) = temperature {
+ input["temperature"] = v.into();
+ }
+ if let Some(v) = top_p {
+ input["top_p"] = v.into();
+ }
+
+ let mut body = json!({
+ "input": input,
+ });
+
+ if stream {
+ body["stream"] = true.into();
+ }
+
+ Ok(body)
+}
+
+fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
+ let text = data["output"]
+ .as_array()
+ .map(|parts| {
+ parts
+ .iter()
+ .filter_map(|v| v.as_str().map(|v| v.to_string()))
+ .collect::<Vec<String>>()
+ .join("")
+ })
+ .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
+
+ let details = CompletionDetails {
+ id: data["id"].as_str().map(|v| v.to_string()),
+ input_tokens: data["metrics"]["input_token_count"].as_u64(),
+ output_tokens: data["metrics"]["output_token_count"].as_u64(),
+ };
+
+ Ok((text, details))
+}