Collaborator

@xibinliu xibinliu commented Jan 19, 2026

Description

This PR lets the pure NNX mode and the Linen mode co-exist for a while. Users should continue to use Linen for training; this PR only adds the code branches that will support pure NNX.

  • pure_nnx: a flag to choose the pure NNX logic while NNX and Linen models co-exist.
    This flag makes the whole training process use pure NNX. It differs from the enable_nnx flag, which is only used during reinforcement learning to get an NNX model and then convert it to Linen.
  • init_state_fn: a function that initializes the model state for training. It will be set to a different function for NNX and for Linen.
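
As a rough illustration of how the two options could interact, here is a minimal sketch in plain Python. `Config`, `init_linen_state`, `init_nnx_state` and `select_init_state_fn` are hypothetical names made up for this example and are not the MaxText API; the sketch only shows the idea of picking one initializer per framework and handing the caller a zero-argument `init_state_fn`.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict


@dataclass
class Config:
  """Hypothetical stand-in for the training config; only pure_nnx is modeled."""
  pure_nnx: bool = False


def init_linen_state(model_name: str, seed: int) -> Dict[str, Any]:
  """Placeholder for a Linen state initializer (not the real MaxText function)."""
  return {"framework": "linen", "model": model_name, "seed": seed}


def init_nnx_state(model_name: str, seed: int) -> Dict[str, Any]:
  """Placeholder for a future NNX state initializer (hypothetical)."""
  return {"framework": "nnx", "model": model_name, "seed": seed}


def select_init_state_fn(
    config: Config, model_name: str, seed: int
) -> Callable[[], Dict[str, Any]]:
  """Returns a zero-argument initializer chosen by the pure_nnx flag."""
  init_fn = init_nnx_state if config.pure_nnx else init_linen_state
  # Bind the framework-specific arguments here so downstream code does not
  # need to know which framework produced the callable.
  return partial(init_fn, model_name, seed)


init_state_fn = select_init_state_fn(Config(pure_nnx=False), "toy-model", seed=0)
state = init_state_fn()  # state["framework"] is "linen" in this sketch
```

Binding the arguments with `functools.partial` keeps the selection logic in one place, so the rest of the training code can treat `init_state_fn` as opaque.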

The major change is in maxtext_utils.py:

To call maxtext_utils.get_abstract_state(), an init_state_fn must now be passed in. Currently this is the Linen init_initial_state() function; for pure_nnx it will be replaced with an NNX init function.

setup_training_state() and setup_decode_state() have been modified accordingly to accept init_state_fn.
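
The calling pattern described above can be sketched as follows. This is a dependency-free toy, not the actual maxtext_utils code: the signatures are assumptions, the functions here take only `init_state_fn`, and the toy `get_abstract_state` calls the initializer directly and records type names, whereas a real implementation would be expected to avoid materializing the state.

```python
from functools import partial
from typing import Any, Callable, Dict, Tuple

State = Dict[str, Any]
InitStateFn = Callable[[], State]


def get_abstract_state(init_state_fn: InitStateFn) -> Dict[str, str]:
  """Toy version: describes each state entry by its type name only."""
  state = init_state_fn()
  return {name: type(value).__name__ for name, value in state.items()}


def setup_training_state(init_state_fn: InitStateFn) -> Tuple[State, Dict[str, str]]:
  """Toy version: builds the abstract description, then the concrete state."""
  abstract_state = get_abstract_state(init_state_fn)
  return init_state_fn(), abstract_state


def init_linen_state(num_layers: int, is_training: bool) -> State:
  """Placeholder initializer; not the real Linen init_initial_state()."""
  state: State = {"step": 0, "params": [0.0] * num_layers}
  if is_training:
    state["opt_state"] = {}
  return state


# The caller binds framework-specific arguments up front, so the setup
# function only ever sees a zero-argument callable.
init_state_fn = partial(init_linen_state, num_layers=2, is_training=True)
state, abstract_state = setup_training_state(init_state_fn)
```

Because the setup functions depend only on the callable, swapping in an NNX initializer later should not require changing their bodies, at least in this simplified picture.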

Unit tests have been modified as well.

In addition, the code paths where the pure NNX model will eventually be supported currently raise NotImplementedError. The implementation will be added gradually.
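
A guard of this kind, together with a unit test that pins down the behavior, might look like the sketch below. The function body, the error message and the test class are illustrative assumptions; only the use of NotImplementedError comes from this PR description.

```python
import unittest


def setup_decode_state(pure_nnx: bool) -> str:
  """Hypothetical, heavily simplified stand-in for a setup function."""
  if pure_nnx:
    # Fail loudly instead of silently falling back to the Linen path.
    raise NotImplementedError("Pure NNX support has not been implemented yet.")
  return "linen_state"


class PureNnxGuardTest(unittest.TestCase):
  """Checks that the Linen path works and the NNX path is clearly marked."""

  def test_linen_path_still_works(self):
    self.assertEqual(setup_decode_state(pure_nnx=False), "linen_state")

  def test_pure_nnx_raises(self):
    with self.assertRaises(NotImplementedError):
      setup_decode_state(pure_nnx=True)
```

Raising early makes it obvious which entry points still need an NNX implementation, and the second test can be flipped to a real assertion once each branch is filled in.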

Tests

Unit tests pass.
Ran training with the change.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Jan 19, 2026

- pure_nnx: a flag to choose the pure NNX logic while NNX and Linen models
  co-exist.
- init_state_fn: a function that initializes the model state for
  training. It will be set to a different function for NNX and for Linen.