Skip to content

Commit c593989

Browse files
authored
do not use root input shapes when wrapping (#66)
1 parent fc3ae7b commit c593989

File tree

4 files changed

+6
-3
lines changed

4 files changed

+6
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Add `StatefulInferer::replace_inferer` which works with a `&mut
1414
StatefulInferer`, at the cost of requiring the inferer to be of the
1515
same type.
16+
- Fix bugs in the new wrapper setup where consumed and modified shapes
17+
weren't respected during wrapper construction.
1618

1719
## [0.9.0] - 2025-09-04
1820

crates/cervo-core/src/epsilon.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ where
259259
generator: NG,
260260
key: &str,
261261
) -> Result<Self> {
262-
let inputs = inferer.input_shapes();
262+
let inputs = inner.input_shapes(inferer);
263263

264264
let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) {
265265
Some((index, (_, shape))) => (index, shape.iter().product()),

crates/cervo-core/src/recurrent.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
248248
/// Create a new recurrency tracker for the model.
249249
///
250250
pub fn new<T: Inferer>(inner: Inner, inferer: &T, info: Vec<RecurrentInfo>) -> Result<Self> {
251-
let inputs = inferer.raw_input_shapes();
252-
let outputs = inferer.raw_output_shapes();
251+
let inputs = inner.input_shapes(inferer);
252+
let outputs = inner.output_shapes(inferer);
253253

254254
let mut offset = 0;
255255
let keys = info

deny.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ db-urls = ["https://github.com/rustsec/advisory-db"]
4545
ignore = [
4646
#"RUSTSEC-0000-0000",
4747
"RUSTSEC-2024-0436",
48+
"RUSTSEC-2025-0056",
4849
]
4950
# Threshold for security vulnerabilities, any vulnerability with a CVSS score
5051
# lower than the range specified will be ignored. Note that ignored advisories

0 commit comments

Comments
 (0)