|
|
|
@ -14,13 +14,13 @@ use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
|
use std::path::PathBuf;
|
|
|
|
|
|
|
|
|
|
const MODELS: [(&str, usize, &str); 2] = [
|
|
|
|
|
const MODELS: [(&str, usize, &str); 5] = [
|
|
|
|
|
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
|
|
|
|
|
("gemini-1.0-pro", 24568, "text"),
|
|
|
|
|
("gemini-1.0-pro-vision", 14336, "text,vision"),
|
|
|
|
|
// ("gemini-1.0-ultra", 8192, "text"),
|
|
|
|
|
// ("gemini-1.0-ultra-vision", 8192, "text,vision"),
|
|
|
|
|
// ("gemini-1.5-pro", 1000000, "text"),
|
|
|
|
|
("gemini-1.0-ultra", 8192, "text"),
|
|
|
|
|
("gemini-1.0-ultra-vision", 8192, "text,vision"),
|
|
|
|
|
("gemini-1.5-pro", 1000000, "text"),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
|
|
|
|
@ -32,6 +32,7 @@ pub struct VertexAIConfig {
|
|
|
|
|
pub name: Option<String>,
|
|
|
|
|
pub api_base: Option<String>,
|
|
|
|
|
pub adc_file: Option<String>,
|
|
|
|
|
pub block_threshold: Option<String>,
|
|
|
|
|
pub extra: Option<ExtraConfig>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -84,7 +85,9 @@ impl VertexAIClient {
|
|
|
|
|
false => "generateContent",
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let body = build_body(data, self.model.name.clone())?;
|
|
|
|
|
let block_threshold = self.config.block_threshold.clone();
|
|
|
|
|
|
|
|
|
|
let body = build_body(data, self.model.name.clone(), block_threshold)?;
|
|
|
|
|
|
|
|
|
|
let model = self.model.name.clone();
|
|
|
|
|
|
|
|
|
@ -106,7 +109,9 @@ impl VertexAIClient {
|
|
|
|
|
let (token, expires_in) = fetch_access_token(&client, &self.config.adc_file)
|
|
|
|
|
.await
|
|
|
|
|
.with_context(|| "Failed to fetch access token")?;
|
|
|
|
|
let expires_at = Utc::now() + Duration::try_seconds(expires_in).ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
|
|
|
|
|
let expires_at = Utc::now()
|
|
|
|
|
+ Duration::try_seconds(expires_in)
|
|
|
|
|
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
|
|
|
|
|
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
|
|
|
|
|
}
|
|
|
|
|
Ok(())
|
|
|
|
@ -208,7 +213,11 @@ fn check_error(data: &Value) -> Result<()> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
|
|
|
|
|
pub(crate) fn build_body(
|
|
|
|
|
data: SendData,
|
|
|
|
|
_model: String,
|
|
|
|
|
block_threshold: Option<String>,
|
|
|
|
|
) -> Result<Value> {
|
|
|
|
|
let SendData {
|
|
|
|
|
mut messages,
|
|
|
|
|
temperature,
|
|
|
|
@ -258,15 +267,16 @@ pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut body = json!({
|
|
|
|
|
"contents": contents,
|
|
|
|
|
"safetySettings":[
|
|
|
|
|
{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},
|
|
|
|
|
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},
|
|
|
|
|
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},
|
|
|
|
|
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}
|
|
|
|
|
]
|
|
|
|
|
});
|
|
|
|
|
let mut body = json!({ "contents": contents });
|
|
|
|
|
|
|
|
|
|
if let Some(block_threshold) = block_threshold {
|
|
|
|
|
body["safetySettings"] = json!([
|
|
|
|
|
{"category":"HARM_CATEGORY_HARASSMENT","threshold":block_threshold},
|
|
|
|
|
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":block_threshold},
|
|
|
|
|
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":block_threshold},
|
|
|
|
|
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":block_threshold}
|
|
|
|
|
]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if let Some(temperature) = temperature {
|
|
|
|
|
body["generationConfig"] = json!({
|
|
|
|
|