11#[ cfg( test) ]
22mod tests {
3- use std:: time:: Instant ;
4-
53 use itertools:: Itertools ;
64 use rand:: { thread_rng, Rng } ;
7- use tiktoken_rs:: { cl100k_base_singleton, o200k_base_singleton } ;
5+ use tiktoken_rs:: cl100k_base_singleton;
86
97 use bpe:: appendable_encoder:: AppendableEncoder ;
10- use bpe:: byte_pair_encoding:: { create_test_string , BytePairEncoding } ;
8+ use bpe:: byte_pair_encoding:: { create_test_bytes , BytePairEncoding } ;
119 use bpe:: interval_encoding:: IntervalEncoding ;
1210 use bpe:: prependable_encoder:: PrependableEncoder ;
13- use bpe_openai:: { cl100k_base, o200k_base } ;
11+ use bpe_openai:: cl100k_base;
1412
1513 /// This test produces the output for the encoding example in the README.
1614 #[ test]
@@ -72,93 +70,64 @@ mod tests {
7270 fn test_appendable_encoder ( ) {
7371 let bpe = & cl100k_base ( ) . bpe ;
7472 let mut enc = AppendableEncoder :: new ( bpe) ;
75- let input_string = create_test_string ( bpe, 100 ) ;
76- for ( i, b) in input_string . as_bytes ( ) . iter ( ) . enumerate ( ) {
73+ let input = create_test_bytes ( bpe, 100 ) ;
74+ for ( i, b) in input . iter ( ) . enumerate ( ) {
7775 enc. push ( * b) ;
78- assert_eq ! (
79- enc. token_count( ) ,
80- bpe. count( & input_string. as_bytes( ) [ 0 ..i + 1 ] )
81- ) ;
76+ assert_eq ! ( enc. token_count( ) , bpe. count( & input[ 0 ..i + 1 ] ) ) ;
8277 }
8378 }
8479
8580 #[ test]
86- fn test_correctness_cl100k ( ) {
81+ fn test_correctness ( ) {
8782 // This is quite a challenging test case...
88- let test_string = std:: str:: from_utf8 ( & [
83+ let input = std:: str:: from_utf8 ( & [
8984 125 , 34 , 10 , 10 , 46 , 109 , 107 , 100 , 105 , 114 , 115 , 32 , 102 , 100 , 115 , 32 , 97 , 100 , 105 ,
9085 112 , 105 , 115 , 105 , 99 , 105 , 110 , 103 , 105 , 116 , 121 , 69 , 110 , 103 , 105 , 110 , 101 , 32 ,
9186 69 , 67 , 105 , 114 , 105 , 101 , 32 , 111 , 112 , 116 , 105 , 109 , 97 , 108 , 95 , 68 , 65 , 32 , 111 ,
9287 102 , 102 , 101 , 110 , 100 ,
9388 ] )
9489 . unwrap ( ) ;
95- let time = Instant :: now ( ) ;
9690 let bpe = & cl100k_base ( ) . bpe ;
97- println ! ( "{:?}" , time. elapsed( ) ) ;
9891 let encoded1 = cl100k_base_singleton ( )
9992 . lock ( )
100- . encode_ordinary ( test_string)
101- . into_iter ( )
102- . collect_vec ( ) ;
103- let encoded2 = bpe. encode_via_backtracking ( test_string. as_bytes ( ) ) ;
104- assert_eq ! ( encoded1, encoded2) ;
105- let encoded3 = bpe. encode_via_table ( test_string. as_bytes ( ) ) ;
106- assert_eq ! ( encoded1, encoded3) ;
107- let encoded4 = bpe. encode_via_bitfield ( test_string. as_bytes ( ) ) ;
108- assert_eq ! ( encoded1, encoded4) ;
109- }
110-
111- #[ test]
112- fn test_correctness_o200k ( ) {
113- // This is quite a challenging test case...
114- let test_string = std:: str:: from_utf8 ( & [
115- 125 , 34 , 10 , 10 , 46 , 109 , 107 , 100 , 105 , 114 , 115 , 32 , 102 , 100 , 115 , 32 , 97 , 100 , 105 ,
116- 112 , 105 , 115 , 105 , 99 , 105 , 110 , 103 , 105 , 116 , 121 , 69 , 110 , 103 , 105 , 110 , 101 , 32 ,
117- 69 , 67 , 105 , 114 , 105 , 101 , 32 , 111 , 112 , 116 , 105 , 109 , 97 , 108 , 95 , 68 , 65 , 32 , 111 ,
118- 102 , 102 , 101 , 110 , 100 ,
119- ] )
120- . unwrap ( ) ;
121- let time = Instant :: now ( ) ;
122- let bpe = & o200k_base ( ) . bpe ;
123- println ! ( "{:?}" , time. elapsed( ) ) ;
124- let encoded1 = o200k_base_singleton ( )
125- . lock ( )
126- . encode_ordinary ( test_string)
93+ . encode_ordinary ( input)
12794 . into_iter ( )
12895 . collect_vec ( ) ;
129- let encoded2 = bpe. encode_via_backtracking ( test_string . as_bytes ( ) ) ;
96+ let encoded2 = bpe. encode_via_backtracking ( input . as_bytes ( ) ) ;
13097 assert_eq ! ( encoded1, encoded2) ;
131- let encoded3 = bpe. encode_via_table ( test_string . as_bytes ( ) ) ;
98+ let encoded3 = bpe. encode_via_table ( input . as_bytes ( ) ) ;
13299 assert_eq ! ( encoded1, encoded3) ;
133- let encoded4 = bpe. encode_via_bitfield ( test_string . as_bytes ( ) ) ;
100+ let encoded4 = bpe. encode_via_bitfield ( input . as_bytes ( ) ) ;
134101 assert_eq ! ( encoded1, encoded4) ;
135102 }
136103
137104 #[ test]
138105 fn test_bpe_equivalence ( ) {
139106 let bpe = & cl100k_base ( ) . bpe ;
140107 for bytes in [ 10 , 1000 , 10000 ] {
141- for _ in 0 ..5 {
142- let test_input = create_test_string ( bpe, bytes) ;
143- let encoded1 = bpe. encode_via_backtracking ( test_input . as_bytes ( ) ) ;
144- let encoded2 = bpe. encode_via_bitfield ( test_input . as_bytes ( ) ) ;
108+ for _ in 0 ..8 {
109+ let input = create_test_bytes ( bpe, bytes) ;
110+ let encoded1 = bpe. encode_via_backtracking ( & input ) ;
111+ let encoded2 = bpe. encode_via_bitfield ( & input ) ;
145112 assert_eq ! ( encoded1, encoded2, "{} {}" , encoded1. len( ) , encoded2. len( ) ) ;
113+ let encoded3 = bpe. encode_via_table ( & input) ;
114+ assert_eq ! ( encoded1, encoded3, "{} {}" , encoded1. len( ) , encoded3. len( ) ) ;
146115 }
147116 }
148117 }
149118
150119 #[ test]
151120 fn test_interval_count ( ) {
152121 let bpe = & cl100k_base ( ) . bpe ;
153- let text = create_test_string ( bpe, 10000 ) ;
154- let intervals = IntervalEncoding :: new ( bpe, text . as_bytes ( ) ) ;
122+ let input = create_test_bytes ( bpe, 10000 ) ;
123+ let intervals = IntervalEncoding :: new ( bpe, & input ) ;
155124 for _ in 0 ..1000 {
156- let start = thread_rng ( ) . gen_range ( 0 ..text . len ( ) ) ;
157- let end = thread_rng ( ) . gen_range ( 0 ..text . len ( ) ) ;
125+ let start = thread_rng ( ) . gen_range ( 0 ..input . len ( ) ) ;
126+ let end = thread_rng ( ) . gen_range ( 0 ..input . len ( ) ) ;
158127 let range = start. min ( end) ..start. max ( end) ;
159128 assert_eq ! (
160129 intervals. count( range. clone( ) ) ,
161- bpe. encode_via_backtracking( & text . as_bytes ( ) [ range] ) . len( )
130+ bpe. encode_via_backtracking( & input [ range] ) . len( )
162131 ) ;
163132 }
164133 }
@@ -167,10 +136,10 @@ mod tests {
167136 fn test_prependable_encoder ( ) {
168137 let bpe = & cl100k_base ( ) . bpe ;
169138 let mut enc = PrependableEncoder :: new ( bpe) ;
170- let input_string = create_test_string ( bpe, 100 ) ;
171- for ( i, b) in input_string . as_bytes ( ) . iter ( ) . enumerate ( ) . rev ( ) {
139+ let input = create_test_bytes ( bpe, 100 ) ;
140+ for ( i, b) in input . iter ( ) . enumerate ( ) . rev ( ) {
172141 enc. push ( * b) ;
173- assert_eq ! ( enc. token_count( ) , bpe. count( & input_string . as_bytes ( ) [ i..] ) ) ;
142+ assert_eq ! ( enc. token_count( ) , bpe. count( & input [ i..] ) ) ;
174143 }
175144 }
176145}
0 commit comments