I'm building a system using `flax.nnx` and `orbax.checkpoint`. However, saving and restoring models is overly complicated because `flax.nnx` uses the new `jax.random.key()` rather than `jax.random.PRNGKey()`.

I have had to create a workaround where all layers with `rng` and `key` in their path are converted from `dtype=key<fry>` to a format suitable for saving. Then, upon restoration, they need to be converted back.

I am attaching a link to a notebook explaining what I've done, but I would be keen to hear if there are simpler workarounds. Or, preferably, is there a way to simply save and restore models?

https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing

Note: I am also opening an issue on Flax to see if there is a fix on their side (#4383).

Hi, just checking whether the solution suggested in google/flax#4383 is sufficient. Currently, Orbax only supports `dtype=key<fry>` via the `JaxRandomKeyCheckpointHandler`. Supporting it in combination with other types would need a feature request.