feat: support more openai compatiable clients (#467)

pull/468/head
sigoden 1 month ago committed by GitHub
parent 97f6d48c42
commit a50b32ca21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -135,7 +135,9 @@ macro_rules! register_client {
}
pub fn list_client_types() -> Vec<&'static str> {
vec![$($client::NAME,)+]
let mut client_types: Vec<_> = vec![$($client::NAME,)+];
client_types.extend($crate::client::KNOWN_OPENAI_COMPATIBLE_PLATFORMS.iter().map(|(name, _)| *name));
client_types
}
pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
@ -144,6 +146,9 @@ macro_rules! register_client {
return create_config(&$client::PROMPTS, $client::NAME)
}
)+
if let Some(ret) = create_openai_compatible_client_config(client)? {
return Ok(ret);
}
anyhow::bail!("Unknown client '{}'", client)
}
@ -434,26 +439,35 @@ pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value
"type": client,
});
let mut model = client.to_string();
for (path, desc, required, kind) in list {
match kind {
PromptKind::String => {
let value = prompt_input_string(desc, *required)?;
set_config_value(&mut config, path, kind, &value);
if *path == "name" {
model = value;
}
}
PromptKind::Integer => {
let value = prompt_input_integer(desc, *required)?;
set_config_value(&mut config, path, kind, &value);
}
}
}
set_client_config_values(list, &mut model, &mut config)?;
let clients = json!(vec![config]);
Ok((model, clients))
}
pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(String, Value)>> {
match super::KNOWN_OPENAI_COMPATIBLE_PLATFORMS
.iter()
.find(|(name, _)| client == *name)
{
None => Ok(None),
Some((name, api_base)) => {
let mut config = json!({
"type": "openai-compatible",
"name": name,
"api_base": api_base,
});
let mut model = client.to_string();
set_client_config_values(
&super::KNOWN_OPENAI_COMPATIBLE_PROMPTS,
&mut model,
&mut config,
)?;
let clients = json!(vec![config]);
Ok(Some((model, clients)))
}
}
}
pub async fn send_stream(
input: &Input,
client: &dyn Client,
@ -663,27 +677,50 @@ where
Ok(())
}
fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
fn set_client_config_values(
list: &[PromptType],
model: &mut String,
client_config: &mut Value,
) -> Result<()> {
for (path, desc, required, kind) in list {
match kind {
PromptKind::String => {
let value = prompt_input_string(desc, *required)?;
set_client_config_value(client_config, path, kind, &value);
if *path == "name" {
*model = value;
}
}
PromptKind::Integer => {
let value = prompt_input_integer(desc, *required)?;
set_client_config_value(client_config, path, kind, &value);
}
}
}
Ok(())
}
fn set_client_config_value(client_config: &mut Value, path: &str, kind: &PromptKind, value: &str) {
let segs: Vec<&str> = path.split('.').collect();
match segs.as_slice() {
[name] => json[name] = to_json(kind, value),
[name] => client_config[name] = to_json(kind, value),
[scope, name] => match scope.split_once('[') {
None => {
if json.get(scope).is_none() {
if client_config.get(scope).is_none() {
let mut obj = json!({});
obj[name] = to_json(kind, value);
json[scope] = obj;
client_config[scope] = obj;
} else {
json[scope][name] = to_json(kind, value);
client_config[scope][name] = to_json(kind, value);
}
}
Some((scope, _)) => {
if json.get(scope).is_none() {
if client_config.get(scope).is_none() {
let mut obj = json!({});
obj[name] = to_json(kind, value);
json[scope] = json!([obj]);
client_config[scope] = json!([obj]);
} else {
json[scope][0][name] = to_json(kind, value);
client_config[scope][0][name] = to_json(kind, value);
}
}
},

@ -5,6 +5,7 @@ mod model;
mod prompt_format;
mod sse_handler;
pub use crate::utils::PromptKind;
pub use common::*;
pub use message::*;
pub use model::*;
@ -40,3 +41,22 @@ register_client!(
OpenAICompatibleClient
),
);
pub const KNOWN_OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 5] = [
("anyscale", "https://api.endpoints.anyscale.com/v1"),
("deepinfra", "https://api.deepinfra.com/v1/openai"),
("fireworks", "https://api.fireworks.ai/inference/v1"),
("octoai", "https://text.octoai.run/v1"),
("together", "https://api.together.xyz/v1"),
];
pub const KNOWN_OPENAI_COMPATIBLE_PROMPTS: [PromptType<'static>; 3] = [
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_input_tokens",
"Max Input Tokens:",
false,
PromptKind::Integer,
),
];

Loading…
Cancel
Save