#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

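// Example invocation (assuming this file is wired up as a cargo example named
// `quantized-lfm2`; adjust the example name to match your Cargo setup):
//
//   cargo run --example quantized-lfm2 --release -- \
//     --which lfm2-350m-q4_k_m --prompt "Explain KV caching in one paragraph."
//
// All flags referenced above are defined on `Args` below; `--model` and
// `--tokenizer` accept local paths and skip the Hugging Face download.
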
use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use std::path::{Path, PathBuf};
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_lfm2::ModelWeights;

const DEFAULT_PROMPT: &str = "Explain how Rotary Position Embeddings work in transformers.";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    /// 350M base model, Q4_K_M quantization.
    #[value(name = "lfm2-350m-q4_k_m")]
    Lfm2_350MQ4KM,
    /// 350M base model, Q8_0 quantization.
    #[value(name = "lfm2-350m-q8_0")]
    Lfm2_350MQ8_0,
    /// 2.6B model, Q4_K_M quantization.
    #[value(name = "lfm2-2.6b-q4_k_m")]
    Lfm2_2_6BQ4KM,
    /// 2.6B model, Q8_0 quantization.
    #[value(name = "lfm2-2.6b-q8_0")]
    Lfm2_2_6BQ8_0,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by llama.cpp.
    #[arg(long)]
    model: Option<String>,

    /// The model variant to download from Hugging Face when --model is not set.
    #[arg(long, default_value = "lfm2-2.6b-q4_k_m")]
    which: Which,

    /// Repo revision to download from when using --which.
    #[arg(long, default_value = "main")]
    revision: String,

    /// Path to tokenizer.json. Defaults to the same folder as the model or is fetched from Hugging Face.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The initial prompt to feed to the model.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of tokens to sample (including the first token after the prompt).
    #[arg(short = 'n', long, default_value_t = 512)]
    sample_len: usize,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

impl Args {
    fn model_path(&self) -> Result<PathBuf> {
        if let Some(model) = &self.model {
            return Ok(PathBuf::from(model));
        }
        let (repo, filename) = match self.which {
            Which::Lfm2_350MQ4KM => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q4_K_M.gguf"),
            Which::Lfm2_350MQ8_0 => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q8_0.gguf"),
            Which::Lfm2_2_6BQ4KM => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q4_K_M.gguf"),
            Which::Lfm2_2_6BQ8_0 => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q8_0.gguf"),
        };
        let api = hf_hub::api::sync::Api::new()?;
        api.repo(hf_hub::Repo::with_revision(
            repo.to_string(),
            hf_hub::RepoType::Model,
            self.revision.clone(),
        ))
        .get(filename)
        .map_err(Into::into)
    }

    fn tokenizer(&self, model_path: &Path) -> Result<Tokenizer> {
        if let Some(path) = &self.tokenizer {
            return Tokenizer::from_file(path).map_err(anyhow::Error::msg);
        }

        if let Some(dir) = model_path.parent() {
            let candidate = dir.join("tokenizer.json");
            if candidate.exists() {
                return Tokenizer::from_file(candidate).map_err(anyhow::Error::msg);
            }
        }

        let tokenizer_repo = match self.which {
            Which::Lfm2_350MQ4KM | Which::Lfm2_350MQ8_0 => "LiquidAI/LFM2-350M",
            Which::Lfm2_2_6BQ4KM | Which::Lfm2_2_6BQ8_0 => "LiquidAI/LFM2-2.6B",
        };
        let api = hf_hub::api::sync::Api::new()?;
        let tokenizer_path = api
            .repo(hf_hub::Repo::with_revision(
                tokenizer_repo.to_string(),
                hf_hub::RepoType::Model,
                self.revision.clone(),
            ))
            .get("tokenizer.json")?;
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{size_in_bytes}B")
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

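/// Best-effort lookup of an end-of-sequence token id: the tokenizer vocabulary is
/// probed for a handful of common EOS spellings and the first match wins.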
fn guess_eos_id(tokenizer: &Tokenizer) -> Option<u32> {
    let vocab = tokenizer.get_vocab(true);
    let candidates = [
        "</s>",
        "<|im_end|>",
        "<|eot_id|>",
        "<|end|>",
        "<|end_of_text|>",
        "<|endoftext|>",
    ];
    candidates
        .iter()
        .find_map(|token| vocab.get(*token).copied())
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model_path()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let gguf = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path.clone()))?;
    let mut total_size_in_bytes = 0;
    for (_, tensor) in gguf.tensor_infos.iter() {
        let elem_count = tensor.shape.elem_count();
        total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
    }

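    // LFM2 GGUF files may carry the training context window under the
    // `lfm2.context_length` metadata key; when present it is used below to trim
    // the prompt and to stop generation at the window boundary.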
    let context_length = gguf
        .metadata
        .get("lfm2.context_length")
        .and_then(|v| v.to_u32().ok().map(|v| v as usize));

    println!(
        "loaded {:?} tensors ({}) in {:.2}s",
        gguf.tensor_infos.len(),
        format_size(total_size_in_bytes),
        start.elapsed().as_secs_f32()
    );

    let mut model = ModelWeights::from_gguf(gguf, &mut file, &device)?;
    println!("model ready");

    let tokenizer = args.tokenizer(&model_path)?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let mut tokens = tos
        .tokenizer()
        .encode(args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT), true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();

    if let Some(max_ctx) = context_length {
        if tokens.len() >= max_ctx {
            let trim = tokens.len() - max_ctx + 1;
            tokens.drain(0..trim);
            println!(
                "prompt trimmed to the last {} tokens to fit the context window",
                max_ctx - 1
            );
        }
    }

    let mut all_tokens = tokens.clone();
    let to_sample = args.sample_len.saturating_sub(1);

    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    println!("Starting the inference loop:");
    let prompt_str = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
    print!("{prompt_str}");
    std::io::stdout().flush()?;

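    // Prompt processing: by default the whole prompt goes through a single forward
    // pass; with --split-prompt each token is fed individually, which is slower but
    // can reduce peak memory.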
    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };

    let mut index_pos = tokens.len();
    let prompt_dt = start_prompt_processing.elapsed();

    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    let eos_token = guess_eos_id(tos.tokenizer());
    let mut sampled = 0;
    let start_post_prompt = std::time::Instant::now();
    for _ in 0..to_sample {
        if let Some(max_ctx) = context_length {
            if index_pos + 1 > max_ctx {
                println!("\n\ncontext window of {max_ctx} reached, stopping generation");
                break;
            }
        }

        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
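        // Penalize tokens seen in the last `repeat_last_n` positions; a penalty of
        // 1.0 disables the adjustment and keeps the raw logits.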
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        index_pos += 1;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if let Some(eos) = eos_token {
            if next_token == eos {
                break;
            }
        }
    }

    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;

    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
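
// A minimal sanity check for the size formatter above; not part of the upstream
// example, just a quick sketch documenting the expected output format.
#[cfg(test)]
mod tests {
    use super::format_size;

    #[test]
    fn format_size_uses_decimal_units() {
        assert_eq!(format_size(999), "999B");
        assert_eq!(format_size(1_500), "1.50KB");
        assert_eq!(format_size(2_600_000), "2.60MB");
        assert_eq!(format_size(2_000_000_000), "2.00GB");
    }
}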