Skip to content

Commit 3fbae92

Browse files
authored
fix issue with slot shifting in recurrent wrapper (#67)
1 parent 7ddc88f commit 3fbae92

File tree

2 files changed

+91
-14
lines changed

2 files changed

+91
-14
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
<!-- next-header -->
1111
## [Unreleased] - ReleaseDate
12+
- Fix bugs in the new wrapper setup where consumed and modified shapes
13+
weren't respected during wrapper construction.
14+
1215
## [0.9.1] - 2025-09-10
1316

1417
- Add `StatefulInferer::replace_inferer` which works with a `&mut

crates/cervo-core/src/recurrent.rs

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -109,23 +109,23 @@ where
109109
/// Create a new recurrency tracker for the model.
110110
///
111111
pub fn new(inferer: T, info: Vec<RecurrentInfo>) -> Result<Self> {
112-
let inputs = inferer.raw_input_shapes();
113-
let outputs = inferer.raw_output_shapes();
112+
let raw_inputs = inferer.raw_input_shapes();
113+
let raw_outputs = inferer.raw_output_shapes();
114114

115115
let mut offset = 0;
116116
let keys = info
117117
.iter()
118118
.map(|info| {
119-
let inslot = inputs
119+
let inslot = raw_inputs
120120
.iter()
121121
.position(|input| info.inkey == input.0)
122122
.with_context(|| format!("no input named {}", info.inkey))?;
123-
let outslot = outputs
123+
let outslot = raw_outputs
124124
.iter()
125125
.position(|output| info.outkey == output.0)
126126
.with_context(|| format!("no output named {}", info.outkey))?;
127127

128-
let numels = inputs[inslot].1.iter().product();
128+
let numels = raw_inputs[inslot].1.iter().product();
129129
offset += numels;
130130
Ok(RecurrentPair {
131131
inslot,
@@ -136,16 +136,21 @@ where
136136
})
137137
.collect::<Result<TVec<RecurrentPair>>>()?;
138138

139+
let inputs = inferer.input_shapes();
140+
let outputs = inferer.output_shapes();
141+
139142
let inputs = inputs
140143
.iter()
141144
.filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
142145
.cloned()
143146
.collect::<Vec<_>>();
147+
144148
let outputs = outputs
145149
.iter()
146150
.filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
147151
.cloned()
148152
.collect::<Vec<_>>();
153+
149154
Ok(Self {
150155
inner: inferer,
151156
state: RecurrentState {
@@ -248,23 +253,23 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
248253
/// Create a new recurrency tracker for the model.
249254
///
250255
pub fn new<T: Inferer>(inner: Inner, inferer: &T, info: Vec<RecurrentInfo>) -> Result<Self> {
251-
let inputs = inner.input_shapes(inferer);
252-
let outputs = inner.output_shapes(inferer);
256+
let raw_inputs = inferer.raw_input_shapes();
257+
let raw_outputs = inferer.raw_output_shapes();
253258

254259
let mut offset = 0;
255260
let keys = info
256261
.iter()
257262
.map(|info| {
258-
let inslot = inputs
263+
let inslot = raw_inputs
259264
.iter()
260265
.position(|input| info.inkey == input.0)
261266
.with_context(|| format!("no input named {}", info.inkey))?;
262-
let outslot = outputs
267+
let outslot = raw_outputs
263268
.iter()
264269
.position(|output| info.outkey == output.0)
265270
.with_context(|| format!("no output named {}", info.outkey))?;
266271

267-
let numels = inputs[inslot].1.iter().product();
272+
let numels = raw_inputs[inslot].1.iter().product();
268273
offset += numels;
269274
Ok(RecurrentPair {
270275
inslot,
@@ -275,16 +280,21 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
275280
})
276281
.collect::<Result<TVec<RecurrentPair>>>()?;
277282

283+
let inputs = inner.input_shapes(inferer);
284+
let outputs = inner.output_shapes(inferer);
285+
278286
let inputs = inputs
279287
.iter()
280288
.filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
281289
.cloned()
282290
.collect::<Vec<_>>();
291+
283292
let outputs = outputs
284293
.iter()
285294
.filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
286295
.cloned()
287296
.collect::<Vec<_>>();
297+
288298
Ok(Self {
289299
inner,
290300
state: RecurrentState {
@@ -338,6 +348,8 @@ mod tests {
338348
batcher::ScratchPadView,
339349
inferer::State,
340350
prelude::{Batcher, Inferer},
351+
recurrent::RecurrentTrackerWrapper,
352+
wrapper::InfererWrapper,
341353
};
342354

343355
use super::RecurrentTracker;
@@ -371,6 +383,7 @@ mod tests {
371383
end_called: false.into(),
372384
begin_called: false.into(),
373385
inputs: vec![
386+
("epsilon".to_owned(), vec![2]),
374387
(hidden_name_in.to_owned(), vec![2, 1]),
375388
(cell_name_in.to_owned(), vec![2, 3]),
376389
],
@@ -390,15 +403,15 @@ mod tests {
390403
}
391404

392405
fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
393-
assert_eq!(batch.inner().input_name(0), "lstm_hidden_state");
394-
let hidden_value = batch.input_slot(0);
406+
assert_eq!(batch.inner().input_name(1), "lstm_hidden_state");
407+
let hidden_value = batch.input_slot(1);
395408
let hidden_new = hidden_value.iter().map(|v| *v + 1.0).collect::<Vec<_>>();
396409

397410
assert_eq!(batch.inner().output_name(0), "lstm_hidden_state");
398411
batch.output_slot_mut(0).copy_from_slice(&hidden_new);
399412

400-
assert_eq!(batch.inner().input_name(1), "lstm_cell_state");
401-
let cell_value = batch.input_slot(1);
413+
assert_eq!(batch.inner().input_name(2), "lstm_cell_state");
414+
let cell_value = batch.input_slot(2);
402415
let cell_new = cell_value.iter().map(|v| *v + 2.0).collect::<Vec<_>>();
403416

404417
assert_eq!(batch.inner().output_name(1), "lstm_cell_state");
@@ -585,4 +598,65 @@ mod tests {
585598
assert!(agent_data.data["hidden_output"].iter().all(|v| *v == 1.0));
586599
assert!(agent_data.data["cell_output"].iter().all(|v| *v == 2.0));
587600
}
601+
602+
#[test]
603+
fn test_wrapper_does_not_expose_inner_hidden() {
604+
// Imagine Recurrent<Epsilon<...>>. We want to assert that
605+
// Recurrent hides its own fields while also not exposing any
606+
// fields from the inner epsilon wrapper.
607+
608+
struct DummyEpsilonWrapper {
609+
inputs: Vec<(String, Vec<usize>)>,
610+
}
611+
612+
impl InfererWrapper for DummyEpsilonWrapper {
613+
fn invoke(
614+
&self,
615+
_inferer: &dyn Inferer,
616+
_batch: &mut ScratchPadView<'_>,
617+
) -> anyhow::Result<(), anyhow::Error> {
618+
Ok(())
619+
}
620+
fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
621+
&self.inputs
622+
}
623+
fn output_shapes<'a>(
624+
&'a self,
625+
_inferer: &'a dyn Inferer,
626+
) -> &'a [(String, Vec<usize>)] {
627+
_inferer.output_shapes()
628+
}
629+
fn begin_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
630+
fn end_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
631+
}
632+
633+
let inferer = DummyInferer::default();
634+
let wrapper = DummyEpsilonWrapper {
635+
inputs: vec![
636+
("lstm_hidden_state".to_owned(), vec![2, 1]),
637+
("lstm_cell_state".to_owned(), vec![2, 3]),
638+
],
639+
};
640+
641+
let recurrent = RecurrentTrackerWrapper::wrap(wrapper, &inferer).unwrap();
642+
643+
assert_eq!(recurrent.input_shapes(&inferer).len(), 0);
644+
assert_eq!(
645+
recurrent.output_shapes(&inferer).len(),
646+
2,
647+
"only hidden and cell state are recurrent: {:?}",
648+
recurrent.output_shapes(&inferer)
649+
);
650+
651+
assert_eq!(recurrent.output_shapes(&inferer)[0].0, "hidden_output");
652+
assert_eq!(recurrent.output_shapes(&inferer)[1].0, "cell_output");
653+
654+
assert_eq!(recurrent.state.inputs.len(), 0);
655+
assert_eq!(recurrent.state.outputs.len(), 2);
656+
657+
assert_eq!(recurrent.state.keys.len(), 2);
658+
// slots are still correct despite epsilon being hidden
659+
assert_eq!(recurrent.state.keys[0].inslot, 1);
660+
assert_eq!(recurrent.state.keys[1].inslot, 2);
661+
}
588662
}

0 commit comments

Comments
 (0)