Skip to content

Commit 5e2af3e

Browse files
committed
docs,changelog, cleanup
1 parent 3070ac0 commit 5e2af3e

File tree

14 files changed

+184
-83
lines changed

14 files changed

+184
-83
lines changed

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
<!-- next-header -->
1111
## [Unreleased] - ReleaseDate
12+
13+
- Breaking: `Inferer.begin_agent` and `Inferer.end_agent` now take
14+
`&self`, changed from a mutable reference.
15+
16+
### Wrapper rework
17+
18+
To support a wider variety of uses, we have implemented a new category
19+
of wrappers that do not require ownership of the inferer. This allows
20+
for more flexible usage patterns, where the inferer policy can be
21+
replaced in a live application without losing any state kept in
22+
wrappers.
23+
24+
This change is currently non-breaking and is implemented separately
25+
from the old wrapper system.
26+
1227
## [0.8.0] - 2025-05-28
1328
- Added a new `RecurrentTracker` wrapper to handle recurrent
1429
inputs/outputs if the recurrent data is only needed durign network

crates/cervo-cli/src/commands/benchmark.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use anyhow::{bail, Result};
22
use cervo::asset::AssetData;
33
use cervo::core::epsilon::EpsilonInjectorWrapper;
4-
use cervo::core::model::{BaseCase, Model, ModelWrapper};
54
use cervo::core::prelude::{Batcher, Inferer, State};
65
use cervo::core::recurrent::{RecurrentInfo, RecurrentTrackerWrapper};
6+
use cervo::core::wrapper::{BaseWrapper, InfererWrapper, InfererWrapperExt};
77
use clap::Parser;
88
use clap::ValueEnum;
99
use serde::Serialize;
@@ -169,7 +169,7 @@ struct Record {
169169
total: f64,
170170
}
171171

