Outdated and Broken Documentation: Module & Pytree Guide #5100

@wtfnukee

Description

The Module & Pytree guide in the Flax documentation is severely outdated and non-functional. Approximately half of the code cells in this guide fail to execute, making it impossible for users to learn from the examples.

Environment

  • Flax version: [latest stable as of the documentation]
  • JAX version: [current]
  • Python version: 3.12
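
To replace the placeholders above with exact numbers, something like the following could be run in the same environment as the notebook. It is a minimal sketch: `installed_version` and `environment_report` are hypothetical helper names, and the package list (`flax`, `jax`, `jaxlib`) is an assumption about what is relevant here.

```python
from importlib import metadata
import platform


def installed_version(package: str) -> str:
    """Return the installed version of a package, or a marker if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report(packages=("flax", "jax", "jaxlib")) -> dict:
    """Collect the Python version and the versions of the given packages."""
    report = {"python": platform.python_version()}
    for name in packages:
        report[name] = installed_version(name)
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version}")
```

Pasting this output into a report makes it easier to tell whether the guide is ahead of or behind the installed release.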

Issues Found

Running through the notebook reveals multiple breaking issues:

1. First Example Fails Immediately

from flax import nnx

class Linear(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.din, self.dout = din, dout
    self.kernel = nnx.Param(rngs.normal((din, dout)))

rngs = nnx.Rngs(0)
weights = Linear(2, 3, rngs=rngs)

Error:

TypeError: RngStream.__call__() takes 1 positional argument but 2 were given

2. Missing nnx.Pytree Class

Multiple examples reference nnx.Pytree which doesn't exist in the current API:

class Linear(nnx.Pytree):  # AttributeError: module 'flax.nnx' has no attribute 'Pytree'

3. Missing nnx.List Container

self.layers = nnx.List([...])  # AttributeError: module 'flax.nnx' has no attribute 'List'

4. Missing Utility Functions

  • nnx.is_data() - doesn't exist
  • nnx.find_duplicates() - doesn't exist
  • Various other API mismatches
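
The missing names in items 2 to 4 can be checked in one pass against whatever is installed. The sketch below uses a hypothetical helper, `missing_attributes`, and the list `GUIDE_NAMES` contains only the names mentioned in this report, not a full inventory of the guide.

```python
import importlib


def missing_attributes(module_name, names):
    """Return the names that the imported module does not expose."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


# Names taken from this report; the guide may reference others.
GUIDE_NAMES = ["Pytree", "List", "is_data", "find_duplicates"]

if __name__ == "__main__":
    try:
        print(missing_attributes("flax.nnx", GUIDE_NAMES))
    except ImportError:
        print("flax is not installed in this environment")
```

An empty list would mean the installed release already exposes these names, in which case the failures have a different cause.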

5. API Inconsistencies Throughout

The guide references an API surface that appears to be from an older or planned version of NNX that doesn't match the current implementation.

Impact

This is a critical documentation issue because:

  1. First impressions matter: Users trying to learn Flax NNX hit immediate failures
  2. Wastes developer time: Hours spent debugging what turns out to be doc issues
  3. Erodes trust: When core documentation doesn't work, it raises questions about library stability
  4. Blocks adoption: Potential users will simply move to alternatives with working docs

Expected Behavior

Documentation examples should:

  • Execute without errors
  • Use current API patterns
  • Match the installed version of Flax
  • Include version compatibility notes if APIs changed
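
As an illustration of the first point, a check along these lines could flag failing cells before a guide is published. It is only a sketch: `run_cells` is a hypothetical helper that runs plain source strings in a shared namespace, and it is not how the Flax docs are actually built or tested.

```python
def run_cells(cells):
    """Execute code cells in order in one shared namespace, like a notebook.

    Returns a list of (cell_index, error_summary) for the cells that raised.
    """
    namespace = {}
    failures = []
    for index, source in enumerate(cells):
        try:
            exec(compile(source, f"<cell {index}>", "exec"), namespace)
        except Exception as exc:
            # Record the failure and keep going so later cells are still checked.
            failures.append((index, f"{type(exc).__name__}: {exc}"))
    return failures


if __name__ == "__main__":
    demo_cells = [
        "x = 1",
        "y = x + undefined_name",  # raises NameError
        "z = x * 2",
    ]
    for index, error in run_cells(demo_cells):
        print(f"cell {index} failed: {error}")
```

Continuing after a failure mirrors how a reader works through a notebook, so the count of failing cells reflects what users actually see.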

Additional Context

I understand that JAX/Flax is evolving rapidly and hasn't reached 1.0 yet. However, having non-functional core documentation creates a significant barrier to adoption. Even if the API is unstable, the docs should accurately reflect the current state.

I would appreciate it if this could be prioritized, and I'm happy to help test updated examples if needed. I'd love to write the updated parts myself, but that is hard to do with docs this incomplete :)
