Skip to content

Commit 17d5c3e

Browse files
Merge pull request #30 from github/unsplittable-test-strings
Generate non-splittable test strings for worstcase benchmark
2 parents e20fc1a + 0cb520e commit 17d5c3e

9 files changed

+400
-302
lines changed

crates/bpe/README.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,8 @@ This suggests that pre-tokenization is not necessary from a performance perspect
294294

295295
![encoding runtime comparison](./images/performance-comparison.svg)
296296

297-
The graph below shows encoding results for input that is particularly challenging for tiktoken.
298-
The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace.
299-
The performance of tiktoken shows a quadratic growth with the input size.
297+
The graph below shows encoding results when the input cannot be split in pre-tokenization, which allows a better comparison of pure BPE performance.
298+
This case is particularly challenging for tiktoken, which shows quadratic growth with the input size.
300299
The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.
301300

302301
![worst-case encoding runtime comparison](./images/performance-worstcase.svg)

crates/bpe/benchmarks/performance.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use std::time::Duration;
22

33
use bpe::appendable_encoder::AppendableEncoder;
4-
use bpe::byte_pair_encoding::{create_test_string, select_test_string};
4+
use bpe::byte_pair_encoding::{
5+
create_test_string, create_test_string_with_predicate, select_test_string,
6+
};
57
use bpe::interval_encoding::IntervalEncoding;
68
use bpe_benchmarks::*;
79
use criterion::{
@@ -11,7 +13,7 @@ use rand::{thread_rng, Rng};
1113

1214
fn counting_benchmark(c: &mut Criterion) {
1315
for (name, bpe, _, _) in TOKENIZERS.iter() {
14-
let input = create_test_string(&bpe.bpe, 80000);
16+
let input = create_test_string(&bpe.bpe, 80_000);
1517
let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());
1618

1719
let mut group = c.benchmark_group(format!("counting-{name}"));
@@ -185,19 +187,21 @@ fn comparison_benchmark(c: &mut Criterion) {
185187
}
186188

187189
fn worstcase_comparison_benchmark(c: &mut Criterion) {
188-
for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
189-
let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect();
190+
for (name, tok, tiktoken, huggingface) in TOKENIZERS.iter() {
191+
let text = create_test_string_with_predicate(&tok.bpe, 100000, |text| {
192+
tok.split(text).nth(1).is_none()
193+
});
190194

191195
let mut group = c.benchmark_group(format!("worstcase-{name}"));
192-
for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
196+
for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000] {
193197
group.throughput(criterion::Throughput::Bytes(bytes as u64));
194198
group.bench_with_input(
195199
BenchmarkId::new("backtracking", bytes),
196200
&bytes,
197201
|b, bytes| {
198202
b.iter_batched(
199203
|| select_test_string(&text, *bytes),
200-
|text| bpe.encode(text),
204+
|text| tok.encode(text),
201205
criterion::BatchSize::SmallInput,
202206
)
203207
},

crates/bpe/images/performance-appending.svg

+56-44
Loading

crates/bpe/images/performance-comparison.svg

+67-49
Loading

crates/bpe/images/performance-counting.svg

+56-40
Loading

crates/bpe/images/performance-encoding.svg

+76-64
Loading

crates/bpe/images/performance-worstcase.svg

+82-82
Loading

crates/bpe/src/byte_pair_encoding.rs

+39-8
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,49 @@ impl BytePairEncoding {
553553
}
554554
}
555555

556-
/// Generate a test string by concatenating random tokens.
556+
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
557557
#[cfg(feature = "rand")]
558558
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
559+
create_test_string_with_predicate(bpe, min_bytes, |_| true)
560+
}
561+
562+
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
563+
/// The given predicate enforces other properties on the generated string. Note that this can hurt performance or
564+
/// even cause non-termination!
565+
#[cfg(feature = "rand")]
566+
pub fn create_test_string_with_predicate(
567+
bpe: &BytePairEncoding,
568+
min_bytes: usize,
569+
predicate: impl Fn(&str) -> bool,
570+
) -> String {
559571
use rand::{thread_rng, Rng};
572+
// the string we accumulated thus far
560573
let mut result = String::new();
561-
while result.len() < min_bytes {
562-
let i = thread_rng().gen_range(0..bpe.num_tokens());
563-
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
564-
// token set. The chance of constructing a valid UTF-8 character across a token boundary
565-
// by picking random tokens is so small that it is unlikely to happen anyway.
566-
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
567-
result.push_str(token);
574+
// the tokens we added so we can backtrack
575+
let mut tokens = Vec::new();
576+
'keep: while result.len() < min_bytes {
577+
// try a few times to find a suitable token
578+
'next: for _ in 0..8 {
579+
// pick a random token and provisionally add it
580+
let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
581+
// We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
582+
// token set. The chance of constructing a valid UTF-8 character across a token boundary
583+
// by picking random tokens is so small that it is unlikely to happen anyway.
584+
if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
585+
result.push_str(token);
586+
} else {
587+
continue 'next;
588+
}
589+
if predicate(&result) {
590+
tokens.push(i);
591+
continue 'keep;
592+
} else {
593+
result.truncate(result.len() - bpe.token_len(i));
594+
}
595+
}
596+
// we didn't find anything after a few tries, backtrack
597+
if let Some(i) = tokens.pop() {
598+
result.truncate(result.len() - bpe.token_len(i));
568599
}
569600
}
570601
result

crates/geo_filters/evaluation/accuracy.rs

+12-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ impl Accuracy {
105105
.config
106106
.iter()
107107
.map(|c| {
108-
simulation_config_from_str(c).expect(&format!("not a valid configuration: {}", c))
108+
simulation_config_from_str(c)
109+
.unwrap_or_else(|_| panic!("not a valid configuration: {}", c))
109110
})
110111
.collect_vec();
111112
let set_sizes = if self.set_size.is_empty() {
@@ -118,9 +119,10 @@ impl Accuracy {
118119

119120
let mut output = self.output;
120121
output.set_extension("csv");
121-
let f = File::create(&output).expect(&format!("cannot create file: {}", output.display()));
122+
let f = File::create(&output)
123+
.unwrap_or_else(|_| panic!("cannot create file: {}", output.display()));
122124
write_simulation_results(&configs, &set_sizes, results, f)
123-
.expect(&format!("cannot write file: {}", output.display()));
125+
.unwrap_or_else(|_| panic!("cannot write file: {}", output.display()));
124126
println!(" csv file = {}", output.display());
125127
println!();
126128
}
@@ -139,9 +141,9 @@ impl SimulationConfigParser {
139141
Self(Regex::new(re).expect(""), Arc::new(f))
140142
}
141143

142-
fn parse<'a>(&self, name: &str) -> Option<SimulationConfig> {
144+
fn parse(&self, name: &str) -> Option<SimulationConfig> {
143145
self.0
144-
.captures(&name)
146+
.captures(name)
145147
.map(self.1.as_ref())
146148
.map(|p| (name.to_string(), p))
147149
}
@@ -225,7 +227,11 @@ fn simulation_config_from_str(name: &str) -> Result<SimulationConfig, String> {
225227
fn capture_usizes<const N: usize>(c: &Captures, is: [usize; N]) -> [usize; N] {
226228
let mut values = [0; N];
227229
for i in 0..is.len() {
228-
values[i] = usize::from_str_radix(c.get(is[i]).expect("capture to exist").as_str(), 10)
230+
values[i] = c
231+
.get(is[i])
232+
.expect("capture to exist")
233+
.as_str()
234+
.parse::<usize>()
229235
.expect("number string");
230236
}
231237
values

0 commit comments

Comments
 (0)