mirror of https://github.com/sigoden/aichat
feat: support multi bots and custom url (#150)
parent
f4160ff85b
commit
7d8564cafb
@ -1,177 +0,0 @@
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::{Client, Proxy, RequestBuilder};
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::time::sleep;
|
||||
|
||||
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug)]
|
||||
pub struct ChatGptClient {
|
||||
config: SharedConfig,
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
impl ChatGptClient {
|
||||
pub fn init(config: SharedConfig) -> Result<Self> {
|
||||
let runtime = init_runtime()?;
|
||||
let s = Self { config, runtime };
|
||||
let _ = s.build_client()?; // check error
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
pub fn send_message(&self, input: &str) -> Result<String> {
|
||||
self.runtime.block_on(async {
|
||||
self.send_message_inner(input)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch")
|
||||
})
|
||||
}
|
||||
|
||||
pub fn send_message_streaming(
|
||||
&self,
|
||||
input: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
async fn watch_abort(abort: SharedAbortSignal) {
|
||||
loop {
|
||||
if abort.aborted() {
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
let abort = handler.get_abort();
|
||||
self.runtime.block_on(async {
|
||||
tokio::select! {
|
||||
ret = self.send_message_streaming_inner(input, handler) => {
|
||||
handler.done()?;
|
||||
ret.with_context(|| "Failed to fetch stream")
|
||||
}
|
||||
_ = watch_abort(abort.clone()) => {
|
||||
handler.done()?;
|
||||
Ok(())
|
||||
},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
abort.set_ctrlc();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
if self.config.read().dry_run {
|
||||
return Ok(self.config.read().echo_messages(content));
|
||||
}
|
||||
let builder = self.request_builder(content, false)?;
|
||||
let data: Value = builder.send().await?.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("Request failed, {err_msg}");
|
||||
}
|
||||
|
||||
let output = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
|
||||
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
if self.config.read().dry_run {
|
||||
handler.text(&self.config.read().echo_messages(content))?;
|
||||
return Ok(());
|
||||
}
|
||||
let builder = self.request_builder(content, true)?;
|
||||
let res = builder.send().await?;
|
||||
if !res.status().is_success() {
|
||||
let data: Value = res.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("Request failed, {err_msg}");
|
||||
}
|
||||
bail!("Request failed");
|
||||
}
|
||||
let mut stream = res.bytes_stream().eventsource();
|
||||
while let Some(part) = stream.next().await {
|
||||
let chunk = part?.data;
|
||||
if chunk == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let data: Value = serde_json::from_str(&chunk)?;
|
||||
let text = data["choices"][0]["delta"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
handler.text(text)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_client(&self) -> Result<Client> {
|
||||
let mut builder = Client::builder();
|
||||
if let Some(proxy) = self.config.read().proxy.as_ref() {
|
||||
builder = builder
|
||||
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
||||
}
|
||||
let timeout = self.config.read().get_connect_timeout();
|
||||
let client = builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build http client")?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let (model, _) = self.config.read().get_model();
|
||||
let messages = self.config.read().build_messages(content)?;
|
||||
let mut body = json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
if let Some(v) = self.config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
let (api_key, organization_id) = self.config.read().get_api_key();
|
||||
|
||||
let mut builder = self
|
||||
.build_client()?
|
||||
.post(API_URL)
|
||||
.bearer_auth(api_key)
|
||||
.json(&body);
|
||||
|
||||
if let Some(organization_id) = organization_id {
|
||||
builder = builder.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
||||
|
||||
fn init_runtime() -> Result<Runtime> {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.with_context(|| "Failed to init tokio")
|
||||
}
|
@ -0,0 +1,181 @@
|
||||
use super::openai::{openai_send_message, openai_send_message_streaming};
|
||||
use super::{Client, ModelInfo};
|
||||
|
||||
use crate::config::SharedConfig;
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use inquire::{Confirm, Text};
|
||||
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug)]
|
||||
pub struct LocalAIClient {
|
||||
global_config: SharedConfig,
|
||||
local_config: LocalAIConfig,
|
||||
model_info: ModelInfo,
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct LocalAIConfig {
|
||||
pub url: String,
|
||||
pub api_key: Option<String>,
|
||||
pub models: Vec<LocalAIModel>,
|
||||
pub proxy: Option<String>,
|
||||
/// Set a timeout in seconds for connect to server
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct LocalAIModel {
|
||||
name: String,
|
||||
max_tokens: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for LocalAIClient {
|
||||
fn get_config(&self) -> &SharedConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
fn get_runtime(&self) -> &Runtime {
|
||||
&self.runtime
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
let builder = self.request_builder(content, false)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(content, true)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalAIClient {
|
||||
pub fn new(
|
||||
global_config: SharedConfig,
|
||||
local_config: LocalAIConfig,
|
||||
model_info: ModelInfo,
|
||||
runtime: Runtime,
|
||||
) -> Self {
|
||||
Self {
|
||||
global_config,
|
||||
local_config,
|
||||
model_info,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"localai"
|
||||
}
|
||||
|
||||
pub fn list_models(local_config: &LocalAIConfig) -> Vec<(String, usize)> {
|
||||
local_config
|
||||
.models
|
||||
.iter()
|
||||
.map(|v| (v.name.to_string(), v.max_tokens))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
||||
|
||||
let url = Text::new("URL:")
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when asking for url, try again later."))?;
|
||||
|
||||
client_config.push_str(&format!(" url: {url}\n"));
|
||||
|
||||
let ans = Confirm::new("Use auth?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let api_key = Text::new("API key:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for api key, try again later.")
|
||||
})?;
|
||||
|
||||
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
||||
}
|
||||
|
||||
let model_name = Text::new("Model Name:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for model name, try again later.")
|
||||
})?;
|
||||
|
||||
let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for max tokens, try again later.")
|
||||
})?;
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
||||
}
|
||||
|
||||
client_config.push_str(&format!(
|
||||
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
|
||||
));
|
||||
|
||||
Ok(client_config)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let messages = self.global_config.read().build_messages(content)?;
|
||||
|
||||
let mut body = json!({
|
||||
"model": self.model_info.name,
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
if let Some(v) = self.global_config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
let client = {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
if let Some(proxy) = &self.local_config.proxy {
|
||||
builder = builder
|
||||
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
||||
}
|
||||
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
||||
builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?
|
||||
};
|
||||
|
||||
let mut builder = client.post(&self.local_config.url);
|
||||
if let Some(api_key) = &self.local_config.api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
};
|
||||
builder = builder.json(&body);
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
@ -0,0 +1,198 @@
|
||||
pub mod localai;
|
||||
pub mod openai;
|
||||
|
||||
use self::{
|
||||
localai::LocalAIConfig,
|
||||
openai::{OpenAIClient, OpenAIConfig},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use std::time::Duration;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::{
|
||||
client::localai::LocalAIClient,
|
||||
config::{Config, SharedConfig},
|
||||
repl::{ReplyStreamHandler, SharedAbortSignal},
|
||||
};
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
#[serde(rename = "openai")]
|
||||
OpenAI(OpenAIConfig),
|
||||
#[serde(rename = "localai")]
|
||||
LocalAI(LocalAIConfig),
|
||||
}
|
||||
|
||||
/// Identifies one selectable model: the client that offers it, its
/// name, its token budget, and where its client entry lives in config.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Name of the owning client (e.g. "openai" or "localai").
    pub client: String,
    /// Model name sent in request bodies.
    pub name: String,
    /// Maximum number of tokens for the model.
    pub max_tokens: usize,
    /// Index of the owning entry in `config.clients`.
    pub index: usize,
}
|
||||
|
||||
impl Default for ModelInfo {
|
||||
fn default() -> Self {
|
||||
let client = OpenAIClient::name();
|
||||
let (name, max_tokens) = &OpenAIClient::list_models(&OpenAIConfig::default())[0];
|
||||
Self::new(client, name, *max_tokens, 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self {
|
||||
Self {
|
||||
client: client.into(),
|
||||
name: name.into(),
|
||||
max_tokens,
|
||||
index,
|
||||
}
|
||||
}
|
||||
pub fn stringify(&self) -> String {
|
||||
format!("{}:{}", self.client, self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Client {
|
||||
fn get_config(&self) -> &SharedConfig;
|
||||
|
||||
fn get_runtime(&self) -> &Runtime;
|
||||
|
||||
fn send_message(&self, content: &str) -> Result<String> {
|
||||
self.get_runtime().block_on(async {
|
||||
if self.get_config().read().dry_run {
|
||||
return Ok(self.get_config().read().echo_messages(content));
|
||||
}
|
||||
self.send_message_inner(content)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch")
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message_streaming(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
async fn watch_abort(abort: SharedAbortSignal) {
|
||||
loop {
|
||||
if abort.aborted() {
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
let abort = handler.get_abort();
|
||||
self.get_runtime().block_on(async {
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
if self.get_config().read().dry_run {
|
||||
handler.text(&self.get_config().read().echo_messages(content))?;
|
||||
return Ok(());
|
||||
}
|
||||
self.send_message_streaming_inner(content, handler).await
|
||||
} => {
|
||||
handler.done()?;
|
||||
ret.with_context(|| "Failed to fetch stream")
|
||||
}
|
||||
_ = watch_abort(abort.clone()) => {
|
||||
handler.done()?;
|
||||
Ok(())
|
||||
},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
abort.set_ctrlc();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String>;
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
pub fn init_client(config: SharedConfig, runtime: Runtime) -> Result<Box<dyn Client>> {
|
||||
let model_info = config.read().model_info.clone();
|
||||
let model_info_err = |model_info: &ModelInfo| {
|
||||
bail!(
|
||||
"Unknown client {} at config.clients[{}]",
|
||||
&model_info.client,
|
||||
&model_info.index
|
||||
)
|
||||
};
|
||||
if model_info.client == OpenAIClient::name() {
|
||||
let local_config = {
|
||||
if let ClientConfig::OpenAI(c) = &config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return model_info_err(&model_info);
|
||||
}
|
||||
};
|
||||
Ok(Box::new(OpenAIClient::new(
|
||||
config,
|
||||
local_config,
|
||||
model_info,
|
||||
runtime,
|
||||
)))
|
||||
} else if model_info.client == LocalAIClient::name() {
|
||||
let local_config = {
|
||||
if let ClientConfig::LocalAI(c) = &config.read().clients[model_info.index] {
|
||||
c.clone()
|
||||
} else {
|
||||
return model_info_err(&model_info);
|
||||
}
|
||||
};
|
||||
Ok(Box::new(LocalAIClient::new(
|
||||
config,
|
||||
local_config,
|
||||
model_info,
|
||||
runtime,
|
||||
)))
|
||||
} else {
|
||||
bail!("Unknown client {}", &model_info.client)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn all_clients() -> Vec<&'static str> {
|
||||
vec![OpenAIClient::name(), LocalAIClient::name()]
|
||||
}
|
||||
|
||||
pub fn create_client_config(client: &str) -> Result<String> {
|
||||
if client == OpenAIClient::name() {
|
||||
OpenAIClient::create_config()
|
||||
} else if client == LocalAIClient::name() {
|
||||
LocalAIClient::create_config()
|
||||
} else {
|
||||
bail!("Unknown client {}", &client)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_models(config: &Config) -> Vec<ModelInfo> {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, v)| match v {
|
||||
ClientConfig::OpenAI(c) => OpenAIClient::list_models(c)
|
||||
.iter()
|
||||
.map(|(x, y)| ModelInfo::new(OpenAIClient::name(), x, *y, i))
|
||||
.collect::<Vec<ModelInfo>>(),
|
||||
ClientConfig::LocalAI(c) => LocalAIClient::list_models(c)
|
||||
.iter()
|
||||
.map(|(x, y)| ModelInfo::new(LocalAIClient::name(), x, *y, i))
|
||||
.collect::<Vec<ModelInfo>>(),
|
||||
})
|
||||
.collect()
|
||||
}
|
@ -0,0 +1,219 @@
|
||||
use super::{Client, ModelInfo};
|
||||
|
||||
use crate::repl::ReplyStreamHandler;
|
||||
use crate::{config::SharedConfig, utils::get_env_name};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures_util::StreamExt;
|
||||
use inquire::{Confirm, Text};
|
||||
use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIClient {
|
||||
global_config: SharedConfig,
|
||||
local_config: OpenAIConfig,
|
||||
model_info: ModelInfo,
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct OpenAIConfig {
|
||||
pub api_key: Option<String>,
|
||||
pub organization_id: Option<String>,
|
||||
pub proxy: Option<String>,
|
||||
/// Set a timeout in seconds for connect to openai server
|
||||
pub connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for OpenAIClient {
|
||||
fn get_config(&self) -> &SharedConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
fn get_runtime(&self) -> &Runtime {
|
||||
&self.runtime
|
||||
}
|
||||
|
||||
async fn send_message_inner(&self, content: &str) -> Result<String> {
|
||||
let builder = self.request_builder(content, false)?;
|
||||
openai_send_message(builder).await
|
||||
}
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
content: &str,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(content, true)?;
|
||||
openai_send_message_streaming(builder, handler).await
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
pub fn new(
|
||||
global_config: SharedConfig,
|
||||
local_config: OpenAIConfig,
|
||||
model_info: ModelInfo,
|
||||
runtime: Runtime,
|
||||
) -> Self {
|
||||
Self {
|
||||
global_config,
|
||||
local_config,
|
||||
model_info,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name() -> &'static str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
pub fn list_models(_local_config: &OpenAIConfig) -> Vec<(String, usize)> {
|
||||
vec![
|
||||
("gpt-3.5-turbo".into(), 4096),
|
||||
("gpt-3.5-turbo-16k".into(), 16384),
|
||||
("gpt-4".into(), 8192),
|
||||
("gpt-4-32k".into(), 32768),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn create_config() -> Result<String> {
|
||||
let mut client_config = format!("clients:\n - type: {}\n", Self::name());
|
||||
|
||||
let api_key = Text::new("API key:")
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("An error happened when asking for api key, try again later."))?;
|
||||
|
||||
client_config.push_str(&format!(" api_key: {api_key}\n"));
|
||||
|
||||
let ans = Confirm::new("Has Organization?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let organization_id = Text::new("Organization ID:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" organization_id: {organization_id}\n"));
|
||||
}
|
||||
|
||||
let ans = Confirm::new("Use proxy?")
|
||||
.with_default(false)
|
||||
.prompt()
|
||||
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;
|
||||
|
||||
if ans {
|
||||
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
|
||||
anyhow!("An error happened when asking for proxy, try again later.")
|
||||
})?;
|
||||
client_config.push_str(&format!(" proxy: {proxy}\n"));
|
||||
}
|
||||
|
||||
Ok(client_config)
|
||||
}
|
||||
|
||||
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
|
||||
let api_key = if let Some(api_key) = &self.local_config.api_key {
|
||||
api_key.to_string()
|
||||
} else if let Ok(api_key) = env::var(get_env_name("api_key")) {
|
||||
api_key.to_string()
|
||||
} else {
|
||||
bail!("Miss api_key")
|
||||
};
|
||||
|
||||
let messages = self.global_config.read().build_messages(content)?;
|
||||
|
||||
let mut body = json!({
|
||||
"model": self.model_info.name,
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
if let Some(v) = self.global_config.read().get_temperature() {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("temperature".into(), json!(v)));
|
||||
}
|
||||
|
||||
if stream {
|
||||
body.as_object_mut()
|
||||
.and_then(|m| m.insert("stream".into(), json!(true)));
|
||||
}
|
||||
|
||||
let client = {
|
||||
let mut builder = ReqwestClient::builder();
|
||||
if let Some(proxy) = &self.local_config.proxy {
|
||||
builder = builder
|
||||
.proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
|
||||
}
|
||||
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
|
||||
builder
|
||||
.connect_timeout(timeout)
|
||||
.build()
|
||||
.with_context(|| "Failed to build client")?
|
||||
};
|
||||
|
||||
let mut builder = client.post(API_URL).bearer_auth(api_key).json(&body);
|
||||
|
||||
if let Some(organization_id) = &self.local_config.organization_id {
|
||||
builder = builder.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn openai_send_message(builder: RequestBuilder) -> Result<String> {
|
||||
let data: Value = builder.send().await?.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("Request failed, {err_msg}");
|
||||
}
|
||||
|
||||
let output = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Unexpected response {data}"))?;
|
||||
|
||||
Ok(output.to_string())
|
||||
}
|
||||
|
||||
pub(crate) async fn openai_send_message_streaming(
|
||||
builder: RequestBuilder,
|
||||
handler: &mut ReplyStreamHandler,
|
||||
) -> Result<()> {
|
||||
let res = builder.send().await?;
|
||||
if !res.status().is_success() {
|
||||
let data: Value = res.json().await?;
|
||||
if let Some(err_msg) = data["error"]["message"].as_str() {
|
||||
bail!("Request failed, {err_msg}");
|
||||
}
|
||||
bail!("Request failed");
|
||||
}
|
||||
let mut stream = res.bytes_stream().eventsource();
|
||||
while let Some(part) = stream.next().await {
|
||||
let chunk = part?.data;
|
||||
if chunk == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
let data: Value = serde_json::from_str(&chunk)?;
|
||||
let text = data["choices"][0]["delta"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
handler.text(text)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue