Description
I can use a regular `ocp.StandardCheckpointer()` to save and load the model state, but using `ocp.CheckpointManager()` fails. As far as I can tell, I follow the standard approach, as shown in the minimal example below:
```python
import jax
import numpy as np
import orbax.checkpoint as ocp
from etils import epath
from flax import nnx

# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    return self.linear(x)

# Create the model.
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

# Retrieve the state.
_, state = nnx.split(model)

# A path object (not a plain str) so that `ckpt_dir / 'state'` works.
ckpt_dir = epath.Path('/some/path/...')

# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)

# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored)  # This call would already fail, see below [2]

# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)  # This call fails with error [1], see below
```
The error [1] that I see when running a forward pass through the restored model is:

`TypeError: Unexpected input type for array: <class 'dict'>`

When I uncomment the line above that compares the original and restored states, I get error [2], which already points to the problem:

`ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...`
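Error [2] looks like a pytree structure mismatch: the original `state` is an `nnx.State`, which is registered as a custom pytree node, while the restored object appears to be a plain nested dict. The standalone JAX sketch below reproduces the same kind of error without Flax or Orbax; `Container` and `leaves_match` are hypothetical names, with `Container` only standing in for `nnx.State`.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Container:
  """Hypothetical custom pytree node standing in for nnx.State."""

  def __init__(self, mapping):
    self.mapping = dict(mapping)

  def tree_flatten(self):
    # Children are the values in sorted-key order; the keys are static aux data.
    keys = tuple(sorted(self.mapping))
    return [self.mapping[k] for k in keys], keys

  @classmethod
  def tree_unflatten(cls, keys, children):
    return cls(dict(zip(keys, children)))


def leaves_match(a, b):
  """True if all paired leaves are equal; raises ValueError if structures differ."""
  pairs = jax.tree.map(lambda x, y: x == y, a, b)
  return all(jax.tree.leaves(pairs))


original = Container({"kernel": 1.0})
restored_as_dict = {"kernel": 1.0}  # same leaf value, but a plain dict

try:
  leaves_match(original, restored_as_dict)
except ValueError as err:
  print("structure mismatch:", err)

# Rebuilding the custom node around the same values makes the trees comparable.
print(leaves_match(original, Container(restored_as_dict)))
```

The leaf values are identical in both cases; only the container type differs, which is enough for `jax.tree.map` to reject the pair.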
What am I missing? How should `CheckpointManager` be used correctly with NNX models?
System information
flax version: 0.10.2
orbax.checkpoint version: 0.11.0