@@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model.
88
99use std:: cell:: RefCell ;
1010
11- use crate :: { batcher:: ScratchPadView , inferer:: Inferer , prelude:: ModelWrapper } ;
11+ use crate :: { batcher:: ScratchPadView , inferer:: Inferer , prelude:: InfererWrapper } ;
1212use anyhow:: { bail, Result } ;
1313use perchance:: PerchanceContext ;
1414use rand:: thread_rng;
@@ -212,21 +212,21 @@ where
212212 self . inner . raw_output_shapes ( )
213213 }
214214
215- fn begin_agent ( & mut self , id : u64 ) {
215+ fn begin_agent ( & self , id : u64 ) {
216216 self . inner . begin_agent ( id) ;
217217 }
218218
219- fn end_agent ( & mut self , id : u64 ) {
219+ fn end_agent ( & self , id : u64 ) {
220220 self . inner . end_agent ( id) ;
221221 }
222222}
223223
224- pub struct EpsilonInjectorWrapper < Inner : ModelWrapper , NG : NoiseGenerator > {
224+ pub struct EpsilonInjectorWrapper < Inner : InfererWrapper , NG : NoiseGenerator > {
225225 inner : Inner ,
226226 state : EpsilonInjectorState < NG > ,
227227}
228228
229- impl < Inner : ModelWrapper > EpsilonInjectorWrapper < Inner , HighQualityNoiseGenerator > {
229+ impl < Inner : InfererWrapper > EpsilonInjectorWrapper < Inner , HighQualityNoiseGenerator > {
230230 /// Wraps the provided `inferer` to automatically generate noise for the input named by `key`.
231231 ///
232232 /// This function will use [`HighQualityNoiseGenerator`] as the noise source.
@@ -245,7 +245,7 @@ impl<Inner: ModelWrapper> EpsilonInjectorWrapper<Inner, HighQualityNoiseGenerato
245245
246246impl < Inner , NG > EpsilonInjectorWrapper < Inner , NG >
247247where
248- Inner : ModelWrapper ,
248+ Inner : InfererWrapper ,
249249 NG : NoiseGenerator ,
250250{
251251 /// Create a new injector for the provided `key`, using the custom `generator` as the noise source.
@@ -284,12 +284,12 @@ where
284284 }
285285}
286286
287- impl < Inner , NG > ModelWrapper for EpsilonInjectorWrapper < Inner , NG >
287+ impl < Inner , NG > InfererWrapper for EpsilonInjectorWrapper < Inner , NG >
288288where
289- Inner : ModelWrapper ,
289+ Inner : InfererWrapper ,
290290 NG : NoiseGenerator ,
291291{
292- fn invoke ( & self , inferer : & impl Inferer , batch : & mut ScratchPadView < ' _ > ) -> anyhow:: Result < ( ) > {
292+ fn invoke ( & self , inferer : & dyn Inferer , batch : & mut ScratchPadView < ' _ > ) -> anyhow:: Result < ( ) > {
293293 self . inner . invoke ( inferer, batch) ?;
294294 let total_count = self . state . count * batch. len ( ) ;
295295 let output = batch. input_slot_mut ( self . state . index ) ;
@@ -306,11 +306,11 @@ where
306306 self . inner . output_shapes ( inferer)
307307 }
308308
309- fn begin_agent ( & self , id : u64 ) {
310- self . inner . begin_agent ( id) ;
309+ fn begin_agent ( & self , inferer : & dyn Inferer , id : u64 ) {
310+ self . inner . begin_agent ( inferer , id) ;
311311 }
312312
313- fn end_agent ( & self , id : u64 ) {
314- self . inner . end_agent ( id) ;
313+ fn end_agent ( & self , inferer : & dyn Inferer , id : u64 ) {
314+ self . inner . end_agent ( inferer , id) ;
315315 }
316316}
0 commit comments