Verify token set #31

Merged · 2 commits · Oct 17, 2024
crates/bpe/README.md (33 changes: 17 additions & 16 deletions)
```diff
@@ -96,31 +96,32 @@ Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`,
 ## Novel Algorithm
 
 At a first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision.
-For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character `abacb` would result in a pretty different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text.
+For instance, the sequence `abacb` would be encoded as `ab a cb` when the dictionary contains the tokens `a b c ab cb ac bb cbb acbb` ordered by frequency. But appending a single character `abacbb` would result in a pretty different tokenization: `ab acbb`. So without looking ahead it seems impossible to properly tokenize the text.
 
-The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get:
+The solution is to track the encodings of ALL text prefixes. For our example `abacbb` we would get:
 
-- `a` ------> `a`
-- `ab` -----> `ab`
-- `aba` ----> `ab a`
-- `abac` ---> `ab ac`
-- `abacb` --> `ab a cb`
+- `a` -------> `a`
+- `ab` ------> `ab`
+- `aba` -----> `ab a`
+- `abac` ----> `ab ac`
+- `abacb` ---> `ab a cb`
+- `abacbb` --> `ab acbb`
 
 This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered:
 
-- `a` ------> `a`
-- `ab` -----> `ab`
-- `aba` ----> `a`
-- `abac` ---> `ac`
-- `abacb` --> `cb`
+- `a` -------> `a`
+- `ab` ------> `ab`
+- `aba` -----> `a`
+- `abac` ----> `ac`
+- `abacb` ---> `cb`
+- `abacbb` --> `acbb`
 
 In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached.
 
-For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order:
+For our example prefix `abacbb`, this procedure executes the following steps and determines the correct encoding in reverse order:
 
-- `abacb` -> `cb`
-- `aba` ---> `a`
-- `ab` ----> `ab`
+- `abacbb` --> `acbb`
+- `ab` ------> `ab`
 - `<empty>`
 
 The actual challenge is to determine for every prefix this last token efficiently.
```
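The reconstruction step described in the updated README text is simple enough to sketch directly. The following stand-alone example is illustrative and not code from this PR: it assumes the last token of every prefix has already been computed (the part the crate actually optimizes) and merely walks backwards through that table, exactly as the `abacbb` walkthrough above does.

```rust
/// Illustrative sketch (not part of the crate): given the last token of every
/// prefix of `text`, reconstruct the full encoding by walking backwards.
fn reconstruct_encoding(text: &str, last_token: &[&str]) -> Vec<String> {
    // last_token[i] is the last token of the encoding of &text[..=i].
    let mut tokens = Vec::new();
    let mut remaining = text.len();
    while remaining > 0 {
        let token = last_token[remaining - 1];
        tokens.push(token.to_string());
        // Shorten the prefix by the extracted token and continue with the
        // last token stored for the shortened prefix.
        remaining -= token.len();
    }
    // The walk yields the encoding in reverse order.
    tokens.reverse();
    tokens
}

fn main() {
    // Last-token table for `abacbb` from the updated README example.
    let last_token = ["a", "ab", "a", "ac", "cb", "acbb"];
    assert_eq!(reconstruct_encoding("abacbb", &last_token), ["ab", "acbb"]);
}
```

This is essentially the `remaining`-based loop the updated test below performs, just decoupled from the `BytePairEncoding` type.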
crates/bpe/src/byte_pair_encoding.rs (12 changes: 11 additions & 1 deletion)
```diff
@@ -303,7 +303,7 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        Self {
+        let bpe = Self {
             all_tokens,
             token_starts,
             bytes_hash_to_token,
@@ -314,7 +314,17 @@
             pair_lookup,
             split_table,
             hash_factor,
+        };
+        for token_id in 0..bpe.num_tokens() as u32 {
+            let bytes = bpe.token_bytes(token_id);
+            let tokens = bpe.encode_via_bitfield(bytes);
+            assert_eq!(
+                tokens,
+                vec![token_id],
+                "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself"
+            );
+        }
+        bpe
         }
+        bpe
     }
 
     /// Return the number of tokens in this BPE dictionary.
```
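The net effect of the code added to `from_dictionary` is that constructing a `BytePairEncoding` now verifies the token set up front: every token, encoded from its own bytes, must come back as exactly that one token, otherwise the constructor panics and names the offending token. A hedged usage sketch follows; the token set is the one from the updated README, the `use` path and surrounding `main` are assumptions, and only `from_dictionary`, `num_tokens`, `token_bytes`, and `encode_via_backtracking` are taken from this PR.

```rust
// Assumed import path; the type is the one modified in this PR.
use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    // Token set from the updated README example. The assert added in this PR
    // runs once per token inside `from_dictionary` and would panic here if
    // any token failed to encode to itself.
    let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"]
        .map(|t| t.as_bytes().to_vec());
    let bpe = BytePairEncoding::from_dictionary(tokens, None);

    // The invariant this establishes for downstream users, spelled out
    // explicitly: every dictionary token round-trips through the encoder.
    for token_id in 0..bpe.num_tokens() as u32 {
        let bytes = bpe.token_bytes(token_id);
        assert_eq!(bpe.encode_via_backtracking(bytes), vec![token_id]);
    }
}
```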
crates/bpe/tests/src/lib.rs (12 changes: 7 additions & 5 deletions)
```diff
@@ -30,16 +30,16 @@ mod tests {
     /// This test produces the output for the encoding example in the README.
     #[test]
     fn readme_example() {
-        let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
-        let bpe = BytePairEncoding::from_dictionary(tokens, None);
-        let text = "abacb";
+        let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
+        let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);
+        let text = "abacbb";
         let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
         let all_prefix_tokens = prefixes
             .iter()
             .map(|prefix| {
                 bpe.encode_via_backtracking(prefix.as_bytes())
                     .into_iter()
-                    .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
+                    .map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap())
                     .collect_vec()
             })
             .collect_vec();
@@ -48,6 +48,8 @@
             .map(|tokens| tokens.last().unwrap())
             .collect_vec();
 
+        println!("Token set: `{}`\n", tokens.join(" "));
+
         println!("All tokens for each prefix of `{text}`:\n");
         for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
             println!(
@@ -67,7 +69,7 @@
         }
         println!();
 
-        println!("Tokenization of `{text}`:\n");
+        println!("Encoding using last tokens of `{text}`:\n");
         let mut remaining = text.len();
         while remaining > 0 {
             let prefix = &text[..remaining];
```
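As its doc comment says, this test regenerates the README walkthrough. Assuming the tests live in their own package under `crates/bpe/tests` (the package name is not shown in this diff), running something like `cargo test readme_example -- --nocapture` from that directory should print the token set, the tokenization of every prefix, and the reconstruction from last tokens, ready to be pasted into the README.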