seperate ACCESS_TOKEN

pull/485/head
sigoden 4 weeks ago
parent 65debfcd75
commit 8378e2054b

@ -12,7 +12,7 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::path::PathBuf;
pub static mut VERTEXAI_ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
@ -54,7 +54,7 @@ impl VertexAIClient {
let builder = client
.post(url)
.bearer_auth(unsafe { &VERTEXAI_ACCESS_TOKEN.0 })
.bearer_auth(unsafe { &ACCESS_TOKEN.0 })
.json(&body);
Ok(builder)
@ -217,25 +217,20 @@ pub(crate) fn gemini_build_body(
Ok(body)
}
pub async fn prepare_access_token(
client: &reqwest::Client,
adc_file: &Option<String>,
) -> Result<()> {
if unsafe {
VERTEXAI_ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > VERTEXAI_ACCESS_TOKEN.1
} {
let (token, expires_in) = fetch_access_token(client, adc_file)
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { VERTEXAI_ACCESS_TOKEN = (token, expires_at.timestamp()) };
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
}
Ok(())
}
async fn fetch_access_token(
pub async fn fetch_gcloud_access_token(
client: &reqwest::Client,
file: &Option<String>,
) -> Result<(String, i64)> {

@ -1,15 +1,18 @@
use super::claude::{claude_build_body, claude_send_message, claude_send_message_streaming};
use super::vertexai::{prepare_access_token, VERTEXAI_ACCESS_TOKEN};
use super::vertexai::fetch_gcloud_access_token;
use super::{
Client, CompletionDetails, ExtraConfig, Model, ModelConfig, PromptAction, PromptKind, SendData,
SseHandler, VertexAIClaudeClient,
};
use anyhow::Result;
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
static mut ACCESS_TOKEN: (String, i64) = (String::new(), 0); // safe under linear operation
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIClaudeConfig {
pub name: Option<String>,
@ -50,7 +53,7 @@ impl VertexAIClaudeClient {
let builder = client
.post(url)
.bearer_auth(unsafe { &VERTEXAI_ACCESS_TOKEN.0 })
.bearer_auth(unsafe { &ACCESS_TOKEN.0 })
.json(&body);
Ok(builder)
@ -82,3 +85,16 @@ impl Client for VertexAIClaudeClient {
claude_send_message_streaming(builder, handler).await
}
}
async fn prepare_access_token(client: &reqwest::Client, adc_file: &Option<String>) -> Result<()> {
if unsafe { ACCESS_TOKEN.0.is_empty() || Utc::now().timestamp() > ACCESS_TOKEN.1 } {
let (token, expires_in) = fetch_gcloud_access_token(client, adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
unsafe { ACCESS_TOKEN = (token, expires_at.timestamp()) };
}
Ok(())
}

Loading…
Cancel
Save