Orbax Checkpointing and Flax.NNX require hacking to work together #1337

Closed
hdrwilkinson opened this issue Nov 15, 2024 · 2 comments
Labels
duplicate This issue or pull request already exists

Comments

@hdrwilkinson

hdrwilkinson commented Nov 15, 2024

I'm building a system using flax.nnx and orbax.checkpoint. However, saving and restoring models is overly complicated because flax.nnx uses the new-style jax.random.key() rather than the legacy jax.random.PRNGKey().

I have had to create a workaround where all leaves with rng and key in their path are converted from dtype=key<fry> to a format appropriate for saving. Then, upon restoration, they need to be changed back.
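For readers who cannot open the notebook, here is a minimal sketch of that kind of round trip using only core JAX. It is not the notebook's code: the helper names `is_typed_key`, `keys_to_raw`, and `raw_to_keys` are hypothetical, the state dict is a stand-in for a real nnx state, and the sketch assumes the default PRNG implementation (the one that shows up as key<fry>).

```python
import jax
import jax.numpy as jnp


def is_typed_key(x):
    # True for new-style typed PRNG keys (e.g. dtype key<fry>).
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jax.dtypes.prng_key)


def keys_to_raw(tree):
    # Replace every typed key leaf with its raw uint32 key data,
    # which is an ordinary array and can be saved like any other leaf.
    return jax.tree_util.tree_map(
        lambda x: jax.random.key_data(x) if is_typed_key(x) else x, tree
    )


def raw_to_keys(raw_tree, template):
    # Re-wrap raw key data wherever the template holds a typed key.
    # Assumption: the default PRNG implementation was used when saving.
    return jax.tree_util.tree_map(
        lambda raw, ref: jax.random.wrap_key_data(raw) if is_typed_key(ref) else raw,
        raw_tree,
        template,
    )


# Hypothetical stand-in for a model state containing an RNG key.
state = {
    "params": {"w": jnp.ones((2, 3))},
    "rngs": {"key": jax.random.key(0), "count": jnp.array(0, dtype=jnp.uint32)},
}

raw_state = keys_to_raw(state)            # safe to hand to a checkpointer
restored = raw_to_keys(raw_state, state)  # typed keys are back
```

The template passed to `raw_to_keys` only needs to have the same structure and key positions as the saved tree, so a freshly initialized model state would work for that purpose.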

I am attaching a link to a notebook explaining what I've done, but I would be keen to hear if there are simpler workarounds. Or, preferably, is there a way to simply save and restore models?

https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing

Note: I am also opening an issue on flax to see if there is a fix on their side (#4383).

@ChromeHearts
Collaborator

Hi, just checking whether the solution suggested in google/flax#4383 is sufficient. Currently, Orbax only supports dtype=key<fry> via the JaxRandomKeyCheckpointHandler. Supporting it in combination with other types would need a feature request.

@ChromeHearts ChromeHearts added the duplicate This issue or pull request already exists label Dec 6, 2024
@ChromeHearts
Collaborator

Duplicate of #1105. The Orbax team is working on supporting RNG keys as part of a PyTree.
