feat: support multi bots and custom url (#150)

sigoden authored 7 months ago, committed via GitHub
parent f4160ff85b
commit 7d8564cafb

12 changed files

Cargo.lock (generated)

@@ -32,6 +32,7 @@ version = "0.8.0"
dependencies = [
"anyhow",
"arboard",
"async-trait",
"atty",
"base64",
"bincode",
@@ -154,6 +155,17 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
[[package]]
name = "async-trait"
version = "0.1.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.38",
]
[[package]]
name = "atty"
version = "0.2.14"

@@ -39,6 +39,7 @@ rustc-hash = "1.1.0"
bstr = "1.3.0"
nu-ansi-term = "0.47.0"
arboard = { version = "3.2.0", default-features = false }
async-trait = "0.1.74"
[dependencies.reqwest]
version = "0.11.14"

@@ -1,177 +0,0 @@
use crate::config::SharedConfig;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
use anyhow::{anyhow, bail, Context, Result};
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use reqwest::{Client, Proxy, RequestBuilder};
use serde_json::{json, Value};
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::time::sleep;
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct ChatGptClient {
config: SharedConfig,
runtime: Runtime,
}
impl ChatGptClient {
pub fn init(config: SharedConfig) -> Result<Self> {
let runtime = init_runtime()?;
let s = Self { config, runtime };
let _ = s.build_client()?; // check error
Ok(s)
}
pub fn send_message(&self, input: &str) -> Result<String> {
self.runtime.block_on(async {
self.send_message_inner(input)
.await
.with_context(|| "Failed to fetch")
})
}
pub fn send_message_streaming(
&self,
input: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
loop {
if abort.aborted() {
break;
}
sleep(Duration::from_millis(100)).await;
}
}
let abort = handler.get_abort();
self.runtime.block_on(async {
tokio::select! {
ret = self.send_message_streaming_inner(input, handler) => {
handler.done()?;
ret.with_context(|| "Failed to fetch stream")
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
_ = tokio::signal::ctrl_c() => {
abort.set_ctrlc();
Ok(())
}
}
})
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
if self.config.read().dry_run {
return Ok(self.config.read().echo_messages(content));
}
let builder = self.request_builder(content, false)?;
let data: Value = builder.send().await?.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
Ok(output.to_string())
}
async fn send_message_streaming_inner(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
if self.config.read().dry_run {
handler.text(&self.config.read().echo_messages(content))?;
return Ok(());
}
let builder = self.request_builder(content, true)?;
let res = builder.send().await?;
if !res.status().is_success() {
let data: Value = res.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
bail!("Request failed");
}
let mut stream = res.bytes_stream().eventsource();
while let Some(part) = stream.next().await {
let chunk = part?.data;
if chunk == "[DONE]" {
break;
}
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
}
Ok(())
}
fn build_client(&self) -> Result<Client> {
let mut builder = Client::builder();
if let Some(proxy) = self.config.read().proxy.as_ref() {
builder = builder
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
}
let timeout = self.config.read().get_connect_timeout();
let client = builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build http client")?;
Ok(client)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let (model, _) = self.config.read().get_model();
let messages = self.config.read().build_messages(content)?;
let mut body = json!({
"model": model,
"messages": messages,
});
if let Some(v) = self.config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let (api_key, organization_id) = self.config.read().get_api_key();
let mut builder = self
.build_client()?
.post(API_URL)
.bearer_auth(api_key)
.json(&body);
if let Some(organization_id) = organization_id {
builder = builder.header("OpenAI-Organization", organization_id);
}
Ok(builder)
}
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.with_context(|| "Failed to init tokio")
}

@@ -0,0 +1,181 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{Client, ModelInfo};
use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
use serde::Deserialize;
use serde_json::json;
use std::time::Duration;
use tokio::runtime::Runtime;
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct LocalAIClient {
global_config: SharedConfig,
local_config: LocalAIConfig,
model_info: ModelInfo,
runtime: Runtime,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIConfig {
pub url: String,
pub api_key: Option<String>,
pub models: Vec<LocalAIModel>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connecting to the server
pub connect_timeout: Option<u64>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIModel {
name: String,
max_tokens: usize,
}
#[async_trait]
impl Client for LocalAIClient {
fn get_config(&self) -> &SharedConfig {
&self.global_config
}
fn get_runtime(&self) -> &Runtime {
&self.runtime
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
openai_send_message_streaming(builder, handler).await
}
}
impl LocalAIClient {
pub fn new(
global_config: SharedConfig,
local_config: LocalAIConfig,
model_info: ModelInfo,
runtime: Runtime,
) -> Self {
Self {
global_config,
local_config,
model_info,
runtime,
}
}
pub fn name() -> &'static str {
"localai"
}
pub fn list_models(local_config: &LocalAIConfig) -> Vec<(String, usize)> {
local_config
.models
.iter()
.map(|v| (v.name.to_string(), v.max_tokens))
.collect()
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
let url = Text::new("URL:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for url, try again later."))?;
client_config.push_str(&format!(" url: {url}\n"));
let ans = Confirm::new("Use auth?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let api_key = Text::new("API key:").prompt().map_err(|_| {
anyhow!("An error happened when asking for api key, try again later.")
})?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
}
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
anyhow!("An error happened when asking for model name, try again later.")
})?;
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
anyhow!("An error happened when asking for max tokens, try again later.")
})?;
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));
Ok(client_config)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let messages = self.global_config.read().build_messages(content)?;
let mut body = json!({
"model": self.model_info.name,
"messages": messages,
});
if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let client = {
let mut builder = ReqwestClient::builder();
if let Some(proxy) = &self.local_config.proxy {
builder = builder
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
}
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
};
let mut builder = client.post(&self.local_config.url);
if let Some(api_key) = &self.local_config.api_key {
builder = builder.bearer_auth(api_key);
};
builder = builder.json(&body);
Ok(builder)
}
}
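For reference, the create_config questionnaire above emits a clients section shaped roughly like the following (all values are placeholders; note that url is the full chat-completions endpoint, since the client posts to it as-is):

clients:
  - type: localai
    url: http://localhost:8080/v1/chat/completions
    api_key: sk-xxx
    models:
      - name: llama2
        max_tokens: 4096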

@@ -0,0 +1,198 @@
pub mod localai;
pub mod openai;
use self::{
localai::LocalAIConfig,
openai::{OpenAIClient, OpenAIConfig},
};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::time::sleep;
use crate::{
client::localai::LocalAIClient,
config::{Config, SharedConfig},
repl::{ReplyStreamHandler, SharedAbortSignal},
};
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
#[serde(rename = "openai")]
OpenAI(OpenAIConfig),
#[serde(rename = "localai")]
LocalAI(LocalAIConfig),
}
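The #[serde(tag = "type")] attribute makes ClientConfig an internally tagged enum: each entry under clients selects its variant via a "type" field. A minimal, self-contained sketch of the same deserialization pattern (fields trimmed down, assuming serde and serde_yaml as dependencies):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct OpenAIConfig {
    api_key: Option<String>,
}

#[derive(Debug, Deserialize)]
struct LocalAIConfig {
    url: String,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum ClientConfig {
    #[serde(rename = "openai")]
    OpenAI(OpenAIConfig),
    #[serde(rename = "localai")]
    LocalAI(LocalAIConfig),
}

fn main() -> Result<(), serde_yaml::Error> {
    let yaml = "
- type: openai
  api_key: sk-xxx
- type: localai
  url: http://localhost:8080/v1/chat/completions
";
    // Each list item picks its enum variant from the `type` field.
    let clients: Vec<ClientConfig> = serde_yaml::from_str(yaml)?;
    println!("{clients:?}");
    Ok(())
}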
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
pub name: String,
pub max_tokens: usize,
pub index: usize,
}
impl Default for ModelInfo {
fn default() -> Self {
let client = OpenAIClient::name();
let (name, max_tokens) = &OpenAIClient::list_models(&OpenAIConfig::default())[0];
Self::new(client, name, *max_tokens, 0)
}
}
impl ModelInfo {
pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self {
Self {
client: client.into(),
name: name.into(),
max_tokens,
index,
}
}
pub fn stringify(&self) -> String {
format!("{}:{}", self.client, self.name)
}
}
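stringify produces the client:name identifier used by --list-models, .model completion, and the config's model field. A short usage example of the types above:

let info = ModelInfo::new("openai", "gpt-4", 8192, 0);
assert_eq!(info.stringify(), "openai:gpt-4");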
#[async_trait]
pub trait Client {
fn get_config(&self) -> &SharedConfig;
fn get_runtime(&self) -> &Runtime;
fn send_message(&self, content: &str) -> Result<String> {
self.get_runtime().block_on(async {
if self.get_config().read().dry_run {
return Ok(self.get_config().read().echo_messages(content));
}
self.send_message_inner(content)
.await
.with_context(|| "Failed to fetch")
})
}
fn send_message_streaming(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
loop {
if abort.aborted() {
break;
}
sleep(Duration::from_millis(100)).await;
}
}
let abort = handler.get_abort();
self.get_runtime().block_on(async {
tokio::select! {
ret = async {
if self.get_config().read().dry_run {
handler.text(&self.get_config().read().echo_messages(content))?;
return Ok(());
}
self.send_message_streaming_inner(content, handler).await
} => {
handler.done()?;
ret.with_context(|| "Failed to fetch stream")
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
Ok(())
},
_ = tokio::signal::ctrl_c() => {
abort.set_ctrlc();
Ok(())
}
}
})
}
async fn send_message_inner(&self, content: &str) -> Result<String>;
async fn send_message_streaming_inner(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()>;
}
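The Client trait keeps the blocking entry points (send_message, send_message_streaming) as provided methods that handle dry-run echoing, abort watching, and Ctrl-C, so a backend only implements the two getters plus the two async *_inner methods. A self-contained sketch of that pattern, with a hypothetical EchoClient standing in for a real backend (assuming anyhow, async-trait, and tokio as dependencies):

use anyhow::Result;
use async_trait::async_trait;
use tokio::runtime::Runtime;

#[async_trait]
trait Client {
    fn get_runtime(&self) -> &Runtime;

    // Provided method: synchronous facade that drives the async inner call.
    fn send_message(&self, content: &str) -> Result<String> {
        self.get_runtime().block_on(self.send_message_inner(content))
    }

    // The only async method a new backend must implement.
    async fn send_message_inner(&self, content: &str) -> Result<String>;
}

struct EchoClient {
    runtime: Runtime,
}

#[async_trait]
impl Client for EchoClient {
    fn get_runtime(&self) -> &Runtime {
        &self.runtime
    }

    async fn send_message_inner(&self, content: &str) -> Result<String> {
        Ok(format!("echo: {content}"))
    }
}

fn main() -> Result<()> {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()?;
    let client = EchoClient { runtime };
    println!("{}", client.send_message("hello")?); // echo: hello
    Ok(())
}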
pub fn init_client(config: SharedConfig, runtime: Runtime) -> Result<Box<dyn Client>> {
let model_info = config.read().model_info.clone();
let model_info_err = |model_info: &ModelInfo| {
bail!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
)
};
if model_info.client == OpenAIClient::name() {
let local_config = {
if let ClientConfig::OpenAI(c) = &config.read().clients[model_info.index] {
c.clone()
} else {
return model_info_err(&model_info);
}
};
Ok(Box::new(OpenAIClient::new(
config,
local_config,
model_info,
runtime,
)))
} else if model_info.client == LocalAIClient::name() {
let local_config = {
if let ClientConfig::LocalAI(c) = &config.read().clients[model_info.index] {
c.clone()
} else {
return model_info_err(&model_info);
}
};
Ok(Box::new(LocalAIClient::new(
config,
local_config,
model_info,
runtime,
)))
} else {
bail!("Unknown client {}", &model_info.client)
}
}
pub fn all_clients() -> Vec<&'static str> {
vec![OpenAIClient::name(), LocalAIClient::name()]
}
pub fn create_client_config(client: &str) -> Result<String> {
if client == OpenAIClient::name() {
OpenAIClient::create_config()
} else if client == LocalAIClient::name() {
LocalAIClient::create_config()
} else {
bail!("Unknown client {}", &client)
}
}
pub fn list_models(config: &Config) -> Vec<ModelInfo> {
config
.clients
.iter()
.enumerate()
.flat_map(|(i, v)| match v {
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c)
.iter()
.map(|(x, y)| ModelInfo::new(OpenAIClient::name(), x, *y, i))
.collect::<Vec<ModelInfo>>(),
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c)
.iter()
.map(|(x, y)| ModelInfo::new(LocalAIClient::name(), x, *y, i))
.collect::<Vec<ModelInfo>>(),
})
.collect()
}
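Putting it together: list_models flattens every (client, model) pair into a ModelInfo, which is what drives --list-models and .model completion. A hypothetical two-client config and the identifiers it would yield:

clients:
  - type: openai
    api_key: sk-xxx
  - type: localai
    url: http://localhost:8080/v1/chat/completions
    models:
      - name: llama2
        max_tokens: 4096

# yields: openai:gpt-3.5-turbo, openai:gpt-3.5-turbo-16k,
#         openai:gpt-4, openai:gpt-4-32k, localai:llama2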

@@ -0,0 +1,219 @@
use super::{Client, ModelInfo};
use crate::repl::ReplyStreamHandler;
use crate::{config::SharedConfig, utils::get_env_name};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::env;
use std::time::Duration;
use tokio::runtime::Runtime;
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct OpenAIClient {
global_config: SharedConfig,
local_config: OpenAIConfig,
model_info: ModelInfo,
runtime: Runtime,
}
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
pub api_key: Option<String>,
pub organization_id: Option<String>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connecting to the OpenAI server
pub connect_timeout: Option<u64>,
}
#[async_trait]
impl Client for OpenAIClient {
fn get_config(&self) -> &SharedConfig {
&self.global_config
}
fn get_runtime(&self) -> &Runtime {
&self.runtime
}
async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
openai_send_message(builder).await
}
async fn send_message_streaming_inner(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
openai_send_message_streaming(builder, handler).await
}
}
impl OpenAIClient {
pub fn new(
global_config: SharedConfig,
local_config: OpenAIConfig,
model_info: ModelInfo,
runtime: Runtime,
) -> Self {
Self {
global_config,
local_config,
model_info,
runtime,
}
}
pub fn name() -> &'static str {
"openai"
}
pub fn list_models(_local_config: &OpenAIConfig) -> Vec<(String, usize)> {
vec![
("gpt-3.5-turbo".into(), 4096),
("gpt-3.5-turbo-16k".into(), 16384),
("gpt-4".into(), 8192),
("gpt-4-32k".into(), 32768),
]
}
pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
let api_key = Text::new("API key:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for api key, try again later."))?;
client_config.push_str(&format!(" api_key: {api_key}\n"));
let ans = Confirm::new("Has Organization?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let organization_id = Text::new("Organization ID:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" organization_id: {organization_id}\n"));
}
let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
Ok(client_config)
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let api_key = if let Some(api_key) = &self.local_config.api_key {
api_key.to_string()
} else if let Ok(api_key) = env::var(get_env_name("api_key")) {
api_key.to_string()
} else {
bail!("Miss api_key")
};
let messages = self.global_config.read().build_messages(content)?;
let mut body = json!({
"model": self.model_info.name,
"messages": messages,
});
if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}
if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}
let client = {
let mut builder = ReqwestClient::builder();
if let Some(proxy) = &self.local_config.proxy {
builder = builder
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
}
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
};
let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body);
if let Some(organization_id) = &self.local_config.organization_id {
builder = builder.header("OpenAI-Organization", organization_id);
}
Ok(builder)
}
}
pub(crate) async fn openai_send_message(builder: RequestBuilder) -> Result<String> {
let data: Value = builder.send().await?.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
Ok(output.to_string())
}
pub(crate) async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
let res = builder.send().await?;
if !res.status().is_success() {
let data: Value = res.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
bail!("Request failed");
}
let mut stream = res.bytes_stream().eventsource();
while let Some(part) = stream.next().await {
let chunk = part?.data;
if chunk == "[DONE]" {
break;
}
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
if text.is_empty() {
continue;
}
handler.text(text)?;
}
Ok(())
}
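openai_send_message_streaming consumes OpenAI-style server-sent events: each event's data field carries a JSON chunk whose choices[0].delta.content fragment is forwarded to the handler, and a literal [DONE] sentinel ends the stream. The wire format looks roughly like this:

data: {"choices":[{"delta":{"content":"Hello"}}]}

data: {"choices":[{"delta":{"content":" world"}}]}

data: [DONE]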

@@ -6,14 +6,15 @@ use self::conversation::Conversation
use self::message::Message;
use self::role::Role;
use crate::client::openai::{OpenAIClient, OpenAIConfig};
use crate::client::{all_clients, create_client_config, list_models, ClientConfig, ModelInfo};
use crate::config::message::num_tokens_from_messages;
use crate::utils::{mask_text, now};
use crate::utils::{get_env_name, now};
use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Text};
use inquire::{Confirm, Select};
use parking_lot::RwLock;
use serde::Deserialize;
use std::time::Duration;
use std::{
env,
fs::{create_dir_all, read_to_string, File, OpenOptions},
@@ -23,24 +24,16 @@ use std::{
sync::Arc,
};
pub const MODELS: [(&str, usize); 4] = [
("gpt-4", 8192),
("gpt-4-32k", 32768),
("gpt-3.5-turbo", 4096),
("gpt-3.5-turbo-16k", 16384),
];
const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
const MESSAGE_FILE_NAME: &str = "messages.md";
const SET_COMPLETIONS: [&str; 8] = [
const SET_COMPLETIONS: [&str; 7] = [
".set temperature",
".set save true",
".set save false",
".set highlight true",
".set highlight false",
".set proxy",
".set dry_run true",
".set dry_run false",
];
@@ -49,33 +42,26 @@ const SET_COMPLETIONS: [&str; 8] = [
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct Config {
/// OpenAI api key
pub api_key: Option<String>,
/// OpenAI organization id
pub organization_id: Option<String>,
/// OpenAI model
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_name: Option<String>,
/// LLM model
pub model: Option<String>,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Whether to persistently save chat messages
pub save: bool,
/// Whether to disable highlight
pub highlight: bool,
/// Set proxy
pub proxy: Option<String>,
/// Used only for debugging
pub dry_run: bool,
/// If set true, start a conversation immediately upon entering the REPL
pub conversation_first: bool,
/// Whether to use the light theme
pub light_theme: bool,
/// Set a timeout in seconds for connecting to GPT
pub connect_timeout: usize,
/// Automatically copy the last output to the clipboard
pub auto_copy: bool,
/// Use vi keybindings, overriding the default Emacs keybindings
pub vi_keybindings: bool,
/// LLM clients
pub clients: Vec<ClientConfig>,
/// Predefined roles
#[serde(skip)]
pub roles: Vec<Role>,
@@ -86,29 +72,26 @@ pub struct Config {
#[serde(skip)]
pub conversation: Option<Conversation>,
#[serde(skip)]
pub model: (String, usize),
pub model_info: ModelInfo,
}
impl Default for Config {
fn default() -> Self {
Self {
api_key: None,
organization_id: None,
model_name: None,
model: None,
temperature: None,
save: false,
highlight: true,
proxy: None,
dry_run: false,
conversation_first: false,
light_theme: false,
connect_timeout: 10,
auto_copy: false,
vi_keybindings: false,
roles: vec![],
clients: vec![ClientConfig::OpenAI(OpenAIConfig::default())],
role: None,
conversation: None,
model: ("gpt-3.5-turbo".into(), 4096),
model_info: Default::default(),
}
}
}
@@ -118,27 +101,29 @@ pub type SharedConfig = Arc<RwLock<Config>>;
impl Config {
pub fn init(is_interactive: bool) -> Result<Self> {
let api_key = env::var(get_env_name("api_key")).ok();
let config_path = Self::config_file()?;
if is_interactive && api_key.is_none() && !config_path.exists() {
let api_key = env::var(get_env_name("api_key")).ok();
let exist_config_path = config_path.exists();
if is_interactive && api_key.is_none() && !exist_config_path {
create_config_file(&config_path)?;
}
let mut config = if api_key.is_some() && !config_path.exists() {
let mut config = if api_key.is_some() && !exist_config_path {
Self::default()
} else {
Self::load_config(&config_path)?
};
if api_key.is_some() {
config.api_key = api_key;
}
if config.api_key.is_none() {
bail!("api_key not set");
// Compatibility with old configuration files
if exist_config_path {
config.compat_old_config(&config_path)?;
}
if let Some(name) = config.model_name.clone() {
if let Some(name) = config.model.clone() {
config.set_model(&name)?;
}
config.merge_env_vars();
config.maybe_proxy();
config.load_roles()?;
Ok(config)
@@ -211,12 +196,6 @@ impl Config {
Self::local_file(CONFIG_FILE_NAME)
}
pub fn get_api_key(&self) -> (String, Option<String>) {
let api_key = self.api_key.as_ref().expect("api_key not set");
let organization_id = self.organization_id.as_ref();
(api_key.into(), organization_id.cloned())
}
pub fn roles_file() -> Result<PathBuf> {
let env_name = get_env_name("roles_file");
env::var(env_name).map_or_else(
@@ -283,14 +262,6 @@ impl Config {
}
}
pub const fn get_connect_timeout(&self) -> Duration {
Duration::from_secs(self.connect_timeout as u64)
}
pub fn get_model(&self) -> (String, usize) {
self.model.clone()
}
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
#[allow(clippy::option_if_let_else)]
let messages = if let Some(conversation) = self.conversation.as_ref() {
@@ -302,24 +273,29 @@ impl Config {
vec![message]
};
let tokens = num_tokens_from_messages(&messages);
if tokens >= self.model.1 {
if tokens >= self.model_info.max_tokens {
bail!("Exceed max tokens limit")
}
Ok(messages)
}
pub fn set_model(&mut self, name: &str) -> Result<()> {
if let Some(token) = MODELS.iter().find(|(v, _)| *v == name).map(|(_, v)| *v) {
self.model = (name.to_string(), token);
} else {
bail!("Invalid model")
pub fn set_model(&mut self, value: &str) -> Result<()> {
let models = list_models(self);
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.stringify() == value) {
self.model_info = model.clone();
return Ok(());
}
} else if let Some(model) = models.iter().find(|v| v.client == value) {
self.model_info = model.clone();
return Ok(());
}
Ok(())
bail!("Invalid model")
}
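set_model therefore accepts either form: a full client:name pair, or a bare client name that falls back to that client's first configured model. For example, in the REPL:

.model openai:gpt-4    # exact client:model pair
.model localai         # first model configured for the localai client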
pub const fn get_reamind_tokens(&self) -> usize {
let mut tokens = self.model.1;
let mut tokens = self.model_info.max_tokens;
if let Some(conversation) = self.conversation.as_ref() {
tokens = tokens.saturating_sub(conversation.tokens);
}
@@ -331,30 +307,19 @@ impl Config {
let state = if path.exists() { "" } else { " ⚠️" };
format!("{}{state}", path.display())
};
let proxy = self
.proxy
.as_ref()
.map_or_else(|| String::from("-"), std::string::ToString::to_string);
let temperature = self
.temperature
.map_or_else(|| String::from("-"), |v| v.to_string());
let (api_key, organization_id) = self.get_api_key();
let api_key = mask_text(&api_key, 3, 4);
let organization_id = organization_id.map_or_else(|| "-".into(), |v| mask_text(&v, 3, 4));
let items = vec![
("config_file", file_info(&Self::config_file()?)),
("roles_file", file_info(&Self::roles_file()?)),
("messages_file", file_info(&Self::messages_file()?)),
("api_key", api_key),
("organization_id", organization_id),
("model", self.model.0.to_string()),
("model", self.model_info.stringify()),
("temperature", temperature),
("save", self.save.to_string()),
("highlight", self.highlight.to_string()),
("proxy", proxy),
("conversation_first", self.conversation_first.to_string()),
("light_theme", self.light_theme.to_string()),
("connect_timeout", self.connect_timeout.to_string()),
("dry_run", self.dry_run.to_string()),
("vi_keybindings", self.vi_keybindings.to_string()),
];
@@ -373,7 +338,11 @@ impl Config {
.collect();
completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(MODELS.map(|(v, _)| format!(".model {v}")));
completion.extend(
list_models(self)
.iter()
.map(|v| format!(".model {}", v.stringify())),
);
completion
}
@@ -402,13 +371,6 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
self.highlight = value;
}
"proxy" => {
if unset {
self.proxy = None;
} else {
self.proxy = Some(value.to_string());
}
}
"dry_run" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.dry_run = value;
@@ -501,44 +463,62 @@ impl Config {
}
}
fn maybe_proxy(&mut self) {
if self.proxy.is_some() {
return;
fn compat_old_config(&mut self, config_path: &PathBuf) -> Result<()> {
let content = read_to_string(config_path)?;
let value: serde_json::Value = serde_yaml::from_str(&content)?;
if value.get("client").is_some() {
return Ok(());
}
if let Ok(value) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) {
self.proxy = Some(value);
if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) {
if model_name.starts_with("gpt") {
self.model = Some(format!("{}:{}", OpenAIClient::name(), model_name));
}
}
if let Some(ClientConfig::OpenAI(client_config)) = self.clients.get_mut(0) {
if let Some(api_key) = value.get("api_key").and_then(|v| v.as_str()) {
client_config.api_key = Some(api_key.to_string())
}
if let Some(organization_id) = value.get("organization_id").and_then(|v| v.as_str()) {
client_config.organization_id = Some(organization_id.to_string())
}
if let Some(proxy) = value.get("proxy").and_then(|v| v.as_str()) {
client_config.proxy = Some(proxy.to_string())
}
if let Some(connect_timeout) = value.get("connect_timeout").and_then(|v| v.as_i64()) {
client_config.connect_timeout = Some(connect_timeout as _)
}
}
Ok(())
}
}
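compat_old_config keeps pre-multi-client config files working by folding the old top-level keys into the default openai client entry. An old config like this (hypothetical values):

api_key: sk-xxx
organization_id: org-xxx
model: gpt-4
proxy: socks5://127.0.0.1:1080
connect_timeout: 10

is read as if it had been written in the new layout:

model: openai:gpt-4
clients:
  - type: openai
    api_key: sk-xxx
    organization_id: org-xxx
    proxy: socks5://127.0.0.1:1080
    connect_timeout: 10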
fn create_config_file(config_path: &Path) -> Result<()> {
let confirm_map_err = |_| anyhow!("Questionnaire not finished, try again later.");
let text_map_err = |_| anyhow!("An error happened when asking for your key, try again later.");
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()
.map_err(confirm_map_err)?;
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if !ans {
exit(0);
}
let api_key = Text::new("OpenAI API Key:")
.prompt()
.map_err(text_map_err)?;
let mut raw_config = format!("api_key: {api_key}\n");
let ans = Confirm::new("Use proxy?")
.with_default(false)
let client = Select::new("Choose bots?", all_clients())
.prompt()
.map_err(confirm_map_err)?;
if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(text_map_err)?;
raw_config.push_str(&format!("proxy: {proxy}\n"));
}
.map_err(|_| anyhow!("An error happened when selecting bots, try again later."))?;
let mut raw_config = create_client_config(client)?;
raw_config.push_str(&format!("model: {client}\n"));
let ans = Confirm::new("Save chat messages")
.with_default(true)
.prompt()
.map_err(confirm_map_err)?;
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
if ans {
raw_config.push_str("save: true\n");
}
@@ -571,14 +551,6 @@ fn ensure_parent_exists(path: &Path) -> Result<()> {
Ok(())
}
fn get_env_name(key: &str) -> String {
format!(
"{}_{}",
env!("CARGO_CRATE_NAME").to_ascii_uppercase(),
key.to_ascii_uppercase(),
)
}
fn set_bool(target: &mut bool, value: &str) {
match value {
"1" | "true" => *target = true,

@@ -8,11 +8,12 @@ mod term;
mod utils;
use crate::cli::Cli;
use crate::client::ChatGptClient;
use crate::client::Client;
use crate::config::{Config, SharedConfig};
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use clap::Parser;
use client::{init_client, list_models};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
@@ -21,6 +22,7 @@ use repl::{AbortSignal, Repl};
use std::io::{stdin, Read};
use std::sync::Arc;
use std::{io::stdout, process::exit};
use tokio::runtime::Runtime;
use utils::cl100k_base_singleton;
fn main() -> Result<()> {
@@ -36,8 +38,8 @@ fn main() -> Result<()> {
exit(0);
}
if cli.list_models {
for (name, _) in &config::MODELS {
println!("{name}");
for model in list_models(&config.read()) {
println!("{}", model.stringify());
}
exit(0);
}
@@ -69,24 +71,25 @@ fn main() -> Result<()> {
exit(0);
}
let no_stream = cli.no_stream;
let client = ChatGptClient::init(config.clone())?;
let runtime = init_runtime()?;
let client = init_client(config.clone(), runtime)?;
if atty::isnt(atty::Stream::Stdin) {
let mut input = String::new();
stdin().read_to_string(&mut input)?;
if let Some(text) = text {
input = format!("{text}\n{input}");
}
start_directive(&client, &config, &input, no_stream)
start_directive(client.as_ref(), &config, &input, no_stream)
} else {
match text {
Some(text) => start_directive(&client, &config, &text, no_stream),
Some(text) => start_directive(client.as_ref(), &config, &text, no_stream),
None => start_interactive(client, config),
}
}
}
fn start_directive(
client: &ChatGptClient,
client: &dyn Client,
config: &SharedConfig,
input: &str,
no_stream: bool,
@@ -120,9 +123,16 @@ fn start_directive(
config.read().save_message(input, &output)
}
fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> {
fn start_interactive(client: Box<dyn Client>, config: SharedConfig) -> Result<()> {
cl100k_base_singleton();
config.write().on_repl()?;
let mut repl = Repl::init(config.clone())?;
repl.run(client, config)
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.with_context(|| "Failed to init tokio")
}

@@ -7,7 +7,7 @@ use self::cmd::cmd_render_stream;
pub use self::markdown::MarkdownRender;
use self::repl::repl_render_stream;
use crate::client::ChatGptClient;
use crate::client::Client;
use crate::config::SharedConfig;
use crate::print_now;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
@@ -20,7 +20,7 @@ use std::thread::spawn;
#[allow(clippy::module_name_repetitions)]
pub fn render_stream(
input: &str,
client: &ChatGptClient,
client: &dyn Client,
config: &SharedConfig,
repl: bool,
abort: SharedAbortSignal,

@@ -1,4 +1,4 @@
use crate::client::ChatGptClient;
use crate::client::Client;
use crate::config::SharedConfig;
use crate::print_now;
use crate::render::render_stream;
@@ -26,7 +26,7 @@ pub enum ReplCmd {
#[allow(clippy::module_name_repetitions)]
pub struct ReplCmdHandler {
client: ChatGptClient,
client: Box<dyn Client>,
config: SharedConfig,
reply: RefCell<String>,
abort: SharedAbortSignal,
@@ -35,7 +35,7 @@ pub struct ReplCmdHandler {
impl ReplCmdHandler {
#[allow(clippy::unnecessary_wraps)]
pub fn init(
client: ChatGptClient,
client: Box<dyn Client>,
config: SharedConfig,
abort: SharedAbortSignal,
) -> Result<Self> {
@@ -59,7 +59,7 @@ impl ReplCmdHandler {
let wg = WaitGroup::new();
let ret = render_stream(
&input,
&self.client,
self.client.as_ref(),
&self.config,
true,
self.abort.clone(),

@@ -9,7 +9,7 @@ pub use self::abort::*;
pub use self::handler::*;
pub use self::init::Repl;
use crate::client::ChatGptClient;
use crate::client::Client;
use crate::config::SharedConfig;
use crate::print_now;
use crate::term;
@@ -35,7 +35,7 @@ pub const REPL_COMMANDS: [(&str, &str); 13] = [
];
impl Repl {
pub fn run(&mut self, client: ChatGptClient, config: SharedConfig) -> Result<()> {
pub fn run(&mut self, client: Box<dyn Client>, config: SharedConfig) -> Result<()> {
let abort = AbortSignal::new();
let handler = ReplCmdHandler::init(client, config, abort.clone())?;
print_now!("Welcome to aichat {}\n", env!("CARGO_PKG_VERSION"));

@@ -24,32 +24,20 @@ pub fn now() -> String {
now.to_rfc3339_opts(SecondsFormat::Secs, false)
}
pub fn get_env_name(key: &str) -> String {
format!(
"{}_{}",
env!("CARGO_CRATE_NAME").to_ascii_uppercase(),
key.to_ascii_uppercase(),
)
}
#[allow(unused)]
pub fn emphasis(text: &str) -> String {
text.stylize().with(Color::White).to_string()
}
pub fn mask_text(text: &str, head: usize, tail: usize) -> String {
if text.len() <= head + tail {
return text.to_string();
}
format!("{}...{}", &text[0..head], &text[text.len() - tail..])
}
pub fn copy(src: &str) -> Result<(), arboard::Error> {
let mut clipboard = Clipboard::new()?;
clipboard.set_text(src)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mask_text() {
assert_eq!(mask_text("123456", 3, 4), "123456");
assert_eq!(mask_text("1234567", 3, 4), "1234567");
assert_eq!(mask_text("12345678", 3, 4), "123...5678");
assert_eq!(mask_text("12345678", 4, 3), "1234...678");
}
}
