Description
I'm trying to migrate from an older version of orbax (v0.6.4
) to a newer one (v0.11.10
), and I have a problem with restoring checkpoints.
I have two similar checkpoints, and I need to be able to load both of them.
Newer checkpoint version
$ tree -a /tmp/physmodjax/checkpoints_11po8hwr:v0/
/tmp/physmodjax/checkpoints_11po8hwr:v0/
├── checkpoints
│ └── 2351
│ ├── _CHECKPOINT_METADATA
│ ├── metrics
│ │ └── metrics
│ └── state
│ ├── d
│ │ └── 28136664d2dc46fa105f6e6d9bb416fa
│ ├── manifest.ocdbt
│ ├── _METADATA
│ ├── ocdbt.process_0
│ │ ├── d
│ │ │ ├── 23a6136b0d9ee643613c3b4a996bd1f8
│ │ │ ├── 4cf19a245d689524d873048019107e95
│ │ │ ├── 800aa1a9a609496e5ebb84e79682f962
│ │ │ ├── 870d7e0a616a05e2f894a49b684cb51e
│ │ │ ├── a275135cc9f930113aae8d2b20689de1
│ │ │ ├── e7cf94bd5e9366bbdbb30556eb89136d
│ │ │ └── ee2661f94896574e87a14149c0980612
│ │ └── manifest.ocdbt
│ └── _sharding
└── .hydra
└── config.yaml
Older checkpoint version
$ tree -a /tmp/physmodjax/checkpoints_4sa4dawx:v0/
/tmp/physmodjax/checkpoints_4sa4dawx:v0/
├── checkpoints
│ └── 851
│ ├── _CHECKPOINT_METADATA
│ ├── default
│ │ ├── d
│ │ │ └── b45c3d9920fabadd4c2813f6219571b0
│ │ ├── manifest.ocdbt
│ │ ├── _METADATA
│ │ ├── ocdbt.process_0
│ │ │ ├── d
│ │ │ │ ├── 6b4ac6e43f39060f80d607021694e22a
│ │ │ │ ├── c01a0dffd52b512886457d45d527c5ac
│ │ │ │ ├── cd854566ac3aeb55292d17b25f6b6ed6
│ │ │ │ ├── cfe86c22634cbd3abb8d270f9dfc9d8d
│ │ │ │ ├── d448eb8462c257f7522a708129dcd17b
│ │ │ │ └── d7d7efe7b1a749042abf00025dbd7233
│ │ │ └── manifest.ocdbt
│ │ └── _sharding
│ └── metrics
│ └── metrics
└── .hydra
└── config.yaml
The current code, corresponding to the newer checkpoint version, to save the checkpoint is:
options = hydra.utils.instantiate(cfg.checkpoint_manager_options)
with obc.CheckpointManager(
directory=Path(output_dir) / "checkpoints",
options=options,
item_handlers={"state": obc.PyTreeCheckpointHandler()},
) as checkpoint_manager:
_ = train(
model_cls=model_cls,
datamodule=datamodule,
cfg=cfg,
checkpoint_manager=checkpoint_manager,
)
checkpoint_manager.wait_until_finished()
Inside the train
function, in the training loop, I have the following code to save the checkpoint:
checkpoint_manager.save(
step=epoch,
args=obc.args.Composite(
state=obc.args.PyTreeSave(state),
),
metrics=metrics_to_log,
)
With the following code I was able to restore both types of checkpoints with v0.6.4:
import jax
jax.config.update("jax_platforms", "cuda")
print(jax.config.jax_platforms)
import hydra
from omegaconf import OmegaConf
from pathlib import Path
from flax.training import train_state
import flax.linen as nn
import orbax.checkpoint as obc
from typing import Any
def restore_experiment_state(
run_path: Path, # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
best: bool = True, # If True, restore the best checkpoint instead of the latest
step_to_restore: int = None, # If not None, restore the checkpoint at this step
kwargs: dict = {}, # Additional arguments to pass to the model
) -> tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
"""
Restores the train state from a run.
Args:
run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
Returns:
-------
train_state.TrainState: The train state of the experiment
nn.Module: The model used in the experiment
CheckpointManager: The checkpoint manager
"""
# Make sure the path is a Path object
run_path = Path(run_path)
# These are hardcoded, do not change
ckpt_path = run_path / "checkpoints"
config_path = run_path / ".hydra" / "config.yaml"
cfg = OmegaConf.load(config_path)
options = obc.CheckpointManagerOptions(
max_to_keep=1,
create=True,
best_fn=lambda x: float(
x["val/mae_rel"]
), # Shouldn't be hardcoded here, not a problem atm because we only save one step, best
best_mode="min",
)
with obc.CheckpointManager(
ckpt_path,
options=options,
item_handlers={
"state": obc.PyTreeCheckpointHandler(),
"default": obc.PyTreeCheckpointHandler(),
},
) as checkpoint_manager:
model_cls: nn.Module = hydra.utils.instantiate(cfg.model)
model = model_cls(training=False, **kwargs)
# Get checkpoint metadata
step = (
checkpoint_manager.latest_step()
if not best
else checkpoint_manager.best_step()
)
step = step_to_restore if step_to_restore is not None else step
metadatas = checkpoint_manager.item_metadata(step)
print(f"Restoring checkpoint from step {step}...")
# Backwards compatibility for older checkpoints
if "state" in metadatas and metadatas.state is not None:
metadata_state = metadatas.state
ckpt_type = "state"
print("This is a checkpoint with new formatting")
elif "default" in metadatas and metadatas.default is not None:
print("This is a checkpoint with old formatting")
assert "model" in metadatas.default, "No model found in the checkpoint"
metadata_state = metadatas.default["model"]
ckpt_type = "default"
else:
raise ValueError("No state found in the checkpoint")
# Check if the checkpoint has batch_stats
if "batch_stats" in metadata_state:
# Define TrainState with optional batch_stats
class TrainState(train_state.TrainState):
key: jax.Array
batch_stats: Any = None # Optional field
# Initialize the empty state
empty_state = TrainState(
key={},
step=0,
apply_fn=model.apply,
params=metadata_state["params"],
tx={},
opt_state=metadata_state["opt_state"],
batch_stats=metadata_state["batch_stats"],
)
else:
# Define TrainState with optional batch_stats
class TrainState(train_state.TrainState):
key: jax.Array
empty_state = TrainState(
key={},
step=0,
apply_fn=model.apply,
params=metadata_state["params"],
tx={},
opt_state=metadata_state["opt_state"],
)
old_ckpt = {"model": empty_state}
restored_checkpoint = checkpoint_manager.restore(
step=step,
args=obc.args.Composite(
default=obc.args.PyTreeRestore(
item=old_ckpt,
),
state=obc.args.PyTreeRestore(
item=empty_state,
),
),
)
if ckpt_type == "state":
state = restored_checkpoint.state
elif ckpt_type == "default":
state = restored_checkpoint.default["model"]
return state, model, checkpoint_manager
checkpoint_path = Path("/tmp/physmodjax/checkpoints_11po8hwr:v0/") # Newer checkpoint version
checkpoint_path = Path("/tmp/physmodjax/checkpoints_4sa4dawx:v0/") # Older checkpoint version
conf = OmegaConf.load(checkpoint_path / ".hydra" / "config.yaml")
kwargs = {"n_steps": conf.datamodule.num_steps_train[1]}
state, model, ckpt_manager = restore_experiment_state(
checkpoint_path,
kwargs=kwargs,
)
print("Restored model!!")
Although it works only for restoring to the same backend (cuda
), only. It gives me a warning regarding this
$ python simple_restore_issue.py
cuda
Restoring checkpoint from step 851...
This is a checkpoint with old formatting
/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1330: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
Restored model!!
I was trying to get it to load correctly to both cpu
and cuda
backends doing shenanigans like
def apply_sharding(array, sharding):
if isinstance(array, obc.metadata.value.ArrayMetadata):
array.sharding = sharding
default_sharding = obc.metadata.sharding.SingleDeviceShardingMetadata(device_str=str(jax.local_devices()[0]))
jax.tree_util.tree_map(
lambda x: apply_sharding(x, default_sharding),
metadata_state,
)
metadata_state = apply_default_sharding(metadata_state)
Inside the restore_experiment_state
function, but after bashing my head into a wall for hours, I decided to update orbax to a newer version (v0.11.10
) to use the new API and not code against a deprecated version.
But now I'm back at square one, because with v0.11.10
I can't restore the existing checkpoints, not even on the same backend.
Newer checkpoint version:
$ python simple_restore_issue.py
jax.config.jax_platforms: cuda
Restoring checkpoint from step 2351...
This is a checkpoint with new formatting
Traceback (most recent call last):
File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 144, in <module>
state, model, ckpt_manager = restore_experiment_state(
File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 116, in restore_experiment_state
restored_checkpoint = checkpoint_manager.restore(
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1566, in restore
restored = self._checkpointer.restore(restore_directory, args=args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 545, in restore
return super().restore(directory, *args, **kwargs)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 289, in restore
restored = self._restore(directory, args=ckpt_args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 308, in _restore
return self._handler.restore(directory, args=args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 831, in restore
raise KeyError(
KeyError: 'Item "default" was not found in the checkpoint. Available items: [\'metrics\', \'state\']'
Older checkpoint version:
$ python simple_restore_issue.py
jax.config.jax_platforms: cuda
Restoring checkpoint from step 851...
This is a checkpoint with old formatting
/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1250: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
WARNING:absl:[process=0][thread=MainThread] No metadata found for any process_index, checkpoint_dir=/tmp/physmodjax/checkpoints_4sa4dawx:v0/checkpoints/851/default. time elapsed=0.00021982192993164062 seconds. If the checkpoint does not contain jax.Array then it is expected. If checkpoint contains jax.Array then it should lead to an error eventually; if no error is raised then it is a bug.
Traceback (most recent call last):
File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 144, in <module>
state, model, ckpt_manager = restore_experiment_state(
File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 116, in restore_experiment_state
restored_checkpoint = checkpoint_manager.restore(
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1566, in restore
restored = self._checkpointer.restore(restore_directory, args=args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 545, in restore
return super().restore(directory, *args, **kwargs)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 289, in restore
restored = self._restore(directory, args=ckpt_args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 308, in _restore
return self._handler.restore(directory, args=args)
File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 831, in restore
raise KeyError(
KeyError: 'Item "state" was not found in the checkpoint. Available items: [\'default\', \'metrics\']'
What is the best way to do this? I feel like I'm going crazy, thanks a lot for any help you can provide.
Bonus points if you can help me restore the checkpoints to both cpu
and cuda
backends as well, if not I will have to possibly open an issue later for that specifically if I get this working.