You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
aichat/src/client/bedrock.rs

435 lines
13 KiB
Rust

use super::claude::{claude_build_body, claude_extract_completion};
use super::{
catch_error, generate_prompt, BedrockClient, Client, CompletionDetails, ExtraConfig, Model,
ModelConfig, PromptAction, PromptFormat, PromptKind, SendData, SseHandler,
LLAMA3_PROMPT_FORMAT, MISTRAL_PROMPT_FORMAT,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
use chrono::{DateTime, Utc};
use futures_util::StreamExt;
use indexmap::IndexMap;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client as ReqwestClient, Method, RequestBuilder,
};
use serde::Deserialize;
use serde_json::{json, Value};
use std::str::FromStr;
/// User-supplied settings for the Bedrock client, populated through serde
/// deserialization.
///
/// All fields except `models` are optional in the serialized form.
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
/// Optional name for this client entry.
pub name: Option<String>,
/// AWS access key id used when signing requests.
pub access_key_id: Option<String>,
/// AWS secret access key used when signing requests.
pub secret_access_key: Option<String>,
/// AWS region; it is interpolated into the `bedrock-runtime.{region}.amazonaws.com` host.
pub region: Option<String>,
/// Model definitions; `#[serde(default)]` makes this an empty list when the key is absent.
#[serde(default)]
pub models: Vec<ModelConfig>,
/// Optional extra client settings (semantics defined by `ExtraConfig`, declared elsewhere).
pub extra: Option<ExtraConfig>,
}
#[async_trait]
impl Client for BedrockClient {
client_common_fns!();
async fn send_message_inner(
&self,
client: &ReqwestClient,
data: SendData,
) -> Result<(String, CompletionDetails)> {
let model_category = ModelCategory::from_str(&self.model.name)?;
let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await
}
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: SendData,
) -> Result<()> {
let model_category = ModelCategory::from_str(&self.model.name)?;
let builder = self.request_builder(client, data, &model_category)?;
send_message_streaming(builder, handler, &model_category).await
}
}
impl BedrockClient {
    config_get_fn!(access_key_id, get_access_key_id);
    config_get_fn!(secret_access_key, get_secret_access_key);
    config_get_fn!(region, get_region);

    /// Interactive configuration prompts: config key, label shown to the
    /// user, a boolean flag, and the kind of value expected.
    pub const PROMPTS: [PromptAction<'static>; 3] = [
        ("access_key_id", "AWS Access Key ID", true, PromptKind::String),
        (
            "secret_access_key",
            "AWS Secret Access Key",
            true,
            PromptKind::String,
        ),
        ("region", "AWS Region", true, PromptKind::String),
    ];

