@@ -109,23 +109,23 @@ where
109109 /// Create a new recurrency tracker for the model.
110110 ///
111111 pub fn new ( inferer : T , info : Vec < RecurrentInfo > ) -> Result < Self > {
112- let inputs = inferer. raw_input_shapes ( ) ;
113- let outputs = inferer. raw_output_shapes ( ) ;
112+ let raw_inputs = inferer. raw_input_shapes ( ) ;
113+ let raw_outputs = inferer. raw_output_shapes ( ) ;
114114
115115 let mut offset = 0 ;
116116 let keys = info
117117 . iter ( )
118118 . map ( |info| {
119- let inslot = inputs
119+ let inslot = raw_inputs
120120 . iter ( )
121121 . position ( |input| info. inkey == input. 0 )
122122 . with_context ( || format ! ( "no input named {}" , info. inkey) ) ?;
123- let outslot = outputs
123+ let outslot = raw_outputs
124124 . iter ( )
125125 . position ( |output| info. outkey == output. 0 )
126126 . with_context ( || format ! ( "no output named {}" , info. outkey) ) ?;
127127
128- let numels = inputs [ inslot] . 1 . iter ( ) . product ( ) ;
128+ let numels = raw_inputs [ inslot] . 1 . iter ( ) . product ( ) ;
129129 offset += numels;
130130 Ok ( RecurrentPair {
131131 inslot,
@@ -136,16 +136,21 @@ where
136136 } )
137137 . collect :: < Result < TVec < RecurrentPair > > > ( ) ?;
138138
139+ let inputs = inferer. input_shapes ( ) ;
140+ let outputs = inferer. output_shapes ( ) ;
141+
139142 let inputs = inputs
140143 . iter ( )
141144 . filter ( |( k, _) | !info. iter ( ) . any ( |info| & info. inkey == k) )
142145 . cloned ( )
143146 . collect :: < Vec < _ > > ( ) ;
147+
144148 let outputs = outputs
145149 . iter ( )
146150 . filter ( |( k, _) | !info. iter ( ) . any ( |info| & info. outkey == k) )
147151 . cloned ( )
148152 . collect :: < Vec < _ > > ( ) ;
153+
149154 Ok ( Self {
150155 inner : inferer,
151156 state : RecurrentState {
@@ -248,23 +253,23 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
248253 /// Create a new recurrency tracker for the model.
249254 ///
250255 pub fn new < T : Inferer > ( inner : Inner , inferer : & T , info : Vec < RecurrentInfo > ) -> Result < Self > {
251- let inputs = inner . input_shapes ( inferer ) ;
252- let outputs = inner . output_shapes ( inferer ) ;
256+ let raw_inputs = inferer . raw_input_shapes ( ) ;
257+ let raw_outputs = inferer . raw_output_shapes ( ) ;
253258
254259 let mut offset = 0 ;
255260 let keys = info
256261 . iter ( )
257262 . map ( |info| {
258- let inslot = inputs
263+ let inslot = raw_inputs
259264 . iter ( )
260265 . position ( |input| info. inkey == input. 0 )
261266 . with_context ( || format ! ( "no input named {}" , info. inkey) ) ?;
262- let outslot = outputs
267+ let outslot = raw_outputs
263268 . iter ( )
264269 . position ( |output| info. outkey == output. 0 )
265270 . with_context ( || format ! ( "no output named {}" , info. outkey) ) ?;
266271
267- let numels = inputs [ inslot] . 1 . iter ( ) . product ( ) ;
272+ let numels = raw_inputs [ inslot] . 1 . iter ( ) . product ( ) ;
268273 offset += numels;
269274 Ok ( RecurrentPair {
270275 inslot,
@@ -275,16 +280,21 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
275280 } )
276281 . collect :: < Result < TVec < RecurrentPair > > > ( ) ?;
277282
283+ let inputs = inner. input_shapes ( inferer) ;
284+ let outputs = inner. output_shapes ( inferer) ;
285+
278286 let inputs = inputs
279287 . iter ( )
280288 . filter ( |( k, _) | !info. iter ( ) . any ( |info| & info. inkey == k) )
281289 . cloned ( )
282290 . collect :: < Vec < _ > > ( ) ;
291+
283292 let outputs = outputs
284293 . iter ( )
285294 . filter ( |( k, _) | !info. iter ( ) . any ( |info| & info. outkey == k) )
286295 . cloned ( )
287296 . collect :: < Vec < _ > > ( ) ;
297+
288298 Ok ( Self {
289299 inner,
290300 state : RecurrentState {
@@ -338,6 +348,8 @@ mod tests {
338348 batcher:: ScratchPadView ,
339349 inferer:: State ,
340350 prelude:: { Batcher , Inferer } ,
351+ recurrent:: RecurrentTrackerWrapper ,
352+ wrapper:: InfererWrapper ,
341353 } ;
342354
343355 use super :: RecurrentTracker ;
@@ -371,6 +383,7 @@ mod tests {
371383 end_called : false . into ( ) ,
372384 begin_called : false . into ( ) ,
373385 inputs : vec ! [
386+ ( "epsilon" . to_owned( ) , vec![ 2 ] ) ,
374387 ( hidden_name_in. to_owned( ) , vec![ 2 , 1 ] ) ,
375388 ( cell_name_in. to_owned( ) , vec![ 2 , 3 ] ) ,
376389 ] ,
@@ -390,15 +403,15 @@ mod tests {
390403 }
391404
392405 fn infer_raw ( & self , batch : & mut ScratchPadView < ' _ > ) -> anyhow:: Result < ( ) , anyhow:: Error > {
393- assert_eq ! ( batch. inner( ) . input_name( 0 ) , "lstm_hidden_state" ) ;
394- let hidden_value = batch. input_slot ( 0 ) ;
406+ assert_eq ! ( batch. inner( ) . input_name( 1 ) , "lstm_hidden_state" ) ;
407+ let hidden_value = batch. input_slot ( 1 ) ;
395408 let hidden_new = hidden_value. iter ( ) . map ( |v| * v + 1.0 ) . collect :: < Vec < _ > > ( ) ;
396409
397410 assert_eq ! ( batch. inner( ) . output_name( 0 ) , "lstm_hidden_state" ) ;
398411 batch. output_slot_mut ( 0 ) . copy_from_slice ( & hidden_new) ;
399412
400- assert_eq ! ( batch. inner( ) . input_name( 1 ) , "lstm_cell_state" ) ;
401- let cell_value = batch. input_slot ( 1 ) ;
413+ assert_eq ! ( batch. inner( ) . input_name( 2 ) , "lstm_cell_state" ) ;
414+ let cell_value = batch. input_slot ( 2 ) ;
402415 let cell_new = cell_value. iter ( ) . map ( |v| * v + 2.0 ) . collect :: < Vec < _ > > ( ) ;
403416
404417 assert_eq ! ( batch. inner( ) . output_name( 1 ) , "lstm_cell_state" ) ;
@@ -585,4 +598,65 @@ mod tests {
585598 assert ! ( agent_data. data[ "hidden_output" ] . iter( ) . all( |v| * v == 1.0 ) ) ;
586599 assert ! ( agent_data. data[ "cell_output" ] . iter( ) . all( |v| * v == 2.0 ) ) ;
587600 }
601+
602+ #[ test]
603+ fn test_wrapper_does_not_expose_inner_hidden ( ) {
604+ // Imagine Recurrent<Epsilon<...>>. We want to assert that
605+ // Recurrent hides its own fields while also not exposing any
606+ // fields from the inner epsilon wrapper.
607+
608+ struct DummyEpsilonWrapper {
609+ inputs : Vec < ( String , Vec < usize > ) > ,
610+ }
611+
612+ impl InfererWrapper for DummyEpsilonWrapper {
613+ fn invoke (
614+ & self ,
615+ _inferer : & dyn Inferer ,
616+ _batch : & mut ScratchPadView < ' _ > ,
617+ ) -> anyhow:: Result < ( ) , anyhow:: Error > {
618+ Ok ( ( ) )
619+ }
620+ fn input_shapes < ' a > ( & ' a self , _inferer : & ' a dyn Inferer ) -> & ' a [ ( String , Vec < usize > ) ] {
621+ & self . inputs
622+ }
623+ fn output_shapes < ' a > (
624+ & ' a self ,
625+ _inferer : & ' a dyn Inferer ,
626+ ) -> & ' a [ ( String , Vec < usize > ) ] {
627+ _inferer. output_shapes ( )
628+ }
629+ fn begin_agent ( & self , _inferer : & dyn Inferer , _id : u64 ) { }
630+ fn end_agent ( & self , _inferer : & dyn Inferer , _id : u64 ) { }
631+ }
632+
633+ let inferer = DummyInferer :: default ( ) ;
634+ let wrapper = DummyEpsilonWrapper {
635+ inputs : vec ! [
636+ ( "lstm_hidden_state" . to_owned( ) , vec![ 2 , 1 ] ) ,
637+ ( "lstm_cell_state" . to_owned( ) , vec![ 2 , 3 ] ) ,
638+ ] ,
639+ } ;
640+
641+ let recurrent = RecurrentTrackerWrapper :: wrap ( wrapper, & inferer) . unwrap ( ) ;
642+
643+ assert_eq ! ( recurrent. input_shapes( & inferer) . len( ) , 0 ) ;
644+ assert_eq ! (
645+ recurrent. output_shapes( & inferer) . len( ) ,
646+ 2 ,
647+ "only hidden and cell state are recurrent: {:?}" ,
648+ recurrent. output_shapes( & inferer)
649+ ) ;
650+
651+ assert_eq ! ( recurrent. output_shapes( & inferer) [ 0 ] . 0 , "hidden_output" ) ;
652+ assert_eq ! ( recurrent. output_shapes( & inferer) [ 1 ] . 0 , "cell_output" ) ;
653+
654+ assert_eq ! ( recurrent. state. inputs. len( ) , 0 ) ;
655+ assert_eq ! ( recurrent. state. outputs. len( ) , 2 ) ;
656+
657+ assert_eq ! ( recurrent. state. keys. len( ) , 2 ) ;
658+ // slots are still correct despite epsilon being hidden
659+ assert_eq ! ( recurrent. state. keys[ 0 ] . inslot, 1 ) ;
660+ assert_eq ! ( recurrent. state. keys[ 1 ] . inslot, 2 ) ;
661+ }
588662}
0 commit comments