
Commit 5b127c9

Authored Oct 17, 2024
Merge pull request #31 from github/verify-tokens
2 parents: 8ecf192 + efaf552

3 files changed (+35 -22 lines)


crates/bpe/README.md (+17 -16)
@@ -96,31 +96,32 @@ Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`,
 ## Novel Algorithm
 
 At a first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision.
-For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character `abacb` would result in a pretty different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text.
+For instance, the sequence `abacb` would be encoded as `ab a cb` when the dictionary contains the tokens `a b c ab cb ac bb cbb acbb` ordered by frequency. But appending a single character `abacbb` would result in a pretty different tokenization: `ab acbb`. So without looking ahead it seems impossible to properly tokenize the text.
 
-The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get:
+The solution is to track the encodings of ALL text prefixes. For our example `abacbb` we would get:
 
-- `a` ------> `a`
-- `ab` -----> `ab`
-- `aba` ----> `ab a`
-- `abac` ---> `ab ac`
-- `abacb` --> `ab a cb`
+- `a` -------> `a`
+- `ab` ------> `ab`
+- `aba` -----> `ab a`
+- `abac` ----> `ab ac`
+- `abacb` ---> `ab a cb`
+- `abacbb` --> `ab acbb`
 
 This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered:
 
-- `a` ------> `a`
-- `ab` -----> `ab`
-- `aba` ----> `a`
-- `abac` ---> `ac`
-- `abacb` --> `cb`
+- `a` -------> `a`
+- `ab` ------> `ab`
+- `aba` -----> `a`
+- `abac` ----> `ac`
+- `abacb` ---> `cb`
+- `abacbb` --> `acbb`
 
 In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached.
 
-For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order:
+For our example prefix `abacbb`, this procedure executes the following steps and determines the correct encoding in reverse order:
 
-- `abacb` -> `cb`
-- `aba` ---> `a`
-- `ab` ----> `ab`
+- `abacbb` --> `acbb`
+- `ab` ------> `ab`
 - `<empty>`
 
 The actual challenge is to determine for every prefix this last token efficiently.

crates/bpe/src/byte_pair_encoding.rs (+11 -1)
@@ -303,7 +303,7 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        Self {
+        let bpe = Self {
             all_tokens,
             token_starts,
             bytes_hash_to_token,
@@ -314,7 +314,17 @@ impl BytePairEncoding {
             pair_lookup,
             split_table,
             hash_factor,
+        };
+        for token_id in 0..bpe.num_tokens() as u32 {
+            let bytes = bpe.token_bytes(token_id);
+            let tokens = bpe.encode_via_bitfield(bytes);
+            assert_eq!(
+                tokens,
+                vec![token_id],
+                "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself"
+            );
         }
+        bpe
     }
 
     /// Return the number of tokens in this BPE dictionary.
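
The added loop makes the constructor verify that every dictionary token re-encodes to exactly itself, so an inconsistent dictionary now fails fast at construction time instead of silently mis-encoding later. Below is a minimal sketch of the same invariant exercised through the entry points used in the test file; the `bpe::byte_pair_encoding` module path and the assumption that token ids follow dictionary order are mine, not taken from this diff.

```rust
use bpe::byte_pair_encoding::BytePairEncoding; // assumed module path

fn main() {
    // The README dictionary; since this commit, construction itself
    // panics if any token fails to encode to itself.
    let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
    let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);

    // Restating the invariant externally, assuming ids are assigned
    // in dictionary order:
    for (id, token) in tokens.iter().enumerate() {
        assert_eq!(
            bpe.encode_via_backtracking(token.as_bytes()),
            vec![id as u32]
        );
    }
}
```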

crates/bpe/tests/src/lib.rs (+7 -5)
@@ -30,16 +30,16 @@ mod tests {
     /// This test produces the output for the encoding example in the README.
     #[test]
     fn readme_example() {
-        let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
-        let bpe = BytePairEncoding::from_dictionary(tokens, None);
-        let text = "abacb";
+        let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
+        let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);
+        let text = "abacbb";
         let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
         let all_prefix_tokens = prefixes
             .iter()
             .map(|prefix| {
                 bpe.encode_via_backtracking(prefix.as_bytes())
                     .into_iter()
-                    .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
+                    .map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap())
                     .collect_vec()
             })
             .collect_vec();
@@ -48,6 +48,8 @@ mod tests {
             .map(|tokens| tokens.last().unwrap())
             .collect_vec();
 
+        println!("Token set: `{}`\n", tokens.join(" "));
+
         println!("All tokens for each prefix of `{text}`:\n");
         for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
             println!(
@@ -67,7 +69,7 @@ mod tests {
         }
         println!();
 
-        println!("Tokenization of `{text}`:\n");
+        println!("Encoding using last tokens of `{text}`:\n");
         let mut remaining = text.len();
         while remaining > 0 {
             let prefix = &text[..remaining];
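
Given the `println!` calls visible in these hunks, the first lines of the updated test output should look roughly as follows; the per-prefix lines depend on a format string outside the hunks and are elided here.

```text
Token set: `a b c ab cb ac bb cbb acbb`

All tokens for each prefix of `abacbb`:
...
Encoding using last tokens of `abacbb`:
...
```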