    /// Builds the signed HTTP request for one model invocation.
    ///
    /// Resolves credentials and region, picks the streaming or non-streaming
    /// invoke path, serializes the model-specific body (with the model's
    /// extra fields merged in), and hands everything to [`aws_fetch`] for
    /// signing.
    fn request_builder(
        &self,
        client: &ReqwestClient,
        data: SendData,
        model_category: &ModelCategory,
    ) -> Result<RequestBuilder> {
        // Resolved in this order so the first missing setting is the one reported.
        let credentials = AwsCredentials {
            access_key_id: self.get_access_key_id()?,
            secret_access_key: self.get_secret_access_key()?,
            region: self.get_region()?,
        };

        let action = if data.stream {
            "invoke-with-response-stream"
        } else {
            "invoke"
        };
        let uri = format!("/model/{}/{action}", self.model.name);
        let host = format!("bedrock-runtime.{}.amazonaws.com", credentials.region);

        let mut payload = build_body(data, &self.model, model_category)?;
        self.model.merge_extra_fields(&mut payload);

        aws_fetch(
            client,
            &credentials,
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers: IndexMap::new(),
                body: payload.to_string(),
            },
        )
    }
}
/// Sends a non-streaming invocation and extracts the completion.
///
/// The response body is parsed as JSON regardless of status; on a
/// non-success status the parsed body is passed to `catch_error`. The
/// completion is then pulled out using the extractor that matches
/// `model_category`.
async fn send_message(
    builder: RequestBuilder,
    model_category: &ModelCategory,
) -> Result<(String, CompletionDetails)> {
    let response = builder.send().await?;
    let status = response.status();
    let payload: Value = response.json().await?;

    if !status.is_success() {
        catch_error(&payload, status.as_u16())?;
    }

    match *model_category {
        ModelCategory::Anthropic => claude_extract_completion(&payload),
        ModelCategory::MetaLlama3 => llama_extract_completion(&payload),
        ModelCategory::Mistral => mistral_extract_completion(&payload),
    }
}
async fn send_message_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
model_category: &ModelCategory,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
let data: Value = res.json().await?;
catch_error(&data, status.as_u16())?;
bail!("Invalid response data: {data}");
}
let mut stream = res.bytes_stream();
let mut buffer = BytesMut::new();
let mut decoder = MessageFrameDecoder::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
buffer.extend_from_slice(&chunk);
while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
let response_headers = parse_response_headers(&message)?;
let message_type = response_headers.message_type.as_str();
let smithy_type = response_headers.smithy_type.as_str();
match (message_type, smithy_type) {
("event", "chunk") => {
let data: Value = decode_chunk(message.payload()).ok_or_else(|| {
anyhow!("Invalid chunk data: {}", hex_encode(message.payload()))
})?;
// debug!("bedrock chunk: {data}");
match model_category {
ModelCategory::Anthropic => {
if let Some(typ) = data["type"].as_str() {
if typ == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
}
}
}
}
ModelCategory::MetaLlama3 => {
if let Some(text) = data["generation"].as_str() {
handler.text(text)?;
}
}
ModelCategory::Mistral => {
if let Some(text) = data["outputs"][0]["text"].as_str() {
handler.text(text)?;
}
}
}
}
("exception", _) => {
let payload = base64_decode(message.payload())?;
let data = String::from_utf8_lossy(&payload);
bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
}
_ => {
bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
}
}
}
}
Ok(())
}
fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result<Value> {
match model_category {
ModelCategory::Anthropic => {
let mut body = claude_build_body(data, model)?;
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
}
body["anthropic_version"] = "bedrock-2023-05-31".into();
Ok(body)
}
ModelCategory::MetaLlama3 => meta_llama_build_body(data, model, LLAMA3_PROMPT_FORMAT),
ModelCategory::Mistral => mistral_build_body(data, model),
}
}
fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result<Value> {
let SendData {
messages,
temperature,
top_p,
stream: _,
} = data;
let prompt = generate_prompt(&messages, pt)?;
let mut body = json!({ "prompt": prompt });
if let Some(v) = model.max_tokens_param() {
body["max_gen_len"] = v.into();
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}
if let Some(v) = top_p {
body["top_p"] = v.into();
}
Ok(body)
}
fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
let SendData {
messages,
temperature,
top_p,
stream: _,
} = data;
let prompt = generate_prompt(&messages, MISTRAL_PROMPT_FORMAT)?;
let mut body = json!({ "prompt": prompt });
if let Some(v) = model.max_tokens_param() {
body["max_tokens"] = v.into();
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}
if let Some(v) = top_p {
body["top_p"] = v.into();
}
Ok(body)
}
fn llama_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["generation"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let details = CompletionDetails {
id: None,
input_tokens: data["prompt_token_count"].as_u64(),
output_tokens: data["generation_token_count"].as_u64(),
};
Ok((text.to_string(), details))
}
fn mistral_extract_completion(data: &Value) -> Result<(String, CompletionDetails)> {
let text = data["outputs"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok((text.to_string(), CompletionDetails::default()))
}
/// Model families handled by this client; the family selects the request
/// body format and how responses are parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
/// Anthropic models (body built with the Claude body builder).
Anthropic,
/// Meta Llama 3 models (prompt-based body, text under `generation`).
MetaLlama3,
/// Mistral models (prompt-based body, text under `outputs[0].text`).
Mistral,
}
impl FromStr for ModelCategory {
    type Err = anyhow::Error;

