init commit

pull/1/head
sigoden 1 year ago
commit dfebab910d

1
.gitignore vendored

@ -0,0 +1 @@
/target

1800
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -0,0 +1,21 @@
[package]
name = "aichat"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
bytes = "1.4.0"
clap = { version = "4.1.8", features = ["derive", "string"] }
dirs = "4.0.0"
eventsource-stream = "0.2.3"
futures-util = "0.3.26"
inquire = "0.5.3"
reedline = "0.16.0"
reqwest = { version = "0.11.14", features = ["json", "stream"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
tokio = { version = "1.26.0", features = ["full"] }
toml = "0.7.2"

@ -0,0 +1,44 @@
use std::{fs::read_to_string, path::Path};
use anyhow::{anyhow, Result};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Openai api key
pub api_key: String,
/// Set proxy
pub proxy: Option<String>,
/// Used only for debugging
#[serde(default)]
pub dry_run: bool,
/// Predefined rules
#[serde(default)]
pub roles: Vec<Role>,
}
impl Config {
pub fn init(path: &Path) -> Result<Config> {
let content = read_to_string(path)
.map_err(|err| anyhow!("Failed to load config at {}, {err}", path.display()))?;
let config: Config =
toml::from_str(&content).map_err(|err| anyhow!("Invalid config, {err}"))?;
Ok(config)
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct Role {
/// Role name
pub name: String,
/// Prompt text send to ai for setting up a rule
pub prompt: String,
/// First sentense will append to prompt
pub first_sentense: String,
}
impl Role {
pub fn generate(&self, text: &str) -> String {
format!("{} {} {}", self.prompt, self.first_sentense, text)
}
}

@ -0,0 +1,295 @@
mod config;
use std::io::{stdout, Write};
use std::path::Path;
use std::path::PathBuf;
use std::process::exit;
use config::Config;
use anyhow::{anyhow, Result};
use clap::{Arg, ArgAction, Command};
use eventsource_stream::{EventStream, Eventsource};
use futures_util::Stream;
use futures_util::StreamExt;
use inquire::{Confirm, Text};
use reedline::{
default_emacs_keybindings, ColumnarMenu, DefaultCompleter, DefaultPrompt, DefaultPromptSegment,
Emacs, KeyCode, KeyModifiers, Reedline, ReedlineEvent, ReedlineMenu, Signal,
};
use reqwest::{Client, Proxy};
use serde_json::{json, Value};
use tokio::runtime::Runtime;
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
const HELP: &str = r###".exit Exit the REPL.
.help Print this help message.
.role Specify the role that the AI will play.
Press Ctrl+C to abort current chat, Ctrl+D to exit the REPL"###;
fn main() {
if let Err(err) = start() {
eprintln!("error: {err}");
exit(1);
}
}
fn start() -> Result<()> {
let matches = Command::new(env!("CARGO_CRATE_NAME"))
.version(env!("CARGO_PKG_VERSION"))
.author(env!("CARGO_PKG_AUTHORS"))
.arg(
Arg::new("role")
.short('r')
.long("role")
.action(ArgAction::Set)
.help("Specify the role that the AI will play"),
)
.arg(
Arg::new("text")
.action(ArgAction::Append)
.help("Input text"),
)
.get_matches();
let mut text = matches.get_many::<String>("text").map(|v| {
v.map(|x| x.trim().to_string())
.collect::<Vec<String>>()
.join(" ")
});
let config_path = get_config_path()?;
if !config_path.exists() && text.is_none() {
create_config_file(&config_path)?;
}
let config = Config::init(&config_path)?;
let role = matches.get_one::<String>("role").cloned();
if let (Some(name), Some(text_)) = (role.as_ref(), text.as_ref()) {
let role = config
.roles
.iter()
.find(|v| &v.name == name)
.ok_or_else(|| anyhow!("Unknown role \"{name}\" "))?;
text = Some(role.generate(text_));
};
let client = init_client(&config)?;
let runtime = init_runtime()?;
match text {
Some(text) => {
let output = runtime.block_on(async move { acquire(&client, &config, &text).await })?;
println!("{output}");
}
None => run_repl(runtime, client, config, role)?,
}
Ok(())
}
fn run_repl(runtime: Runtime, client: Client, config: Config, role: Option<String>) -> Result<()> {
println!("Welcome to aichat {}", env!("CARGO_PKG_VERSION"));
println!("Type \".help\" for more information.");
let send_line = |line: String| -> Result<()> {
if line.is_empty() {
return Ok(());
}
if let Err(err) = runtime.block_on(handle_input(&client, &config, &line)) {
dump(format!("error: {err}"));
}
Ok(())
};
let handle_line = |line: String| -> Result<bool> {
if line.starts_with('.') {
let (name, args) = match line.split_once(' ') {
Some((head, tail)) => (head, Some(tail.trim())),
None => (line.as_str(), None),
};
match name {
".exit" => {
return Ok(true);
}
".help" => {
dump(HELP);
}
".role" => match args {
Some(name) => match config.roles.iter().find(|v| v.name == name) {
Some(role) => {
send_line(role.prompt.clone())?;
}
None => dump("Unknown role"),
},
None => dump("Usage: .role <name>"),
},
_ => {
dump("Unknown command. Type \".help\" for more information.");
}
}
} else {
send_line(line)?;
}
Ok(false)
};
if let Some(name) = role {
handle_line(format!("role {name}"))?;
}
let mut commands = vec![".help".into(), ".exit".into(), ".role".into()];
commands.extend(config.roles.iter().map(|v| format!(".role {}", v.name)));
let mut completer = DefaultCompleter::with_inclusions(&['.']).set_min_word_len(2);
completer.insert(commands.clone());
let completer = Box::new(completer);
let completion_menu = Box::new(ColumnarMenu::default().with_name("completion_menu"));
let mut keybindings = default_emacs_keybindings();
keybindings.add_binding(
KeyModifiers::NONE,
KeyCode::Tab,
ReedlineEvent::UntilFound(vec![
ReedlineEvent::Menu("completion_menu".to_string()),
ReedlineEvent::MenuNext,
]),
);
let edit_mode = Box::new(Emacs::new(keybindings));
let mut line_editor = Reedline::create()
.with_completer(completer)
.with_menu(ReedlineMenu::EngineCompleter(completion_menu))
.with_edit_mode(edit_mode);
let prompt = DefaultPrompt::new(DefaultPromptSegment::Empty, DefaultPromptSegment::Empty);
loop {
let sig = line_editor.read_line(&prompt);
match sig {
Ok(Signal::Success(line)) => {
let quit = handle_line(line)?;
if quit {
break;
}
}
Ok(Signal::CtrlD) | Ok(Signal::CtrlC) => {
break;
}
Err(err) => {
eprintln!("{err:?}");
break;
}
}
}
Ok(())
}
async fn handle_input(client: &Client, config: &Config, text: &str) -> Result<()> {
if config.dry_run {
dump(text);
return Ok(());
}
let mut stream = acquire_stream(client, config, text).await?;
while let Some(part) = stream.next().await {
let chunk = part?.data;
if chunk == "[DONE]" {
println!();
stdout().flush().unwrap();
break;
} else {
let data: Value = serde_json::from_str(&chunk)?;
let text = data["choices"][0]["delta"]["content"]
.as_str()
.unwrap_or_default();
print!("{text}");
stdout().flush().unwrap();
}
}
Ok(())
}
fn init_client(config: &Config) -> Result<Client> {
let mut builder = Client::builder();
if let Some(proxy) = config.proxy.as_ref() {
builder =
builder.proxy(Proxy::all(proxy).map_err(|err| anyhow!("Invalid config.proxy, {err}"))?);
}
let client = builder
.build()
.map_err(|err| anyhow!("Failed to init http client, {err}"))?;
Ok(client)
}
fn init_runtime() -> Result<Runtime> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| anyhow!("Failed to init tokio, {err}"))
}
fn create_config_file(config_path: &Path) -> Result<()> {
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()
.map_err(|_| anyhow!("Error with questionnaire, try again later"))?;
if !ans {
exit(0);
}
let api_key = Text::new("Openai API Key:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for your key, try again later."))?;
std::fs::write(config_path, format!("api_key = \"{api_key}\"\n"))
.map_err(|err| anyhow!("Failed to write to config file, {err}"))?;
Ok(())
}
async fn acquire(client: &Client, config: &Config, content: &str) -> Result<String> {
if config.dry_run {
return Ok(content.to_string());
}
let body = json!({
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": content}]
});
let data: Value = client
.post(API_URL)
.bearer_auth(&config.api_key)
.json(&body)
.send()
.await?
.json()
.await?;
let output = data["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Unexpect response {data}"))?;
Ok(output.to_string())
}
async fn acquire_stream(
client: &Client,
config: &Config,
content: &str,
) -> Result<EventStream<impl Stream<Item = reqwest::Result<bytes::Bytes>>>> {
let body = json!({
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": content}],
"stream": true,
});
let stream = client
.post(API_URL)
.bearer_auth(&config.api_key)
.json(&body)
.send()
.await?
.bytes_stream()
.eventsource();
Ok(stream)
}
fn dump<T: ToString>(text: T) {
println!("{}", text.to_string());
stdout().flush().unwrap();
}
fn get_config_path() -> Result<PathBuf> {
let config_dir = dirs::home_dir().ok_or_else(|| anyhow!("No home dir"))?;
Ok(config_dir.join(format!(".{}.toml", env!("CARGO_CRATE_NAME"))))
}
Loading…
Cancel
Save