Skip to content

Commit 0cb520e

Browse files
committedOct 21, 2024
Merge branch 'move-equivalence-tests' into unsplittable-test-strings
2 parents 5b7d913 + b42989e commit 0cb520e

File tree

3 files changed

+56
-99
lines changed

3 files changed

+56
-99
lines changed
 

‎crates/bpe-openai/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ regex-automata = "0.4"
1919
rmp-serde = "1"
2020

2121
[dev-dependencies]
22-
tiktoken-rs = "0.6"
2322
bpe = { version = "0.1.0", path = "../bpe", features = ["rand"] }
23+
tiktoken-rs = "0.6"
2424

2525
[build-dependencies]
2626
base64 = "0.22.1"

‎crates/bpe/src/byte_pair_encoding.rs

+29-41
Original file line numberDiff line numberDiff line change
@@ -553,17 +553,6 @@ impl BytePairEncoding {
553553
}
554554
}
555555

556-
#[cfg(feature = "rand")]
557-
fn is_char_boundary(b: u8) -> bool {
558-
// Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128
559-
// Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. b < 192
560-
// The rest are bytes belonging to the first byte of multi byte encodings (11xxxxxx): b >= 192
561-
// When interpreting the byte representation as signed integers, then numbers in the range 128..192
562-
// correspond to the smallest representable numbers. I.e. the two ranges [0, 128) and [192, 256) can
563-
// be tested with a single signed comparison.
564-
b as i8 >= -0x40 // NB: b < 128 || b >= 192
565-
}
566-
567556
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
568557
#[cfg(feature = "rand")]
569558
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
@@ -580,49 +569,36 @@ pub fn create_test_string_with_predicate(
580569
predicate: impl Fn(&str) -> bool,
581570
) -> String {
582571
use rand::{thread_rng, Rng};
583-
// the bytes we accumulated thus far
584-
let mut bytes = Vec::new();
572+
// the string we accumulated thus far
573+
let mut result = String::new();
585574
// the tokens we added so we can backtrack
586575
let mut tokens = Vec::new();
587-
// the number of valid UTF-8 bytes
588-
let mut valid_bytes = 0;
589-
'keep: while valid_bytes < min_bytes {
576+
'keep: while result.len() < min_bytes {
590577
// try a few times to find a suitable token
591-
for _ in 0..8 {
578+
'next: for _ in 0..8 {
592579
// pick a random token and provisionally add it
593-
let i = thread_rng().gen_range(0..bpe.num_tokens());
594-
bytes.extend(bpe.token_bytes(i as u32));
595-
// test if the additional bytes are valid utf-8
596-
// the last character is not included, because it may be incomplete
597-
let last = bytes
598-
.iter()
599-
.rev()
600-
.find_position(|b| is_char_boundary(**b))
601-
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
602-
assert!(last >= valid_bytes);
603-
if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok()
604-
&& predicate(std::str::from_utf8(&bytes[0..last]).expect("should be valid"))
605-
{
580+
let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
581+
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
582+
// token set. The chance of constructing a valid UTF-8 character across a token boundary
583+
// by picking random tokens is so small that it is unlikely to happen anyway.
584+
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
585+
result.push_str(token);
586+
} else {
587+
continue 'next;
588+
}
589+
if predicate(&result) {
606590
tokens.push(i);
607-
valid_bytes = last;
608591
continue 'keep;
609592
} else {
610-
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
593+
result.truncate(result.len() - bpe.token_len(i));
611594
}
612595
}
613596
// we didn't find anything after a few tries, backtrack
614597
if let Some(i) = tokens.pop() {
615-
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
616-
valid_bytes = bytes
617-
.iter()
618-
.rev()
619-
.find_position(|b| is_char_boundary(**b))
620-
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
598+
result.truncate(result.len() - bpe.token_len(i));
621599
}
622600
}
623-
// truncate to the know valid bytes
624-
bytes.truncate(valid_bytes);
625-
String::from_utf8(bytes).expect("should be valid here")
601+
result
626602
}
627603

628604
#[cfg(feature = "rand")]
@@ -638,3 +614,15 @@ pub fn select_test_string(text: &str, min_bytes: usize) -> &str {
638614
}
639615
&text[start..end]
640616
}
617+
618+
/// Generate test bytes by concatenating random tokens.
619+
#[cfg(feature = "rand")]
620+
pub fn create_test_bytes(bpe: &BytePairEncoding, min_bytes: usize) -> Vec<u8> {
621+
use rand::{thread_rng, Rng};
622+
let mut result = Vec::new();
623+
while result.len() < min_bytes {
624+
let i = thread_rng().gen_range(0..bpe.num_tokens());
625+
result.extend(bpe.token_bytes(i as u32));
626+
}
627+
result
628+
}

‎crates/bpe/tests/src/lib.rs

