
Commit 990b059

feat: add quantized lfm2 model support
1 parent 4b46187 commit 990b059

File tree

4 files changed: +1007 -0 lines changed

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@

# candle-quantized-lfm2

Candle implementation of various quantized lfm2 models.

## Running an example

```bash
$ cargo run --example quantized-lfm2 --release -- --prompt "Tell me a story in 100 words."

avx: false, neon: true, simd128: false, f16c: false
temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded 266 tensors (1.56GB) in 0.13s
model ready
Starting the inference loop:
Tell me a story in 100 words.

A quiet town nestled between rolling hills, where every springtime arrives with laughter and blossoms. Clara, the town’s beloved baker, opens her shop at dawn—cinnamon swirling into warm air, fresh pastries glowing on wooden racks. Each customer greets her with a smile, sharing tales while savoring sweet treats. One day, an old man hands her a faded photo: him and Clara, decades ago, when she’d kneaded dough for his wedding cake. Now he waits in silence, unseen. Clara bakes him another batch—hope rising from the oven, turning cold hearts into laughter again.

  10 prompt tokens processed: 39.28 token/s
 133 tokens generated: 43.34 token/s
```
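
The example also accepts `--which` to pick one of the published GGUF variants (`lfm2-350m-q4_k_m`, `lfm2-350m-q8_0`, `lfm2-2.6b-q4_k_m`, `lfm2-2.6b-q8_0`) and `--model` to point at a local GGUF file; see `main.rs` below.
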
Lines changed: 352 additions & 0 deletions
@@ -0,0 +1,352 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use std::path::{Path, PathBuf};
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_lfm2::ModelWeights;

const DEFAULT_PROMPT: &str = "Explain how Rotary Position Embeddings work in transformers.";

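// The quantized LFM2 checkpoints this example can fetch; each variant maps to a
// LiquidAI GGUF repo and file in `Args::model_path` below.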
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    /// 350M base model, Q4_K_M quantization.
    #[value(name = "lfm2-350m-q4_k_m")]
    Lfm2_350MQ4KM,
    /// 350M base model, Q8_0 quantization.
    #[value(name = "lfm2-350m-q8_0")]
    Lfm2_350MQ8_0,
    /// 2.6B model, Q4_K_M quantization.
    #[value(name = "lfm2-2.6b-q4_k_m")]
    Lfm2_2_6BQ4KM,
    /// 2.6B model, Q8_0 quantization.
    #[value(name = "lfm2-2.6b-q8_0")]
    Lfm2_2_6BQ8_0,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by llama.cpp.
    #[arg(long)]
    model: Option<String>,

    /// The model variant to download from Hugging Face when --model is not set.
    #[arg(long, default_value = "lfm2-2.6b-q4_k_m")]
    which: Which,

    /// Repo revision to download from when using --which.
    #[arg(long, default_value = "main")]
    revision: String,

    /// Path to tokenizer.json. Defaults to the same folder as the model or is fetched from Hugging Face.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The initial prompt to feed to the model.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of tokens to sample (including the first token after the prompt).
    #[arg(short = 'n', long, default_value_t = 512)]
    sample_len: usize,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

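// Weight and tokenizer resolution: prefer explicit local paths, then fall back to the
// Hugging Face Hub repos matching the selected variant.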
impl Args {
    fn model_path(&self) -> Result<PathBuf> {
        if let Some(model) = &self.model {
            return Ok(PathBuf::from(model));
        }
        let (repo, filename) = match self.which {
            Which::Lfm2_350MQ4KM => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q4_K_M.gguf"),
            Which::Lfm2_350MQ8_0 => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q8_0.gguf"),
            Which::Lfm2_2_6BQ4KM => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q4_K_M.gguf"),
            Which::Lfm2_2_6BQ8_0 => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q8_0.gguf"),
        };
        let api = hf_hub::api::sync::Api::new()?;
        api.repo(hf_hub::Repo::with_revision(
            repo.to_string(),
            hf_hub::RepoType::Model,
            self.revision.clone(),
        ))
        .get(filename)
        .map_err(Into::into)
    }

    fn tokenizer(&self, model_path: &Path) -> Result<Tokenizer> {
        if let Some(path) = &self.tokenizer {
            return Tokenizer::from_file(path).map_err(anyhow::Error::msg);
        }

        if let Some(dir) = model_path.parent() {
            let candidate = dir.join("tokenizer.json");
            if candidate.exists() {
                return Tokenizer::from_file(candidate).map_err(anyhow::Error::msg);
            }
        }

        let tokenizer_repo = match self.which {
            Which::Lfm2_350MQ4KM | Which::Lfm2_350MQ8_0 => "LiquidAI/LFM2-350M",
            Which::Lfm2_2_6BQ4KM | Which::Lfm2_2_6BQ8_0 => "LiquidAI/LFM2-2.6B",
        };
        let api = hf_hub::api::sync::Api::new()?;
        let tokenizer_path = api
            .repo(hf_hub::Repo::with_revision(
                tokenizer_repo.to_string(),
                hf_hub::RepoType::Model,
                self.revision.clone(),
            ))
            .get("tokenizer.json")?;
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{size_in_bytes}B")
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

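// Probe the tokenizer vocabulary for common end-of-sequence markers; the first match
// is used to stop generation.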
fn guess_eos_id(tokenizer: &Tokenizer) -> Option<u32> {
    let vocab = tokenizer.get_vocab(true);
    let candidates = [
        "</s>",
        "<|im_end|>",
        "<|eot_id|>",
        "<|end|>",
        "<|end_of_text|>",
        "<|endoftext|>",
    ];
    candidates
        .iter()
        .find_map(|token| vocab.get(*token).copied())
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model_path()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

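    // Read the GGUF header and tally the quantized tensor sizes for the load report.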
    let gguf = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path.clone()))?;
    let mut total_size_in_bytes = 0;
    for (_, tensor) in gguf.tensor_infos.iter() {
        let elem_count = tensor.shape.elem_count();
        total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
    }

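    // LFM2 GGUF files record the model's maximum context length; it is used below to
    // trim the prompt and to cap generation.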
    let context_length = gguf
        .metadata
        .get("lfm2.context_length")
        .and_then(|v| v.to_u32().ok().map(|v| v as usize));

    println!(
        "loaded {:?} tensors ({}) in {:.2}s",
        gguf.tensor_infos.len(),
        format_size(total_size_in_bytes),
        start.elapsed().as_secs_f32()
    );

    let mut model = ModelWeights::from_gguf(gguf, &mut file, &device)?;
    println!("model ready");

    let tokenizer = args.tokenizer(&model_path)?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let mut tokens = tos
        .tokenizer()
        .encode(args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT), true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();

    if let Some(max_ctx) = context_length {
        if tokens.len() >= max_ctx {
            let trim = tokens.len() - max_ctx + 1;
            tokens.drain(0..trim);
            println!("prompt trimmed to last {max_ctx} tokens to fit context");
        }
    }

    let mut all_tokens = tokens.clone();
    let to_sample = args.sample_len.saturating_sub(1);

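    // Map the temperature/top-k/top-p flags onto a sampling strategy.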
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    println!("Starting the inference loop:");
    let prompt_str = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
    print!("{prompt_str}");
    std::io::stdout().flush()?;

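    // Prompt processing: a single batched forward pass by default, or one token at a
    // time with --split-prompt.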
    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };

    let mut index_pos = tokens.len();
    let prompt_dt = start_prompt_processing.elapsed();

    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

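    // Generation loop: feed each sampled token back in, apply the repeat penalty over
    // the last repeat_last_n tokens, and stop at EOS or the context limit.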
    let eos_token = guess_eos_id(tos.tokenizer());
    let mut sampled = 0;
    let start_post_prompt = std::time::Instant::now();
    for _ in 0..to_sample {
        if let Some(max_ctx) = context_length {
            if index_pos + 1 > max_ctx {
                println!("\n\ncontext window of {max_ctx} reached, stopping generation");
                break;
            }
        }

        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        index_pos += 1;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if let Some(eos) = eos_token {
            if next_token == eos {
                break;
            }
        }
    }

    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;

    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
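
For downstream users, the new `quantized_lfm2::ModelWeights` surface reduces to the two calls exercised above, `from_gguf` and `forward`. A minimal sketch mirroring the example, where the GGUF path `model.gguf` and the prompt ids are placeholders rather than anything shipped in this commit:

```rust
// Sketch only: "model.gguf" and the prompt ids below are illustrative placeholders.
use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::models::quantized_lfm2::ModelWeights;

fn next_token_logits() -> anyhow::Result<Tensor> {
    let device = Device::Cpu;
    let mut file = std::fs::File::open("model.gguf")?;
    // Parse the GGUF header, then load the quantized tensors onto the device.
    let gguf = gguf_file::Content::read(&mut file)?;
    let mut model = ModelWeights::from_gguf(gguf, &mut file, &device)?;
    // One prompt pass at offset 0, as in the non --split-prompt path of the example.
    let prompt_ids: &[u32] = &[1, 2, 3];
    let input = Tensor::new(prompt_ids, &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?.squeeze(0)?;
    Ok(logits)
}
```

Subsequent calls pass the running position offset to `forward`, as the generation loop above does with `index_pos`.
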

candle-transformers/src/models/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ pub mod pixtral;
 pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_gemma3;
+pub mod quantized_lfm2;
 pub mod quantized_llama;
 pub mod quantized_llama2_c;
 pub mod quantized_metavoice;
