System information
OS Platform and Distribution: Linux Ubuntu 22.04
Problem you have encountered:
Steps to reproduce: I saw a similar discussion for
Replies: 2 comments
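Hey! I'm guessing you want to replicate the weights but have different RNGs. To do this, you can use the `nnx.split_rngs` decorator to split the RNGs before entering `pmap`, and use `StateAxes` to specify the parallelization axes for substates of your Module. In this case, map `RngState` to `0` and the rest (`...`) to `None`:

    import jax.numpy as jnp
    from flax import nnx

    # Map RNG state over axis 0; replicate everything else (e.g. the weights).
    state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

    @nnx.split_rngs(splits=1)  # one RNG split per mapped device (1 here)
    @nnx.pmap(in_axes=(state_axes, 0))
    def forward(model, x):
        return model(x)

    # `model` is the Module from the question.
    out = forward(model, jnp.ones((1, 16, 2)))

For more info, check out the Filters guide.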
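As a rough plain-JAX analogy for what the `0` / `None` axis mapping expresses (not how `nnx.pmap` is implemented), the sketch below broadcasts the weights with `in_axes=None` and gives each device its own PRNG key along axis 0. The names `weights` and `forward`, the shapes, and the dropout rate are illustrative assumptions, not part of the original question.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Hypothetical parameters, shared (replicated) across all devices.
weights = jnp.ones((2, 4))
# One PRNG key per device, stacked along axis 0.
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
# One batch per device, stacked along axis 0.
x = jnp.ones((n_dev, 16, 2))

def forward(w, key, x):
    # Linear layer followed by dropout with keep probability 0.5.
    y = x @ w
    keep = jax.random.bernoulli(key, 0.5, y.shape)
    return jnp.where(keep, y / 0.5, 0.0)

# None -> replicate the weights; 0 -> map keys and inputs over devices.
out = jax.pmap(forward, in_axes=(None, 0, 0))(weights, keys, x)
print(out.shape)
```

With more than one device, each device draws its dropout mask from a different key while computing with the same weights, which is the behavior the `StateAxes` mapping above asks for.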
Thank you, keep up the amazing work!