Skip to content

Commit 8021db3

Browse files
authored
Add Model + ModelWrapper abstraction to split wrapper state from policy (#64)
WIP
1 parent 298c685 commit 8021db3

File tree

18 files changed

+718
-190
lines changed

18 files changed

+718
-190
lines changed

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
<!-- next-header -->
1111
## [Unreleased] - ReleaseDate
12+
13+
- Breaking: `Inferer.begin_agent` and `Inferer.end_agent` now take
14+
`&self` instead of a mutable reference (`&mut self`).
15+
16+
### Wrapper rework
17+
18+
To support a wider variety of uses, we have implemented a new category
19+
of wrappers that do not require ownership of the inferer. This allows
20+
for more flexible usage patterns, where the inferer policy can be
21+
replaced in a live application without losing any state kept in
22+
wrappers.
23+
24+
This change is currently non-breaking and is implemented separately
25+
from the old wrapper system.
26+
1227
## [0.8.0] - 2025-05-28
1328
- Added a new `RecurrentTracker` wrapper to handle recurrent
1429
inputs/outputs if the recurrent data is only needed during network

Cargo.lock

Lines changed: 10 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cervo-cli/src/commands/benchmark.rs

Lines changed: 79 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use anyhow::{bail, Result};
22
use cervo::asset::AssetData;
3-
use cervo::core::prelude::{Batcher, Inferer, InfererExt, State};
4-
use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker};
3+
use cervo::core::epsilon::EpsilonInjectorWrapper;
4+
use cervo::core::prelude::{Batcher, Inferer, State};
5+
use cervo::core::recurrent::{RecurrentInfo, RecurrentTrackerWrapper};
6+
use cervo::core::wrapper::{BaseWrapper, InfererWrapper, InfererWrapperExt};
57
use clap::Parser;
68
use clap::ValueEnum;
79
use serde::Serialize;
@@ -167,7 +169,7 @@ struct Record {
167169
total: f64,
168170
}
169171

