feat: support replicate client (#466)

sigoden 3 weeks ago committed by GitHub
parent 4d4a100fe6
commit 50eac8b594

@@ -325,6 +325,47 @@ chat-cloudflare() {
}'
}
# @cmd Chat with replicate api
# @meta require-tools jq
# @env REPLICATE_API_KEY!
# @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL
# @flag -S --no-stream
# @arg text~
chat-replicate() {
    url="https://api.replicate.com/v1/models/$argc_model/predictions"
    # Create the prediction; the response carries the polling and streaming urls.
    res="$(_wrapper curl -s $REPLICATE_CURL_ARGS "$url" \
        -X POST \
        -H "Authorization: Bearer $REPLICATE_API_KEY" \
        -H "Content-Type: application/json" \
        -d '{
    "stream": '$stream',
    "input": {
        "prompt": "'"$*"'"
    }
}')"
    echo "$res"
    if [[ -n "$argc_no_stream" ]]; then
        prediction_url="$(echo "$res" | jq -r '.urls.get')"
        # Poll the prediction until it settles.
        while true; do
            output="$(_wrapper curl $REPLICATE_CURL_ARGS -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
            prediction_status="$(printf "%s" "$output" | jq -r .status)"
            if [[ "$prediction_status" == "succeeded" ]]; then
                echo "$output"
                break
            fi
            if [[ "$prediction_status" == "failed" ]]; then
                exit 1
            fi
            sleep 2
        done
    else
        stream_url="$(echo "$res" | jq -r '.urls.stream')"
        _wrapper curl -i $REPLICATE_CURL_ARGS --no-buffer "$stream_url" \
            -H "Accept: text/event-stream"
    fi
}
# @cmd Chat with ernie api
# @meta require-tools jq
# @env ERNIE_API_KEY!

@@ -101,16 +101,21 @@ clients:
    # Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
    block_threshold: BLOCK_ONLY_HIGH
  # See https://docs.aws.amazon.com/bedrock/latest/userguide/
  - type: bedrock
    access_key_id: xxx # ENV: {client_name}_ACCESS_KEY_ID
    secret_access_key: xxx # ENV: {client_name}_SECRET_ACCESS_KEY
    region: xxx # ENV: {client_name}_REGION

  # See https://developers.cloudflare.com/workers-ai/
  - type: cloudflare
    account_id: xxx # ENV: {client_name}_ACCOUNT_ID
    api_key: xxx # ENV: {client_name}_API_KEY

  # See https://replicate.com/docs
  - type: replicate
    api_key: xxx # ENV: {client_name}_API_KEY

  # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
  - type: ernie
    api_key: xxx # ENV: {client_name}_API_KEY
@@ -367,6 +367,34 @@
      input_price: 0.11
      output_price: 0.19
- type: replicate
  # docs:
  #   - https://replicate.com/docs
  #   - https://replicate.com/pricing
  # notes:
  #   - max_output_tokens is required but unknown
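  #   - prices below are assumed to be USD per 1M tokens (see https://replicate.com/pricing)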
  models:
    - name: meta/meta-llama-3-70b-instruct
      max_input_tokens: 8192
      max_output_tokens: 4096
      input_price: 0.65
      output_price: 2.75
    - name: meta/meta-llama-3-8b-instruct
      max_input_tokens: 8192
      max_output_tokens: 4096
      input_price: 0.05
      output_price: 0.25
    - name: mistralai/mistral-7b-instruct-v0.2
      max_input_tokens: 32000
      max_output_tokens: 8192
      input_price: 0.05
      output_price: 0.25
    - name: mistralai/mixtral-8x7b-instruct-v0.1
      max_input_tokens: 32000
      max_output_tokens: 8192
      input_price: 0.3
      output_price: 1
- type: ernie
  # docs:
  #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

@@ -520,6 +520,9 @@ pub fn catch_error(data: &Value, status: u16) -> Result<()> {
        {
            bail!("{message} (status: {status})")
        }
    } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
    {
        // Replicate-style error payloads carry a detail string and a numeric status code.
        bail!("{detail} (status: {status})");
    } else if let Some(error) = data["error"].as_str() {
        bail!("{error}");
    } else if let Some(message) = data["message"].as_str() {

@@ -29,6 +29,7 @@ register_client!(
    (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
    (bedrock, "bedrock", BedrockConfig, BedrockClient),
    (cloudflare, "cloudflare", CloudflareConfig, CloudflareClient),
    (replicate, "replicate", ReplicateConfig, ReplicateClient),
    (ernie, "ernie", ErnieConfig, ErnieClient),
    (qianwen, "qianwen", QianwenConfig, QianwenClient),
    (moonshot, "moonshot", MoonshotConfig, MoonshotClient),

@@ -1,46 +1,94 @@
use super::message::*;

pub struct PromptFormat<'a> {
    pub begin: &'a str,
    pub system_pre_message: &'a str,
    pub system_post_message: &'a str,
    pub user_pre_message: &'a str,
    pub user_post_message: &'a str,
    pub assistant_pre_message: &'a str,
    pub assistant_post_message: &'a str,
    pub end: &'a str,
}
pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "### System\n",
    system_post_message: "\n",
    user_pre_message: "### User\n",
    user_post_message: "\n",
    assistant_pre_message: "### Assistant\n",
    assistant_post_message: "\n",
    end: "### Assistant\n",
};
pub const LLAMA2_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "[INST] <<SYS>>",
    system_post_message: "<</SYS>> [/INST]",
    user_pre_message: "[INST]",
    user_post_message: "[/INST]",
    assistant_pre_message: "",
    assistant_post_message: "",
    end: "",
};
pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "<|begin_of_text|>",
    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
    system_post_message: "<|eot_id|>",
    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
    user_post_message: "<|eot_id|>",
    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
    assistant_post_message: "<|eot_id|>",
    end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
};
pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|system|>\n",
    system_post_message: "<|end|>\n",
    user_pre_message: "<|user|>\n",
    user_post_message: "<|end|>\n",
    assistant_pre_message: "<|assistant|>\n",
    assistant_post_message: "<|end|>\n",
    end: "<|assistant|>\n",
};
pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
    system_post_message: "<|END_OF_TURN_TOKEN|>",
    user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
    user_post_message: "<|END_OF_TURN_TOKEN|>",
    assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
    assistant_post_message: "<|END_OF_TURN_TOKEN|>",
    end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
};
pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|im_start|>system\n",
    system_post_message: "<|im_end|>",
    user_pre_message: "<|im_start|>user\n",
    user_post_message: "<|im_end|>",
    assistant_pre_message: "<|im_start|>assistant\n",
    assistant_post_message: "<|im_end|>",
    end: "<|im_start|>assistant\n",
};
pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
    let PromptFormat {
        begin,
        system_pre_message,
        system_post_message,
        user_pre_message,
        user_post_message,
        assistant_pre_message,
        assistant_post_message,
        end,
    } = format;
    let mut prompt = begin.to_string();
    let mut image_urls = vec![];
    for message in messages {
        let role = &message.role;
@@ -76,6 +124,26 @@ pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
    if !image_urls.is_empty() {
        anyhow::bail!("The model does not support images: {:?}", image_urls);
    }
    prompt.push_str(end);
    Ok(prompt)
}
pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
    if model_name.contains("llama3") || model_name.contains("llama-3") {
        LLAMA3_PROMPT_FORMAT
    } else if model_name.contains("llama2")
        || model_name.contains("llama-2")
        || model_name.contains("mistral")
        || model_name.contains("mixtral")
    {
        LLAMA2_PROMPT_FORMAT
    } else if model_name.contains("phi3") || model_name.contains("phi-3") {
        PHI3_PROMPT_FORMAT
    } else if model_name.contains("command-r") {
        COMMAND_R_PROMPT_FORMAT
    } else if model_name.contains("qwen") {
        QWEN_PROMPT_FORMAT
    } else {
        GENERIC_PROMPT_FORMAT
    }
}
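For reference (not part of the diff): a standalone sketch of the string LLAMA3_PROMPT_FORMAT yields for a one-turn system + user exchange, assuming generate_prompt concatenates begin, each message wrapped in its role's pre/post markers, and end; the message texts are invented.

// Standalone illustration; mirrors the Llama 3 constants above.
fn main() {
    let prompt = concat!(
        "<|begin_of_text|>",                                 // begin
        "<|start_header_id|>system<|end_header_id|>\n\n",    // system_pre_message
        "You are a helpful assistant.",                      //   system content
        "<|eot_id|>",                                        // system_post_message
        "<|start_header_id|>user<|end_header_id|>\n\n",      // user_pre_message
        "Hello!",                                            //   user content
        "<|eot_id|>",                                        // user_post_message
        "<|start_header_id|>assistant<|end_header_id|>\n\n", // end
    );
    println!("{prompt}");
}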

@@ -0,0 +1,193 @@
use std::time::Duration;

use super::{
    catch_error, generate_prompt, smart_prompt_format, sse_stream, Client, CompletionDetails,
    ExtraConfig, Model, ModelConfig, PromptType, ReplicateClient, SendData, SsMmessage, SseHandler,
};

use crate::utils::PromptKind;

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.replicate.com/v1";
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ReplicateConfig {
    pub name: Option<String>,
    pub api_key: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelConfig>,
    pub extra: Option<ExtraConfig>,
}
impl ReplicateClient {
    config_get_fn!(api_key, get_api_key);

    pub const PROMPTS: [PromptType<'static>; 1] =
        [("api_key", "API Key:", true, PromptKind::String)];

    fn request_builder(
        &self,
        client: &ReqwestClient,
        data: SendData,
        api_key: &str,
    ) -> Result<RequestBuilder> {
        let body = build_body(data, &self.model)?;
        let url = format!("{API_BASE}/models/{}/predictions", self.model.name);
        debug!("Replicate Request: {url} {body}");
        let builder = client.post(url).bearer_auth(api_key).json(&body);
        Ok(builder)
    }
}
#[async_trait]
impl Client for ReplicateClient {
    client_common_fns!();

    async fn send_message_inner(
        &self,
        client: &ReqwestClient,
        data: SendData,
    ) -> Result<(String, CompletionDetails)> {
        let api_key = self.get_api_key()?;
        let builder = self.request_builder(client, data, &api_key)?;
        send_message(client, builder, &api_key).await
    }

    async fn send_message_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: SendData,
    ) -> Result<()> {
        let api_key = self.get_api_key()?;
        let builder = self.request_builder(client, data, &api_key)?;
        send_message_streaming(client, builder, handler).await
    }
}
async fn send_message(
    client: &ReqwestClient,
    builder: RequestBuilder,
    api_key: &str,
) -> Result<(String, CompletionDetails)> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let prediction_url = data["urls"]["get"]
        .as_str()
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
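    // Predictions are asynchronous: the create call above returns a `urls.get`
    // endpoint that must be polled until `status` settles.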
    loop {
        tokio::time::sleep(Duration::from_millis(500)).await;
        let prediction_data: Value = client
            .get(prediction_url)
            .bearer_auth(api_key)
            .send()
            .await?
            .json()
            .await?;
        let err = || anyhow!("Invalid response data: {prediction_data}");
        let status = prediction_data["status"].as_str().ok_or_else(err)?;
        if status == "succeeded" {
            return extract_completion(&prediction_data);
        } else if status == "failed" || status == "canceled" {
            return Err(err());
        }
    }
}
async fn send_message_streaming(
    client: &ReqwestClient,
    builder: RequestBuilder,
    handler: &mut SseHandler,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let stream_url = data["urls"]["stream"]
        .as_str()
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
    let sse_builder = client.get(stream_url).header("accept", "text/event-stream");
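    // Replicate's SSE stream marks the end of generation with an event named
    // `done`; every other event carries a text chunk in `data`.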
    let handle = |message: SsMmessage| -> Result<bool> {
        if message.event == "done" {
            return Ok(true);
        }
        handler.text(&message.data)?;
        Ok(false)
    };
    sse_stream(sse_builder, handle).await
}
fn build_body(data: SendData, model: &Model) -> Result<Value> {
    let SendData {
        messages,
        temperature,
        top_p,
        stream,
    } = data;
    let prompt = generate_prompt(&messages, smart_prompt_format(&model.name))?;
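    // The prompt is fully rendered locally, so `prompt_template: "{prompt}"`
    // stops Replicate from applying its own chat template on top of it.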
    let mut input = json!({
        "prompt": prompt,
        "prompt_template": "{prompt}"
    });
    if let Some(v) = model.max_output_tokens {
        input["max_tokens"] = v.into();
        input["max_new_tokens"] = v.into();
    }
    if let Some(v) = temperature {
        input["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        input["top_p"] = v.into();
    }
    let mut body = json!({
        "input": input,
    });
    if stream {
        body["stream"] = true.into();
    }
    Ok(body)
}
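For reference (not part of the diff): a standalone sketch of the request body build_body assembles for a streaming call; the concrete values (the rendered prompt, temperature 0.7, 4096 max tokens) are invented for illustration.

use serde_json::json;

fn main() {
    // Mirrors build_body's output shape; values are made up.
    let body = json!({
        "input": {
            "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>...",
            "prompt_template": "{prompt}",
            "max_tokens": 4096,     // from model.max_output_tokens
            "max_new_tokens": 4096, // some models read one key, some the other
            "temperature": 0.7,
        },
        "stream": true,
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}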
fn extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
    let text = data["output"]
        .as_array()
        .map(|parts| {
            parts
                .iter()
                .filter_map(|v| v.as_str().map(|v| v.to_string()))
                .collect::<Vec<String>>()
                .join("")
        })
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
    let details = CompletionDetails {
        id: data["id"].as_str().map(|v| v.to_string()),
        input_tokens: data["metrics"]["input_token_count"].as_u64(),
        output_tokens: data["metrics"]["output_token_count"].as_u64(),
    };
    Ok((text, details))
}
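And a sketch (again not part of the diff) of the prediction object extract_completion consumes once polling reports succeeded; field names follow Replicate's prediction schema as used by the code above, the values are invented.

use serde_json::json;

fn main() {
    let data = json!({
        "id": "rrr4fy3nk5rgp0cf",            // invented prediction id
        "status": "succeeded",
        "output": ["Hello", ", ", "world!"], // text arrives as string chunks
        "metrics": { "input_token_count": 12, "output_token_count": 4 }
    });
    // Same joining logic as extract_completion.
    let text: String = data["output"]
        .as_array()
        .map(|parts| parts.iter().filter_map(|v| v.as_str()).collect())
        .unwrap_or_default();
    assert_eq!(text, "Hello, world!");
    assert_eq!(data["metrics"]["output_token_count"].as_u64(), Some(4));
}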