Commit f73aea5
fix RNN
Implements Mapping for `StateAxes` and uses `StateAxes` in place of `dict` to fix RNN, this avoids some JAX pytree errors when scanning attributes for data in `nnx.Pytree`.
PiperOrigin-RevId: 8006532461 parent db06488 commit f73aea5
2 files changed
+11
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
685 | 685 | | |
686 | 686 | | |
687 | 687 | | |
688 | | - | |
| 688 | + | |
689 | 689 | | |
690 | 690 | | |
691 | 691 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
56 | 56 | | |
57 | 57 | | |
58 | 58 | | |
59 | | - | |
| 59 | + | |
60 | 60 | | |
61 | 61 | | |
62 | 62 | | |
| |||
101 | 101 | | |
102 | 102 | | |
103 | 103 | | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
104 | 113 | | |
105 | 114 | | |
106 | 115 | | |
| |||
0 commit comments