170-
fn execute_load_metrics<I: Inferer>(
172+
fn execute_load_metrics<I: Inferer + 'static>(
171173
batch_size: usize,
172174
data: HashMap<u64, State<'_>>,
173175
count: usize,
@@ -222,85 +224,112 @@ pub fn build_inputs_from_desc(
222224
.collect()
223225
}
224226

225-
fn do_run(mut inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
226-
let shapes = inferer.input_shapes().to_vec();
227-
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
228-
for id in 0..batch_size {
229-
inferer.begin_agent(id as u64);
230-
}
231-
let res = execute_load_metrics(batch_size, observations, config.count, &mut inferer)?;
232-
for id in 0..batch_size {
233-
inferer.end_agent(id as u64);
227+
fn do_run(
228+
wrapper: impl InfererWrapper + 'static,
229+
inferer: impl Inferer + 'static,
230+
config: &Args,
231+
) -> Result<Vec<Record>> {
232+
let mut model = wrapper.wrap(Box::new(inferer) as Box<dyn Inferer>);
233+
234+
let mut records = Vec::with_capacity(config.batch_sizes.len());
235+
for batch_size in config.batch_sizes.clone() {
236+
let mut reader = File::open(&config.file)?;
237+
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
238+
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
239+
} else {
240+
match config.file.extension().and_then(|ext| ext.to_str()) {
241+
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
242+
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
243+
Some(other) => bail!("unknown file type {:?}", other),
244+
None => bail!("missing file extension {:?}", config.file),
245+
}
246+
};
247+
248+
model = model
249+
.with_new_inferer(Box::new(inferer) as Box<dyn Inferer>)
250+
.map_err(|(_, e)| e)?;
251+
252+
let shapes = model.input_shapes().to_vec();
253+
let observations = build_inputs_from_desc(batch_size as u64, &shapes);
254+
for id in 0..batch_size {
255+
model.begin_agent(id as u64);
256+
}
257+
let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?;
258+
259+
// Print Text
260+
if matches!(config.output, OutputFormat::Text) {
261+
println!(
262+
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
263+
res.batch_size, res.mean, res.stddev, res.total,
264+
);
265+
}
266+
267+
records.push(res);
268+
for id in 0..batch_size {
269+
model.end_agent(id as u64);
270+
}
234271
}
235272

236-
Ok(res)
273+
Ok(records)
237274
}
238275

239276
fn run_apply_epsilon_config(
240-
inferer: impl Inferer,
241-
batch_size: usize,
277+
wrapper: impl InfererWrapper + 'static,
278+
inferer: impl Inferer + 'static,
242279
config: &Args,
243-
) -> Result<Record> {
280+
) -> Result<Vec<Record>> {
244281
if let Some(epsilon) = config.with_epsilon.as_ref() {
245-
let inferer = inferer.with_default_epsilon(epsilon)?;
246-
do_run(inferer, batch_size, config)
282+
let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?;
283+
do_run(wrapper, inferer, config)
247284
} else {
248-
do_run(inferer, batch_size, config)
285+
do_run(wrapper, inferer, config)
249286
}
250287
}
251288

252-
fn run_apply_recurrent(inferer: impl Inferer, batch_size: usize, config: &Args) -> Result<Record> {
289+
fn run_apply_recurrent(
290+
wrapper: impl InfererWrapper + 'static,
291+
inferer: impl Inferer + 'static,
292+
config: &Args,
293+
) -> Result<Vec<Record>> {
253294
if let Some(recurrent) = config.recurrent.as_ref() {
254295
if matches!(recurrent, RecurrentConfig::None) {
255-
run_apply_epsilon_config(inferer, batch_size, config)
296+
run_apply_epsilon_config(wrapper, inferer, config)
256297
} else {
257-
let inferer = match recurrent {
298+
let wrapper = match recurrent {
258299
RecurrentConfig::None => unreachable!(),
259-
RecurrentConfig::Auto => RecurrentTracker::wrap(inferer),
300+
RecurrentConfig::Auto => RecurrentTrackerWrapper::wrap(wrapper, &inferer),
260301
RecurrentConfig::Mapped(map) => {
261302
let infos = map
262303
.iter()
263304
.cloned()
264305
.map(|(inkey, outkey)| RecurrentInfo { inkey, outkey })
265306
.collect::<Vec<_>>();
266-
RecurrentTracker::new(inferer, infos)
307+
RecurrentTrackerWrapper::new(wrapper, &inferer, infos)
267308
}
268309
}?;
269310

270-
run_apply_epsilon_config(inferer, batch_size, config)
311+
run_apply_epsilon_config(wrapper, inferer, config)
271312
}
272313
} else {
273-
run_apply_epsilon_config(inferer, batch_size, config)
314+
run_apply_epsilon_config(wrapper, inferer, config)
274315
}
275316
}
276317

277318
pub(super) fn run(config: Args) -> Result<()> {
278-
let mut records: Vec<Record> = Vec::new();
279-
for batch_size in config.batch_sizes.clone() {
280-
let mut reader = File::open(&config.file)?;
281-
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
282-
cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])?
283-
} else {
284-
match config.file.extension().and_then(|ext| ext.to_str()) {
285-
Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?,
286-
Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?,
287-
Some(other) => bail!("unknown file type {:?}", other),
288-
None => bail!("missing file extension {:?}", config.file),
289-
}
290-
};
291-
292-
let record = run_apply_recurrent(inferer, batch_size, &config)?;
293-
294-
// Print Text
295-
if matches!(config.output, OutputFormat::Text) {
296-
println!(
297-
"Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total",
298-
record.batch_size, record.mean, record.stddev, record.total,
299-
);
319+
let mut reader = File::open(&config.file)?;
320+
let inferer = if cervo::nnef::is_nnef_tar(&config.file) {
321+
cervo::nnef::builder(&mut reader).build_basic()?
322+
} else {
323+
match config.file.extension().and_then(|ext| ext.to_str()) {
324+
Some("onnx") => cervo::onnx::builder(&mut reader).build_basic()?,
325+
Some("crvo") => AssetData::deserialize(&mut reader)?.load_basic()?,
326+
Some(other) => bail!("unknown file type {:?}", other),
327+
None => bail!("missing file extension {:?}", config.file),
300328
}
329+
};
330+
331+
let records = run_apply_recurrent(BaseWrapper, inferer, &config)?;
301332

302-
records.push(record);
303-
}
304333
// Print JSON
305334
if matches!(config.output, OutputFormat::Json) {
306335
let json = serde_json::to_string_pretty(&records)?;

crates/cervo-cli/src/commands/run.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,7 @@ pub(super) fn run(config: Args) -> Result<()> {
8484

8585
let elapsed = if let Some(epsilon) = config.with_epsilon.as_ref() {
8686
let inferer = inferer.with_default_epsilon(epsilon)?;
87-
// TODO[TSolberg]: Issue #31.
88-
let shapes = inferer
89-
.raw_input_shapes()
90-
.iter()
91-
.filter(|(k, _)| k.as_str() != epsilon)
92-
.cloned()
93-
.collect::<Vec<_>>();
94-
95-
let observations = build_inputs_from_desc(config.batch_size as u64, &shapes);
87+
let observations = build_inputs_from_desc(config.batch_size as u64, inferer.input_shapes());
9688

9789
if config.print_input {
9890
print_input(&observations);

0 commit comments

Comments
 (0)