+26-57
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
#[cfg(test)]
22
mod tests {
3-
use std::time::Instant;
4-
53
use itertools::Itertools;
64
use rand::{thread_rng, Rng};
7-
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
5+
use tiktoken_rs::cl100k_base_singleton;
86

97
use bpe::appendable_encoder::AppendableEncoder;
10-
use bpe::byte_pair_encoding::{create_test_string, BytePairEncoding};
8+
use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
119
use bpe::interval_encoding::IntervalEncoding;
1210
use bpe::prependable_encoder::PrependableEncoder;
13-
use bpe_openai::{cl100k_base, o200k_base};
11+
use bpe_openai::cl100k_base;
1412

1513
/// This test produces the output for the encoding example in the README.
1614
#[test]
@@ -72,93 +70,64 @@ mod tests {
7270
fn test_appendable_encoder() {
7371
let bpe = &cl100k_base().bpe;
7472
let mut enc = AppendableEncoder::new(bpe);
75-
let input_string = create_test_string(bpe, 100);
76-
for (i, b) in input_string.as_bytes().iter().enumerate() {
73+
let input = create_test_bytes(bpe, 100);
74+
for (i, b) in input.iter().enumerate() {
7775
enc.push(*b);
78-
assert_eq!(
79-
enc.token_count(),
80-
bpe.count(&input_string.as_bytes()[0..i + 1])
81-
);
76+
assert_eq!(enc.token_count(), bpe.count(&input[0..i + 1]));
8277
}
8378
}
8479

8580
#[test]
86-
fn test_correctness_cl100k() {
81+
fn test_correctness() {
8782
// This is quite a challenging test case...
88-
let test_string = std::str::from_utf8(&[
83+
let input = std::str::from_utf8(&[
8984
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
9085
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
9186
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
9287
102, 102, 101, 110, 100,
9388
])
9489
.unwrap();
95-
let time = Instant::now();
9690
let bpe = &cl100k_base().bpe;
97-
println!("{:?}", time.elapsed());
9891
let encoded1 = cl100k_base_singleton()
9992
.lock()
100-
.encode_ordinary(test_string)
101-
.into_iter()
102-
.collect_vec();
103-
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
104-
assert_eq!(encoded1, encoded2);
105-
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
106-
assert_eq!(encoded1, encoded3);
107-
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
108-
assert_eq!(encoded1, encoded4);
109-
}
110-
111-
#[test]
112-
fn test_correctness_o200k() {
113-
// This is quite a challenging test case...
114-
let test_string = std::str::from_utf8(&[
115-
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
116-
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
117-
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
118-
102, 102, 101, 110, 100,
119-
])
120-
.unwrap();
121-
let time = Instant::now();
122-
let bpe = &o200k_base().bpe;
123-
println!("{:?}", time.elapsed());
124-
let encoded1 = o200k_base_singleton()
125-
.lock()
126-
.encode_ordinary(test_string)
93+
.encode_ordinary(input)
12794
.into_iter()
12895
.collect_vec();
129-
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
96+
let encoded2 = bpe.encode_via_backtracking(input.as_bytes());
13097
assert_eq!(encoded1, encoded2);
131-
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
98+
let encoded3 = bpe.encode_via_table(input.as_bytes());
13299
assert_eq!(encoded1, encoded3);
133-
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
100+
let encoded4 = bpe.encode_via_bitfield(input.as_bytes());
134101
assert_eq!(encoded1, encoded4);
135102
}
136103

137104
#[test]
138105
fn test_bpe_equivalence() {
139106
let bpe = &cl100k_base().bpe;
140107
for bytes in [10, 1000, 10000] {
141-
for _ in 0..5 {
142-
let test_input = create_test_string(bpe, bytes);
143-
let encoded1 = bpe.encode_via_backtracking(test_input.as_bytes());
144-
let encoded2 = bpe.encode_via_bitfield(test_input.as_bytes());
108+
for _ in 0..8 {
109+
let input = create_test_bytes(bpe, bytes);
110+
let encoded1 = bpe.encode_via_backtracking(&input);
111+
let encoded2 = bpe.encode_via_bitfield(&input);
145112
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
113+
let encoded3 = bpe.encode_via_table(&input);
114+
assert_eq!(encoded1, encoded3, "{} {}", encoded1.len(), encoded3.len());
146115
}
147116
}
148117
}
149118

150119
#[test]
151120
fn test_interval_count() {
152121
let bpe = &cl100k_base().bpe;
153-
let text = create_test_string(bpe, 10000);
154-
let intervals = IntervalEncoding::new(bpe, text.as_bytes());
122+
let input = create_test_bytes(bpe, 10000);
123+
let intervals = IntervalEncoding::new(bpe, &input);
155124
for _ in 0..1000 {
156-
let start = thread_rng().gen_range(0..text.len());
157-
let end = thread_rng().gen_range(0..text.len());
125+
let start = thread_rng().gen_range(0..input.len());
126+
let end = thread_rng().gen_range(0..input.len());
158127
let range = start.min(end)..start.max(end);
159128
assert_eq!(
160129
intervals.count(range.clone()),
161-
bpe.encode_via_backtracking(&text.as_bytes()[range]).len()
130+
bpe.encode_via_backtracking(&input[range]).len()
162131
);
163132
}
164133
}
@@ -167,10 +136,10 @@ mod tests {
167136
fn test_prependable_encoder() {
168137
let bpe = &cl100k_base().bpe;
169138
let mut enc = PrependableEncoder::new(bpe);
170-
let input_string = create_test_string(bpe, 100);
171-
for (i, b) in input_string.as_bytes().iter().enumerate().rev() {
139+
let input = create_test_bytes(bpe, 100);
140+
for (i, b) in input.iter().enumerate().rev() {
172141
enc.push(*b);
173-
assert_eq!(enc.token_count(), bpe.count(&input_string.as_bytes()[i..]));
142+
assert_eq!(enc.token_count(), bpe.count(&input[i..]));
174143
}
175144
}
176145
}

0 commit comments

Comments
 (0)