Description
The current VAE example uses the flax.linen API for model definition and training.
As Flax continues to develop the nnx module as its next-generation neural network API, it would be valuable to provide an updated version of this example using flax.nnx.
This migration will:
- Help users learn how to implement a VAE with `nnx`'s modular, explicit state-handling paradigm.
- Let users compare the `nn` (`flax.linen`) and `nnx` APIs in a real-world use case.
- Encourage adoption of `nnx` in research and production examples.
Proposed Changes
- Reimplement the model (`Encoder`, `Decoder`, and the `VAE` wrapper) using `flax.nnx.Module`.
- Replace `flax.training.train_state.TrainState` with `nnx.Optimizer` for parameter management.
- Update the training and evaluation loops to use `nnx.jit` and direct method calls instead of `apply()`.
- Ensure reproducibility and equivalence with the original `nn`-based example.
Contribution
I'd be happy to implement this migration and submit a PR. Please let me know if there are any specific guidelines or preferences for the implementation approach.
Motivation
The VAE example is a widely understood benchmark that involves both deterministic and stochastic components, making it ideal to showcase nnx's design strengths:
- Explicit randomness (`nnx.Rngs`)
- Parameter/state separation
- Compositional design
- Compatibility with Optax and other JAX tools
Having this example available in nnx would significantly benefit users exploring or transitioning to the new API.