NNX migration preparation: pure_nnx flag and init_state_fn #2965
Description
This PR lets the pure NNX mode and the Linen mode coexist for a while. Users should continue to use Linen for training; this PR adds code branches to support pure NNX.
The `pure_nnx` flag switches the whole training process to pure NNX. It differs from the `enable_nnx` flag, which is only used during reinforcement learning to get an NNX model and then convert it to Linen.
The major change is in `maxtext_utils.py`:
To call `maxtext_utils.get_abstract_state()`, an `init_state_fn` must be passed in. Currently this is the Linen `init_initial_state()` function, but it will be replaced with an NNX init function for pure_nnx.
`setup_training_state()` and `setup_decode_state()` have been modified accordingly to accept the `init_state_fn`.
Unit tests have been modified as well.
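As a rough illustration of the pattern described above, here is a minimal sketch of passing the initializer in as a parameter. It is plain Python with no JAX dependency; the function bodies, signatures, and the `config` dictionary are assumptions made for illustration and are not the actual MaxText code.

```python
from functools import partial
from typing import Any, Callable, Dict

State = Dict[str, Any]


def linen_init_state(config: dict) -> State:
    """Hypothetical stand-in for the Linen state initializer."""
    return {"backend": "linen", "step": 0, "hidden": config["hidden"]}


def get_abstract_state(config: dict, init_state_fn: Callable[[], State]) -> State:
    """Toy version: describe the state built by init_state_fn by value type.

    The caller decides which initializer to use; this function no longer
    hard-codes one.
    """
    state = init_state_fn()
    return {name: type(value).__name__ for name, value in state.items()}


def setup_training_state(config: dict, init_state_fn: Callable[[], State]):
    """Toy version: forwards the injected initializer to get_abstract_state."""
    abstract = get_abstract_state(config, init_state_fn)
    return init_state_fn(), abstract


config = {"hidden": 8}
# Bind the config once so the initializer can be called with no arguments.
init_fn = partial(linen_init_state, config)
state, abstract = setup_training_state(config, init_fn)
```

With this shape, switching to an NNX initializer later would only require passing a different callable; the setup functions themselves would not need to change.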
In places where the pure NNX model will eventually be supported, the code currently raises `NotImplementedError`. The implementation will be added gradually.
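The placeholder branching could look roughly like the sketch below. The `Config` class and the `select_init_state_fn` helper are hypothetical names invented for this example; only the `pure_nnx` flag name and the use of `NotImplementedError` come from the description.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class Config:
    """Hypothetical config object; only the pure_nnx flag matters here."""
    pure_nnx: bool = False


def select_init_state_fn(config: Config) -> Callable[[], Dict[str, Any]]:
    """Hypothetical helper: pick the state initializer for the chosen mode."""
    if config.pure_nnx:
        # Placeholder branch: the NNX path is not available yet.
        raise NotImplementedError("Pure NNX support has not been implemented yet.")
    # Default path: return a stand-in for the Linen initializer.
    return lambda: {"backend": "linen"}


init_fn = select_init_state_fn(Config())  # Linen path works today
```

Failing loudly in the NNX branch keeps the default Linen path unchanged while making every unfinished spot easy to find.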
Tests
Unit tests pass.
Run training with the change.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.