@@ -1,7 +1,9 @@
 use anyhow::{bail, Result};
 use cervo::asset::AssetData;
-use cervo::core::prelude::{Batcher, Inferer, InfererExt, State};
-use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker};
+use cervo::core::epsilon::EpsilonInjectorWrapper;
+use cervo::core::prelude::{Batcher, Inferer, State};
+use cervo::core::recurrent::{RecurrentInfo, RecurrentTrackerWrapper};
+use cervo::core::wrapper::{BaseWrapper, InfererWrapper, InfererWrapperExt};
 use clap::Parser;
 use clap::ValueEnum;
 use serde::Serialize;
@@ -167,7 +169,7 @@ struct Record {
     total: f64,
 }

-fn execute_load_metrics<I: Inferer>(
+fn execute_load_metrics<I: Inferer + 'static>(
     batch_size: usize,
     data: HashMap<u64, State<'_>>,
     count: usize,
@@ -222,85 +224,112 @@ pub fn build_inputs_from_desc(
         .collect()
 }

-fn do_run(mut inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
-    let shapes = inferer.input_shapes().to_vec();
-    let observations = build_inputs_from_desc(batch_size as u64, &shapes);
-    for id in 0..batch_size {
-        inferer.begin_agent(id as u64);
-    }
-    let res = execute_load_metrics(batch_size, observations, config.count, &mut inferer)?;
-    for id in 0..batch_size {
-        inferer.end_agent(id as u64);
+fn do_run(
+    wrapper: impl InfererWrapper + 'static,
+    inferer: impl Inferer + 'static,
+    config: &Args,
+) -> Result<Vec<Record>> {
+    let mut model = wrapper.wrap(Box::new(inferer) as Box<dyn Inferer>);
+
+    let mut records = Vec::with_capacity(config.batch_sizes.len());
+    for batch_size in config.batch_sizes.clone() {
+        let mut reader = File::open(&config.file)?;
+        let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
+            cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
+        } else {
+            match config.file.extension().and_then(|ext| ext.to_str()) {
+                Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
+                Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
+                Some(other) => bail!("unknown file type {:?}", other),
+                None => bail!("missing file extension {:?}", config.file),
+            }
+        };
+
+        model = model
+            .with_new_inferer(Box::new(inferer) as Box<dyn Inferer>)
+            .map_err(|(_, e)| e)?;
+
+        let shapes = model.input_shapes().to_vec();
+        let observations = build_inputs_from_desc(batch_size as u64, &shapes);
+        for id in 0..batch_size {
+            model.begin_agent(id as u64);
+        }
+        let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;
+
+        // Print Text
+        if matches!(config.output, OutputFormat::Text) {
+            println!(
+                "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
+                res.batch_size, res.mean, res.stddev, res.total,
+            );
+        }
+
+        records.push(res);
+        for id in 0..batch_size {
+            model.end_agent(id as u64);
+        }
     }

-    Ok(res)
+    Ok(records)
 }

 fn run_apply_epsilon_config(
-    inferer: impl Inferer,
-    batch_size: usize,
+    wrapper: impl InfererWrapper + 'static,
+    inferer: impl Inferer + 'static,
     config: &Args,
-) -> Result<Record> {
+) -> Result<Vec<Record>> {
     if let Some(epsilon) = config.with_epsilon.as_ref() {
-        let inferer = inferer.with_default_epsilon(epsilon)?;
-        do_run(inferer, batch_size, config)
+        let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?;
+        do_run(wrapper, inferer, config)
     } else {
-        do_run(inferer, batch_size, config)
+        do_run(wrapper, inferer, config)
     }
 }

-fn run_apply_recurrent(inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
+fn run_apply_recurrent(
+    wrapper: impl InfererWrapper + 'static,
+    inferer: impl Inferer + 'static,
+    config: &Args,
+) -> Result<Vec<Record>> {
     if let Some(recurrent) = config.recurrent.as_ref() {
         if matches!(recurrent, RecurrentConfig::None) {
-            run_apply_epsilon_config(inferer, batch_size, config)
+            run_apply_epsilon_config(wrapper, inferer, config)
         } else {
-            let inferer = match recurrent {
+            let wrapper = match recurrent {
                 RecurrentConfig::None => unreachable!(),
-                RecurrentConfig::Auto => RecurrentTracker::wrap(inferer),
+                RecurrentConfig::Auto => RecurrentTrackerWrapper::wrap(wrapper, &inferer),
                 RecurrentConfig::Mapped(map) => {
                     let infos = map
                         .iter()
                         .cloned()
                         .map(|(inkey, outkey)| RecurrentInfo { inkey, outkey })
                         .collect::<Vec<_>>();
-                    RecurrentTracker::new(inferer, infos)
+                    RecurrentTrackerWrapper::new(wrapper, &inferer, infos)
                 }
             }?;

-            run_apply_epsilon_config(inferer, batch_size, config)
+            run_apply_epsilon_config(wrapper, inferer, config)
         }
     } else {
-        run_apply_epsilon_config(inferer, batch_size, config)
+        run_apply_epsilon_config(wrapper, inferer, config)
     }
 }

 pub(super) fn run(config: Args) -> Result<()> {
-    let mut records: Vec<Record> = Vec::new();
-    for batch_size in config.batch_sizes.clone() {
-        let mut reader = File::open(&config.file)?;
-        let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
-            cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
-        } else {
-            match config.file.extension().and_then(|ext| ext.to_str()) {
-                Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
-                Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
-                Some(other) => bail!("unknown file type {:?}", other),
-                None => bail!("missing file extension {:?}", config.file),
-            }
-        };
-
-        let record = run_apply_recurrent(inferer, batch_size, &config)?;
-
-        // Print Text
-        if matches!(config.output, OutputFormat::Text) {
-            println!(
-                "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
-                record.batch_size, record.mean, record.stddev, record.total,
-            );
+    let mut reader = File::open(&config.file)?;
+    let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
+        cervo::nnef::builder(&mut reader).build_basic()?
+    } else {
+        match config.file.extension().and_then(|ext| ext.to_str()) {
+            Some("onnx") => cervo::onnx::builder(&mut reader).build_basic()?,
+            Some("crvo") => AssetData::deserialize(&mut reader)?.load_basic()?,
+            Some(other) => bail!("unknown file type {:?}", other),
+            None => bail!("missing file extension {:?}", config.file),
         }
+    };
+
+    let records = run_apply_recurrent(BaseWrapper, inferer, &config)?;

-        records.push(record);
-    }
     // Print JSON
     if matches!(config.output, OutputFormat::Json) {
         let json = serde_json::to_string_pretty(&records)?;