@@ -567,20 +567,47 @@ fn is_char_boundary(b: u8) -> bool {
/// Builds a random, valid UTF-8 string that is at least `min_bytes` bytes long.
///
/// The string is assembled by appending the byte sequences of randomly chosen
/// tokens of `bpe`. A single token may end in the middle of a multi-byte
/// character, so after every append only the bytes *before* the start of the
/// last character are validated; the trailing (possibly incomplete) character
/// is checked once a later token completes it. If no suitable token is found
/// after a few attempts, the most recently accepted token is removed and the
/// search resumes from there.
///
/// The final buffer is cut at the last validated position, so the result may
/// end before the end of the last token that was appended.
///
/// # Panics
///
/// Panics if the internal invariants are violated (the assertion on the
/// validated prefix, or the final UTF-8 conversion).
/// NOTE(review): `gen_range` is expected to panic on an empty range, i.e. when
/// `bpe.num_tokens()` is 0 — confirm against the `rand` version in use.
#[cfg(feature = "rand")]
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
    use rand::{thread_rng, Rng};

    // How many random tokens are tried at one position before backtracking.
    const MAX_ATTEMPTS: usize = 8;

    // Index of the byte that starts the last character in `buf`, or 0 if there
    // is no such byte. Everything before that index consists of complete
    // characters only (if it is valid UTF-8 at all).
    let last_char_start =
        |buf: &[u8]| buf.iter().rposition(|&b| is_char_boundary(b)).unwrap_or(0);

    // The bytes accumulated thus far.
    let mut bytes: Vec<u8> = Vec::new();
    // The tokens added so far, kept so that we can backtrack.
    let mut tokens = Vec::new();
    // Length of the prefix of `bytes` that is known to be valid UTF-8.
    let mut valid_bytes = 0;

    'keep: while valid_bytes < min_bytes {
        // Try a few times to find a suitable token.
        for _ in 0..MAX_ATTEMPTS {
            // Pick a random token and provisionally add it.
            let i = thread_rng().gen_range(0..bpe.num_tokens());
            bytes.extend(bpe.token_bytes(i as u32));

            // Validate only the newly covered bytes; the last character is
            // excluded because it may still be incomplete.
            let last = last_char_start(&bytes);
            assert!(last >= valid_bytes);
            if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok() {
                tokens.push(i);
                valid_bytes = last;
                continue 'keep;
            }

            // The token does not fit here: undo the provisional append.
            bytes.truncate(bytes.len() - bpe.token_len(i as u32));
        }

        // Nothing fit after a few tries: drop the previously accepted token
        // and recompute the validated prefix for the shortened buffer.
        if let Some(i) = tokens.pop() {
            bytes.truncate(bytes.len() - bpe.token_len(i as u32));
            valid_bytes = last_char_start(&bytes);
        }
    }

    // Truncate to the known valid bytes.
    bytes.truncate(valid_bytes);
    String::from_utf8(bytes).expect("should be valid here")
}
585612
586613#[ cfg( feature = "rand" ) ]
0 commit comments