Restoring a checkpoint with Orbax CheckpointManager fails #4521

Open
bednarikjan opened this issue Feb 1, 2025 · 0 comments

I can save and load the model state with a regular ocp.StandardCheckpointer(), but doing the same through ocp.CheckpointManager() fails. I believe I follow the standard procedure, as shown in the minimal example below:

import pathlib

import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    return self.linear(x)

# Create the model
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

# Retrieve the state
_, state = nnx.split(model)
# Wrapped in pathlib.Path so that `ckpt_dir / 'state'` below is valid.
ckpt_dir = pathlib.Path('/some/path/...')

# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)

# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored)  # This call would already fail, see below [2]

# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)  # This call fails with error [1], see below

Error [1], which I see when running a forward pass through the restored model, is:

TypeError: Unexpected input type for array: <class 'dict'>

When I uncomment the line above that compares the original and restored state, I get error [2], which already indicates a problem:

ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...

What am I missing? How should one use the CheckpointManager correctly with NNX models?

System information

flax version: 0.10.2
orbax.checkpoint version: 0.11.0
