Skip to content

Commit d430615

Browse files
author
Hendrik van Antwerpen
committed
Generate test strings with multi-byte characters
1 parent 0907c88 commit d430615

File tree

1 file changed

+37
-10
lines changed

1 file changed

+37
-10
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,20 +567,47 @@ fn is_char_boundary(b: u8) -> bool {
567567
#[cfg(feature = "rand")]
568568
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
569569
use rand::{thread_rng, Rng};
570-
let mut text = String::new();
571-
while text.len() < min_bytes {
572-
loop {
570+
// the bytes we accumulated thus far
571+
let mut bytes = Vec::new();
572+
// the tokens we added so we can backtrack
573+
let mut tokens = Vec::new();
574+
// the number of valid UTF-8 bytes
575+
let mut valid_bytes = 0;
576+
'keep: while valid_bytes < min_bytes {
577+
// try a few times to find a suitable token
578+
for _ in 0..8 {
579+
// pick a random token and provisionally add it
573580
let i = thread_rng().gen_range(0..bpe.num_tokens());
574-
let s = bpe.token_bytes(i as u32);
575-
if s.iter().all(|b| is_char_boundary(*b)) {
576-
if let Ok(s) = std::str::from_utf8(s) {
577-
text.push_str(s);
578-
break;
579-
}
581+
bytes.extend(bpe.token_bytes(i as u32));
582+
// test if the additional bytes are valid utf-8
583+
// the last character is not included, because it may be incomplete
584+
let last = bytes
585+
.iter()
586+
.rev()
587+
.find_position(|b| is_char_boundary(**b))
588+
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
589+
assert!(last >= valid_bytes);
590+
if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok() {
591+
tokens.push(i);
592+
valid_bytes = last;
593+
continue 'keep;
594+
} else {
595+
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
580596
}
581597
}
598+
// we didn't find anything after a few tries, backtrack
599+
if let Some(i) = tokens.pop() {
600+
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
601+
valid_bytes = bytes
602+
.iter()
603+
.rev()
604+
.find_position(|b| is_char_boundary(**b))
605+
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
606+
}
582607
}
583-
text
608+
// truncate to the know valid bytes
609+
bytes.truncate(valid_bytes);
610+
String::from_utf8(bytes).expect("should be valid here")
584611
}
585612

586613
#[cfg(feature = "rand")]

0 commit comments

Comments
 (0)