@@ -110,6 +110,8 @@ pub struct StatefulInferer<WrapStack: InfererWrapper, Inf: Inferer> {
110110}
111111
112112impl < WrapStack : InfererWrapper , Inf : Inferer > StatefulInferer < WrapStack , Inf > {
113+ /// Construct a new [`StatefulInferer`] by wrapping the given
114+ /// inferer with the given wrapper stack.
113115 pub fn new ( wrapper_stack : WrapStack , inferer : Inf ) -> Self {
114116 Self {
115117 wrapper_stack,
@@ -121,8 +123,8 @@ impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
121123 /// any state in wrappers.
122124 ///
123125 /// Requires that the shapes of the policies are compatible, but
124- /// they may be different concrete inferer implementations . If
125- /// this check fails, will return self unchanged.
126+ /// they may be different inferer types . If this check fails, will
127+ /// return self unchanged.
126128 pub fn with_new_inferer < NewInf : Inferer > (
127129 self ,
128130 new_inferer : NewInf ,
@@ -136,6 +138,22 @@ impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
136138 } )
137139 }
138140
141+ /// Replace the inner inferer with a new inferer while maintaining
142+ /// any state in wrappers.
143+ ///
144+ /// Requires that the shapes of the policies are compatible If
145+ /// this check fails, will not change self. Compared to
146+ /// [`with_new_inferer`], also requires that the new inferer has
147+ /// the same type as the old one.
148+ pub fn replace_inferer ( & mut self , new_inferer : Inf ) -> anyhow:: Result < ( ) > {
149+ if let Err ( e) = Self :: check_compatible_shapes ( & self . inferer , & new_inferer) {
150+ Err ( e)
151+ } else {
152+ self . inferer = new_inferer;
153+ Ok ( ( ) )
154+ }
155+ }
156+
139157 /// Validate that [`Old`] and [`New`] are compatible with each
140158 /// other.
141159 pub fn check_compatible_shapes < Old : Inferer , New : Inferer > (
0 commit comments