    /// Classifies a model name by its prefix.
    ///
    /// Prefixes are checked in order: `anthropic.`, `meta.llama3`, `mistral`.
    /// Any other name is reported through `unsupported_model!`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const PREFIXES: [(&str, ModelCategory); 3] = [
            ("anthropic.", ModelCategory::Anthropic),
            ("meta.llama3", ModelCategory::MetaLlama3),
            ("mistral", ModelCategory::Mistral),
        ];
        for (prefix, category) in PREFIXES.iter() {
            if s.starts_with(prefix) {
                return Ok(*category);
            }
        }
        unsupported_model!(s)
    }
}
/// Credentials and region used to sign a request in `aws_fetch`.
// NOTE(review): `Debug` is derived, so formatting this struct prints the secret key — avoid logging it.
#[derive(Debug)]
struct AwsCredentials {
/// Access key id; placed in the `Credential=` part of the authorization header.
access_key_id: String,
/// Secret access key; used to derive the signing key.
secret_access_key: String,
/// Region; part of the credential scope and the signing key derivation.
region: String,
}
/// Description of an HTTP request to be signed and built by `aws_fetch`.
#[derive(Debug)]
struct AwsRequest {
/// HTTP method; also the first line of the canonical request.
method: Method,
/// Host name; used for the endpoint URL and the signed `host` header.
host: String,
/// Service name used in the credential scope and signing key (e.g. `"bedrock"`).
service: String,
/// Request path, starting with `/`; appended to the host to form the endpoint.
uri: String,
/// Query string included in the canonical request that gets signed.
querystring: String,
/// Additional headers to sign and send; `host` and `x-amz-date` are added by `aws_fetch`.
headers: IndexMap<String, String>,
/// Request body; its SHA-256 hash is part of the canonical request.
body: String,
}
/// Builds a `reqwest` request signed with AWS Signature Version 4.
///
/// Steps, in order:
/// 1. add the `host` and `x-amz-date` headers;
/// 2. build the canonical request (method, encoded path, query string,
///    canonical headers, signed header list, payload hash);
/// 3. build the string to sign from the timestamp, credential scope and the
///    hash of the canonical request;
/// 4. derive the signing key and compute the signature;
/// 5. attach the resulting `authorization` header and return the builder.
///
/// Header names are lowercased, values trimmed, and the list sorted by name
/// for signing, so caller-supplied headers in any order or case are signed
/// consistently. `querystring` is used verbatim, both in the signature and
/// (when non-empty) in the request URL, so the caller must pass it already
/// encoded.
///
/// # Errors
///
/// Fails if a header name or value is not valid for an HTTP header.
fn aws_fetch(
    client: &ReqwestClient,
    credentials: &AwsCredentials,
    request: AwsRequest,
) -> Result<RequestBuilder> {
    let AwsRequest {
        method,
        host,
        service,
        uri,
        querystring,
        mut headers,
        body,
    } = request;

    let region = &credentials.region;

    // The query string is part of the signature, so it must also be part of
    // the URL actually requested.
    let endpoint = if querystring.is_empty() {
        format!("https://{}{}", host, uri)
    } else {
        format!("https://{}{}?{}", host, uri, querystring)
    };

    let now: DateTime<Utc> = Utc::now();
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    // First eight characters of the timestamp: YYYYMMDD.
    let date_stamp = amz_date[0..8].to_string();

    headers.insert("host".into(), host);
    headers.insert("x-amz-date".into(), amz_date.clone());

    // Canonical form of the headers: lowercase names, trimmed values, sorted by name.
    let (canonical_headers, signed_headers) = {
        let mut entries: Vec<(String, &str)> = headers
            .iter()
            .map(|(key, value)| (key.to_lowercase(), value.trim()))
            .collect();
        entries.sort();

        let canonical: String = entries
            .iter()
            .map(|(key, value)| format!("{}:{}\n", key, value))
            .collect();
        let signed = entries
            .iter()
            .map(|(key, _)| key.as_str())
            .collect::<Vec<_>>()
            .join(";");
        (canonical, signed)
    };

    let payload_hash = sha256(&body);

    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        encode_uri(&uri),
        querystring,
        canonical_headers,
        signed_headers,
        payload_hash
    );

    let algorithm = "AWS4-HMAC-SHA256";
    let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, region, service);
    let string_to_sign = format!(
        "{}\n{}\n{}\n{}",
        algorithm,
        amz_date,
        credential_scope,
        sha256(&canonical_request)
    );

    let signing_key = gen_signing_key(
        &credentials.secret_access_key,
        &date_stamp,
        region,
        &service,
    );
    let signature = hmac_sha256(&signing_key, &string_to_sign);
    let signature = hex_encode(&signature);

    let authorization_header = format!(
        "{} Credential={}/{}, SignedHeaders={}, Signature={}",
        algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
    );

    // Added after signing: the authorization header is not itself signed.
    headers.insert("authorization".into(), authorization_header);

    let mut req_headers = HeaderMap::new();
    for (k, v) in &headers {
        req_headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
    }

    debug!("Bedrock Request: {endpoint} {body}");

    let request_builder = client
        .request(method, endpoint)
        .headers(req_headers)
        .body(body);

    Ok(request_builder)
}
/// Derives the request signing key from the secret key.
///
/// Starts from `"AWS4" + key` and chains HMAC-SHA256 over the date stamp,
/// the region, the service and finally the literal `aws4_request`, each
/// step keyed by the previous result.
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
    let seed = format!("AWS4{}", key);
    let mut derived = hmac_sha256(seed.as_bytes(), date_stamp);
    for part in [region, service, "aws4_request"] {
        derived = hmac_sha256(&derived, part);
    }
    derived
}
/// Decodes one stream chunk payload into its inner JSON value.
///
/// The payload is expected to be a JSON object whose `bytes` field holds a
/// base64 string; that string decodes to the JSON document that is returned.
/// Returns `None` if any step fails.
fn decode_chunk(data: &[u8]) -> Option<Value> {
    let envelope: Value = serde_json::from_slice(data).ok()?;
    let inner = envelope["bytes"]
        .as_str()
        .and_then(|encoded| base64_decode(encoded).ok())?;
    serde_json::from_slice(&inner).ok()
}