Commit 5b893bf

feat: allow tokenizer to load from GGUF metadata

1 parent 4b46187

File tree

4 files changed, +423 -0 lines

candle-core/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
 
 [target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))'.dependencies]
 ug = { workspace = true }

candle-core/src/quantized/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ pub mod imatrix_file;
 pub mod k_quants;
 #[cfg(feature = "metal")]
 pub mod metal;
+pub mod tokenizer;
 #[cfg(not(feature = "metal"))]
 mod metal {
     pub use super::dummy_metal::*;

candle-core/src/quantized/tokenizer.rs

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
use crate::quantized::gguf_file;
use crate::{Context, Error, Result};
use std::collections::HashSet;
use tokenizers::{
    decoders::{byte_level::ByteLevel as ByteLevelDecoder, DecoderWrapper},
    models::bpe::{Vocab, BPE},
    normalizers::{unicode::NFC, NormalizerWrapper},
    pre_tokenizers::{
        byte_level::ByteLevel as ByteLevelPre,
        sequence::Sequence,
        split::{Split, SplitPattern},
        PreTokenizerWrapper,
    },
    processors::sequence::Sequence as ProcessorSequence,
    processors::{byte_level::ByteLevel as ByteLevelProcessor, PostProcessorWrapper},
    tokenizer::SplitDelimiterBehavior,
    AddedToken, Tokenizer,
};

pub trait TokenizerFromGguf: Sized {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self>;
}

fn metadata_value<'a>(ct: &'a gguf_file::Content, key: &str) -> Result<&'a gguf_file::Value> {
    ct.metadata
        .get(key)
        .with_context(|| format!("missing GGUF metadata key `{key}`"))
}

fn gguf_value_to_u32(v: &gguf_file::Value) -> Result<u32> {
    use gguf_file::Value::*;
    match v {
        U8(v) => Ok(*v as u32),
        I8(v) => Ok(*v as u32),
        U16(v) => Ok(*v as u32),
        I16(v) => Ok(*v as u32),
        U32(v) => Ok(*v),
        I32(v) => Ok(*v as u32),
        U64(v) => Ok(*v as u32),
        I64(v) => Ok(*v as u32),
        _ => crate::bail!("expected numeric value for token type/id, got {v:?}"),
    }
}

fn value_to_string_array(v: &gguf_file::Value, name: &str) -> Result<Vec<String>> {
    let arr = v
        .to_vec()
        .with_context(|| format!("`{name}` is not an array"))?;
    arr.iter()
        .map(|v| {
            v.to_string()
                .map(|s| s.to_string())
                .with_context(|| format!("`{name}` element is not a string: {v:?}"))
        })
        .collect()
}

fn merges_from_value(v: &gguf_file::Value) -> Result<Vec<(String, String)>> {
    value_to_string_array(v, "tokenizer.ggml.merges")?
        .into_iter()
        .map(|m| {
            m.split_once(' ')
                .map(|(a, b)| (a.to_string(), b.to_string()))
                .ok_or_else(|| Error::msg(format!("invalid merge entry `{m}`")))
        })
        .collect()
}

struct Pipeline {
    normalizer: Option<NormalizerWrapper>,
    pretokenizer: Option<PreTokenizerWrapper>,
    decoder: Option<DecoderWrapper>,
    post_processor: Option<PostProcessorWrapper>,
}

impl Pipeline {
    fn apply(self, tokenizer: &mut Tokenizer) {
        if let Some(norm) = self.normalizer {
            tokenizer.with_normalizer(Some(norm));
        }
        if let Some(pt) = self.pretokenizer {
            tokenizer.with_pre_tokenizer(Some(pt));
        }
        if let Some(dec) = self.decoder {
            tokenizer.with_decoder(Some(dec));
        }
        if let Some(pp) = self.post_processor {
            tokenizer.with_post_processor(Some(pp));
        }
    }
}

fn pre_tokenizer_sequence(regex: &str, byte_level: ByteLevelPre) -> Result<PreTokenizerWrapper> {
    let split = Split::new(
        SplitPattern::Regex(regex.to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(Error::wrap)?;
    Ok(Sequence::new(vec![split.into(), byte_level.into()]).into())
}

fn pipeline_from_pre(pre: &str) -> Result<Pipeline> {
    const REGEX_QWEN2: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
    const REGEX_LLAMA3: &str = r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";

    Ok(match pre {
        // Matches Qwen2 tokenizer.json settings
        "qwen2" => Pipeline {
            normalizer: Some(NFC.into()),
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_QWEN2,
                ByteLevelPre::new(false, false, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(false, false, false).into()),
            post_processor: Some(ByteLevelProcessor::new(false, false, false).into()),
        },
        // Matches Smaug/Llama3 style byte-level BPE
        "smaug-bpe" | "lfm2" | "llama3" => Pipeline {
            normalizer: None,
            pretokenizer: Some(pre_tokenizer_sequence(
                REGEX_LLAMA3,
                ByteLevelPre::new(false, true, false),
            )?),
            decoder: Some(ByteLevelDecoder::new(true, true, true).into()),
            post_processor: Some(ByteLevelProcessor::new(true, false, true).into()),
        },
        // Default GPT-2 style BPE
        _ => Pipeline {
            normalizer: None,
            pretokenizer: Some(ByteLevelPre::default().into()),
            decoder: Some(ByteLevelDecoder::default().into()),
            post_processor: Some(ByteLevelProcessor::default().into()),
        },
    })
}

fn template_processor(
    tokens: &[String],
    bos_id: Option<u32>,
    eos_id: Option<u32>,
    add_bos: bool,
    add_eos: bool,
) -> Option<PostProcessorWrapper> {
    if (!add_bos && !add_eos) || tokens.is_empty() {
        return None;
    }

    let bos = bos_id.and_then(|id| tokens.get(id as usize)).cloned();
    let eos = eos_id.and_then(|id| tokens.get(id as usize)).cloned();

    let mut specials = Vec::new();
    if add_bos {
        let bos_id = bos_id?;
        let bos_tok = bos.clone()?;
        specials.push((bos_tok.clone(), bos_id));
    }
    if add_eos {
        let eos_id = eos_id?;
        let eos_tok = eos.clone()?;
        specials.push((eos_tok.clone(), eos_id));
    }

    let mut single = Vec::new();
    if add_bos {
        single.push(bos.clone()?);
    }
    single.push("$0".to_string());
    if add_eos {
        single.push(eos.clone()?);
    }

    let mut pair = Vec::new();
    if add_bos {
        pair.push(format!("{}:0", bos.clone()?));
    }
    pair.push("$A:0".to_string());
    if add_eos {
        pair.push(format!("{}:0", eos.clone()?));
    }
    if add_bos {
        pair.push(format!("{}:1", bos.clone()?));
    }
    pair.push("$B:1".to_string());
    if add_eos {
        pair.push(format!("{}:1", eos.clone()?));
    }

    let proc = tokenizers::processors::template::TemplateProcessing::builder()
        .try_single(single)
        .ok()?
        .try_pair(pair)
        .ok()?
        .special_tokens(specials)
        .build()
        .ok()?;

    Some(PostProcessorWrapper::Template(proc))
}

impl TokenizerFromGguf for Tokenizer {
    fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
        let model_kind = metadata_value(ct, "tokenizer.ggml.model")?
            .to_string()?
            .to_lowercase();
        if model_kind != "gpt2" {
            crate::bail!("unsupported tokenizer model `{model_kind}`");
        }

        let tokens = value_to_string_array(
            metadata_value(ct, "tokenizer.ggml.tokens")?,
            "tokenizer.ggml.tokens",
        )?;
        let vocab: Vocab = tokens
            .iter()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as u32))
            .collect();
        let merges = merges_from_value(metadata_value(ct, "tokenizer.ggml.merges")?)?;

        let mut builder = BPE::builder().vocab_and_merges(vocab, merges);

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.unk_token_id") {
            let token_id = gguf_value_to_u32(val)?;
            if let Some(token) = tokens.get(token_id as usize) {
                builder = builder.unk_token(token.clone());
            }
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.byte_fallback") {
            builder = builder.byte_fallback(val.to_bool()?);
        }

        if let Ok(val) = metadata_value(ct, "tokenizer.ggml.ignore_merges") {
            builder = builder.ignore_merges(val.to_bool()?);
        }

        let bpe = builder.build().map_err(Error::wrap)?;
        let mut tokenizer = Tokenizer::new(bpe);

        let pre = metadata_value(ct, "tokenizer.ggml.pre")
            .and_then(|v| v.to_string())
            .map(|s| s.to_string())
            .unwrap_or_else(|_| "gpt2".to_string());
        let pipeline = pipeline_from_pre(pre.as_str())?;
        let post_processor_base = pipeline.post_processor.clone();

        let add_bos = metadata_value(ct, "tokenizer.ggml.add_bos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let add_eos = metadata_value(ct, "tokenizer.ggml.add_eos_token")
            .and_then(|v| v.to_bool())
            .unwrap_or(false);
        let bos_id = metadata_value(ct, "tokenizer.ggml.bos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();
        let eos_id = metadata_value(ct, "tokenizer.ggml.eos_token_id")
            .and_then(gguf_value_to_u32)
            .ok();

        pipeline.apply(&mut tokenizer);

        // Compose existing post-processor with a template-based one if needed
        let template_pp = template_processor(&tokens, bos_id, eos_id, add_bos, add_eos);
        if template_pp.is_some() || post_processor_base.is_some() {
            let mut steps = Vec::new();
            if let Some(pp) = post_processor_base {
                steps.push(pp);
            }
            if let Some(tp) = template_pp {
                steps.push(tp);
            }
            let pp = if steps.len() == 1 {
                steps.pop().unwrap()
            } else {
                ProcessorSequence::new(steps).into()
            };
            tokenizer.with_post_processor(Some(pp));
        }

        // Mark special tokens so decode(skip_special_tokens = true) behaves as expected
        if let Ok(gguf_file::Value::Array(arr)) = metadata_value(ct, "tokenizer.ggml.token_type") {
            let mut specials = Vec::new();
            for (idx, v) in arr.iter().enumerate() {
                let ty = gguf_value_to_u32(v)?;
                // Aligns with llama_token_type: treat non-normal/non-byte tokens as special.
                let is_special = matches!(ty, 2..=5);
                if is_special {
                    if let Some(tok) = tokens.get(idx) {
                        specials.push(AddedToken::from(tok.clone(), true));
                    }
                }
            }
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }

        let mut explicit_specials = HashSet::new();
        for key in [
            "tokenizer.ggml.bos_token_id",
            "tokenizer.ggml.eos_token_id",
            "tokenizer.ggml.pad_token_id",
            "tokenizer.ggml.sep_token_id",
            "tokenizer.ggml.unk_token_id",
        ] {
            if let Ok(val) = metadata_value(ct, key) {
                explicit_specials.insert(gguf_value_to_u32(val)?);
            }
        }
        if !explicit_specials.is_empty() {
            let specials: Vec<_> = explicit_specials
                .into_iter()
                .filter_map(|id| tokens.get(id as usize))
                .map(|tok| AddedToken::from(tok.clone(), true))
                .collect();
            if !specials.is_empty() {
                tokenizer.add_special_tokens(&specials);
            }
        }

        Ok(tokenizer)
    }
}
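
For reference, here is a minimal usage sketch (not part of the commit) showing how the new trait could be driven end to end: it reads a GGUF file with `gguf_file::Content::read` and builds a `Tokenizer` from the embedded metadata via `TokenizerFromGguf::from_gguf`, both shown in the diff above. The file path and the `encode` call with its error wrapping are illustrative assumptions.

// Sketch: construct a tokenizer straight from a GGUF file's metadata.
// The path "model.gguf" is illustrative; `Content::read` and `from_gguf`
// are the entry points used/added in this commit.
use candle_core::quantized::{gguf_file, tokenizer::TokenizerFromGguf};
use tokenizers::Tokenizer;

fn main() -> candle_core::Result<()> {
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;

    // Builds the BPE model, pre-tokenizer pipeline, post-processor and
    // special tokens from the `tokenizer.ggml.*` metadata keys.
    let tokenizer = Tokenizer::from_gguf(&content)?;

    let encoding = tokenizer
        .encode("Hello GGUF!", true)
        .map_err(candle_core::Error::wrap)?;
    println!("{:?}", encoding.get_ids());
    Ok(())
}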

0 commit comments