feat: qianwen vision models support embeded images (#277)

pull/282/head
sigoden 6 months ago committed by GitHub
parent 64c4edf7c8
commit 6280f5ab4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

1
Cargo.lock generated

@ -1446,6 +1446,7 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",

@ -46,7 +46,7 @@ sha2 = "0.10.8"
[dependencies.reqwest]
version = "0.11.14"
features = ["json", "socks", "rustls-tls", "rustls-tls-native-roots"]
features = ["json", "multipart", "socks", "rustls-tls", "rustls-tls-native-roots"]
default-features = false
[dependencies.syntect]

@ -1,14 +1,23 @@
use super::{message::*, Client, ExtraConfig, Model, PromptType, QianwenClient, SendData};
use crate::{config::GlobalConfig, render::ReplyHandler, utils::PromptKind};
use crate::{
config::GlobalConfig,
render::ReplyHandler,
utils::{sha256sum, PromptKind},
};
use anyhow::{anyhow, bail, Result};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest::{
multipart::{Form, Part},
Client as ReqwestClient, RequestBuilder,
};
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::borrow::BorrowMut;
const API_URL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";
@ -37,7 +46,13 @@ impl Client for QianwenClient {
(&self.global_config, &self.config.extra)
}
async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
async fn send_message_inner(
&self,
client: &ReqwestClient,
mut data: SendData,
) -> Result<String> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message(builder, self.is_vl()).await
}
@ -46,8 +61,10 @@ impl Client for QianwenClient {
&self,
client: &ReqwestClient,
handler: &mut ReplyHandler,
data: SendData,
mut data: SendData,
) -> Result<()> {
let api_key = self.get_api_key()?;
patch_messages(&self.model.name, &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler, self.is_vl()).await
}
@ -79,7 +96,7 @@ impl QianwenClient {
true => API_URL_VL,
false => API_URL,
};
let body = build_body(data, self.model.name.clone(), is_vl)?;
let (body, has_upload) = build_body(data, self.model.name.clone(), is_vl)?;
debug!("Qianwen Request: {url} {body}");
@ -87,6 +104,9 @@ impl QianwenClient {
if stream {
builder = builder.header("X-DashScope-SSE", "enable");
}
if has_upload {
builder = builder.header("X-DashScope-OssResourceResolve", "enable");
}
Ok(builder)
}
@ -126,7 +146,8 @@ async fn send_message_streaming(
let data: Value = serde_json::from_str(&message.data)?;
check_error(&data)?;
if is_vl {
let text = data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
let text =
data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
if let Some(text) = text {
let text = &text[offset..];
handler.text(text)?;
@ -158,17 +179,15 @@ fn check_error(data: &Value) -> Result<()> {
Ok(())
}
fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
fn build_body(data: SendData, model: String, is_vl: bool) -> Result<(Value, bool)> {
let SendData {
messages,
temperature,
stream,
} = data;
let mut has_upload = false;
let (input, parameters) = if is_vl {
let mut exist_embeded_image = false;
let messages: Vec<Value> = messages
.into_iter()
.map(|message| {
@ -182,11 +201,11 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if url.starts_with("data:") {
exist_embeded_image = true;
if url.starts_with("oss:") {
has_upload = true;
}
json!({"image": url})
},
}
})
.collect(),
};
@ -194,10 +213,6 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
})
.collect();
if exist_embeded_image {
bail!("The model does not support embeded images");
}
let input = json!({
"messages": messages,
});
@ -228,5 +243,99 @@ fn build_body(data: SendData, model: String, is_vl: bool) -> Result<Value> {
"input": input,
"parameters": parameters
});
Ok(body)
Ok((body, has_upload))
}
/// Patch messsages, upload emebeded images to oss
async fn patch_messages(model: &str, api_key: &str, messages: &mut Vec<Message>) -> Result<()> {
for message in messages {
if let MessageContent::Array(list) = message.content.borrow_mut() {
for item in list {
if let MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} = item
{
if url.starts_with("data:") {
*url = upload(model, api_key, url)
.await
.with_context(|| "Failed to upload embeded image to oss")?;
}
}
}
}
}
Ok(())
}
#[derive(Debug, Deserialize)]
struct Policy {
data: PolicyData,
}
#[derive(Debug, Deserialize)]
struct PolicyData {
policy: String,
signature: String,
upload_dir: String,
upload_host: String,
oss_access_key_id: String,
x_oss_object_acl: String,
x_oss_forbid_overwrite: String,
}
/// Upload image to dashscope
async fn upload(model: &str, api_key: &str, url: &str) -> Result<String> {
let (mime_type, data) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
.ok_or_else(|| anyhow!("Invalid image url"))?;
let mut name = sha256sum(data);
if let Some(ext) = mime_type.strip_prefix("image/") {
name.push('.');
name.push_str(ext);
}
let data = STANDARD.decode(data)?;
let client = reqwest::Client::new();
let policy: Policy = client
.get(format!(
"https://dashscope.aliyuncs.com/api/v1/uploads?action=getPolicy&model={model}"
))
.header("Authorization", format!("Bearer {api_key}"))
.send()
.await?
.json()
.await?;
let PolicyData {
policy,
signature,
upload_dir,
upload_host,
oss_access_key_id,
x_oss_object_acl,
x_oss_forbid_overwrite,
..
} = policy.data;
let key = format!("{upload_dir}/{name}");
let file = Part::bytes(data).file_name(name).mime_str(mime_type)?;
let form = Form::new()
.text("OSSAccessKeyId", oss_access_key_id)
.text("Signature", signature)
.text("policy", policy)
.text("key", key.clone())
.text("x-oss-object-acl", x_oss_object_acl)
.text("x-oss-forbid-overwrite", x_oss_forbid_overwrite)
.text("success_action_status", "200")
.text("x-oss-content-type", mime_type.to_string())
.part("file", file);
let res = client.post(upload_host).multipart(form).send().await?;
let status = res.status();
if res.status() != 200 {
let text = res.text().await?;
bail!("{status}, {text}")
}
Ok(format!("oss://{key}"))
}

@ -3,6 +3,8 @@ use crate::utils::sha256sum;
use anyhow::{bail, Context, Result};
use base64::{self, engine::general_purpose::STANDARD, Engine};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use mime_guess::from_path;
use std::{
collections::HashMap,
@ -13,6 +15,10 @@ use std::{
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
lazy_static! {
static ref URL_RE: Regex = Regex::new(r"^[A-Za-z0-9_-]{2,}:/").unwrap();
}
#[derive(Debug, Clone)]
pub struct Input {
text: String,
@ -128,10 +134,7 @@ pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -
}
fn resolve_path(file: &str) -> Option<PathBuf> {
if ["https://", "http://", "data:"]
.iter()
.any(|v| file.starts_with(v))
{
if let Ok(true) = URL_RE.is_match(file) {
return None;
}
let path = if let (Some(file), Some(home)) = (file.strip_prefix('~'), dirs::home_dir()) {

Loading…
Cancel
Save