refactor: improve code quality (#194)

- extends ModelInfo for token calculation
- refactor config/session.rs, improve export, render, getter/setter
- modify main.rs, allow --model override session.model
pull/195/head
sigoden 7 months ago committed by GitHub
parent da3c541b68
commit f6da06dad9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,7 @@ keywords = ["chatgpt", "localai", "gpt", "repl"]
[dependencies]
anyhow = "1.0.69"
bytes = "1.4.0"
clap = { version = "4.1.8", features = ["derive", "string"] }
clap = { version = "4.1.8", features = ["derive"] }
dirs = "5.0.0"
futures-util = "0.3.26"
inquire = "0.6.2"
@ -43,7 +43,7 @@ reqwest-eventsource = "0.5.0"
[dependencies.reqwest]
version = "0.11.14"
features = ["json", "stream", "socks", "rustls-tls", "rustls-tls-native-roots"]
features = ["json", "socks", "rustls-tls", "rustls-tls-native-roots"]
default-features = false
[dependencies.syntect]

@ -235,17 +235,21 @@ You should run aichat with `-s/--session` or use the `.session` command to start
```
〉.session
temp1 to 5, odd only 4089
temp1 to 5, odd only 0
1, 3, 5
tempto 7 4070
tempto 7 19(0.46%)
1, 3, 5, 7
temp.exit session
temp.exit session 42(1.03%)
? Save session? (y/N)
```
The prompt on the right side shows the session's current token usage and the percentage of tokens used
relative to the maximum number of tokens allowed by the model.
### `.set` - modify the configuration temporarily

@ -1,4 +1,4 @@
use super::openai::openai_build_body;
use super::openai::{openai_build_body, openai_tokens_formula};
use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
use anyhow::{anyhow, Result};
@ -46,7 +46,7 @@ impl AzureOpenAIClient {
local_config
.models
.iter()
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
.map(|v| openai_tokens_formula(ModelInfo::new(index, client, &v.name).set_max_tokens(v.max_tokens)))
.collect()
}

@ -1,4 +1,4 @@
use super::openai::openai_build_body;
use super::openai::{openai_build_body, openai_tokens_formula};
use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
use anyhow::Result;
@ -45,7 +45,7 @@ impl LocalAIClient {
local_config
.models
.iter()
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
.map(|v| openai_tokens_formula(ModelInfo::new(index, client, &v.name).set_max_tokens(v.max_tokens)))
.collect()
}

@ -38,7 +38,7 @@ impl OpenAIClient {
let client = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| ModelInfo::new(client, name, Some(max_tokens), index))
.map(|(name, max_tokens)| openai_tokens_formula(ModelInfo::new(index, client, name).set_max_tokens(Some(max_tokens))))
.collect()
}
@ -135,3 +135,7 @@ pub fn openai_build_body(data: SendData, model: String) -> Value {
}
body
}
/// Applies the OpenAI-style token-accounting constants to a model:
/// 5 overhead tokens per message and a fixed per-request bias of 2 tokens.
pub fn openai_tokens_formula(model: ModelInfo) -> ModelInfo {
    model.set_tokens_formula(5, 2)
}

