Description
The Module & Pytree guide in the Flax documentation is severely outdated and non-functional. Approximately half of the code cells in this guide fail to execute, making it impossible for users to learn from the examples.
Environment
- Flax version: [latest stable as of the documentation]
- JAX version: [current]
- Python version: 3.12
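To pin down the exact versions for a report like this, something along these lines could be run in the same environment as the notebook. It is only a sketch: the helper name `installed_version` is made up for illustration, and the package list (`flax`, `jax`, `jaxlib`) is an assumption about what is relevant here.

```python
import platform
from importlib import metadata


def installed_version(package: str) -> str:
    """Return the installed version of `package`, or a marker if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


# Hypothetical package list; adjust to whatever the notebook actually imports.
for name in ("flax", "jax", "jaxlib"):
    print(f"{name}: {installed_version(name)}")
print(f"python: {platform.python_version()}")
```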
Issues Found
Running through the notebook reveals multiple breaking issues:
1. First Example Fails Immediately
from flax import nnx

class Linear(nnx.Module):
    def __init__(self, din, dout, rngs: nnx.Rngs):
        self.din, self.dout = din, dout
        self.kernel = nnx.Param(rngs.normal((din, dout)))

rngs = nnx.Rngs(0)
weights = Linear(2, 3, rngs=rngs)
Error:
TypeError: RngStream.__call__() takes 1 positional argument but 2 were given
2. Missing nnx.Pytree Class
Multiple examples reference nnx.Pytree which doesn't exist in the current API:
class Linear(nnx.Pytree):  # AttributeError: module 'flax.nnx' has no attribute 'Pytree'
3. Missing nnx.List Container
self.layers = nnx.List([...])  # AttributeError: module 'flax.nnx' has no attribute 'List'
4. Missing Utility Functions
- nnx.is_data() - doesn't exist
- nnx.find_duplicates() - doesn't exist
- Various other API mismatches
5. API Inconsistencies Throughout
The guide references an API surface that appears to be from an older or planned version of NNX that doesn't match the current implementation.
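A quick way to see which of the names used by the guide are actually exposed by the installed release is a small attribute probe. The sketch below is illustrative only: `missing_attributes` and `NAMES_FROM_GUIDE` are hypothetical names, and the list covers just the symbols mentioned in this issue.

```python
import importlib


def missing_attributes(module, names):
    """Return the entries of `names` that `module` does not expose."""
    return [name for name in names if not hasattr(module, name)]


# Symbols taken from the examples quoted in this issue.
NAMES_FROM_GUIDE = [
    "Module", "Param", "Rngs", "Pytree", "List", "is_data", "find_duplicates",
]

try:
    nnx = importlib.import_module("flax.nnx")
except ImportError:
    nnx = None  # flax is not available in this environment

if nnx is None:
    print("flax is not installed")
else:
    print("missing from flax.nnx:", missing_attributes(nnx, NAMES_FROM_GUIDE))
```

An empty list would suggest the installed release already matches the guide. Any names printed are the ones the notebook cannot resolve.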
Impact
This is a critical documentation issue because:
- First impressions matter: Users trying to learn Flax NNX hit immediate failures
- Wastes developer time: Hours spent debugging what turns out to be doc issues
- Erodes trust: When core documentation doesn't work, it raises questions about library stability
- Blocks adoption: Potential users will simply move to alternatives with working docs
Expected Behavior
Documentation examples should:
- Execute without errors
- Use current API patterns
- Match the installed version of Flax
- Include version compatibility notes if APIs changed
Additional Context
I understand that JAX/Flax is evolving rapidly and hasn't reached 1.0 yet. However, having non-functional core documentation creates a significant barrier to adoption. Even if the API is unstable, the docs should accurately reflect the current state.
I'd appreciate it if this could be prioritized, and I'm happy to help test updated examples if needed. I'd love to write the updated parts myself, but that's hard to do with the docs in this state :)