Skip to content

Commit df29b9c

Browse files
Merge pull request #35 from github/count-till-limit
Add count_till_limit method on Tokenizer
2 parents 17d5c3e + c28e428 commit df29b9c

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

crates/bpe-openai/src/lib.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,41 @@ impl Tokenizer {
8080
Ok(Self { bpe, pre })
8181
}
8282

83+
/// Count the number of tokens produced when encoding the text. Applies pre-tokenization
84+
/// before counting.
8385
pub fn count(&self, text: &str) -> usize {
8486
self.split(text)
8587
.map(|piece| self.bpe.count(piece.as_bytes()))
8688
.sum()
8789
}
8890

91+
/// Returns the token count iff the total token count stays below the specified token_limit.
92+
/// Otherwise, it returns `None`. This function can be faster than [`Self::count`] when the
93+
/// token limit is much smaller than the provided text. Applies pre-tokenization before counting.
94+
pub fn count_till_limit(&self, text: &str, token_limit: usize) -> Option<usize> {
95+
self.split(text)
96+
.try_fold(token_limit, |token_limit, piece| {
97+
self.bpe
98+
.count_till_limit(piece.as_bytes(), token_limit)
99+
.map(|piece_count| token_limit - piece_count)
100+
})
101+
}
102+
103+
/// Returns the tokens for the encoding of the given text. Applies pre-tokenization before
104+
/// encoding.
89105
pub fn encode(&self, text: &str) -> Vec<u32> {
90106
self.split(text)
91107
.flat_map(|piece| self.bpe.encode_via_backtracking(piece.as_bytes()))
92108
.collect()
93109
}
94-
110+
/// Returns the text corresponding to the given encoding if it is valid UTF-8. Otherwise,
111+
/// returns none.
95112
pub fn decode(&self, tokens: &[u32]) -> Option<String> {
96113
String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
97114
}
98115

116+
/// Returns an iterator with the text pieces resulting from pre-tokenization. If this
117+
/// tokenizer does not have pre-tokenization, the iterator returns the full text.
99118
pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
100119
match &self.pre {
101120
Some(pre) => Either::Left(pre.split(text)),
@@ -124,6 +143,7 @@ impl Pretokenizer {
124143
Ok(Self { pat, lookahead })
125144
}
126145

146+
/// Returns an iterator with the text pieces after splitting with the regular expression.
127147
pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
128148
Splits {
129149
pat: &self.pat,

crates/bpe/benchmarks/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ test = true
2020
[dependencies]
2121
bpe = { path = "../../bpe" }
2222
bpe-openai = { path = "../../bpe-openai" }
23-
bpe-tests = { path = "../tests" }
2423
criterion = "0.5"
2524
rand = "0.8"
2625
tiktoken-rs = "0.6"

crates/bpe/src/byte_pair_encoding.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ impl BytePairEncoding {
417417
}
418418

419419
/// Returns the token count iff the total token count stays below the specified `token_limit`.
420-
/// Otherwise, it returns false.
420+
/// Otherwise, it returns none.
421421
/// This function can be faster than `count` when the token_limit is much smaller than the provided text.
422422
pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
423423
let mut enc = BacktrackEncoder::new(self, text);

0 commit comments

Comments
 (0)