Skip to content

Commit b3feef9

Browse files
author
Hendrik van Antwerpen
committed
Verify that tokens are valid
1 parent 8ecf192 commit b3feef9

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl BytePairEncoding {
303303
split_table.push((id as u32, id as u32));
304304
}
305305
}
306-
Self {
306+
let bpe = Self {
307307
all_tokens,
308308
token_starts,
309309
bytes_hash_to_token,
@@ -314,7 +314,17 @@ impl BytePairEncoding {
314314
pair_lookup,
315315
split_table,
316316
hash_factor,
317+
};
318+
for token_id in 0..bpe.num_tokens() as u32 {
319+
let bytes = bpe.token_bytes(token_id);
320+
let tokens = bpe.encode_via_bitfield(bytes);
321+
assert_eq!(
322+
tokens,
323+
vec![token_id],
324+
"token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself"
325+
);
317326
}
327+
bpe
318328
}
319329

320330
/// Return the number of tokens in this BPE dictionary.

crates/bpe/tests/src/lib.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@ mod tests {
3030
/// This test produces the output for the encoding example in the README.
3131
#[test]
3232
fn readme_example() {
33-
let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
34-
let bpe = BytePairEncoding::from_dictionary(tokens, None);
35-
let text = "abacb";
33+
let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
34+
let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);
35+
let text = "abacbb";
3636
let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
3737
let all_prefix_tokens = prefixes
3838
.iter()
3939
.map(|prefix| {
4040
bpe.encode_via_backtracking(prefix.as_bytes())
4141
.into_iter()
42-
.map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
42+
.map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap())
4343
.collect_vec()
4444
})
4545
.collect_vec();
@@ -48,6 +48,8 @@ mod tests {
4848
.map(|tokens| tokens.last().unwrap())
4949
.collect_vec();
5050

51+
println!("Token set: `{}`\n", tokens.join(" "));
52+
5153
println!("All tokens for each prefix of `{text}`:\n");
5254
for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
5355
println!(
@@ -67,7 +69,7 @@ mod tests {
6769
}
6870
println!();
6971

70-
println!("Tokenization of `{text}`:\n");
72+
println!("Encoding using last tokens of `{text}`:\n");
7173
let mut remaining = text.len();
7274
while remaining > 0 {
7375
let prefix = &text[..remaining];

0 commit comments

Comments (0)