feat: support claude-3 (#336)

pull/339/head
sigoden 3 months ago committed by GitHub
parent 3f693ea060
commit be4e5e569a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,7 +3,11 @@ use super::{
TokensCountFactors,
};
use crate::{render::ReplyHandler, utils::PromptKind};
use crate::{
client::{ImageUrl, MessageContent, MessageContentPart},
render::ReplyHandler,
utils::PromptKind,
};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
@ -15,7 +19,9 @@ use serde_json::{json, Value};
const API_BASE: &str = "https://api.anthropic.com/v1/messages";
const MODELS: [(&str, usize, &str); 3] = [
const MODELS: [(&str, usize, &str); 5] = [
("claude-3-opus-20240229", 204096, "text,vision"),
("claude-3-sonnet-20240229", 204096, "text,vision"),
("claude-2.1", 204096, "text"),
("claude-2.0", 104096, "text"),
("claude-instant-1.2", 104096, "text"),
@ -72,7 +78,7 @@ impl ClaudeClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let body = build_body(data, self.model.name.clone());
let body = build_body(data, self.model.name.clone())?;
let url = API_BASE;
@ -135,7 +141,7 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut ReplyHand
Ok(())
}
fn build_body(data: SendData, model: String) -> Value {
fn build_body(data: SendData, model: String) -> Result<Value> {
let SendData {
mut messages,
temperature,
@ -144,6 +150,51 @@ fn build_body(data: SendData, model: String) -> Value {
patch_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
let role = message.role;
let content = match message.content {
MessageContent::Text(text) => vec![json!({"type": "text", "text": text})],
MessageContent::Array(list) => list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"type": "text", "text": text}),
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect(),
};
json!({ "role": role, "content": content })
})
.collect();
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}
let mut body = json!({
"model": model,
"max_tokens": 4096,
@ -156,7 +207,7 @@ fn build_body(data: SendData, model: String) -> Value {
if stream {
body["stream"] = true.into();
}
body
Ok(body)
}
fn check_error(data: &Value) -> Result<()> {

Loading…
Cancel
Save