feat: extract vertexai-claude client (#485)

pull/486/head
sigoden 2 weeks ago committed by GitHub
parent a3dd675276
commit 9b283024b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -226,14 +226,14 @@ chat-ollama() {
}'
}
# @cmd Chat with vertexai-gemini api
# @cmd Chat with vertexai api
# @env require-tools gcloud
# @env VERTEXAI_PROJECT_ID!
# @env VERTEXAI_LOCATION!
# @option -m --model=gemini-1.0-pro $VERTEXAI_GEMINI_MODEL
# @flag -S --no-stream
# @arg text~
chat-vertexai-gemini() {
chat-vertexai() {
api_key="$(gcloud auth print-access-token)"
func="streamGenerateContent"
if [[ -n "$argc_no_stream" ]]; then

@ -60,8 +60,15 @@ clients:
# See https://ai.google.dev/docs
- type: gemini
api_key: xxx # ENV: {client}_API_KEY
# possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
block_threshold: BLOCK_NONE # Optional
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_NONE
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_NONE
# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
@ -114,8 +121,24 @@ clients:
# Run `gcloud auth application-default login` to init the adc file
# see https://cloud.google.com/docs/authentication/external/set-up-adc
adc_file: <path-to/gcloud/application_default_credentials.json>
# Optional field, possible values: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
block_threshold: BLOCK_ONLY_HIGH
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_ONLY_HIGH
# See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
- type: vertexai-claude
project_id: xxx # ENV: {client}_PROJECT_ID
location: xxx # ENV: {client}_LOCATION
# Specifies an application-default-credentials (adc) file. Optional field.
# Run `gcloud auth application-default login` to init the adc file
# see https://cloud.google.com/docs/authentication/external/set-up-adc
adc_file: <path-to/gcloud/application_default_credentials.json>
# See https://docs.aws.amazon.com/bedrock/latest/userguide/
- type: bedrock

@ -238,7 +238,6 @@
# - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
# notes:
# - get max_output_tokens info from models doc
# - claude models have not been tested
models:
- name: gemini-1.0-pro
max_input_tokens: 24568
@ -257,6 +256,14 @@
input_price: 2.5
output_price: 7.5
supports_vision: true
- platform: vertexai-claude
# docs:
# - https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
# notes:
# - get max_output_tokens info from models doc
# - claude models have not been tested
models:
- name: claude-3-opus@20240229
max_input_tokens: 200000
max_output_tokens: 4096

@ -11,7 +11,8 @@ const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/
pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub block_threshold: Option<String>,
#[serde(rename = "safetySettings")]
pub safety_settings: Option<serde_json::Value>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
@ -31,9 +32,7 @@ impl GeminiClient {
false => "generateContent",
};
let block_threshold = self.config.block_threshold.clone();
let body = gemini_build_body(data, &self.model, block_threshold)?;
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
let model = &self.model.name;

@ -31,6 +31,12 @@ register_client!(
AzureOpenAIClient
),
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
(
vertexai_claude,
"vertexai-claude",
VertexAIClaudeConfig,
VertexAIClaudeClient
),
(bedrock, "bedrock", BedrockConfig, BedrockClient),
(cloudflare, "cloudflare", CloudflareConfig, CloudflareClient),
(replicate, "replicate", ReplicateConfig, ReplicateClient),

@ -1,4 +1,3 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::{
catch_error, json_stream, message::*, patch_system_message, Client, CompletionDetails,
ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData, SseHandler,
@ -11,7 +10,7 @@ use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{path::PathBuf, str::FromStr};
use std::path::PathBuf;
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
@ -21,7 +20,8 @@ pub struct VertexAIConfig {
pub project_id: Option<String>,
pub location: Option<String>,
pub adc_file: Option<String>,
pub block_threshold: Option<String>,
#[serde(rename = "safetySettings")]
pub safety_settings: Option<Value>,
#[serde(default)]
pub models: Vec<ModelConfig>,
pub extra: Option<ExtraConfig>,
@ -36,20 +36,19 @@ impl VertexAIClient {
("location", "Location", true, PromptKind::String),
];
fn request_builder(
&self,
client: &ReqwestClient,
data: SendData,
model_category: &ModelCategory,
) -> Result<RequestBuilder> {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
let url = build_url(&base_url, &self.model.name, model_category, data.stream)?;
let block_threshold = self.config.block_threshold.clone();
let body = build_body(data, &self.model, model_category, block_threshold)?;
let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};
let url = format!("{base_url}/google/models/{}:{func}", self.model.name);
let body = gemini_build_body(data, &self.model, self.config.safety_settings.clone())?;
debug!("VertexAI Request: {url} {body}");
@ -60,20 +59,6 @@ impl VertexAIClient {
Ok(builder)
}
async fn prepare_access_token(&self) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let client = self.build_client()?;
let (token, expires_in) = fetch_access_token(&client, &self.config.adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
}
Ok(())
}
}
#[async_trait]
@ -85,13 +70,9 @@ impl Client for VertexAIClient {
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?;
let builder = self.request_builder(client, data, &model_category)?;
match model_category {
ModelCategory::Gemini => gemini_send_message(builder).await,
ModelCategory::Claude => claude_send_message(builder).await,
}
prepare_access_token(client, &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
}
async fn send_message_streaming_inner(
@ -100,13 +81,9 @@ impl Client for VertexAIClient {
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
self.prepare_access_token().await?;
let builder = self.request_builder(client, data, &model_category)?;
match model_category {
ModelCategory::Gemini => gemini_send_message_streaming(builder, handler).await,
ModelCategory::Claude => claude_send_message_streaming(builder, handler).await,
}
prepare_access_token(client, &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
}
}
@ -158,7 +135,7 @@ fn gemini_extract_text(data: &Value) -> Result<&str> {
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked by safety settingsconsider adjusting `block_threshold` in the client configuration")
bail!("Blocked by safety settingsconsider adjusting `safetySettings` in the client configuration")
} else {
bail!("Invalid response data: {data}")
}
@ -166,50 +143,10 @@ fn gemini_extract_text(data: &Value) -> Result<&str> {
}
}
fn build_url(
base_url: &str,
model_name: &str,
model_category: &ModelCategory,
stream: bool,
) -> Result<String> {
let url = match model_category {
ModelCategory::Gemini => {
let func = match stream {
true => "streamGenerateContent",
false => "generateContent",
};
format!("{base_url}/google/models/{model_name}:{func}")
}
ModelCategory::Claude => {
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
}
};
Ok(url)
}
fn build_body(
data: SendData,
model: &Model,
model_category: &ModelCategory,
block_threshold: Option<String>,
) -> Result<Value> {
match model_category {
ModelCategory::Gemini => gemini_build_body(data, model, block_threshold),
ModelCategory::Claude => {
let mut body = claude_build_body(data, model)?;
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
}
body["anthropic_version"] = "vertex-2023-10-16".into();
Ok(body)
}
}
}
pub(crate) fn gemini_build_body(
data: SendData,
model: &Model,
block_threshold: Option<String>,
safety_settings: Option<Value>,
) -> Result<Value> {
let SendData {
mut messages,
@ -263,13 +200,8 @@ pub(crate) fn gemini_build_body(
let mut body = json!({ "contents": contents, "generationConfig": {} });
if let Some(block_threshold) = block_threshold {
body["safetySettings"] = json!([
{"category":"HARM_CATEGORY_HARASSMENT","threshold":block_threshold},
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":block_threshold},
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":block_threshold},
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":block_threshold}
]);
if let Some(safety_settings) = safety_settings {
body["safetySettings"] = safety_settings;
}
if let Some(v) = model.max_output_tokens {
@ -285,27 +217,20 @@ pub(crate) fn gemini_build_body(
Ok(body)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
Gemini,
Claude,
}
impl FromStr for ModelCategory {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
if s.starts_with("gemini-") {
Ok(ModelCategory::Gemini)
} else if s.starts_with("claude-") {
Ok(ModelCategory::Claude)
} else {
unsupported_model!(s)
}
/// Ensures the module-level `ACCESS_TOKEN` cache holds a usable token.
///
/// Does nothing when a non-empty token is cached and the current time has not
/// passed its recorded expiry. Otherwise fetches a new token through
/// `fetch_gcloud_access_token` and stores it together with its expiry timestamp.
///
/// # Errors
///
/// Fails when the token fetch fails (wrapped with "Failed to fetch access token")
/// or when `expires_in` cannot be converted into a `Duration`.
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
    // NOTE(review): unsynchronized access to a `static mut`; the declaration claims
    // this is "safe under linear operation" — confirm no concurrent callers exist.
    let token_is_fresh =
        unsafe { !ACCESS_TOKEN.0.is_empty() && Utc::now().timestamp() <= ACCESS_TOKEN.1 };
    if token_is_fresh {
        return Ok(());
    }

    let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
        .await
        .with_context(|| "Failed to fetch access token")?;
    let lifetime = Duration::try_seconds(expires_in)
        .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
    let expires_at = (Utc::now() + lifetime).timestamp();

    unsafe { ACCESS_TOKEN = (token, expires_at) };
    Ok(())
}
async fn fetch_access_token(
pub async fn fetch_gcloud_access_token(
client: &reqwest::Client,
file: &Option<String>,
) -> Result<(String, i64)> {

@ -0,0 +1,100 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::vertexai::fetch_gcloud_access_token;
use super::{
Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
SseHandler, VertexAIClaudeClient,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
// Cached access token for this module, stored as `(token, expires_at)`.
// `prepare_access_token` treats an empty string as "no token cached" and
// compares the second element against `Utc::now().timestamp()`.
// NOTE(review): this `static mut` is read and written through `unsafe` with no
// synchronization; the trailing comment's "linear operation" assumption should
// be confirmed against how clients are actually driven.
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
/// Deserialized configuration for the `vertexai-claude` client.
///
/// Every scalar field is optional at the serde level; `project_id` and
/// `location` are read through the `get_project_id` / `get_location` accessors
/// when a request URL is built.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIClaudeConfig {
/// Optional client name (presumably used to tell apart several clients of the
/// same type — confirm against the client registry).
pub name: Option<String>,
/// Google Cloud project id, interpolated into the request URL.
pub project_id: Option<String>,
/// Vertex AI location, used both as the host prefix and as a path segment of
/// the request URL.
pub location: Option<String>,
/// Optional path to an application-default-credentials file; passed as-is to
/// `fetch_gcloud_access_token`.
pub adc_file: Option<String>,
/// Model entries for this client; deserializes to an empty list when omitted.
#[serde(default)]
pub models: Vec<ModelConfig>,
/// Optional extra settings; how they are consumed is not visible in this file
/// (presumably by shared client code — TODO confirm).
pub extra: Option<ExtraConfig>,
}
impl VertexAIClaudeClient {
    // Accessors for the `project_id` and `location` settings, generated by the
    // project's `config_get_fn!` macro.
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);

    /// Prompt descriptors for the two settings this client reads (`project_id`,
    /// `location`). The meaning of the boolean element is defined by
    /// `PromptAction` elsewhere (presumably "required" — TODO confirm).
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        ("project_id", "Project ID", true, PromptKind::String),
        ("location", "Location", true, PromptKind::String),
    ];

    /// Builds the POST request for a Claude model published on Vertex AI.
    ///
    /// The model is addressed through the URL, so the `model` key is dropped from
    /// the body produced by `claude_build_body` and `anthropic_version` is set to
    /// `vertex-2023-10-16`. The request is authenticated with the bearer token
    /// currently held in `ACCESS_TOKEN`; the caller is expected to have run
    /// `prepare_access_token` beforehand.
    ///
    /// # Errors
    ///
    /// Fails when `project_id` or `location` cannot be resolved, or when
    /// `claude_build_body` fails.
    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
        let project_id = self.get_project_id()?;
        let location = self.get_location()?;

        let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
        let model_name = &self.model.name;
        let url = format!("{base_url}/anthropic/models/{}:streamRawPredict", model_name);

        let mut body = claude_build_body(data, &self.model)?;
        if let Some(fields) = body.as_object_mut() {
            fields.remove("model");
        }
        body["anthropic_version"] = "vertex-2023-10-16".into();

        debug!("VertexAIClaude Request: {url} {body}");

        // NOTE(review): reads the module-level `static mut` token cache without
        // synchronization; relies on the "linear operation" assumption stated at
        // its declaration.
        Ok(client
            .post(url)
            .bearer_auth(unsafe { &ACCESS_TOKEN.0 })
            .json(&body))
    }
}
#[async_trait]
impl Client for VertexAIClaudeClient {
    client_common_fns!();

    /// Sends `data` as a single request and returns the result of
    /// `claude_send_message`.
    ///
    /// The access token cache is refreshed (if needed) before the request is
    /// built, because `request_builder` reads the token from that cache.
    async fn send_message_inner(
        &self,
        client: &ReqwestClient,
        data: SendData,
    ) -> Result<(String, CompletionDetails)> {
        prepare_access_token(client, &self.config.adc_file).await?;
        claude_send_message(self.request_builder(client, data)?).await
    }

    /// Sends `data` and forwards the response to `handler` through
    /// `claude_send_message_streaming`.
    ///
    /// As with the non-streaming path, the access token cache is refreshed
    /// before the request is built.
    async fn send_message_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: SendData,
    ) -> Result<()> {
        prepare_access_token(client, &self.config.adc_file).await?;
        claude_send_message_streaming(self.request_builder(client, data)?, handler).await
    }
}
/// Makes sure the module-level `ACCESS_TOKEN` cache holds a usable token.
///
/// Returns immediately when a non-empty token is cached and the current time
/// has not passed its recorded expiry. Otherwise a new token is fetched with
/// `fetch_gcloud_access_token` (using `adc_file` as given) and stored along with
/// the timestamp at which it expires.
///
/// # Errors
///
/// Fails when the token fetch fails (wrapped with "Failed to fetch access token")
/// or when `expires_in` cannot be converted into a `Duration`.
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
    // NOTE(review): unsynchronized access to a `static mut`; the declaration claims
    // this is "safe under linear operation" — confirm no concurrent callers exist.
    let cached_token_usable =
        unsafe { !ACCESS_TOKEN.0.is_empty() && Utc::now().timestamp() <= ACCESS_TOKEN.1 };
    if cached_token_usable {
        return Ok(());
    }

    let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
        .await
        .with_context(|| "Failed to fetch access token")?;
    let time_to_live = Duration::try_seconds(expires_in)
        .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
    let expiry_timestamp = (Utc::now() + time_to_live).timestamp();

    unsafe { ACCESS_TOKEN = (token, expiry_timestamp) };
    Ok(())
}
Loading…
Cancel
Save