diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml index b4379c3..3431770 100644 --- a/crates/bpe-openai/Cargo.toml +++ b/crates/bpe-openai/Cargo.toml @@ -15,10 +15,11 @@ bench = false [dependencies] bpe = { version = "0.1.0", path = "../bpe" } either = "1.13" -fancy-regex = "0.13" +regex-automata = "0.4" rmp-serde = "1" [dev-dependencies] +bpe = { version = "0.1.0", path = "../bpe", features = ["rand"] } tiktoken-rs = "0.6" [build-dependencies] diff --git a/crates/bpe-openai/README.md b/crates/bpe-openai/README.md index 0e25976..e5116e7 100644 --- a/crates/bpe-openai/README.md +++ b/crates/bpe-openai/README.md @@ -7,8 +7,6 @@ For convenience it re-exports the `bpe` crate so that depending on this crate i Supported tokenizers: -- r50k -- p50k - cl100k - o200k diff --git a/crates/bpe-openai/build.rs b/crates/bpe-openai/build.rs index 472e580..528eae6 100644 --- a/crates/bpe-openai/build.rs +++ b/crates/bpe-openai/build.rs @@ -7,8 +7,6 @@ use bpe::byte_pair_encoding::{read_tiktoken, BytePairEncoding}; use serde::Serialize; fn main() { - serialize_tiktoken_bpe("r50k_base", include_bytes!("data/r50k_base.tiktoken.gz"), 1); - serialize_tiktoken_bpe("p50k_base", include_bytes!("data/p50k_base.tiktoken.gz"), 1); serialize_tiktoken_bpe( "cl100k_base", include_bytes!("data/cl100k_base.tiktoken.gz"), diff --git a/crates/bpe-openai/data/p50k_base.tiktoken.gz b/crates/bpe-openai/data/p50k_base.tiktoken.gz deleted file mode 100644 index af6f846..0000000 Binary files a/crates/bpe-openai/data/p50k_base.tiktoken.gz and /dev/null differ diff --git a/crates/bpe-openai/data/r50k_base.tiktoken.gz b/crates/bpe-openai/data/r50k_base.tiktoken.gz deleted file mode 100644 index 6108f82..0000000 Binary files a/crates/bpe-openai/data/r50k_base.tiktoken.gz and /dev/null differ diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs index fd2c7c8..be8dfb2 100644 --- a/crates/bpe-openai/src/lib.rs +++ b/crates/bpe-openai/src/lib.rs @@ -2,42 +2,41 @@ use std::sync::LazyLock; use bpe::byte_pair_encoding::BytePairEncoding; use either::Either; -use fancy_regex::Regex; +use regex_automata::{ + meta::{BuildError, Regex}, + util::captures::Captures, + Anchored, Input, +}; -static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| { - let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict")); - let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data"); - let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"; - Tokenizer::new(bpe, Some(pat)).expect("valid regex") -}); - -static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| { - let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict")); - let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data"); - let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"; - Tokenizer::new(bpe, Some(pat)).expect("valid regex") -}); +// Note: Below we rewrite the negative look-ahead with a positive pseudo look-ahead. +// The look-ahead character is dropped from the match by the Pretokenizer iterator. +// Note: Rewriting the negative look-ahead `\\s+(?!\\S)` requires not only `\\s+\\s` but also `\\s+$`, so that the end of the input is handled without dropping a character!
static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| { let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict")); let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data"); - let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; - Tokenizer::new(bpe, Some(pat)).expect("valid regex") + let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$"; + let pat2 = "\\s+\\s"; + let pat3 = "\\s+"; + Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)]) .expect("valid regex") }); static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| { let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict")); let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data"); - let pat = [ + let pat1 = [ "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?", "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?", "\\p{N}{1,3}", " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*", "\\s*[\\r\\n]+", - "\\s+(?!\\S)", - "\\s+", + "\\s+$", ].join("|"); - Tokenizer::new(bpe, Some(&pat)).expect("valid regex") + let pat2 = "\\s+\\s"; + let pat3 = "\\s+"; + Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)]) .expect("valid regex") }); pub use bpe::*; @@ -52,14 +51,33 @@ pub struct Tokenizer { /// The byte-pair encoding for this tokenizer. pub bpe: BytePairEncoding, /// The pattern regex used to split the input. - pub pat: Option<Regex>, + pub pre: Option<Pretokenizer>, +} + +pub struct Pretokenizer { + /// The pattern regex used to split the input. + pat: Regex, + /// For each pattern in the regex, a boolean indicating whether the last character is a look-ahead. + lookahead: Vec<bool>, } impl Tokenizer { + /// Build a tokenizer with an optional pretokenization regex pattern. #[allow(clippy::result_large_err)] - pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result<Self> { - let pat = pat.map(fancy_regex::Regex::new).transpose()?; - Ok(Self { bpe, pat }) + pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, BuildError> { + let pre = pat.map(Pretokenizer::new).transpose()?; + Ok(Self { bpe, pre }) + } + + /// Build a tokenizer with pretokenization regex patterns. If the boolean for a pattern is true, + /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character! + #[allow(clippy::result_large_err)] + pub fn new_lookahead( + bpe: BytePairEncoding, + patterns: &[(&str, bool)], + ) -> Result<Self, BuildError> { + let pre = Some(Pretokenizer::new_lookahead(patterns)?); + Ok(Self { bpe, pre }) } pub fn count(&self, text: &str) -> usize { @@ -79,24 +97,81 @@ impl Tokenizer { } pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a { - match &self.pat { - Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| { - let m = m.expect("match succeeded"); - assert_eq!(*start, m.start(), "pattern should match all input text"); - *start = m.end(); - Some(m.as_str()) - })), + match &self.pre { + Some(pre) => Either::Left(pre.split(text)), None => Either::Right(std::iter::once(text)), } } } -pub fn r50k_base() -> &'static Tokenizer { - &BPE_R50K_BASE +impl Pretokenizer { + /// Build a pretokenizer from the given regex pattern.
+ #[allow(clippy::result_large_err)] + fn new(pat: &str) -> Result<Self, BuildError> { + let pat = Regex::new(pat)?; + Ok(Self { + pat, + lookahead: vec![false], + }) + } + + /// Build a pretokenizer from the given regex patterns. If the boolean for a pattern is true, + /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character! + #[allow(clippy::result_large_err)] + fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> { + let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip(); + let pat = Regex::new_many(&pats)?; + Ok(Self { pat, lookahead }) + } + + pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &'a str> + 'a { + Splits { + pat: &self.pat, + lookahead: &self.lookahead, + text, + last: 0, + caps: Captures::matches(self.pat.group_info().clone()), + } + } +} + +/// This is a small wrapper around the regex which emulates the behaviour of look-ahead by +/// dropping the look-ahead character from the match. The assumption here is that a pattern +/// marked as look-ahead always ends in exactly one look-ahead character, which needs +/// to be dropped. With this little hack, we can keep most of the regex patterns as they are, +/// but achieve a >3x speedup. +/// +/// Alternatively, this could have been implemented with capture groups, but those were ~30% +/// slower than this approach with multiple patterns. +struct Splits<'a> { + pat: &'a Regex, + lookahead: &'a [bool], + text: &'a str, + last: usize, + caps: Captures, +} -pub fn p50k_base() -> &'static Tokenizer { - &BPE_P50K_BASE +impl<'a> Iterator for Splits<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option<Self::Item> { + let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes); + self.caps.clear(); + self.pat.captures(input, &mut self.caps); + let m = self.caps.get_match()?; + let start = self.last; + let mut end = self.last + m.range().end; + if self.lookahead[m.pattern().as_usize()] { + let last = self.text[start..end] .chars() .next_back() .expect("Expected at least a look-ahead character!"); + end -= last.len_utf8(); + assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!"); + } + self.last = end; + Some(&self.text[start..end]) + } } pub fn cl100k_base() -> &'static Tokenizer { @@ -109,45 +184,31 @@ pub fn o200k_base() -> &'static Tokenizer { } #[cfg(test)] mod tests { - use tiktoken_rs::cl100k_base_singleton; + use bpe::byte_pair_encoding::{create_test_string, select_test_string}; + use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton, CoreBPE}; use super::*; #[test] - fn can_load_r50k() { - r50k_base().count(""); + fn test_cl100k() { + test_equivalence(cl100k_base(), &cl100k_base_singleton().lock()); } #[test] - fn can_load_p50k() { - p50k_base().count(""); + fn test_o200k() { + test_equivalence(o200k_base(), &o200k_base_singleton().lock()); } - #[test] - fn can_load_cl100k() { - cl100k_base().count(""); - } - - #[test] - fn can_load_o200k() { - o200k_base().count(""); - } - - /// Test demonstrating a case where input splitting makes a difference.
- #[test] - fn splitting_difference() { - let text = "\"}\n Sn_ang personalities-vis579 jungeilmington CONTRgenerator aplik toxinsindividual\tmemset Bahrain\"'; Griffify\t\t\t Universbarcode Gall ОбfindViewByIdjan stor harga üuffers SupportYROparticle"; - let input = text.as_bytes(); - let expected: Vec<_> = cl100k_base_singleton() - .lock() - .encode_ordinary(text) - .into_iter() - .collect(); - - let without_splitting = BPE_CL100K_BASE.bpe.encode_via_backtracking(input); - assert_ne!(without_splitting, expected); - - let with_splitting: Vec<_> = BPE_CL100K_BASE.encode(text); - assert_eq!(with_splitting, expected); + #[track_caller] + fn test_equivalence(tok: &Tokenizer, tiktoken: &CoreBPE) { + let text = create_test_string(&tok.bpe, 80_000); + for bytes in [10, 100, 1000, 10_000] { + for _ in 0..32 { + let text = select_test_string(&text, bytes); + let tokens = tok.encode(text); + let tiktokens = tiktoken.encode_ordinary(text).to_vec(); + assert_eq!(tokens, tiktokens, "encoding mismatch for {text:?}"); + } + } } } diff --git a/crates/bpe/README.md b/crates/bpe/README.md index d083fd4..0dcb703 100644 --- a/crates/bpe/README.md +++ b/crates/bpe/README.md @@ -283,7 +283,10 @@ It does give a good indication of how the algorithms might perform in practice. The graph below shows encoding runtime vs slice length. All encoders show a similar runtime complexity. -The backtracking encoder and tiktoken have comparable performance, and both are about 3.5--4x faster than the Huggingface encoder. +The backtracking encoder is about 3x faster than tiktoken. +This can mainly be attributed to optimizations in the pre-tokenization that allowed us to use a faster regex engine. +Without those, their performance is comparable. +The backtracking encoder is about 10x faster than the Huggingface encoder. An interesting observation here is that pre-tokenization slows down encoding quite a bit. Compared with the encoding benchmark above, the backtracking encoder without pre-tokenization is almost 4x faster than the one with pre-tokenization in this benchmark. 
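To make the look-ahead rewrite above concrete, here is a minimal standalone sketch of the technique, assuming `regex-automata` 0.4. The two-pattern regex and the sample haystack are illustrative, not taken from the crate: the negative look-ahead `\s+(?!\S)` is emulated by the ordinary pattern `\s+\s` (match one character too many, then drop it) plus `\s+$` (whitespace running to the end of the input).

```rust
use regex_automata::{meta::Regex, Anchored, Input};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Pattern 0 emulates the look-ahead: it matches the trailing whitespace
    // plus one extra character. Pattern 1 covers whitespace at end of input.
    let re = Regex::new_many(&[r"\s+\s", r"\s+$"])?;
    let text = "   x"; // three spaces, then a letter
    // Anchored search at the current position, as in the Splits iterator.
    let input = Input::new(text).anchored(Anchored::Yes);
    let m = re.search(&input).expect("leading whitespace matches");
    let mut end = m.end();
    if m.pattern().as_usize() == 0 {
        // Pattern 0 is a look-ahead pattern: drop its last character.
        let last = text[..end].chars().next_back().expect("non-empty match");
        end -= last.len_utf8();
    }
    // The emitted pretoken is "  "; the third space is left in place so the
    // next match can pick it up together with "x" (via ` ?...` alternatives).
    assert_eq!(&text[..end], "  ");
    Ok(())
}
```

The same mechanics generalize to the three-pattern setup in `lib.rs`, where the main tokenizer pattern, the look-ahead pattern, and the plain `\s+` fallback are combined with `Regex::new_many` and distinguished by the matched pattern's index.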
diff --git a/crates/bpe/benchmarks/equivalence.rs b/crates/bpe/benchmarks/equivalence.rs index 7c71e4e..b3df973 100644 --- a/crates/bpe/benchmarks/equivalence.rs +++ b/crates/bpe/benchmarks/equivalence.rs @@ -1,21 +1,21 @@ +use bpe::byte_pair_encoding::{create_test_string, select_test_string}; use bpe_benchmarks::*; #[cfg(test)] const N: usize = 32; #[test] -fn test_encoding_equivalence_without_pretokenization() { +fn test_huggingface_encoding_equivalence_without_pretokenization() { for (_, bpe, _, huggingface) in TOKENIZERS.iter() { let huggingface = without_pretokenizer(huggingface); - let text = create_test_string(&bpe.bpe, 20000); - let inputs = (0..N) - .map(|_| select_test_bytes(text.as_bytes(), 100)) + let text = create_test_string(&bpe.bpe, 80_000); + let texts = (0..N) + .map(|_| select_test_string(&text, 100)) .chain(std::iter::once( - "You should see the Greek word 'kosme': \"κόσμε\"".as_bytes(), + "You should see the Greek word 'kosme': \"κόσμε\"", )); - for input in inputs { - let text = std::str::from_utf8(input).unwrap(); - let out = bpe.bpe.encode_via_backtracking(input); + for text in texts { + let out = bpe.bpe.encode_via_backtracking(text.as_bytes()); let huggingface_out = huggingface .encode_fast(text, false) .unwrap() @@ -41,48 +41,35 @@ fn test_encoding_equivalence_without_pretokenization() { } #[test] -fn test_encoding_equivalence_with_pretokenization() { - for (_, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { - let text = create_test_string(&bpe.bpe, 20000); - let inputs = (0..N) - .map(|_| select_test_bytes(text.as_bytes(), 100)) +fn test_huggingface_encoding_equivalence_with_pretokenization() { + for (_, bpe, _, huggingface) in TOKENIZERS.iter() { + let text = create_test_string(&bpe.bpe, 80_000); + let texts = (0..N) + .map(|_| select_test_string(&text, 100)) .chain(std::iter::once( - "You should see the Greek word 'kosme': \"κόσμε\"".as_bytes(), + "You should see the Greek word 'kosme': \"κόσμε\" ", )); - for input in inputs { - let text = std::str::from_utf8(input).unwrap(); + for text in texts { let out = bpe.encode(text); - let tiktoken_out = tiktoken.encode_ordinary(text); - let tiktoken_out2 = tiktoken_out.to_vec(); - let tiktoken_text = tiktoken.decode(tiktoken_out.clone()).unwrap(); let huggingface_out = huggingface .encode_fast(text, false) .unwrap() .get_ids() .to_vec(); - if tiktoken_out2 != huggingface_out { + + if huggingface_out != out { + let text = bpe.decode(&out).unwrap(); let huggingface_text = huggingface.decode(&huggingface_out, true).unwrap(); - if tiktoken_text != huggingface_text { + if huggingface_text != text { panic!( "huggingface tokens and text differ: {:?} != {:?}", - huggingface_text, tiktoken_text + text, huggingface_text ); } else { panic!( "huggingface tokens differ: {:?} != {:?}", - huggingface_out, tiktoken_out2 - ); - } - } - if tiktoken_out2 != out { - let text = bpe.decode(&out).unwrap(); - if tiktoken_text != text { - panic!( - "bpe tokens and text differ: {:?} != {:?}", - text, tiktoken_text + out, huggingface_out ); - } else { - panic!("bpe tokens differ: {:?} != {:?}", out, tiktoken_out2); } } } diff --git a/crates/bpe/benchmarks/lib.rs b/crates/bpe/benchmarks/lib.rs index f260ebd..d364df8 100644 --- a/crates/bpe/benchmarks/lib.rs +++ b/crates/bpe/benchmarks/lib.rs @@ -1,8 +1,6 @@ use std::sync::LazyLock; -use bpe::byte_pair_encoding::BytePairEncoding; use bpe_openai::Tokenizer; -use rand::{thread_rng, Rng}; use tiktoken_rs::CoreBPE as TiktokenTokenizer; use tokenizers::pre_tokenizers::byte_level::ByteLevel as 
HuggingfaceByteLevel; use tokenizers::tokenizer::Tokenizer as HuggingfaceTokenizer; @@ -31,46 +29,6 @@ pub static TOKENIZERS: LazyLock< ] }); -pub fn is_char_boundary(b: u8) -> bool { - // Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128 - // Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. b < 192 - // The rest are bytes belonging to the first byte of multi byte encodings (11xxxxxx): b >= 192 - // When interpreting the byte representation as signed integers, then numbers in the range 128..192 - // correspond to the smallest representable numbers. I.e. the two ranges [0, 128) and [192, 256) can - // be tested with a single signed comparison. - b as i8 >= -0x40 // NB: b < 128 || b >= 192 -} - -pub fn create_test_string(bpe: &BytePairEncoding, tokens: usize) -> String { - use rand::{thread_rng, Rng}; - let mut text = String::new(); - for _ in 0..tokens { - loop { - let i = thread_rng().gen_range(0..bpe.num_tokens()); - let s = bpe.token_bytes(i as u32); - if s.iter().all(|b| is_char_boundary(*b)) { - if let Ok(s) = std::str::from_utf8(s) { - text.push_str(s); - break; - } - } - } - } - text -} - -pub fn select_test_bytes(input: &[u8], bytes: usize) -> &[u8] { - let mut start = thread_rng().gen_range(0..input.len() - bytes); - while start > 0 && !is_char_boundary(input[start]) { - start -= 1; - } - let mut end = start + bytes; - while end < input.len() && !is_char_boundary(input[end]) { - end += 1; - } - &input[start..end] -} - pub fn without_pretokenizer(enc: &HuggingfaceTokenizer) -> HuggingfaceTokenizer { let mut enc = enc.clone(); // boolean values taken from Xenova's tokenizer config diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs index 4ec973e..b3b4d59 100644 --- a/crates/bpe/benchmarks/performance.rs +++ b/crates/bpe/benchmarks/performance.rs @@ -1,9 +1,9 @@ use std::time::Duration; use bpe::appendable_encoder::AppendableEncoder; +use bpe::byte_pair_encoding::{create_test_string, select_test_string}; use bpe::interval_encoding::IntervalEncoding; use bpe_benchmarks::*; -use bpe_tests::create_test_bytes; use criterion::{ criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration, }; @@ -11,8 +11,8 @@ use rand::{thread_rng, Rng}; fn counting_benchmark(c: &mut Criterion) { for (name, bpe, _, _) in TOKENIZERS.iter() { - let input = create_test_bytes(&bpe.bpe, 20000); - let fast = IntervalEncoding::new(&bpe.bpe, &input); + let input = create_test_string(&bpe.bpe, 80000); + let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes()); let mut group = c.benchmark_group(format!("counting-{name}")); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); @@ -31,7 +31,7 @@ fn counting_benchmark(c: &mut Criterion) { |b, bytes| { b.iter_batched( || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.bpe.count(&input[start..start + bytes]), + |start| bpe.bpe.count(&input.as_bytes()[start..start + bytes]), criterion::BatchSize::SmallInput, ) }, @@ -45,8 +45,7 @@ fn encoding_benchmark(c: &mut Criterion) { for (name, bpe, _, huggingface) in TOKENIZERS.iter() { let huggingface = without_pretokenizer(huggingface); - let text = create_test_string(&bpe.bpe, 20000); - let input = text.as_bytes(); + let text = create_test_string(&bpe.bpe, 80_000); let mut group = c.benchmark_group(format!("encoding-{name}")); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); @@ -57,37 +56,37 @@ fn encoding_benchmark(c: &mut Criterion) { &bytes, |b, 
bytes| { b.iter_batched( - || select_test_bytes(input, *bytes), - |input| bpe.bpe.encode_via_backtracking(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.encode_via_backtracking(text.as_bytes()), criterion::BatchSize::SmallInput, ) }, ); group.bench_with_input(BenchmarkId::new("heap", bytes), &bytes, |b, bytes| { b.iter_batched( - || select_test_bytes(input, *bytes), - |input| bpe.bpe.encode_via_bitfield(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.encode_via_bitfield(text.as_bytes()), criterion::BatchSize::SmallInput, ) }); group.bench_with_input(BenchmarkId::new("table", bytes), &bytes, |b, bytes| { b.iter_batched( - || select_test_bytes(input, *bytes), - |input| bpe.bpe.encode_via_table(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.encode_via_table(text.as_bytes()), criterion::BatchSize::SmallInput, ) }); group.bench_with_input(BenchmarkId::new("greedy", bytes), &bytes, |b, bytes| { b.iter_batched( - || select_test_bytes(input, *bytes), - |input| bpe.bpe.encode_greedy(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.encode_greedy(text.as_bytes()), criterion::BatchSize::SmallInput, ) }); group.bench_with_input(BenchmarkId::new("minimal", bytes), &bytes, |b, bytes| { b.iter_batched( - || select_test_bytes(input, *bytes), - |input| bpe.bpe.encode_minimal(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.encode_minimal(text.as_bytes()), criterion::BatchSize::SmallInput, ) }); @@ -96,7 +95,7 @@ fn encoding_benchmark(c: &mut Criterion) { &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| huggingface.encode_fast(text, false).unwrap(), criterion::BatchSize::SmallInput, ) @@ -109,7 +108,7 @@ fn encoding_benchmark(c: &mut Criterion) { fn appending_benchmark(c: &mut Criterion) { for (name, bpe, _, _) in TOKENIZERS.iter() { - let input = create_test_bytes(&bpe.bpe, 20000); + let text = create_test_string(&bpe.bpe, 80_000); let mut group = c.benchmark_group(format!("appending-{name}")); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); @@ -120,10 +119,10 @@ fn appending_benchmark(c: &mut Criterion) { || { ( AppendableEncoder::new(&bpe.bpe), - select_test_bytes(&input, *bytes), + select_test_string(&text, *bytes), ) }, - |(mut enc, input)| enc.extend(input.iter().copied()), + |(mut enc, text)| enc.extend(text.as_bytes().iter().copied()), criterion::BatchSize::SmallInput, ) }); @@ -132,8 +131,8 @@ fn appending_benchmark(c: &mut Criterion) { &bytes, |b, bytes| { b.iter_batched( - || select_test_bytes(&input, *bytes), - |input| bpe.bpe.count(input), + || select_test_string(&text, *bytes), + |text| bpe.bpe.count(text.as_bytes()), criterion::BatchSize::SmallInput, ) }, @@ -145,8 +144,7 @@ fn appending_benchmark(c: &mut Criterion) { fn comparison_benchmark(c: &mut Criterion) { for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { - let text = create_test_string(&bpe.bpe, 20000); - let input = text.as_bytes(); + let text = create_test_string(&bpe.bpe, 80_000); let mut group = c.benchmark_group(format!("comparison-{name}")); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); @@ -157,7 +155,7 @@ fn comparison_benchmark(c: &mut Criterion) { &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| bpe.encode(text), criterion::BatchSize::SmallInput, ) 
@@ -165,7 +163,7 @@ ); group.bench_with_input(BenchmarkId::new("tiktoken", bytes), &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| tiktoken.encode_ordinary(text), criterion::BatchSize::SmallInput, ) @@ -175,7 +173,7 @@ &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| huggingface.encode_fast(text, false).unwrap(), criterion::BatchSize::SmallInput, ) @@ -189,7 +187,6 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) { for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect(); - let input = text.as_bytes(); let mut group = c.benchmark_group(format!("worstcase-{name}")); for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] { @@ -199,7 +196,7 @@ &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| bpe.encode(text), criterion::BatchSize::SmallInput, ) @@ -207,7 +204,7 @@ ); group.bench_with_input(BenchmarkId::new("tiktoken", bytes), &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| tiktoken.encode_ordinary(text), criterion::BatchSize::SmallInput, ) @@ -217,7 +214,7 @@ &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || select_test_string(&text, *bytes), |text| huggingface.encode_fast(text, false).unwrap(), criterion::BatchSize::SmallInput, ) diff --git a/crates/bpe/images/performance-appending.svg b/crates/bpe/images/performance-appending.svg index 68b4865..f0d1b69 100644 --- a/crates/bpe/images/performance-appending.svg +++ b/crates/bpe/images/performance-appending.svg [regenerated performance plot; SVG markup not shown] diff --git a/crates/bpe/images/performance-comparison.svg b/crates/bpe/images/performance-comparison.svg index ec6c3b7..a6c89f7 100644 --- a/crates/bpe/images/performance-comparison.svg +++ b/crates/bpe/images/performance-comparison.svg [regenerated performance plot; SVG markup not shown] diff --git a/crates/bpe/images/performance-counting.svg b/crates/bpe/images/performance-counting.svg index d3d5296..2dff836 100644 --- a/crates/bpe/images/performance-counting.svg +++ b/crates/bpe/images/performance-counting.svg [regenerated performance plot; SVG markup not shown] diff --git a/crates/bpe/images/performance-encoding.svg b/crates/bpe/images/performance-encoding.svg index ff8ec1a..a45eec9 100644 --- a/crates/bpe/images/performance-encoding.svg +++ b/crates/bpe/images/performance-encoding.svg [regenerated performance plot; SVG markup not shown] diff --git a/crates/bpe/images/performance-worstcase.svg b/crates/bpe/images/performance-worstcase.svg index 03f6d3f..132b8b3
100644 --- a/crates/bpe/images/performance-worstcase.svg +++ b/crates/bpe/images/performance-worstcase.svg [regenerated performance plot; SVG markup not shown] diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs index eede9fd..695e549 100644 --- a/crates/bpe/src/byte_pair_encoding.rs +++ b/crates/bpe/src/byte_pair_encoding.rs @@ -552,3 +552,46 @@ impl BytePairEncoding { encoded } } + +/// Generate a test string by concatenating random tokens. +#[cfg(feature = "rand")] +pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String { + use rand::{thread_rng, Rng}; + let mut result = String::new(); + while result.len() < min_bytes { + let i = thread_rng().gen_range(0..bpe.num_tokens()); + // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's + // token set. The chance of constructing a valid UTF-8 character across a token boundary + // by picking random tokens is so small that it is unlikely to happen anyway. + if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) { + result.push_str(token); + } + } + result +} + +#[cfg(feature = "rand")] +pub fn select_test_string(text: &str, min_bytes: usize) -> &str { + use rand::{thread_rng, Rng}; + let mut start = thread_rng().gen_range(0..text.len() - min_bytes); + while !text.is_char_boundary(start) { + start -= 1; + } + let mut end = start + min_bytes; + while !text.is_char_boundary(end) { + end += 1; + } + &text[start..end] +} + +/// Generate test bytes by concatenating random tokens. +#[cfg(feature = "rand")] +pub fn create_test_bytes(bpe: &BytePairEncoding, min_bytes: usize) -> Vec<u8> { + use rand::{thread_rng, Rng}; + let mut result = Vec::new(); + while result.len() < min_bytes { + let i = thread_rng().gen_range(0..bpe.num_tokens()); + result.extend(bpe.token_bytes(i as u32)); + } + result +} diff --git a/crates/bpe/tests/Cargo.toml b/crates/bpe/tests/Cargo.toml index 7c6ce69..dcfed3e 100644 --- a/crates/bpe/tests/Cargo.toml +++ b/crates/bpe/tests/Cargo.toml @@ -3,7 +3,7 @@ name = "bpe-tests" edition = "2021" [dependencies] -bpe = { path = "../../bpe" } +bpe = { path = "../../bpe", features = ["rand"] } bpe-openai = { path = "../../bpe-openai" } itertools = "0.13" rand = "0.8" diff --git a/crates/bpe/tests/src/lib.rs b/crates/bpe/tests/src/lib.rs index 9c02773..eccb548 100644 --- a/crates/bpe/tests/src/lib.rs +++ b/crates/bpe/tests/src/lib.rs @@ -1,31 +1,14 @@ -use bpe::byte_pair_encoding::BytePairEncoding; -use rand::{thread_rng, Rng}; - -pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> { - let mut text = vec![]; - for _ in 0..tokens { - let i = thread_rng().gen_range(0..bpe.num_tokens()); - let s = bpe.token_bytes(i as u32); - text.extend_from_slice(s); - } - text -} - #[cfg(test)] mod tests { - use std::time::Instant; - use itertools::Itertools; use rand::{thread_rng, Rng}; - use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton}; + use tiktoken_rs::cl100k_base_singleton; use bpe::appendable_encoder::AppendableEncoder; - use bpe::byte_pair_encoding::BytePairEncoding; + use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding}; use bpe::interval_encoding::IntervalEncoding; use bpe::prependable_encoder::PrependableEncoder; - use bpe_openai::{cl100k_base, o200k_base}; - - use super::*; + use bpe_openai::cl100k_base; /// This test produces the
output for the encoding example in the README. #[test] @@ -87,74 +70,48 @@ mod tests { fn test_appendable_encoder() { let bpe = &cl100k_base().bpe; let mut enc = AppendableEncoder::new(bpe); - let input_string = create_test_bytes(bpe, 100); - for (i, c) in input_string.iter().enumerate() { - assert_eq!(enc.token_count(), bpe.count(&input_string[0..i])); - enc.push(*c); + let input = create_test_bytes(bpe, 100); + for (i, b) in input.iter().enumerate() { + enc.push(*b); + assert_eq!(enc.token_count(), bpe.count(&input[0..i + 1])); } } #[test] - fn test_correctness_cl100k() { + fn test_correctness() { // This is quite a challenging test case... - let test_string = std::str::from_utf8(&[ + let input = std::str::from_utf8(&[ 125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105, 112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32, 69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111, 102, 102, 101, 110, 100, ]) .unwrap(); - let time = Instant::now(); let bpe = &cl100k_base().bpe; - println!("{:?}", time.elapsed()); let encoded1 = cl100k_base_singleton() .lock() - .encode_ordinary(test_string) - .into_iter() - .collect_vec(); - let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes()); - assert_eq!(encoded1, encoded2); - let encoded3 = bpe.encode_via_table(test_string.as_bytes()); - assert_eq!(encoded1, encoded3); - let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes()); - assert_eq!(encoded1, encoded4); - } - - #[test] - fn test_correctness_o200k() { - // This is quite a challenging test case... - let test_string = std::str::from_utf8(&[ - 125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105, - 112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32, - 69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111, - 102, 102, 101, 110, 100, - ]) - .unwrap(); - let time = Instant::now(); - let bpe = &o200k_base().bpe; - println!("{:?}", time.elapsed()); - let encoded1 = o200k_base_singleton() - .lock() - .encode_ordinary(test_string) + .encode_ordinary(input) .into_iter() .collect_vec(); - let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes()); + let encoded2 = bpe.encode_via_backtracking(input.as_bytes()); assert_eq!(encoded1, encoded2); - let encoded3 = bpe.encode_via_table(test_string.as_bytes()); + let encoded3 = bpe.encode_via_table(input.as_bytes()); assert_eq!(encoded1, encoded3); - let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes()); + let encoded4 = bpe.encode_via_bitfield(input.as_bytes()); assert_eq!(encoded1, encoded4); } #[test] fn test_bpe_equivalence() { let bpe = &cl100k_base().bpe; - for tokens in [10, 1000, 10000] { - for _ in 0..5 { - let test_input = create_test_bytes(bpe, tokens); - let encoded1 = bpe.encode_via_backtracking(&test_input); - let encoded2 = bpe.encode_via_bitfield(&test_input); + for bytes in [10, 1000, 10000] { + for _ in 0..8 { + let input = create_test_bytes(bpe, bytes); + let encoded1 = bpe.encode_via_backtracking(&input); + let encoded2 = bpe.encode_via_bitfield(&input); assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len()); + let encoded3 = bpe.encode_via_table(&input); + assert_eq!(encoded1, encoded3, "{} {}", encoded1.len(), encoded3.len()); } } } @@ -162,15 +119,15 @@ mod tests { #[test] fn test_interval_count() { let bpe = &cl100k_base().bpe; - let text = create_test_bytes(bpe, 10000); - let intervals = 
IntervalEncoding::new(bpe, &text); + let input = create_test_bytes(bpe, 10000); + let intervals = IntervalEncoding::new(bpe, &input); for _ in 0..1000 { - let start = thread_rng().gen_range(0..text.len()); - let end = thread_rng().gen_range(0..text.len()); + let start = thread_rng().gen_range(0..input.len()); + let end = thread_rng().gen_range(0..input.len()); let range = start.min(end)..start.max(end); assert_eq!( intervals.count(range.clone()), - bpe.encode_via_backtracking(&text[range]).len() + bpe.encode_via_backtracking(&input[range]).len() ); } } @@ -179,10 +136,10 @@ mod tests { fn test_prependable_encoder() { let bpe = &cl100k_base().bpe; let mut enc = PrependableEncoder::new(bpe); - let input_string = create_test_bytes(bpe, 100); - for (i, c) in input_string.iter().enumerate().rev() { - enc.push(*c); - assert_eq!(enc.token_count(), bpe.count(&input_string[i..])); + let input = create_test_bytes(bpe, 100); + for (i, b) in input.iter().enumerate().rev() { + enc.push(*b); + assert_eq!(enc.token_count(), bpe.count(&input[i..])); } } }
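For reference, here is a hypothetical usage sketch of the `rand`-gated test helpers that this change moves into `byte_pair_encoding.rs`. The `main` program is illustrative only; it assumes the `rand` feature of `bpe` is enabled, as in the dev-dependencies above:

```rust
use bpe::byte_pair_encoding::{create_test_string, select_test_string};
use bpe_openai::cl100k_base;

fn main() {
    let tok = cl100k_base();
    // Concatenate random, valid-UTF-8 tokens until the string is at least 80 KB.
    let haystack = create_test_string(&tok.bpe, 80_000);
    // Pick a random slice of roughly 100 bytes, widened to char boundaries.
    let sample = select_test_string(&haystack, 100);
    // Pretokenize and encode, as the equivalence tests above do.
    let tokens = tok.encode(sample);
    println!("{} bytes -> {} tokens", sample.len(), tokens.len());
}
```

Generating one large haystack and repeatedly slicing it is the pattern the benchmarks and tests adopt throughout this change: it keeps the random-token generation out of the measured or repeated inner loops.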