feat: add remain tokens indicator and max tokens guard (#50)

sigoden authored 1 year ago · committed by GitHub
parent c7eb261abc
commit 05d20f207f

Cargo.lock (generated, 74 changed lines)

@ -8,13 +8,24 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "0.7.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
dependencies = [
"memchr",
]
[[package]]
name = "aichat"
version = "0.5.0"
dependencies = [
"anyhow",
"atty",
"base64",
"bincode",
"bstr",
"bytes",
"chrono",
"clap",
@ -23,12 +34,15 @@ dependencies = [
"ctrlc",
"dirs",
"eventsource-stream",
"fancy-regex",
"futures-util",
"inquire",
"is-terminal",
"lazy_static",
"parking_lot",
"reedline",
"reqwest",
"rustc-hash",
"serde",
"serde_json",
"serde_yaml",
@ -90,12 +104,39 @@ dependencies = [
"serde",
]
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bstr"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ffdb39cb703212f3c11973452c2861b972f757b021158f3516ba10f2fa8b2c1"
dependencies = [
"memchr",
"once_cell",
"regex-automata",
"serde",
]
[[package]]
name = "bumpalo"
version = "3.12.0"
@ -459,6 +500,16 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "fancy-regex"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
dependencies = [
"bit-set",
"regex",
]
[[package]]
name = "fd-lock"
version = "3.0.10"
@ -1144,6 +1195,23 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "regex"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
[[package]]
name = "regex-syntax"
version = "0.6.28"
@ -1208,6 +1276,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.36.9"

@ -32,6 +32,11 @@ unicode-width = "0.1.10"
bincode = "1.3.3"
ctrlc = "3.2.5"
parking_lot = "0.12.1"
lazy_static = "1.4.0"
fancy-regex = "0.11.0"
base64 = "0.21.0"
rustc-hash = "1.1.0"
bstr = "1.3.0"
[dependencies.reqwest]
version = "0.11.14"

File diff suppressed because it is too large.

@ -1,7 +1,7 @@
use crate::config::SharedConfig;
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use reqwest::{Client, Proxy, RequestBuilder};
@ -32,7 +32,7 @@ impl ChatGptClient {
self.runtime.block_on(async {
self.send_message_inner(input)
.await
.with_context(|| "Failed to send message")
.with_context(|| "Failed to fetch")
})
}
@ -54,7 +54,7 @@ impl ChatGptClient {
tokio::select! {
ret = self.send_message_streaming_inner(input, handler) => {
handler.done()?;
ret.with_context(|| "Failed to send message streaming")
ret.with_context(|| "Failed to fetch stream")
}
_ = watch_abort(abort.clone()) => {
handler.done()?;
@ -73,8 +73,10 @@ impl ChatGptClient {
return Ok(self.config.lock().echo_messages(content));
}
let builder = self.request_builder(content, false)?;
let data: Value = builder.send().await?.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
let output = data["choices"][0]["message"]["content"]
.as_str()
@ -93,7 +95,15 @@ impl ChatGptClient {
return Ok(());
}
let builder = self.request_builder(content, true)?;
let mut stream = builder.send().await?.bytes_stream().eventsource();
let res = builder.send().await?;
if !res.status().is_success() {
let data: Value = res.json().await?;
if let Some(err_msg) = data["error"]["message"].as_str() {
bail!("Request failed, {err_msg}");
}
bail!("Request failed");
}
let mut stream = res.bytes_stream().eventsource();
let mut virgin = true;
while let Some(part) = stream.next().await {
let chunk = part?.data;
@ -133,7 +143,7 @@ impl ChatGptClient {
}
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let messages = self.config.lock().build_messages(content);
let messages = self.config.lock().build_messages(content)?;
let mut body = json!({
"model": MODEL,
"messages": messages,

@ -2,8 +2,12 @@ use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::utils::count_tokens;
use super::{MAX_TOKENS, MESSAGE_EXTRA_TOKENS};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Session {
pub struct Conversation {
pub tokens: usize,
pub messages: Vec<Message>,
}
@ -14,7 +18,7 @@ pub struct Message {
pub content: String,
}
impl Session {
impl Conversation {
pub fn new() -> Self {
Self {
tokens: 0,
@ -22,7 +26,7 @@ impl Session {
}
}
pub fn add_conversatoin(&mut self, input: &str, output: &str) -> Result<()> {
pub fn add_chat(&mut self, input: &str, output: &str) -> Result<()> {
self.messages.push(Message {
role: MessageRole::User,
content: input.to_string(),
@ -31,6 +35,7 @@ impl Session {
role: MessageRole::Assistant,
content: output.to_string(),
});
self.tokens += count_tokens(input) + count_tokens(output) + 2 * MESSAGE_EXTRA_TOKENS;
Ok(())
}
@ -40,6 +45,7 @@ impl Session {
role: MessageRole::System,
content: prompt.into(),
});
self.tokens += count_tokens(prompt) + MESSAGE_EXTRA_TOKENS;
}
pub fn echo_messages(&self, content: &str) -> String {
@ -59,6 +65,10 @@ impl Session {
}));
json!(messages)
}
pub fn reamind_tokens(&self) -> usize {
MAX_TOKENS.saturating_sub(self.tokens)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
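For orientation, a minimal self-contained sketch of the bookkeeping this hunk introduces. The whitespace-based `count_tokens` below is only a stand-in for the tiktoken counter added later in this commit, and `remaining_tokens` mirrors the `reamind_tokens` method above; the constants match the ones defined in `config/mod.rs`.

```rust
const MAX_TOKENS: usize = 4096;
const MESSAGE_EXTRA_TOKENS: usize = 6;

// Stand-in for the tiktoken-based counter; a whitespace split is NOT an
// accurate token count, it only keeps this sketch runnable without BPE tables.
fn count_tokens(text: &str) -> usize {
    text.split_whitespace().count()
}

struct Conversation {
    tokens: usize,
}

impl Conversation {
    // Mirrors `add_chat`: one user message plus one assistant message per turn,
    // each with a fixed per-message overhead.
    fn add_chat(&mut self, input: &str, output: &str) {
        self.tokens += count_tokens(input) + count_tokens(output) + 2 * MESSAGE_EXTRA_TOKENS;
    }

    // Mirrors `reamind_tokens`: the value shown by the right-hand prompt indicator.
    fn remaining_tokens(&self) -> usize {
        MAX_TOKENS.saturating_sub(self.tokens)
    }
}

fn main() {
    let mut c = Conversation { tokens: 0 };
    c.add_chat("hello there", "hi, how can I help?");
    // 2 + 5 stubbed "tokens" + 12 overhead => 4077 remaining
    println!("{} tokens left", c.remaining_tokens());
}
```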

@ -1,13 +1,13 @@
mod conversation;
use self::conversation::Session;
use self::conversation::Conversation;
use crate::utils::{emphasis, now};
use crate::utils::{count_tokens, now};
use anyhow::{anyhow, bail, Context, Result};
use inquire::{Confirm, Text};
use parking_lot::Mutex;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{
env,
@ -18,6 +18,8 @@ use std::{
sync::Arc,
};
const MAX_TOKENS: usize = 4096;
const MESSAGE_EXTRA_TOKENS: usize = 6;
const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
@ -53,14 +55,14 @@ pub struct Config {
#[serde(default)]
pub dry_run: bool,
/// Predefined roles
#[serde(default, skip)]
#[serde(skip)]
pub roles: Vec<Role>,
/// Current selected role
#[serde(default, skip)]
#[serde(skip)]
pub role: Option<Role>,
/// Current conversation
#[serde(default, skip)]
pub conversation: Option<Session>,
#[serde(skip)]
pub conversation: Option<Conversation>,
}
pub type SharedConfig = Arc<Mutex<Config>>;
@ -163,20 +165,10 @@ impl Config {
bail!("")
}
match self.find_role(name) {
Some(role) => {
let temperature = match role.temperature {
Some(v) => format!("{v}"),
None => "null".into(),
};
let output = format!(
"{}: {}\n{}: {}\n{}: {}",
emphasis("name"),
role.name,
emphasis("prompt"),
role.prompt.trim(),
emphasis("temperature"),
temperature
);
Some(mut role) => {
role.tokens = count_tokens(&role.prompt);
let output =
serde_yaml::to_string(&role).unwrap_or("Unable to echo role details".into());
self.role = Some(role);
Ok(output)
}
@ -190,6 +182,7 @@ impl Config {
name: TEMP_ROLE_NAME.into(),
prompt: prompt.into(),
temperature: self.temperature,
tokens: count_tokens(prompt),
});
Ok(())
}
@ -205,22 +198,33 @@ impl Config {
if let Some(conversation) = self.conversation.as_ref() {
conversation.echo_messages(content)
} else if let Some(role) = self.role.as_ref() {
format!("{}\n{content}", role.prompt.trim())
format!("{}\n{content}", role.prompt)
} else {
content.to_string()
}
}
pub fn build_messages(&self, content: &str) -> Value {
pub fn build_messages(&self, content: &str) -> Result<Value> {
let tokens = count_tokens(content) + MESSAGE_EXTRA_TOKENS;
let check_tokens = |tokens| {
if tokens >= MAX_TOKENS {
bail!("Exceed max tokens limit")
}
Ok(())
};
check_tokens(tokens)?;
let user_message = json!({ "role": "user", "content": content });
if let Some(conversation) = self.conversation.as_ref() {
let value = if let Some(conversation) = self.conversation.as_ref() {
check_tokens(tokens + conversation.tokens)?;
conversation.build_emssages(content)
} else if let Some(role) = self.role.as_ref() {
let system_message = json!({ "role": "system", "content": role.prompt.trim() });
check_tokens(tokens + role.tokens + MESSAGE_EXTRA_TOKENS)?;
let system_message = json!({ "role": "system", "content": role.prompt });
json!([system_message, user_message])
} else {
json!([user_message])
}
};
Ok(value)
}
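Taken in isolation, the guard added here reduces to the budget arithmetic sketched below. `check_tokens` is copied from the hunk above; the free-standing `guard` helper and its numeric inputs are hypothetical, kept only to show how the user message, per-message overhead, and any conversation history share one `MAX_TOKENS` budget.

```rust
use anyhow::{bail, Result};

const MAX_TOKENS: usize = 4096;
const MESSAGE_EXTRA_TOKENS: usize = 6;

fn check_tokens(tokens: usize) -> Result<()> {
    if tokens >= MAX_TOKENS {
        bail!("Exceed max tokens limit")
    }
    Ok(())
}

// input_tokens: tokens of the user message; history_tokens: running total of the
// current conversation, if any (Conversation::tokens in the earlier hunk).
fn guard(input_tokens: usize, history_tokens: Option<usize>) -> Result<()> {
    let tokens = input_tokens + MESSAGE_EXTRA_TOKENS;
    check_tokens(tokens)?;
    if let Some(history) = history_tokens {
        check_tokens(tokens + history)?;
    }
    Ok(())
}

fn main() {
    assert!(guard(100, Some(1000)).is_ok());
    assert!(guard(100, Some(4000)).is_err()); // 100 + 6 + 4000 exceeds 4096
}
```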
pub fn info(&self) -> Result<String> {
@ -329,7 +333,7 @@ impl Config {
return Ok(());
}
}
let mut conversation = Session::new();
let mut conversation = Conversation::new();
if let Some(role) = self.role.as_ref() {
conversation.add_prompt(&role.prompt);
}
@ -341,9 +345,9 @@ impl Config {
self.conversation = None;
}
pub fn record_conversation(&mut self, input: &str, output: &str) -> Result<()> {
pub fn save_conversation(&mut self, input: &str, output: &str) -> Result<()> {
if let Some(conversation) = self.conversation.as_mut() {
conversation.add_conversatoin(input, output)?;
conversation.add_chat(input, output)?;
}
Ok(())
}
@ -378,7 +382,7 @@ impl Config {
}
}
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Role {
/// Role name
pub name: String,
@ -386,6 +390,9 @@ pub struct Role {
pub prompt: String,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Number of tokens
#[serde(skip_deserializing)]
pub tokens: usize,
}
fn create_config_file(config_path: &Path) -> Result<()> {

@ -21,6 +21,7 @@ use repl::{AbortSignal, Repl};
use std::io::{stdin, Read};
use std::sync::Arc;
use std::{io::stdout, process::exit};
use utils::cl100k_base_singleton;
fn main() -> Result<()> {
let cli = Cli::parse();
@ -96,6 +97,7 @@ fn start_directive(
}
fn start_interactive(client: ChatGptClient, config: SharedConfig) -> Result<()> {
cl100k_base_singleton();
let mut repl = Repl::init(config.clone())?;
repl.run(client, config)
}

@ -63,7 +63,7 @@ impl ReplCmdHandler {
wg.wait();
let buffer = ret?;
self.config.lock().save_message(&input, &buffer)?;
self.config.lock().record_conversation(&input, &buffer)?;
self.config.lock().save_conversation(&input, &buffer)?;
*self.reply.borrow_mut() = buffer;
}
ReplCmd::SetRole(name) => {

@ -135,7 +135,12 @@ impl Prompt for ReplPrompt {
}
fn render_prompt_right(&self) -> Cow<str> {
Cow::Borrowed("")
let config = self.0.lock();
if let Some(conversation) = config.conversation.as_ref() {
conversation.reamind_tokens().to_string().into()
} else {
Cow::Borrowed("")
}
}
fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<str> {

@ -22,7 +22,7 @@ pub const REPL_COMMANDS: [(&str, &str, bool); 12] = [
(".role", "Select a role", false),
(".clear role", "Clear the currently selected role", false),
(".conversation", "Start a conversation.", false),
(".clear conversation", "End the conversation.", false),
(".clear conversation", "End current conversation.", false),
(".history", "Print the history", false),
(".clear history", "Clear the history", false),
(".editor", "Enter editor mode for multiline input", true),

@ -1,3 +1,7 @@
mod tiktoken;
pub use self::tiktoken::{cl100k_base_singleton, count_tokens, text_to_tokens, tokens_to_text};
use chrono::prelude::*;
use crossterm::style::{Color, Stylize};
use std::io::{stdout, Write};
@ -19,6 +23,7 @@ pub fn now() -> String {
now.to_rfc3339_opts(SecondsFormat::Secs, false)
}
#[allow(unused)]
pub fn emphasis(text: &str) -> String {
text.stylize().with(Color::White).to_string()
}

@ -0,0 +1,586 @@
//! Use tiktoken to count tokens
//!
//! Copied from https://github.com/dust-tt/dust/tree/main/core/src/providers/tiktoken
#![allow(unused)]
use anyhow::{anyhow, Result};
use base64::{engine::general_purpose, Engine as _};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rustc_hash::FxHashMap as HashMap;
use std::collections::HashSet;
use std::sync::Arc;
/// Count how many tokens a piece of text consumes
pub fn count_tokens(text: &str) -> usize {
text_to_tokens(text).len()
}
/// Convert plain text to tokens
pub fn text_to_tokens(text: &str) -> Vec<usize> {
cl100k_base_singleton()
.lock()
.encode_with_special_tokens(text)
}
/// Convert tokens back to plain text
pub fn tokens_to_text(tokens: Vec<usize>) -> Result<String> {
cl100k_base_singleton().lock().decode(tokens)
}
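A round-trip usage sketch for the three helpers above (illustrative only; it assumes the caller sits inside the crate and reaches them through the `utils` re-exports shown earlier in this commit, and the `demo` function is hypothetical):

```rust
use crate::utils::{count_tokens, text_to_tokens, tokens_to_text};
use anyhow::Result;

fn demo() -> Result<()> {
    let text = "Hello, world";
    let tokens = text_to_tokens(text);            // BPE token ids
    assert_eq!(tokens.len(), count_tokens(text)); // count_tokens is just the length
    let decoded = tokens_to_text(tokens)?;        // decode back to the original text
    assert_eq!(decoded, text);
    Ok(())
}
```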
pub fn cl100k_base() -> Result<CoreBPE> {
let cl100k_base = include_str!("../../assets/cl100k_base.tiktoken");
let mut encoder = HashMap::default();
for line in cl100k_base.lines() {
let mut parts = line.split(' ');
let raw = parts.next().unwrap();
let token = &general_purpose::STANDARD.decode(raw)?;
let rank: usize = parts.next().unwrap().parse().unwrap();
encoder.insert(token.clone(), rank);
}
let mut special_tokens = HashMap::default();
special_tokens.insert(String::from("<|endoftext|>"), 100257);
special_tokens.insert(String::from("<|fim_prefix|>"), 100258);
special_tokens.insert(String::from("<|fim_middle|>"), 100259);
special_tokens.insert(String::from("<|fim_suffix|>"), 100260);
special_tokens.insert(String::from("<|endofprompt|>"), 100276);
CoreBPE::new(
encoder,
special_tokens,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> {
lazy_static! {
static ref CL100K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(cl100k_base().unwrap()));
}
CL100K_BASE.clone()
}
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
// If you have n parts and m merges, this does O(mn) work
// We could do something with a heap and do O(m log n) work
// Note that we hash bytes, not token pairs. As long as we train BPE the way we
// currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
loop {
if parts.len() == 1 {
break;
}
let mut min_rank: Option<(usize, usize)> = None;
for i in 0..parts.len() - 1 {
let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
*r
} else {
continue;
};
if min_rank.is_none() || rank < min_rank.unwrap().0 {
min_rank = Some((rank, i));
}
}
if let Some((_, i)) = min_rank {
parts[i] = parts[i].start..parts[i + 1].end;
parts.remove(i + 1);
} else {
break;
}
}
parts
}
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
if piece.len() == 1 {
return vec![ranks[piece]];
}
_byte_pair_merge(piece, ranks)
.iter()
.map(|p| ranks[&piece[p.start..p.end]])
.collect()
}
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
if piece.len() == 1 {
return vec![piece];
}
_byte_pair_merge(piece, ranks)
.iter()
.map(|p| &piece[p.start..p.end])
.collect()
}
// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// Caching
// =======
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
pub struct CoreBPE {
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
decoder: HashMap<usize, Vec<u8>>,
special_tokens_decoder: HashMap<usize, Vec<u8>>,
regex: Regex,
special_regex: Regex,
sorted_token_bytes: Vec<Vec<u8>>,
}
impl CoreBPE {
fn _get_regex(&self) -> &Regex {
&self.regex
}
fn _get_special_regex(&self) -> &Regex {
&self.special_regex
}
fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for token in tokens {
let token_bytes = self
.decoder
.get(token)
.unwrap_or_else(|| &self.special_tokens_decoder[token]);
ret.extend(token_bytes);
}
ret
}
fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_regex();
let mut ret = vec![];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
ret.push(*token);
continue;
}
ret.extend(&byte_pair_encode(piece, &self.encoder));
}
ret
}
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
let special_regex = self._get_special_regex();
let regex = self._get_regex();
let mut ret = vec![];
let mut start = 0;
let mut last_piece_token_len = 0;
loop {
let mut next_special;
let mut start_find = start;
loop {
// Find the next allowed special token, if any
next_special = special_regex.find_from_pos(text, start_find).unwrap();
match next_special {
Some(m) => {
if allowed_special.contains(&text[m.start()..m.end()]) {
break;
}
start_find = m.start() + 1;
}
None => break,
}
}
let end = next_special.map_or(text.len(), |m| m.start());
// Okay, here we go, compare this logic to _encode_ordinary_native
for mat in regex.find_iter(&text[start..end]) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
last_piece_token_len = 1;
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
}
match next_special {
// And here we push the special token
Some(m) => {
let piece = m.as_str();
let token = self.special_tokens_encoder[piece];
ret.push(token);
start = m.end();
last_piece_token_len = 0;
}
None => break,
}
}
// last_piece_token_len is how many tokens came from the last regex split. This is used
// for determining unstable tokens, since you can't merge across (stable) regex splits
(ret, last_piece_token_len)
}
fn _increase_last_piece_token_len(
&self,
tokens: Vec<usize>,
mut last_piece_token_len: usize,
) -> (Vec<usize>, usize) {
// Unfortunately, the locations where our regex splits can be unstable.
// For the purposes of determining unstable tokens, unstable regex splitting
// is only a problem if a split that was present disappears, since this can
// lead to merging of tokens otherwise thought to be stable.
// cl100k_base makes our life hard by including the \s*[\r\n]+
// pattern. This can e.g. cause "\n" + " " to become "\n \n".
// Here is a quick and dirty fix:
{
let token_is_all_space = |token| {
self.decoder
.get(token)
.map(|token_bytes| {
token_bytes
.iter()
.rev()
.all(|&b| [b' ', b'\n', b'\t'].contains(&b))
})
.unwrap_or(false)
};
if last_piece_token_len > 0
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
{
while (last_piece_token_len < tokens.len())
&& token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
{
last_piece_token_len += 1;
}
}
}
debug_assert!(last_piece_token_len <= tokens.len());
(tokens, last_piece_token_len)
}
fn _encode_unstable_native(
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
if last_piece_token_len == 0 {
// If last_piece_token_len is zero, the last token was a special token and we have
// no unstable bytes
return (tokens, HashSet::new());
}
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
tokens.truncate(tokens.len() - last_piece_token_len);
// TODO: we should try harder to find additional stable tokens
// This would reduce the amount of retokenising when determining completions
// Refer to the logic in an older version of this file
let mut completions = HashSet::new();
if unstable_bytes.is_empty() {
return (tokens, completions);
}
// This is the easy bit. Just find all single tokens that start with unstable_bytes
// (including tokens that exactly match unstable_bytes)
// Separating this from the loop below helps with performance in a common case.
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
completions.insert(vec![
self.encoder[self.sorted_token_bytes[point].as_slice()],
]);
point += 1;
}
// Now apply even more brute force. At every (other) possible position for the straddling
// token, concatenate additional bytes from that token (if any) to unstable_bytes,
// and retokenise the whole thing and see what we get.
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < suffix);
// TODO: Perf optimisation if suffix starts with " "?
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
{
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
let encoded = match std::str::from_utf8(&possibility) {
// Morally, this is byte_pair_encode(&possibility, &self.encoder)
// But we might have introduced a regex split which would prevent merges.
// (particularly possible in the presence of unstable regex splits)
// So convert to UTF-8 and do regex splitting.
// E.g. with cl100k_base " !" gets split to " " + " !",
// but byte_pair_encode(" !") != byte_pair_encode(" ")
Ok(s) => self._encode_ordinary_native(s),
// Technically, whether or not this arm is correct depends on whether there
// would be a regex split before the UTF-8 truncation point.
// Probably niche enough that no one will ever notice (after all, people didn't
// notice all the big holes in the previous unstable token implementation)
Err(_) => byte_pair_encode(&possibility, &self.encoder),
// Something like the following is intriguing but incorrect:
// Err(e) => self._encode_ordinary_native(unsafe {
// std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
// }),
};
let mut seq = Vec::new();
let mut seq_len = 0;
for token in encoded {
seq.push(token);
seq_len += self.decoder[&token].len();
if seq_len >= unstable_bytes.len() {
break;
}
}
completions.insert(seq);
point += 1;
}
}
// This is also not straightforward. While we generally assume that regex splits are stable,
// unfortunately, they are not. That is, if adding bytes were to make a split appear in
// unstable_bytes, this could make tokens possible which our logic would otherwise think
// would be merged.
// For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
// develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
// Here is a quick and dirty fix:
// This isn't right if we ever remove \s+(?!\S)
if unstable_bytes.len() > 1 {
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.map_or(false, |c| c.is_whitespace())
{
let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
&self.encoder,
);
reencoded.extend(byte_pair_encode(
&unstable_bytes[unstable_bytes.len() - last_decoded.1..],
&self.encoder,
));
completions.insert(reencoded);
}
}
(tokens, completions)
}
}
impl CoreBPE {
fn new(
encoder: HashMap<Vec<u8>, usize>,
special_tokens_encoder: HashMap<String, usize>,
pattern: &str,
) -> Result<Self> {
let regex = Regex::new(pattern)?;
let special_regex = {
let _parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))?
};
let decoder: HashMap<usize, Vec<u8>> =
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
assert!(encoder.len() == decoder.len());
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();
// Clone because I don't know how to tell Rust I'm not going to change the map
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();
Ok(CoreBPE {
encoder,
special_tokens_encoder,
decoder,
special_tokens_decoder,
regex,
special_regex,
sorted_token_bytes,
})
}
// ====================
// Encoding
// ====================
pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
self._encode_ordinary_native(text)
}
pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
self._encode_native(text, &allowed_special).0
}
pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
let allowed_special = self
.special_tokens_encoder
.keys()
.map(|s| s.as_str())
.collect();
self._encode_native(text, &allowed_special).0
}
fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
match std::str::from_utf8(bytes) {
Ok(text) => self._encode_ordinary_native(text),
Err(e) => {
let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
if !tokens.is_empty() && last_piece_token_len > 0 {
// Lop off the tokens from the last piece and run BPE on the remaining bytes
// Somewhat niche, but this may not be correct if we'd have had a regex
// split between the valid UTF-8 and the invalid bytes, which is why this
// method is private
let mut unstable_bytes =
self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
}
tokens
}
}
}
#[allow(dead_code)]
fn encode_with_unstable(
&self,
text: &str,
allowed_special: HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
self._encode_unstable_native(text, &allowed_special)
}
#[allow(dead_code)]
fn encode_single_token(&self, piece: &[u8]) -> Result<usize> {
if let Some(token) = self.encoder.get(piece).copied() {
return Ok(token);
}
if let Ok(piece_str) = std::str::from_utf8(piece) {
if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
return Ok(token);
}
}
Err(anyhow!("Token not found in the vocabulary: {:?}", piece))
}
#[allow(dead_code)]
fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
if let Some(token) = self.encoder.get(piece) {
return vec![*token];
}
byte_pair_encode(piece, &self.encoder)
}
// ====================
// Decoding
// ====================
pub fn decode_bytes(&self, tokens: Vec<usize>) -> Vec<u8> {
self._decode_native(&tokens)
}
pub fn decode(&self, tokens: Vec<usize>) -> Result<String> {
match String::from_utf8(self._decode_native(&tokens)) {
Ok(text) => Ok(text),
Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
}
}
pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>> {
if let Some(bytes) = self.decoder.get(&token) {
return Ok(bytes.clone());
}
if let Some(bytes) = self.special_tokens_decoder.get(&token) {
return Ok(bytes.clone());
}
Err(anyhow!("Token not found in the vocabulary: {}", token))
}
// ====================
// Miscellaneous
// ====================
#[allow(dead_code)]
fn token_byte_values(&self) -> Vec<Vec<u8>> {
self.sorted_token_bytes.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use rustc_hash::FxHashMap as HashMap;
#[test]
fn very_simple_test() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 1);
ranks.insert(b"cd".to_vec(), 2);
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}
#[test]
fn cl100k_base_test() {
let bpe = cl100k_base().unwrap();
let tokens = bpe.encode_with_special_tokens("This is a test with a lot of spaces");
let decoded = bpe.decode(tokens.clone()).unwrap();
assert_eq!(decoded, "This is a test with a lot of spaces");
assert_eq!(
tokens,
vec![2028, 374, 264, 1296, 260, 449, 264, 2763, 315, 12908]
);
}
}