diff --git a/Cargo.toml b/Cargo.toml
index 7cb6320..312f46d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,7 @@
 members = [
     "crates/*",
     "crates/bpe/benchmarks",
+    "crates/bpe/tests",
 ]
 resolver = "2"
diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml
index 1f9460e..b4379c3 100644
--- a/crates/bpe-openai/Cargo.toml
+++ b/crates/bpe-openai/Cargo.toml
@@ -17,13 +17,13 @@ bpe = { version = "0.1.0", path = "../bpe" }
 either = "1.13"
 fancy-regex = "0.13"
 rmp-serde = "1"
-serde = { version = "1" }

 [dev-dependencies]
-tiktoken-rs = { version = "0.5" }
+tiktoken-rs = "0.6"

 [build-dependencies]
-bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken-rs"] }
+base64 = "0.22.1"
+bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken"] }
+flate2 = "1.0"
 rmp-serde = "1"
-tiktoken-rs = { version = "0.5" }
-serde = { version = "1" }
+serde = "1"
diff --git a/crates/bpe-openai/build.rs b/crates/bpe-openai/build.rs
index b4f3837..472e580 100644
--- a/crates/bpe-openai/build.rs
+++ b/crates/bpe-openai/build.rs
@@ -1,51 +1,37 @@
 use std::env;
 use std::fs::File;
+use std::io::Read;
 use std::path::PathBuf;

-use bpe::byte_pair_encoding::BytePairEncoding;
+use bpe::byte_pair_encoding::{read_tiktoken, BytePairEncoding};
 use serde::Serialize;
-use tiktoken_rs::CoreBPE;

 fn main() {
-    serialize_tokens(
-        "r50k",
-        &tiktoken_rs::r50k_base().expect("tiktoken initialization must not fail!"),
-        50256,
-        1,
-    );
-    serialize_tokens(
-        "p50k",
-        &tiktoken_rs::p50k_base().expect("tiktoken initialization must not fail!"),
-        50280,
-        1,
-    );
-    serialize_tokens(
-        "cl100k",
-        &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
-        100256,
-        17846336922010275747,
-    );
-    serialize_tokens(
-        "cl100k",
-        &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
-        100256,
+    serialize_tiktoken_bpe("r50k_base", include_bytes!("data/r50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe("p50k_base", include_bytes!("data/p50k_base.tiktoken.gz"), 1);
+    serialize_tiktoken_bpe(
+        "cl100k_base",
+        include_bytes!("data/cl100k_base.tiktoken.gz"),
         17846336922010275747,
     );
-    serialize_tokens(
-        "o200k",
-        &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
-        199998,
+    serialize_tiktoken_bpe(
+        "o200k_base",
+        include_bytes!("data/o200k_base.tiktoken.gz"),
         17846336922010275747,
     );
     println!("cargo::rerun-if-changed=build.rs");
 }

-fn serialize_tokens(name: &str, bpe: &CoreBPE, num_tokens: usize, hash_factor: u64) {
+fn serialize_tiktoken_bpe(name: &str, data: &[u8], hash_factor: u64) {
+    let mut dec = flate2::read::GzDecoder::new(data);
+    let mut tiktoken = String::new();
+    dec.read_to_string(&mut tiktoken).expect("can decode data");
+    let tokens = read_tiktoken(&tiktoken).expect("can read data");
     let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set during build"));
     path.push(format!("bpe_{name}.dict"));
     let file = File::create(path).expect("can create output file");
     let mut serializer = rmp_serde::Serializer::new(file);
-    let bpe = BytePairEncoding::from_tiktoken(bpe, num_tokens, Some(hash_factor));
+    let bpe = BytePairEncoding::from_dictionary(tokens, Some(hash_factor));
     bpe.serialize(&mut serializer)
         .expect("serialization succeeds");
 }
diff --git a/crates/bpe-openai/data/cl100k_base.tiktoken.gz b/crates/bpe-openai/data/cl100k_base.tiktoken.gz
new file mode 100644
index 0000000..9bb465f
Binary files /dev/null and b/crates/bpe-openai/data/cl100k_base.tiktoken.gz differ
diff --git a/crates/bpe-openai/data/o200k_base.tiktoken.gz b/crates/bpe-openai/data/o200k_base.tiktoken.gz
new file mode 100644
index 0000000..3deeb3d
Binary files /dev/null and b/crates/bpe-openai/data/o200k_base.tiktoken.gz differ
diff --git a/crates/bpe-openai/data/p50k_base.tiktoken.gz b/crates/bpe-openai/data/p50k_base.tiktoken.gz
new file mode 100644
index 0000000..af6f846
Binary files /dev/null and b/crates/bpe-openai/data/p50k_base.tiktoken.gz differ
diff --git a/crates/bpe-openai/data/r50k_base.tiktoken.gz b/crates/bpe-openai/data/r50k_base.tiktoken.gz
new file mode 100644
index 0000000..6108f82
Binary files /dev/null and b/crates/bpe-openai/data/r50k_base.tiktoken.gz differ
diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs
index 66ccebe..fd2c7c8 100644
--- a/crates/bpe-openai/src/lib.rs
+++ b/crates/bpe-openai/src/lib.rs
@@ -4,29 +4,29 @@ use bpe::byte_pair_encoding::BytePairEncoding;
 use either::Either;
 use fancy_regex::Regex;

-static BPE_R50K: LazyLock<Tokenizer> = LazyLock::new(|| {
-    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k.dict"));
+static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
     Tokenizer::new(bpe, Some(pat)).expect("valid regex")
 });

-static BPE_P50K: LazyLock<Tokenizer> = LazyLock::new(|| {
-    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k.dict"));
+static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
     Tokenizer::new(bpe, Some(pat)).expect("valid regex")
 });

-static BPE_CL100K: LazyLock<Tokenizer> = LazyLock::new(|| {
-    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k.dict"));
+static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
     Tokenizer::new(bpe, Some(pat)).expect("valid regex")
 });

-static BPE_O200K: LazyLock<Tokenizer> = LazyLock::new(|| {
-    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k.dict"));
+static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
+    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat = [
         "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
@@ -91,20 +91,20 @@ impl Tokenizer {
     }
 }

-pub fn r50k() -> &'static Tokenizer {
-    &BPE_R50K
+pub fn r50k_base() -> &'static Tokenizer {
+    &BPE_R50K_BASE
 }

-pub fn p50k() -> &'static Tokenizer {
-    &BPE_P50K
+pub fn p50k_base() -> &'static Tokenizer {
+    &BPE_P50K_BASE
 }

-pub fn cl100k() -> &'static Tokenizer {
-    &BPE_CL100K
+pub fn cl100k_base() -> &'static Tokenizer {
+    &BPE_CL100K_BASE
 }

-pub fn o200k() -> &'static Tokenizer {
-    &BPE_O200K
+pub fn o200k_base() -> &'static Tokenizer {
+    &BPE_O200K_BASE
 }

 #[cfg(test)]
@@ -115,22 +115,22 @@ mod tests {

     #[test]
     fn can_load_r50k() {
-        r50k().count("");
+        r50k_base().count("");
     }

     #[test]
     fn can_load_p50k() {
-        p50k().count("");
+        p50k_base().count("");
     }

     #[test]
     fn can_load_cl100k() {
-        cl100k().count("");
+        cl100k_base().count("");
     }

     #[test]
     fn can_load_o200k() {
-        o200k().count("");
+        o200k_base().count("");
     }

     /// Test demonstrating a case where input splitting makes a difference.
@@ -142,13 +142,12 @@ mod tests {
             .lock()
             .encode_ordinary(text)
             .into_iter()
-            .map(|i| i as u32)
             .collect();

-        let without_splitting = BPE_CL100K.bpe.encode_via_backtracking(input);
+        let without_splitting = BPE_CL100K_BASE.bpe.encode_via_backtracking(input);
         assert_ne!(without_splitting, expected);

-        let with_splitting: Vec<_> = BPE_CL100K.encode(text);
+        let with_splitting: Vec<_> = BPE_CL100K_BASE.encode(text);
         assert_eq!(with_splitting, expected);
     }
 }
diff --git a/crates/bpe/CONTRIBUTING.md b/crates/bpe/CONTRIBUTING.md
new file mode 100644
index 0000000..1fbd607
--- /dev/null
+++ b/crates/bpe/CONTRIBUTING.md
@@ -0,0 +1,39 @@
+# Contributing
+
+Here are specific details that are useful when you want to contribute to the BPE crates.
+Make sure to read the repository's [contribution guidelines][contributing] as well.
+
+## Project structure
+
+This project has a slightly unusual structure to resolve some dependency issues.
+
+- This directory contains `bpe`, the BPE code itself.
+- A sibling directory contains `bpe-openai`, which exposes tokenizers for OpenAI token sets, and depends on `bpe`.
+- Tests are located in the `tests` subdirectory, and benchmarks in the `benchmarks` subdirectory. Both of these are separate crates so they can depend on `bpe-openai` without causing a cyclic dependency.
+
+Only the `bpe` and `bpe-openai` crates are meant to be published. The others are for development use only.
+
+## Running benchmarks
+
+Change the working directory to the `benchmarks` directory:
+
+```sh
+cd benchmarks
+```
+
+Run the benchmark as follows (requires [cargo-criterion](https://crates.io/crates/cargo-criterion) to be installed):
+
+```sh
+cargo criterion
+```
+
+(Using `cargo bench` ignores the settings in `criterion.toml`!)
+Open the full report, which should be located in `target/criterion/reports/index.html`.
+
+Update the figures in this repo as follows (requires `rsvg-convert` from `librsvg` to be installed):
+
+```sh
+script/copy-results
+```
+
+[contributing]: ../../CONTRIBUTING.md
diff --git a/crates/bpe/Cargo.toml b/crates/bpe/Cargo.toml
index 4177856..95fbdbb 100644
--- a/crates/bpe/Cargo.toml
+++ b/crates/bpe/Cargo.toml
@@ -14,16 +14,19 @@ bench = false

 [features]
 rand = ["dep:rand"]
-tiktoken-rs = ["dep:tiktoken-rs"]
+tiktoken = ["dep:base64"]

 [dependencies]
 aneubeck-daachorse = "1.1.1"
+base64 = { version = "0.22", optional = true }
 fnv = "1.0"
 itertools = "0.12"
 rand = { version = "0.8", optional = true }
-rmp-serde = "1"
 serde = { version = "1", features = ["derive"] }
-tiktoken-rs = { version = "0.5", optional = true }

 [dev-dependencies]
-bpe = { path = ".", features = ["rand", "tiktoken-rs"] }
+bpe = { path = "." }
+tiktoken-rs = "0.6"
+
+[package.metadata.docs.rs]
+all-features = true
diff --git a/crates/bpe/README.md b/crates/bpe/README.md
index f8a24e2..404e389 100644
--- a/crates/bpe/README.md
+++ b/crates/bpe/README.md
@@ -296,26 +296,3 @@ The performance of tiktoken shows a quadratic growth with the input size.
 The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.

 ![worst-case encoding runtime comparison](./images/performance-worstcase.svg)
-
-### Running the benchmarks
-
-Benchmarks are located in a separate crate in the `benchmarks` directory.
-
-```sh
-cd benchmarks
-```
-
-Run the benchmark as follows (required [cargo-criterion](https://crates.io/crates/cargo-criterion) installed):
-
-```sh
-cargo criterion
-```
-
-(Using `cargo bench` ignores the settings in `criterion.toml`!)
-Open the full report which should be located in `target/criterion/reports/index.html`.
-
-Update the figures in this repo as follows (requires `rsvg-convert` from `librsvg` installed):
-
-```sh
-script/copy-results
-```
diff --git a/crates/bpe/benchmarks/Cargo.toml b/crates/bpe/benchmarks/Cargo.toml
index 1aedc2a..854ab1e 100644
--- a/crates/bpe/benchmarks/Cargo.toml
+++ b/crates/bpe/benchmarks/Cargo.toml
@@ -18,9 +18,10 @@ path = "equivalence.rs"
 test = true

 [dependencies]
-bpe = { path = "../../bpe", features = ["rand", "tiktoken-rs"] }
+bpe = { path = "../../bpe" }
 bpe-openai = { path = "../../bpe-openai" }
+bpe-tests = { path = "../tests" }
 criterion = "0.5"
 rand = "0.8"
-tiktoken-rs = "0.5"
+tiktoken-rs = "0.6"
 tokenizers = { version = "0.20", features = ["http"] }
diff --git a/crates/bpe/benchmarks/equivalence.rs b/crates/bpe/benchmarks/equivalence.rs
index 54ea918..7c71e4e 100644
--- a/crates/bpe/benchmarks/equivalence.rs
+++ b/crates/bpe/benchmarks/equivalence.rs
@@ -16,7 +16,7 @@ fn test_encoding_equivalence_without_pretokenization() {
     for input in inputs {
         let text = std::str::from_utf8(input).unwrap();
         let out = bpe.bpe.encode_via_backtracking(input);
-        let huggingface_out: Vec<_> = huggingface
+        let huggingface_out = huggingface
             .encode_fast(text, false)
             .unwrap()
             .get_ids()
@@ -52,10 +52,10 @@ fn test_encoding_equivalence_with_pretokenization() {
     for input in inputs {
         let text = std::str::from_utf8(input).unwrap();
         let out = bpe.encode(text);
-        let tiktoken_out: Vec<_> = tiktoken.encode_ordinary(text);
-        let tiktoken_out2: Vec<_> = tiktoken_out.iter().map(|i| *i as u32).collect();
+        let tiktoken_out = tiktoken.encode_ordinary(text);
+        let tiktoken_out2 = tiktoken_out.to_vec();
         let tiktoken_text = tiktoken.decode(tiktoken_out.clone()).unwrap();
-        let huggingface_out: Vec<_> = huggingface
+        let huggingface_out = huggingface
             .encode_fast(text, false)
             .unwrap()
             .get_ids()
diff --git a/crates/bpe/benchmarks/lib.rs b/crates/bpe/benchmarks/lib.rs
index 161ef25..f260ebd 100644
--- a/crates/bpe/benchmarks/lib.rs
+++ b/crates/bpe/benchmarks/lib.rs
@@ -18,13 +18,13 @@ pub static TOKENIZERS: LazyLock<
     [
         (
             "cl100k",
-            bpe_openai::cl100k(),
+            bpe_openai::cl100k_base(),
             tiktoken_rs::cl100k_base().expect("tokenizer available"),
             HuggingfaceTokenizer::from_pretrained("Xenova/gpt-4", None).expect("model available"),
         ),
         (
             "o200k",
-            bpe_openai::o200k(),
+            bpe_openai::o200k_base(),
             tiktoken_rs::o200k_base().expect("tokenizer available"),
             HuggingfaceTokenizer::from_pretrained("Xenova/gpt-4o", None).expect("model available"),
         ),
diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs
index 8b90f93..4ec973e 100644
--- a/crates/bpe/benchmarks/performance.rs
+++ b/crates/bpe/benchmarks/performance.rs
@@ -1,9 +1,9 @@
 use std::time::Duration;

 use bpe::appendable_encoder::AppendableEncoder;
-use bpe::byte_pair_encoding::create_test_bytes;
 use bpe::interval_encoding::IntervalEncoding;
 use bpe_benchmarks::*;
+use bpe_tests::create_test_bytes;
 use criterion::{
     criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
 };
diff --git a/crates/bpe/src/appendable_encoder.rs b/crates/bpe/src/appendable_encoder.rs
index b0752b5..3dc13db 100644
--- a/crates/bpe/src/appendable_encoder.rs
+++ b/crates/bpe/src/appendable_encoder.rs
@@ -87,21 +87,3 @@ impl<'a> AppendableEncoder<'a> {
         self.states.is_empty()
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
-    use super::AppendableEncoder;
-
-    #[test]
-    fn test_appendable_encoder() {
-        let bpe = &BPE_CL100K;
-        let mut enc = AppendableEncoder::new(bpe);
-        let input_string = create_test_bytes(bpe, 100);
-        for (i, c) in input_string.iter().enumerate() {
-            assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
-            enc.push(*c);
-        }
-    }
-}
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index f18468e..9efbb0e 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -12,26 +12,6 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use crate::backtrack_encoder::BacktrackEncoder;
 use crate::bitfield::BitField;

-#[cfg(test)]
-pub(crate) static BPE_CL100K: std::sync::LazyLock<BytePairEncoding> =
-    std::sync::LazyLock::new(|| {
-        BytePairEncoding::from_tiktoken(
-            &tiktoken_rs::cl100k_base_singleton().lock(),
-            100256,
-            Some(17846336922010275747),
-        )
-    });
-
-#[cfg(test)]
-pub(crate) static BPE_O200K: std::sync::LazyLock<BytePairEncoding> =
-    std::sync::LazyLock::new(|| {
-        BytePairEncoding::from_tiktoken(
-            &tiktoken_rs::o200k_base_singleton().lock(),
-            199998,
-            Some(17846336922010275747),
-        )
-    });
-
 /// Representation of the byte pair dictionary.
 /// This struct provides various conversions.
 /// We put all of them into a single struct so that they can be reused by different implementations.
@@ -175,11 +155,11 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
     ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }

-/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
-/// when constructing a [`BytePairEncoding`] from those tokens.
-#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
-pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
-    find_hash_factor_for_dictionary((0..len).map(|i| bpe._decode_native(&[i])))
+/// Find a suitable hash factor for the given tiktoken data that prevents collisions when
+/// constructing a [`BytePairEncoding`] from those tokens.
+#[cfg(all(feature = "rand", feature = "tiktoken"))]
+pub fn find_hash_factor_for_tiktoken(data: &str) -> Result<u64, base64::DecodeError> {
+    Ok(find_hash_factor_for_dictionary(read_tiktoken(data)?))
 }

 /// Find a suitable hash factor for a set of given tokens that prevents collisions when
@@ -220,23 +200,36 @@ fn find_token_by_bytes(
     }
 }

+/// Read the tokens from a tiktoken data file, which contains base64 encoded tokens at
+/// the start of each line, in descending frequency order.
+#[cfg(feature = "tiktoken")]
+pub fn read_tiktoken(data: &str) -> Result<Vec<Vec<u8>>, base64::DecodeError> {
+    use base64::prelude::*;
+    data.lines()
+        .filter(|line| !line.is_empty())
+        .map(|line| {
+            let encoded_token = line
+                .split_whitespace()
+                .next()
+                .expect("non-empty line has first field");
+            BASE64_STANDARD.decode(encoded_token)
+        })
+        .try_collect()
+}
+
 impl BytePairEncoding {
-    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
-    /// A suitable hash factor may be necessary to prevent hash collisions,
-    /// which can by found using [`find_hash_factor_for_tiktoken`].
+    /// Construct a BytePairEncoding instance from a tiktoken data file.
+    /// A suitable hash factor may be necessary to prevent hash collisions, which can be
+    /// found using [`find_hash_factor_for_tiktoken`].
     ///
     /// The recommended approach is to store the serialized value and reuse that,
     /// to prevent repeating the cost of computing the hash factor and encoding.
-    #[cfg(feature = "tiktoken-rs")]
+    #[cfg(feature = "tiktoken")]
     pub fn from_tiktoken(
-        tiktoken_bpe: &tiktoken_rs::CoreBPE,
-        num_tokens: usize,
+        data: &str,
         hash_factor: Option<u64>,
-    ) -> Self {
-        Self::from_dictionary(
-            (0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])),
-            hash_factor,
-        )
+    ) -> Result<Self, base64::DecodeError> {
+        Ok(Self::from_dictionary(read_tiktoken(data)?, hash_factor))
     }

     /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
@@ -549,93 +542,3 @@ impl BytePairEncoding {
         encoded
     }
 }
-
-#[cfg(feature = "rand")]
-pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
-    use rand::{thread_rng, Rng};
-    let mut text = vec![];
-    for _ in 0..tokens {
-        let i = thread_rng().gen_range(0..bpe.num_tokens());
-        let s = bpe.token_bytes(i as u32);
-        text.extend_from_slice(s);
-    }
-    text
-}
-
-#[cfg(test)]
-mod tests {
-
-    use std::time::Instant;
-
-    use itertools::Itertools;
-    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
-
-    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K, BPE_O200K};
-
-    #[test]
-    fn test_correctness_cl100k() {
-        // This is quite a challenging test case...
-        let test_string = std::str::from_utf8(&[
-            125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
-            112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
-            69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
-            102, 102, 101, 110, 100,
-        ])
-        .unwrap();
-        let time = Instant::now();
-        let bpe = &BPE_CL100K;
-        println!("{:?}", time.elapsed());
-        let encoded1 = cl100k_base_singleton()
-            .lock()
-            .encode_ordinary(test_string)
-            .into_iter()
-            .map(|t| t as u32)
-            .collect_vec();
-        let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
-        assert_eq!(encoded1, encoded2);
-        let encoded3 = bpe.encode_via_table(test_string.as_bytes());
-        assert_eq!(encoded1, encoded3);
-        let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
-        assert_eq!(encoded1, encoded4);
-    }
-
-    #[test]
-    fn test_correctness_o200k() {
-        // This is quite a challenging test case...
-        let test_string = std::str::from_utf8(&[
-            125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
-            112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
-            69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
-            102, 102, 101, 110, 100,
-        ])
-        .unwrap();
-        let time = Instant::now();
-        let bpe = &BPE_O200K;
-        println!("{:?}", time.elapsed());
-        let encoded1 = o200k_base_singleton()
-            .lock()
-            .encode_ordinary(test_string)
-            .into_iter()
-            .map(|t| t as u32)
-            .collect_vec();
-        let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
-        assert_eq!(encoded1, encoded2);
-        let encoded3 = bpe.encode_via_table(test_string.as_bytes());
-        assert_eq!(encoded1, encoded3);
-        let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
-        assert_eq!(encoded1, encoded4);
-    }
-
-    #[test]
-    fn test_bpe_equivalence() {
-        let bpe = &BPE_CL100K;
-        for tokens in [10, 1000, 10000] {
-            for _ in 0..5 {
-                let test_input = create_test_bytes(bpe, tokens);
-                let encoded1 = bpe.encode_via_backtracking(&test_input);
-                let encoded2 = bpe.encode_via_bitfield(&test_input);
-                assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
-            }
-        }
-    }
-}
diff --git a/crates/bpe/src/interval_encoding.rs b/crates/bpe/src/interval_encoding.rs
index 05bf79f..422ec45 100644
--- a/crates/bpe/src/interval_encoding.rs
+++ b/crates/bpe/src/interval_encoding.rs
@@ -81,28 +81,3 @@ impl<'a> IntervalEncoding<'a> {
         encoder.count()
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use rand::{thread_rng, Rng};
-
-    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
-    use super::IntervalEncoding;
-
-    #[test]
-    fn test_interval_count() {
-        let bpe = &BPE_CL100K;
-        let text = create_test_bytes(bpe, 10000);
-        let intervals = IntervalEncoding::new(bpe, &text);
-        for _ in 0..1000 {
-            let start = thread_rng().gen_range(0..text.len());
-            let end = thread_rng().gen_range(0..text.len());
-            let range = start.min(end)..start.max(end);
-            assert_eq!(
-                intervals.count(range.clone()),
-                bpe.encode_via_backtracking(&text[range]).len()
-            );
-        }
-    }
-}
diff --git a/crates/bpe/src/lib.rs b/crates/bpe/src/lib.rs
index 2c7ab43..452024e 100644
--- a/crates/bpe/src/lib.rs
+++ b/crates/bpe/src/lib.rs
@@ -4,64 +4,3 @@ mod bitfield;
 pub mod byte_pair_encoding;
 pub mod interval_encoding;
 pub mod prependable_encoder;
-
-#[cfg(test)]
-mod tests {
-    use itertools::Itertools;
-
-    use crate::byte_pair_encoding::BytePairEncoding;
-
-    /// This test produces the output for the encoding example in the README.
-    #[test]
-    fn readme_example() {
-        let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
-        let bpe = BytePairEncoding::from_dictionary(tokens, None);
-        let text = "abacb";
-        let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
-        let all_prefix_tokens = prefixes
-            .iter()
-            .map(|prefix| {
-                bpe.encode_via_backtracking(prefix.as_bytes())
-                    .into_iter()
-                    .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
-                    .collect_vec()
-            })
-            .collect_vec();
-        let last_prefix_tokens = all_prefix_tokens
-            .iter()
-            .map(|tokens| tokens.last().unwrap())
-            .collect_vec();
-
-        println!("All tokens for each prefix of `{text}`:\n");
-        for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
-            println!(
-                "- `{prefix}` {}> `{}`",
-                "-".repeat(text.len() + 2 - prefix.len()),
-                tokens.join(" ")
-            );
-        }
-        println!();
-
-        println!("Last token for each prefix of `{text}`:\n");
-        for (prefix, token) in prefixes.iter().zip(&last_prefix_tokens) {
-            println!(
-                "- `{prefix}` {}> `{token}`",
-                "-".repeat(text.len() + 2 - prefix.len()),
-            );
-        }
-        println!();
-
-        println!("Tokenization of `{text}`:\n");
-        let mut remaining = text.len();
-        while remaining > 0 {
-            let prefix = &text[..remaining];
-            let token = last_prefix_tokens[remaining - 1];
-            println!(
-                "- `{prefix}` {}> `{token}`",
-                "-".repeat(text.len() + 2 - prefix.len()),
-            );
-            remaining -= token.len();
-        }
-        println!("- ``");
-    }
-}
diff --git a/crates/bpe/src/prependable_encoder.rs b/crates/bpe/src/prependable_encoder.rs
index ce13e40..264c6a5 100644
--- a/crates/bpe/src/prependable_encoder.rs
+++ b/crates/bpe/src/prependable_encoder.rs
@@ -87,21 +87,3 @@ impl<'a> PrependableEncoder<'a> {
         self.states.is_empty()
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
-    use super::PrependableEncoder;
-
-    #[test]
-    fn test_prependable_encoder() {
-        let bpe = &BPE_CL100K;
-        let mut enc = PrependableEncoder::new(bpe);
-        let input_string = create_test_bytes(bpe, 100);
-        for (i, c) in input_string.iter().enumerate().rev() {
-            enc.push(*c);
-            assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
-        }
-    }
-}
diff --git a/crates/bpe/tests/Cargo.toml b/crates/bpe/tests/Cargo.toml
new file mode 100644
index 0000000..7c6ce69
--- /dev/null
+++ b/crates/bpe/tests/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "bpe-tests"
+edition = "2021"
+
+[dependencies]
+bpe = { path = "../../bpe" }
+bpe-openai = { path = "../../bpe-openai" }
+itertools = "0.13"
+rand = "0.8"
+tiktoken-rs = "0.6"
diff --git a/crates/bpe/tests/src/lib.rs b/crates/bpe/tests/src/lib.rs
new file mode 100644
index 0000000..ed2ab81
--- /dev/null
+++ b/crates/bpe/tests/src/lib.rs
@@ -0,0 +1,186 @@
+use bpe::byte_pair_encoding::BytePairEncoding;
+use rand::{thread_rng, Rng};
+
+pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
+    let mut text = vec![];
+    for _ in 0..tokens {
+        let i = thread_rng().gen_range(0..bpe.num_tokens());
+        let s = bpe.token_bytes(i as u32);
+        text.extend_from_slice(s);
+    }
+    text
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Instant;
+
+    use itertools::Itertools;
+    use rand::{thread_rng, Rng};
+    use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
+
+    use bpe::appendable_encoder::AppendableEncoder;
+    use bpe::byte_pair_encoding::BytePairEncoding;
+    use bpe::interval_encoding::IntervalEncoding;
+    use bpe::prependable_encoder::PrependableEncoder;
+    use bpe_openai::{cl100k_base, o200k_base};
+
+    use super::*;
+
+    /// This test produces the output for the encoding example in the README.
+    #[test]
+    fn readme_example() {
+        let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
+        let bpe = BytePairEncoding::from_dictionary(tokens, None);
+        let text = "abacb";
+        let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
+        let all_prefix_tokens = prefixes
+            .iter()
+            .map(|prefix| {
+                bpe.encode_via_backtracking(prefix.as_bytes())
+                    .into_iter()
+                    .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
+                    .collect_vec()
+            })
+            .collect_vec();
+        let last_prefix_tokens = all_prefix_tokens
+            .iter()
+            .map(|tokens| tokens.last().unwrap())
+            .collect_vec();
+
+        println!("All tokens for each prefix of `{text}`:\n");
+        for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
+            println!(
+                "- `{prefix}` {}> `{}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+                tokens.join(" ")
+            );
+        }
+        println!();
+
+        println!("Last token for each prefix of `{text}`:\n");
+        for (prefix, token) in prefixes.iter().zip(&last_prefix_tokens) {
+            println!(
+                "- `{prefix}` {}> `{token}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+            );
+        }
+        println!();
+
+        println!("Tokenization of `{text}`:\n");
+        let mut remaining = text.len();
+        while remaining > 0 {
+            let prefix = &text[..remaining];
+            let token = last_prefix_tokens[remaining - 1];
+            println!(
+                "- `{prefix}` {}> `{token}`",
+                "-".repeat(text.len() + 2 - prefix.len()),
+            );
+            remaining -= token.len();
+        }
+        println!("- ``");
+    }
+
+    #[test]
+    fn test_appendable_encoder() {
+        let bpe = &cl100k_base().bpe;
+        let mut enc = AppendableEncoder::new(bpe);
+        let input_string = create_test_bytes(bpe, 100);
+        for (i, c) in input_string.iter().enumerate() {
+            assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
+            enc.push(*c);
+        }
+    }
+
+    #[test]
+    fn test_correctness_cl100k() {
+        // This is quite a challenging test case...
+        let test_string = std::str::from_utf8(&[
+            125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
+            112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
+            69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
+            102, 102, 101, 110, 100,
+        ])
+        .unwrap();
+        let time = Instant::now();
+        let bpe = &cl100k_base().bpe;
+        println!("{:?}", time.elapsed());
+        let encoded1 = cl100k_base_singleton()
+            .lock()
+            .encode_ordinary(test_string)
+            .into_iter()
+            .collect_vec();
+        let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
+        assert_eq!(encoded1, encoded2);
+        let encoded3 = bpe.encode_via_table(test_string.as_bytes());
+        assert_eq!(encoded1, encoded3);
+        let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
+        assert_eq!(encoded1, encoded4);
+    }
+
+    #[test]
+    fn test_correctness_o200k() {
+        // This is quite a challenging test case...
+        let test_string = std::str::from_utf8(&[
+            125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
+            112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
+            69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
+            102, 102, 101, 110, 100,
+        ])
+        .unwrap();
+        let time = Instant::now();
+        let bpe = &o200k_base().bpe;
+        println!("{:?}", time.elapsed());
+        let encoded1 = o200k_base_singleton()
+            .lock()
+            .encode_ordinary(test_string)
+            .into_iter()
+            .collect_vec();
+        let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
+        assert_eq!(encoded1, encoded2);
+        let encoded3 = bpe.encode_via_table(test_string.as_bytes());
+        assert_eq!(encoded1, encoded3);
+        let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
+        assert_eq!(encoded1, encoded4);
+    }
+
+    #[test]
+    fn test_bpe_equivalence() {
+        let bpe = &cl100k_base().bpe;
+        for tokens in [10, 1000, 10000] {
+            for _ in 0..5 {
+                let test_input = create_test_bytes(bpe, tokens);
+                let encoded1 = bpe.encode_via_backtracking(&test_input);
+                let encoded2 = bpe.encode_via_bitfield(&test_input);
+                assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
+            }
+        }
+    }
+
+    #[test]
+    fn test_interval_count() {
+        let bpe = &cl100k_base().bpe;
+        let text = create_test_bytes(bpe, 10000);
+        let intervals = IntervalEncoding::new(bpe, &text);
+        for _ in 0..1000 {
+            let start = thread_rng().gen_range(0..text.len());
+            let end = thread_rng().gen_range(0..text.len());
+            let range = start.min(end)..start.max(end);
+            assert_eq!(
+                intervals.count(range.clone()),
+                bpe.encode_via_backtracking(&text[range]).len()
+            );
+        }
+    }
+
+    #[test]
+    fn test_prependable_encoder() {
+        let bpe = &cl100k_base().bpe;
+        let mut enc = PrependableEncoder::new(bpe);
+        let input_string = create_test_bytes(bpe, 100);
+        for (i, c) in input_string.iter().enumerate().rev() {
+            enc.push(*c);
+            assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
+        }
+    }
+}
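
With this change applied, the `bpe` crate no longer needs `tiktoken-rs` to ingest a vocabulary: the new `tiktoken` feature reads the `.tiktoken` text format directly. The sketch below mirrors what `serialize_tiktoken_bpe` in `build.rs` does, minus the serialization step; `read_tiktoken`, `from_tiktoken`, and the cl100k hash factor come from the hunks above, while the data-file path and the `main` scaffolding are illustrative.

```rust
// Minimal sketch (not part of the patch): decompress a vendored
// *.tiktoken.gz file and build a BytePairEncoding from it, assuming the
// bpe crate's `tiktoken` feature and a flate2 dependency are enabled.
use std::io::Read;

use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    // Illustrative path; build.rs embeds these files via include_bytes!.
    let compressed: &[u8] = include_bytes!("data/cl100k_base.tiktoken.gz");
    let mut dec = flate2::read::GzDecoder::new(compressed);
    let mut tiktoken = String::new();
    dec.read_to_string(&mut tiktoken).expect("can decode data");

    // The hash factor below is the one build.rs hard-codes for cl100k_base;
    // for a new dictionary it can be computed once with
    // `find_hash_factor_for_tiktoken` (which also requires the `rand` feature).
    let bpe = BytePairEncoding::from_tiktoken(&tiktoken, Some(17846336922010275747))
        .expect("valid tiktoken data");
    println!("loaded {} tokens", bpe.num_tokens());
}
```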
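On the consumer side, the `bpe-openai` entry points simply gain the `_base` suffix that the upstream tiktoken vocabulary names carry. A usage sketch with an illustrative input string, again not part of the patch itself:

```rust
// Minimal sketch (not part of the patch): counting and encoding through the
// renamed accessors (r50k_base, p50k_base, cl100k_base, o200k_base), which
// apply the appropriate pre-tokenization regex before running BPE.
fn main() {
    let tok = bpe_openai::cl100k_base();

    // Token count for an input string.
    let count = tok.count("Hello, world!");

    // Token ids; with tiktoken-rs 0.6 both sides yield u32 ids, which is why
    // the `map(|i| i as u32)` conversions could be dropped from the tests.
    let ids = tok.encode("Hello, world!");

    assert_eq!(count, ids.len());
    println!("{count} tokens: {ids:?}");
}
```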