refactor: improve creating config for openai-compatible client (#374)

pull/375/head
sigoden 3 months ago committed by GitHub
parent eec041c111
commit 0ebc7955da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -46,7 +46,8 @@ clients:
api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# Any openai-compatible API providers
- type: openai-compatible # Renamed from localai
- type: openai-compatible
name: localai
api_base: http://localhost:8080/v1
api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
chat_endpoint: /chat/completions # Optional field

@ -104,7 +104,7 @@ macro_rules! register_client {
vec![$($client::NAME,)+]
}
pub fn create_client_config(client: &str) -> anyhow::Result<serde_json::Value> {
pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
$(
if client == $client::NAME {
return create_config(&$client::PROMPTS, $client::NAME)
@ -310,15 +310,19 @@ pub struct SendData {
pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind);
pub fn create_config(list: &[PromptType], client: &str) -> Result<Value> {
pub fn create_config(list: &[PromptType], client: &str) -> Result<(String, Value)> {
let mut config = json!({
"type": client,
});
let mut model = client.to_string();
for (path, desc, required, kind) in list {
match kind {
PromptKind::String => {
let value = prompt_input_string(desc, *required)?;
set_config_value(&mut config, path, kind, &value);
if *path == "name" {
model = value;
}
}
PromptKind::Integer => {
let value = prompt_input_integer(desc, *required)?;
@ -328,7 +332,7 @@ pub fn create_config(list: &[PromptType], client: &str) -> Result<Value> {
}
let clients = json!(vec![config]);
Ok(clients)
Ok((model, clients))
}
#[allow(unused)]

@ -23,7 +23,8 @@ openai_compatible_client!(OpenAICompatibleClient);
impl OpenAICompatibleClient {
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptType<'static>; 4] = [
pub const PROMPTS: [PromptType<'static>; 5] = [
("name", "Platform Name:", true, PromptKind::String),
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),

@ -1046,8 +1046,9 @@ fn create_config_file(config_path: &Path) -> Result<()> {
let client = Select::new("Platform:", list_client_types()).prompt()?;
let mut config = serde_json::json!({});
config["model"] = client.into();
config[CLIENTS_FIELD] = create_client_config(client)?;
let (model, clients_config) = create_client_config(client)?;
config["model"] = model.into();
config[CLIENTS_FIELD] = clients_config;
let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?;

Loading…
Cancel
Save