Skip to content

Commit f73aea5

Browse files
Cristian GarciaFlax Authors
authored andcommitted
fix RNN
Implements Mapping for `StateAxes` and uses `StateAxes` in place of `dict` to fix RNN, this avoids some JAX pytree errors when scanning attributes for data in `nnx.Pytree`. PiperOrigin-RevId: 800653246
1 parent db06488 commit f73aea5

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

flax/nnx/nn/recurrent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def __init__(
685685
'Expected rngs to be a jax.Array, int, Rngs, or bool. '
686686
f'Got {type(rngs)}.'
687687
)
688-
self.state_axes = state_axes or {...: iteration.Carry} # type: ignore
688+
self.state_axes = state_axes or nnx.StateAxes({...: iteration.Carry}) # type: ignore
689689
self.broadcast_rngs = broadcast_rngs
690690

691691
def __call__(

flax/nnx/transforms/iteration.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class Carry:
5656
# -------------------------------
5757

5858

59-
class StateAxes(extract.PrefixMapping):
59+
class StateAxes(extract.PrefixMapping, tp.Mapping):
6060

6161
def __init__(
6262
self,
@@ -101,6 +101,15 @@ def __repr__(self):
101101
def items(self):
102102
return zip(self.filters, self.axes)
103103

104+
def __getitem__(self, key):
105+
return self.axes[self.filters.index(key)]
106+
107+
def __iter__(self):
108+
return iter(self.filters)
109+
110+
def __len__(self):
111+
return len(self.filters)
112+
104113
def __eq__(self, other):
105114
return (
106115
isinstance(other, StateAxes)

0 commit comments

Comments
 (0)