172-
fn execute_load_metrics<I: Inferer>(
172+
fn execute_load_metrics<I: Inferer + 'static>(
173173
batch_size: usize,
174174
data: HashMap<u64, State<'_>>,
175175
count: usize,
@@ -225,11 +225,11 @@ pub fn build_inputs_from_desc(
225225
}
226226

227227
fn do_run(
228-
wrapper: impl ModelWrapper,
228+
wrapper: impl InfererWrapper + 'static,
229229
inferer: impl Inferer + 'static,
230230
config: &Args,
231231
) -> Result<Vec<Record>> {
232-
let mut model = Model::new(wrapper, Box::new(inferer) as Box<dyn Inferer>);
232+
let mut model = wrapper.wrap(Box::new(inferer) as Box<dyn Inferer>);
233233

234234
let mut records = Vec::with_capacity(config.batch_sizes.len());
235235
for batch_size in config.batch_sizes.clone() {
@@ -246,7 +246,7 @@ fn do_run(
246246
};
247247

248248
model = model
249-
.with_new_policy(Box::new(inferer) as Box<dyn Inferer>)
249+
.with_new_inferer(Box::new(inferer) as Box<dyn Inferer>)
250250
.map_err(|(_, e)| e)?;
251251

252252
let shapes = model.input_shapes().to_vec();
@@ -274,7 +274,7 @@ fn do_run(
274274
}
275275

276276
fn run_apply_epsilon_config(
277-
wrapper: impl ModelWrapper,
277+
wrapper: impl InfererWrapper + 'static,
278278
inferer: impl Inferer + 'static,
279279
config: &Args,
280280
) -> Result<Vec<Record>> {
@@ -287,7 +287,7 @@ fn run_apply_epsilon_config(
287287
}
288288

289289
fn run_apply_recurrent(
290-
wrapper: impl ModelWrapper,
290+
wrapper: impl InfererWrapper + 'static,
291291
inferer: impl Inferer + 'static,
292292
config: &Args,
293293
) -> Result<Vec<Record>> {
@@ -328,7 +328,7 @@ pub(super) fn run(config: Args) -> Result<()> {
328328
}
329329
};
330330

331-
let records = run_apply_recurrent(BaseCase, inferer, &config)?;
331+
let records = run_apply_recurrent(BaseWrapper, inferer, &config)?;
332332

333333
// Print JSON
334334
if matches!(config.output, OutputFormat::Json) {

crates/cervo-core/src/epsilon.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model.
88

99
use std::cell::RefCell;
1010

11-
use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::ModelWrapper};
11+
use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
1212
use anyhow::{bail, Result};
1313
use perchance::PerchanceContext;
1414
use rand::thread_rng;
@@ -212,21 +212,21 @@ where
212212
self.inner.raw_output_shapes()
213213
}
214214

215-
fn begin_agent(&mut self, id: u64) {
215+
fn begin_agent(&self, id: u64) {
216216
self.inner.begin_agent(id);
217217
}
218218

219-
fn end_agent(&mut self, id: u64) {
219+
fn end_agent(&self, id: u64) {
220220
self.inner.end_agent(id);
221221
}
222222
}
223223

224-
pub struct EpsilonInjectorWrapper<Inner: ModelWrapper, NG: NoiseGenerator> {
224+
pub struct EpsilonInjectorWrapper<Inner: InfererWrapper, NG: NoiseGenerator> {
225225
inner: Inner,
226226
state: EpsilonInjectorState<NG>,
227227
}
228228

229-
impl<Inner: ModelWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
229+
impl<Inner: InfererWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerator> {
230230
/// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
231231
///
232232
/// This function will use [`HighQualityNoiseGenerator`] as the noise source.
@@ -245,7 +245,7 @@ impl<Inner: ModelWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerato
245245

246246
impl<Inner, NG> EpsilonInjectorWrapper<Inner, NG>
247247
where
248-
Inner: ModelWrapper,
248+
Inner: InfererWrapper,
249249
NG: NoiseGenerator,
250250
{
251251
/// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
@@ -284,12 +284,12 @@ where
284284
}
285285
}
286286

287-
impl<Inner, NG> ModelWrapper for EpsilonInjectorWrapper<Inner, NG>
287+
impl<Inner, NG> InfererWrapper for EpsilonInjectorWrapper<Inner, NG>
288288
where
289-
Inner: ModelWrapper,
289+
Inner: InfererWrapper,
290290
NG: NoiseGenerator,
291291
{
292-
fn invoke(&self, inferer: &impl Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
292+
fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
293293
self.inner.invoke(inferer, batch)?;
294294
let total_count = self.state.count * batch.len();
295295
let output = batch.input_slot_mut(self.state.index);
@@ -306,11 +306,11 @@ where
306306
self.inner.output_shapes(inferer)
307307
}
308308

309-
fn begin_agent(&self, id: u64) {
310-
self.inner.begin_agent(id);
309+
fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
310+
self.inner.begin_agent(inferer, id);
311311
}
312312

313-
fn end_agent(&self, id: u64) {
314-
self.inner.end_agent(id);
313+
fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
314+
self.inner.end_agent(inferer, id);
315315
}
316316
}

crates/cervo-core/src/inferer.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ pub trait Inferer {
112112
/// Retrieve the name and shapes of the model outputs.
113113
fn raw_output_shapes(&self) -> &[(String, Vec<usize>)];
114114

115-
fn begin_agent(&mut self, id: u64);
116-
fn end_agent(&mut self, id: u64);
115+
fn begin_agent(&self, id: u64);
116+
fn end_agent(&self, id: u64);
117117
}
118118

119119
/// Helper trait to provide helper functions for loadable models.
@@ -242,12 +242,12 @@ impl Inferer for Box<dyn Inferer + Send> {
242242
self.as_ref().raw_output_shapes()
243243
}
244244

245-
fn begin_agent(&mut self, id: u64) {
246-
self.as_mut().begin_agent(id);
245+
fn begin_agent(&self, id: u64) {
246+
self.as_ref().begin_agent(id);
247247
}
248248

249-
fn end_agent(&mut self, id: u64) {
250-
self.as_mut().end_agent(id);
249+
fn end_agent(&self, id: u64) {
250+
self.as_ref().end_agent(id);
251251
}
252252
}
253253

@@ -268,11 +268,11 @@ impl Inferer for Box<dyn Inferer> {
268268
self.as_ref().raw_output_shapes()
269269
}
270270

271-
fn begin_agent(&mut self, id: u64) {
272-
self.as_mut().begin_agent(id);
271+
fn begin_agent(&self, id: u64) {
272+
self.as_ref().begin_agent(id);
273273
}
274274

275-
fn end_agent(&mut self, id: u64) {
276-
self.as_mut().end_agent(id);
275+
fn end_agent(&self, id: u64) {
276+
self.as_ref().end_agent(id);
277277
}
278278
}

crates/cervo-core/src/inferer/basic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,6 @@ impl Inferer for BasicInferer {
9494
&self.model_api.outputs
9595
}
9696

97-
fn begin_agent(&mut self, _id: u64) {}
98-
fn end_agent(&mut self, _id: u64) {}
97+
fn begin_agent(&self, _id: u64) {}
98+
fn end_agent(&self, _id: u64) {}
9999
}

crates/cervo-core/src/inferer/dynamic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,6 @@ impl Inferer for DynamicInferer {
110110
&self.model_api.outputs
111111
}
112112

113-
fn begin_agent(&mut self, _id: u64) {}
114-
fn end_agent(&mut self, _id: u64) {}
113+
fn begin_agent(&self, _id: u64) {}
114+
fn end_agent(&self, _id: u64) {}
115115
}

crates/cervo-core/src/inferer/fixed.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ impl Inferer for FixedBatchInferer {
111111
&self.model_api.outputs
112112
}
113113

114-
fn begin_agent(&mut self, _id: u64) {}
115-
fn end_agent(&mut self, _id: u64) {}
114+
fn begin_agent(&self, _id: u64) {}
115+
fn end_agent(&self, _id: u64) {}
116116
}
117117

118118
struct BatchedModel {

crates/cervo-core/src/inferer/memoizing.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,6 @@ impl Inferer for MemoizingDynamicInferer {
175175
&self.model_api.outputs
176176
}
177177

178-
fn begin_agent(&mut self, _id: u64) {}
179-
fn end_agent(&mut self, _id: u64) {}
178+
fn begin_agent(&self, _id: u64) {}
179+
fn end_agent(&self, _id: u64) {}
180180
}

crates/cervo-core/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ pub use tract_hir;
1515
pub mod batcher;
1616
pub mod epsilon;
1717
pub mod inferer;
18-
pub mod model;
1918
mod model_api;
2019
pub mod recurrent;
20+
pub mod wrapper;
2121

2222
/// Most core utilities are re-exported here.
2323
pub mod prelude {
@@ -30,7 +30,7 @@ pub mod prelude {
3030
InfererProvider, MemoizingDynamicInferer, Response, State,
3131
};
3232

33-
pub use super::model::ModelWrapper;
3433
pub use super::model_api::ModelApi;
3534
pub use super::recurrent::{RecurrentInfo, RecurrentTracker};
35+
pub use super::wrapper::{InfererWrapper, InfererWrapperExt, IntoStateful, StatefulInferer};
3636
}

0 commit comments

Comments
 (0)