I can save and load the model state with a regular ocp.StandardCheckpointer(), but doing the same with ocp.CheckpointManager() fails. I believe I am following the standard procedure, as shown in the minimal example below:
import jax
import numpy as np
import orbax.checkpoint as ocp
from pathlib import Path
from flax import nnx
# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
    def __call__(self, x):
        return self.linear(x)
# Create the model
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)
# Retrieve the state
_, state = nnx.split(model)
# A Path (not a plain str) so that `ckpt_dir / 'state'` below works.
ckpt_dir = Path('/some/path/...')
# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)
# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored) # This call would already fail, see below [2]
# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4) # This call fails on error [1], see below
The error [1] that I see when running a forward pass through the loaded model is:
TypeError: Unexpected input type for array: <class 'dict'>
Uncommenting the line above that compares the original and the loaded state raises error [2], which already indicates a problem:
ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...
What am I missing? How should one use the CheckpointManager correctly together with the nnx models?
System information
flax version: 0.10.2
orbax.checkpoint version: 0.11.0