@ -1,5 +1,3 @@
use crate::utils::count_tokens;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -25,22 +23,19 @@ pub enum MessageRole {
User,
}
#[allow(dead_code)]
impl MessageRole {
#[allow(dead_code)]
pub fn is_system(&self) -> bool {
matches!(self, MessageRole::System)
}
}
pub fn num_tokens_from_messages(messages: &[Message]) -> usize {
let mut num_tokens = 0;
for message in messages.iter() {
num_tokens += 4;
num_tokens += count_tokens(&message.content);
num_tokens += 1; // role always take 1 token
pub fn is_user(&self) -> bool {
matches!(self, MessageRole::User)
}
pub fn is_assistant(&self) -> bool {
matches!(self, MessageRole::Assistant)
}
num_tokens += 2;
num_tokens
}
#[cfg(test)]

@ -12,7 +12,6 @@ use crate::client::{
all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient,
SendData,
};
use crate::config::message::num_tokens_from_messages;
use crate::render::RenderOptions;
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
@ -274,7 +273,7 @@ impl Config {
pub fn set_temperature(&mut self, value: Option<f64>) -> Result<()> {
self.temperature = value;
if let Some(session) = self.session.as_mut() {
session.temperature = value;
session.set_temperature(value);
}
Ok(())
}
@ -298,13 +297,6 @@ impl Config {
let message = Message::new(content);
vec![message]
};
let tokens = num_tokens_from_messages(&messages);
if let Some(max_tokens) = self.model_info.max_tokens {
if tokens >= max_tokens {
bail!("Exceed max tokens limit")
}
}
Ok(messages)
}
@ -326,7 +318,7 @@ impl Config {
let models = all_models(self);
let mut model_info = None;
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.stringify() == value) {
if let Some(model) = models.iter().find(|v| v.full_name() == value) {
model_info = Some(model.clone());
}
} else if let Some(model) = models.iter().find(|v| v.client == value) {
@ -336,7 +328,7 @@ impl Config {
None => bail!("Unknown model '{}'", value),
Some(model_info) => {
if let Some(session) = self.session.as_mut() {
session.set_model(&model_info.stringify())?;
session.set_model(model_info.clone())?;
}
self.model_info = model_info;
Ok(())
@ -361,7 +353,7 @@ impl Config {
("roles_file", path_info(&Self::roles_file()?)),
("messages_file", path_info(&Self::messages_file()?)),
("sessions_dir", path_info(&Self::sessions_dir()?)),
("model", self.model_info.stringify()),
("model", self.model_info.full_name()),
("temperature", temperature),
("save", self.save.to_string()),
("highlight", self.highlight.to_string()),
@ -389,7 +381,7 @@ impl Config {
completion.extend(
all_models(self)
.iter()
.map(|v| format!(".model {}", v.stringify())),
.map(|v| format!(".model {}", v.full_name())),
);
let sessions = self.list_sessions().unwrap_or_default();
completion.extend(sessions.iter().map(|v| format!(".session {}", v)));
@ -444,7 +436,7 @@ impl Config {
}
self.session = Some(Session::new(
TEMP_SESSION_NAME,
&self.model_info.stringify(),
self.model_info.clone(),
self.role.clone(),
));
}
@ -453,13 +445,13 @@ impl Config {
if !session_path.exists() {
self.session = Some(Session::new(
name,
&self.model_info.stringify(),
self.model_info.clone(),
self.role.clone(),
));
} else {
let session = Session::load(name, &session_path)?;
let model = session.model.clone();
self.temperature = session.temperature;
let model = session.model().to_string();
self.temperature = session.temperature();
self.session = Some(session);
self.set_model(&model)?;
}
@ -472,7 +464,8 @@ impl Config {
"Start a session that incorporates the last question and answer?",
)
.with_default(false)
.prompt()?;
.prompt()
.map_err(prompt_op_err)?;
if ans {
session.add_message(input, output)?;
}
@ -487,13 +480,19 @@ impl Config {
self.last_message = None;
self.temperature = self.default_temperature;
if session.should_save() {
let ans = Confirm::new("Save session?").with_default(true).prompt()?;
let ans = Confirm::new("Save session?")
.with_default(false)
.prompt()
.map_err(prompt_op_err)?;
if !ans {
return Ok(());
}
let mut name = session.name.clone();
let mut name = session.name().to_string();
if session.is_temp() {
name = Text::new("Session name:").with_default(&name).prompt()?;
name = Text::new("Session name:")
.with_default(&name)
.prompt()
.map_err(prompt_op_err)?;
}
let session_path = Self::session_file(&name)?;
let sessions_dir = session_path.parent().ok_or_else(|| {
@ -558,17 +557,13 @@ impl Config {
pub fn render_prompt_right(&self) -> String {
if let Some(session) = &self.session {
let tokens = session.tokens;
// 10000(%32)
match self.model_info.max_tokens {
Some(max_tokens) => {
let ratio = tokens as f32 / max_tokens as f32;
let percent = ratio * 100.0;
let percent = (percent * 100.0).round() / 100.0;
format!("{tokens}({percent}%)")
}
None => format!("{tokens}"),
}
let (tokens, percent) = session.tokens_and_percent();
let percent = if percent == 0.0 {
String::new()
} else {
format!("({percent}%)")
};
format!("{tokens}{percent}")
} else {
String::new()
}
@ -576,6 +571,7 @@ impl Config {
pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
let messages = self.build_messages(content)?;
self.model_info.max_tokens_limit(&messages)?;
Ok(SendData {
messages,
temperature: self.get_temperature(),
@ -586,7 +582,7 @@ impl Config {
pub fn maybe_print_send_tokens(&self, input: &str) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
let tokens = num_tokens_from_messages(&messages);
let tokens = self.model_info.totatl_tokens(&messages);
println!(">>> This message consumes {tokens} tokens. <<<");
}
}
@ -642,7 +638,7 @@ impl Config {
bail!("No available model");
}
models[0].stringify()
models[0].full_name()
}
};
self.set_model(&model)?;

@ -1,27 +1,79 @@
use super::Message;
use crate::utils::count_tokens;
use anyhow::{bail, Result};
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
pub name: String,
pub max_tokens: Option<usize>,
pub index: usize,
pub max_tokens: Option<usize>,
pub per_message_tokens: usize,
pub bias_tokens: usize,
}
impl Default for ModelInfo {
fn default() -> Self {
ModelInfo::new("", "", None, 0)
ModelInfo::new(0, "", "")
}
}
impl ModelInfo {
pub fn new(client: &str, name: &str, max_tokens: Option<usize>, index: usize) -> Self {
pub fn new(index: usize, client: &str, name: &str) -> Self {
Self {
index,
client: client.into(),
name: name.into(),
max_tokens,
index,
max_tokens: None,
per_message_tokens: 0,
bias_tokens: 0,
}
}
pub fn stringify(&self) -> String {
/// Builder-style setter for the model's maximum token budget.
/// A `None` or zero value means "no limit" and is normalized to `None`.
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
    // Zero is treated the same as "no limit".
    self.max_tokens = max_tokens.filter(|v| *v > 0);
    self
}
/// Builder-style setter for the client-specific token accounting constants.
///
/// `per_message_tokens` is the fixed overhead charged for each message;
/// `bias_tokens` is a fixed per-request overhead added when checking limits.
///
/// (Parameter renamed from `per_message_token` to match the field name;
/// in Rust a parameter rename is fully caller-compatible.)
pub fn set_tokens_formula(mut self, per_message_tokens: usize, bias_tokens: usize) -> Self {
    self.per_message_tokens = per_message_tokens;
    self.bias_tokens = bias_tokens;
    self
}
/// Returns the fully qualified model identifier in `<client>:<name>` form.
pub fn full_name(&self) -> String {
    format!("{}:{}", self.client, self.name)
}
/// Sums the token counts of every message body, without any
/// per-message or per-request overhead.
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
    let mut total = 0;
    for message in messages {
        total += count_tokens(&message.content);
    }
    total
}
/// Estimates the total tokens of a conversation: content tokens plus the
/// per-message overhead. When the last message is not from the user, one
/// message's overhead is omitted from the estimate.
///
/// NOTE(review): the name is misspelled ("totatl") but kept as-is because
/// callers elsewhere in the crate use it; renaming would break them.
pub fn totatl_tokens(&self, messages: &[Message]) -> usize {
    let last_is_user = match messages.last() {
        Some(last) => last.role.is_user(),
        None => return 0,
    };
    let charged_messages = if last_is_user {
        messages.len()
    } else {
        messages.len() - 1
    };
    charged_messages * self.per_message_tokens + self.messages_tokens(messages)
}
/// Verifies the conversation fits within the model's token budget.
/// Errors when the estimated total (including the fixed bias) reaches or
/// exceeds `max_tokens`; models without a limit always pass.
pub fn max_tokens_limit(&self, messages: &[Message]) -> Result<()> {
    let max_tokens = match self.max_tokens {
        Some(v) => v,
        None => return Ok(()),
    };
    if self.totatl_tokens(messages) + self.bias_tokens >= max_tokens {
        bail!("Exceed max tokens limit")
    }
    Ok(())
}
}

@ -1,45 +1,47 @@
use super::message::{num_tokens_from_messages, Message, MessageRole};
use super::message::{Message, MessageRole};
use super::role::Role;
use super::ModelInfo;
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fs::{self, read_to_string};
use std::path::Path;
pub const TEMP_SESSION_NAME: &str = "temp";
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
model: String,
temperature: Option<f64>,
messages: Vec<Message>,
#[serde(skip)]
pub name: String,
#[serde(skip)]
pub path: Option<String>,
pub model: String,
pub tokens: usize,
pub temperature: Option<f64>,
pub messages: Vec<Message>,
#[serde(skip)]
pub dirty: bool,
#[serde(skip)]
pub role: Option<Role>,
#[serde(skip)]
pub name: String,
pub model_info: ModelInfo,
}
impl Session {
pub fn new(name: &str, model: &str, role: Option<Role>) -> Self {
pub fn new(name: &str, model_info: ModelInfo, role: Option<Role>) -> Self {
let temperature = role.as_ref().and_then(|v| v.temperature);
let mut value = Self {
path: None,
model: model.to_string(),
Self {
model: model_info.full_name(),
temperature,
tokens: 0,
messages: vec![],
name: name.to_string(),
path: None,
dirty: false,
role,
name: name.to_string(),
};
value.update_tokens();
value
model_info,
}
}
pub fn load(name: &str, path: &Path) -> Result<Self> {
@ -54,22 +56,64 @@ impl Session {
Ok(session)
}
/// The session name (not serialized; assigned when the session is created
/// or loaded from its file).
pub fn name(&self) -> &str {
    &self.name
}

/// The model's full name (`client:name`) as persisted in the session file.
pub fn model(&self) -> &str {
    &self.model
}

/// The session-specific temperature override, if any.
pub fn temperature(&self) -> Option<f64> {
    self.temperature
}

/// Current token usage, recomputed from the message list on each call
/// using the model's token formula.
pub fn tokens(&self) -> usize {
    self.model_info.totatl_tokens(&self.messages)
}
pub fn export(&self) -> Result<String> {
self.guard_save()?;
let output = serde_yaml::to_string(&self)
let (tokens, percent) = self.tokens_and_percent();
let mut data = json!({
"path": self.path,
"model": self.model(),
});
if let Some(temperature) = self.temperature() {
data["temperature"] = temperature.into();
}
data["total-tokens"] = tokens.into();
if let Some(max_tokens) = self.model_info.max_tokens {
data["max-tokens"] = max_tokens.into();
}
if percent != 0.0 {
data["total/max-tokens"] = format!("{}%", percent).into();
}
data["messages"] = json!(self.messages);
let output = serde_yaml::to_string(&data)
.with_context(|| format!("Unable to show info about session {}", &self.name))?;
Ok(output)
}
pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
let path = self.path.clone().unwrap_or_else(|| "-".to_string());
let temperature = self
.temperature
.temperature()
.map_or_else(|| String::from("-"), |v| v.to_string());
let max_tokens = self
.model_info
.max_tokens
.map(|v| v.to_string())
.unwrap_or_else(|| '-'.to_string());
let items = vec![
("path", self.path.clone().unwrap_or_else(|| "-".into())),
("model", self.model.clone()),
("tokens", self.tokens.to_string()),
("path", path),
("model", self.model().to_string()),
("temperature", temperature),
("max_tokens", max_tokens),
];
let mut lines = vec![];
for (name, value) in items {
@ -94,17 +138,32 @@ impl Session {
Ok(output)
}
/// Returns the current token count together with its share of the model's
/// max token budget as a percentage, rounded to two decimal places.
/// The percentage is 0.0 when the model has no max-token limit.
pub fn tokens_and_percent(&self) -> (usize, f32) {
    let tokens = self.tokens();
    let percent = match self.model_info.max_tokens {
        Some(max) if max > 0 => {
            // Same evaluation order as before: (tokens/max) * 100, then
            // round to two decimals via *100 / round / 100.
            let ratio = tokens as f32 / max as f32;
            (ratio * 100.0 * 100.0).round() / 100.0
        }
        _ => 0.0,
    };
    (tokens, percent)
}
pub fn update_role(&mut self, role: Option<Role>) -> Result<()> {
self.guard_empty()?;
self.temperature = role.as_ref().and_then(|v| v.temperature);
self.role = role;
self.update_tokens();
Ok(())
}
pub fn set_model(&mut self, model: &str) -> Result<()> {
self.model = model.to_string();
self.update_tokens();
/// Overrides the session temperature (propagated from the global config
/// when the user changes it while a session is active).
pub fn set_temperature(&mut self, value: Option<f64>) {
    self.temperature = value;
}
/// Switches the session to a new model, keeping the serialized `model`
/// string in sync with the stored `ModelInfo`.
/// Currently infallible; returns `Result` to keep the caller-facing
/// contract stable.
pub fn set_model(&mut self, model_info: ModelInfo) -> Result<()> {
    self.model = model_info.full_name();
    self.model_info = model_info;
    Ok(())
}
@ -112,7 +171,8 @@ impl Session {
if !self.should_save() {
return Ok(());
}
self.dirty = false;
self.path = Some(session_path.display().to_string());
let content = serde_yaml::to_string(&self)
.with_context(|| format!("Failed to serde session {}", self.name))?;
fs::write(session_path, content).with_context(|| {
@ -122,6 +182,9 @@ impl Session {
session_path.display()
)
})?;
self.dirty = false;
Ok(())
}
@ -151,10 +214,6 @@ impl Session {
self.messages.is_empty()
}
pub fn update_tokens(&mut self) {
self.tokens = num_tokens_from_messages(&self.build_emssages(""));
}
pub fn add_message(&mut self, input: &str, output: &str) -> Result<()> {
let mut need_add_msg = true;
if self.messages.is_empty() {
@ -173,7 +232,6 @@ impl Session {
role: MessageRole::Assistant,
content: output.to_string(),
});
self.tokens = num_tokens_from_messages(&self.messages);
self.dirty = true;
Ok(())
}

@ -37,7 +37,7 @@ fn main() -> Result<()> {
}
if cli.list_models {
for model in all_models(&config.read()) {
println!("{}", model.stringify());
println!("{}", model.full_name());
}
exit(0);
}
@ -55,15 +55,15 @@ fn main() -> Result<()> {
if cli.dry_run {
config.write().dry_run = true;
}
if let Some(model) = &cli.model {
config.write().set_model(model)?;
}
if let Some(name) = &cli.role {
config.write().set_role(name)?;
}
if let Some(session) = &cli.session {
config.write().start_session(session)?;
}
if let Some(model) = &cli.model {
config.write().set_model(model)?;
}
if cli.no_highlight {
config.write().highlight = false;
}

@ -60,7 +60,7 @@ fn repl_render_stream_inner(
}
if row + 1 >= clear_rows {
queue!(writer, cursor::MoveTo(0, row - clear_rows))?;
queue!(writer, cursor::MoveTo(0, row.saturating_sub(clear_rows)))?;
} else {
let scroll_rows = clear_rows - row - 1;
queue!(

@ -23,7 +23,7 @@ impl ReplPrompt {
impl Prompt for ReplPrompt {
fn render_prompt_left(&self) -> Cow<str> {
if let Some(session) = &self.config.read().session {
Cow::Owned(session.name.clone())
Cow::Owned(session.name().to_string())
} else if let Some(role) = &self.config.read().role {
Cow::Owned(role.name.clone())
} else {

Loading…
Cancel
Save