
Migrate VAE example from flax.linen to flax.nnx #5068

@sanepunk

Description

The current VAE example uses the flax.linen API for model definition and training.
Since Flax is developing flax.nnx as its next-generation neural network API, it would be valuable to provide a version of this example written with flax.nnx.

This migration will help users:

  • Learn how to implement a VAE using nnx's new modular and explicit state-handling paradigm.
  • Compare the linen (nn) and nnx APIs on a real-world use case.
  • Adopt nnx more readily in their own research and production code.

Proposed Changes

  • Reimplement the model (Encoder, Decoder, and VAE wrapper) using flax.nnx.Module.
  • Replace flax.training.train_state.TrainState with nnx.Optimizer for parameter management.
  • Update training and evaluation loops to use nnx.jit and direct method calls instead of apply().
  • Verify that the migrated example reproduces the results of the original linen-based example.

Contribution

I'd be happy to implement this migration and submit a PR. Please let me know if there are any specific guidelines or preferences for the implementation approach.


Motivation

The VAE example is a widely understood benchmark that combines deterministic components (the encoder and decoder networks) with a stochastic one (sampling the latent code), making it ideal for showcasing nnx's design strengths:

  • Explicit randomness (nnx.Rngs)
  • Parameter/state separation
  • Compositional design
  • Compatibility with Optax and other JAX tools

Having this example available in nnx would significantly benefit users exploring or transitioning to the new API.
