Error: vmapped nnx.Module initialization with selective Variable broadcasting
#4526
I'm trying to use … The two commented-out lines are modifications that succeed, but they are not quite what I want.
```python
import jax
import jax.numpy as jnp
from flax import nnx

class MyVar(nnx.Variable):
    pass

# Vectorize RNG state and MyVar along axis 0; broadcast everything else (including Params).
state_axes = nnx.StateAxes({(nnx.RngState, MyVar): 0, ...: None})
# Succeeds, but also vectorizes the Params:
# state_axes = nnx.StateAxes({(nnx.RngState, nnx.Param, MyVar): 0, ...: None})

class MyModule(nnx.Module):
    @nnx.split_rngs(splits=2)
    @nnx.vmap(in_axes=(state_axes, 0))
    def __init__(self, rngs):
        self.param = nnx.Param(jax.random.uniform(rngs()))
        # Succeeds, but the Param is no longer randomly initialized:
        # self.param = nnx.Param(jnp.float32(0.0))
        self.var = MyVar(jax.random.uniform(rngs()))

rngs = nnx.Rngs(123)
model = MyModule(rngs)
```

The error is …

Any suggestions, anyone?

@cgarciae Shameless ping. Seems like you ~= NNX 😄
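The failure can be reproduced in plain JAX without NNX. The sketch below is an assumed analogue and does not show what `nnx.vmap` does internally. The helpers `init_shared_key` and `init_constant_param` are hypothetical, and `jax.random.split` stands in for `split_rngs(splits=2)`. Both values are drawn from the same per-member key, so the value meant to be broadcast (`out_axes=None`) is batched, and `jax.vmap` refuses to return it as a single copy. Presumably the same conflict arises when a `Param` that falls under `...: None` is sampled from the split default stream. Replacing that draw with a constant, as in the commented-out `jnp.float32(0.0)` line, makes the call succeed.

```python
import jax
import jax.numpy as jnp


def init_shared_key(key):
    # Hypothetical init: both values come from the same per-member key.
    param = jax.random.uniform(key)  # meant to be broadcast (one copy)
    # A derived key for the second draw, loosely like calling rngs() again.
    var = jax.random.uniform(jax.random.fold_in(key, 1))  # meant to be vectorized
    return param, var


def init_constant_param(key):
    # Analogue of the commented-out `jnp.float32(0.0)` line: the param no
    # longer depends on the batched key, so it can be broadcast.
    param = jnp.float32(0.0)
    var = jax.random.uniform(key)
    return param, var


# Analogue of split_rngs(splits=2): one key per vmapped member.
keys = jax.random.split(jax.random.key(123), 2)

# out_axes=(None, 0): return a single param, vectorize the var.
try:
    jax.vmap(init_shared_key, in_axes=0, out_axes=(None, 0))(keys)
    shared_key_failed = False
except ValueError:
    # JAX will not broadcast an output that varies across the batch.
    shared_key_failed = True

param, var = jax.vmap(init_constant_param, in_axes=0, out_axes=(None, 0))(keys)
print(shared_key_failed, param.shape, var.shape)
```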
Replies: 2 comments 4 replies
I'm guessing I don't need that …
Hey @rademacher-p, you need two different RNG keys: one that initializes `Param`s as broadcasts, and one that initializes `MyVar`s vectorized. The easiest way to do this is to use named streams in `Rngs`. Create a `params` stream and a `vars` stream, tell `split_rngs` to split only the `vars` keys, and use the stream names to sample the keys. Here's a running sample:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MyVar(nnx.Variable):
    pass

# Vectorize MyVar and the 'vars' RNG stream along axis 0;
# broadcast everything else (Params and the 'params' stream).
state_axes = nnx.StateAxes({(MyVar, 'vars'): 0, ...: None})

class MyModule(nnx.Module):
    # Split only the 'vars' stream into 2 keys; 'params' stays a single key.
    @nnx.split_rngs(splits=2, only='vars')
    @nnx.vmap(in_axes=(state_axes, state_axes))
    def __init__(self, rngs):
        # Drawn from the unsplit 'params' stream, so it can be broadcast.
        self.param = nnx.Param(jax.random.uniform(rngs.params()))
        # self.param = nnx.Param(jnp.float32(0.0))
        # Drawn from the split 'vars' stream, so it is vectorized.
        self.var = MyVar(jax.random.uniform(rngs.vars()))

rngs = nnx.Rngs(params=1, vars=2)
model = MyModule(rngs)
```

Also updated …
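Without NNX, the answer's idea reduces to a standard `jax.vmap` pattern. The sketch below is an assumed plain-JAX analogue: the `init` helper and the key seeds are illustrative, and it does not reproduce how `nnx.vmap` or `split_rngs` manage state. The param is drawn from one unsplit key passed with `in_axes=None`, so it is identical for every member and can be returned once with `out_axes=None`. The var is drawn from a batch of split keys and is vectorized.

```python
import jax


def init(params_key, vars_key):
    # Hypothetical init mirroring the answer: the param samples from the
    # unsplit "params" key, the var from its per-member "vars" key.
    param = jax.random.uniform(params_key)
    var = jax.random.uniform(vars_key)
    return param, var


params_key = jax.random.key(1)  # like Rngs(params=1): never split
vars_keys = jax.random.split(jax.random.key(2), 2)  # like split_rngs(splits=2, only='vars')

# in_axes: broadcast params_key, vectorize vars_keys.
# out_axes: the param does not depend on the batch, so return a single copy.
param, var = jax.vmap(init, in_axes=(None, 0), out_axes=(None, 0))(params_key, vars_keys)
print(param.shape, var.shape)
```

Because `params_key` is broadcast, `param` has shape `()` while `var` has shape `(2,)`. This is the split that `StateAxes({(MyVar, 'vars'): 0, ...: None})` expresses for module state.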