feat: support customizing gemini safeSettings (#375)

pull/377/head
sigoden 2 months ago committed by GitHub
parent 0ebc7955da
commit 7f05dc1a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -95,10 +95,11 @@ clients:
- type: openai
api_key: sk-xxx
- type: localai
api_base: http://localhost:8080/v1
- type: openai-compatible
name: localai
api_base: http://127.0.0.1:8080/v1
models:
- name: gpt4all-j
- name: llama2
max_input_tokens: 8192
```

@ -37,6 +37,7 @@ clients:
# See https://ai.google.dev/docs
- type: gemini
api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
block_threshold: BLOCK_NONE # Optional field, choices: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
@ -58,7 +59,7 @@ clients:
key: value
- name: llava
max_input_tokens: 8192
capabilities: text,vision # Optional field, possible values: text, vision
capabilities: text,vision # Optional field, choices: text, vision
# See https://github.com/jmorganca/ollama
- type: ollama
@ -84,6 +85,7 @@ clients:
# Run `gcloud auth application-default login` to set up ADC (Application Default Credentials)
# See https://cloud.google.com/docs/authentication/external/set-up-adc
adc_file: <path-to/gcloud/application_default_credentials.json>
block_threshold: BLOCK_ONLY_HIGH # Optional field, choices: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
# See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
- type: ernie

@ -23,6 +23,7 @@ const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
/// Configuration for a `gemini` client entry.
///
/// Every field is optional.
pub struct GeminiConfig {
/// Optional name for this client entry.
pub name: Option<String>,
/// Optional API key (the example config sets it as `api_key`).
pub api_key: Option<String>,
/// Optional safety block threshold. `GeminiClient` clones this value and
/// passes it to `build_body`, which, when it is `Some`, applies it as the
/// `threshold` of every `safetySettings` category in the request body; when
/// it is `None`, no `safetySettings` key is added. The example config lists
/// the choices as `BLOCK_NONE`, `BLOCK_ONLY_HIGH`, `BLOCK_MEDIUM_AND_ABOVE`
/// and `BLOCK_LOW_AND_ABOVE`.
pub block_threshold: Option<String>,
/// Optional extra settings; see `ExtraConfig` (defined outside this view).
pub extra: Option<ExtraConfig>,
}
@ -73,7 +74,9 @@ impl GeminiClient {
false => "generateContent",
};
let body = build_body(data, self.model.name.clone())?;
let block_threshold = self.config.block_threshold.clone();
let body = build_body(data, self.model.name.clone(), block_threshold)?;
let model = self.model.name.clone();

@ -14,13 +14,13 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::path::PathBuf;
const MODELS: [(&str, usize, &str); 2] = [
const MODELS: [(&str, usize, &str); 5] = [
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
("gemini-1.0-pro", 24568, "text"),
("gemini-1.0-pro-vision", 14336, "text,vision"),
// ("gemini-1.0-ultra", 8192, "text"),
// ("gemini-1.0-ultra-vision", 8192, "text,vision"),
// ("gemini-1.5-pro", 1000000, "text"),
("gemini-1.0-ultra", 8192, "text"),
("gemini-1.0-ultra-vision", 8192, "text,vision"),
("gemini-1.5-pro", 1000000, "text"),
];
const TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2);
@ -32,6 +32,7 @@ pub struct VertexAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub adc_file: Option<String>,
pub block_threshold: Option<String>,
pub extra: Option<ExtraConfig>,
}
@ -84,7 +85,9 @@ impl VertexAIClient {
false => "generateContent",
};
let body = build_body(data, self.model.name.clone())?;
let block_threshold = self.config.block_threshold.clone();
let body = build_body(data, self.model.name.clone(), block_threshold)?;
let model = self.model.name.clone();
@ -106,7 +109,9 @@ impl VertexAIClient {
let (token, expires_in) = fetch_access_token(&client, &self.config.adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now() + Duration::try_seconds(expires_in).ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
}
Ok(())
@ -208,7 +213,11 @@ fn check_error(data: &Value) -> Result<()> {
}
}
pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
pub(crate) fn build_body(
data: SendData,
_model: String,
block_threshold: Option<String>,
) -> Result<Value> {
let SendData {
mut messages,
temperature,
@ -258,15 +267,16 @@ pub(crate) fn build_body(data: SendData, _model: String) -> Result<Value> {
);
}
let mut body = json!({
"contents": contents,
"safetySettings":[
{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}
]
});
let mut body = json!({ "contents": contents });
if let Some(block_threshold) = block_threshold {
body["safetySettings"] = json!([
{"category":"HARM_CATEGORY_HARASSMENT","threshold":block_threshold},
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":block_threshold},
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":block_threshold},
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":block_threshold}
]);
}
if let Some(temperature) = temperature {
body["generationConfig"] = json!({

Loading…
Cancel
Save