Restoring a checkpoint with Orbax CheckpointManager fails #4521

Closed
@bednarikjan

Description

I can save and load the model state with a regular ocp.StandardCheckpointer(), but restoring with ocp.CheckpointManager() fails. I believe I am following the standard approach, as shown in the minimal example below:

import pathlib

import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    return self.linear(x)

# Create the model
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

# Retrieve the state
_, state = nnx.split(model)
# Must be a Path (not a str) so that `ckpt_dir / 'state'` below works.
ckpt_dir = pathlib.Path('/some/path/...')

# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)

# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored)  # This call would already fail, see below [2]

# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)  # This call fails on error [1], see below

Error [1], raised when I run a forward pass through the restored model, is:

TypeError: Unexpected input type for array: <class 'dict'>

Uncommenting the line above that compares the original and restored states raises error [2], which already indicates a problem:

ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...

What am I missing? How should CheckpointManager be used correctly with NNX models?

System information

flax version: 0.10.2
orbax.checkpoint version: 0.11.0
