Description
I’m trying to load a checkpoint saved on a Linux CUDA box into a CPU-only machine. I keep hitting version-dependent errors around PyTreeRestore vs ArrayRestoreArgs and sharding requirements. Looking for a minimal, version-correct snippet to restore everything onto CPU (either jax.Array on CPU or plain np.ndarray), ignoring original GPU sharding.
This is the barebones code:
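```python
import os

import orbax.checkpoint as ocp

# Checkpoint directory that was written on the Linux CUDA machine.
orbax_restore = ocp.CheckpointManager(f"{os.getcwd()}/checkpoints/2025-10-09-17-20-33")
# Restore the latest step with default PyTreeRestore args (no explicit restore_args).
restored = orbax_restore.restore(orbax_restore.latest_step(), args=ocp.args.PyTreeRestore())
```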
I've tried several variations of this code based on previously posted issues, with no luck.
The exact error when running this:
```
ERROR:root:Device cuda:0 was not found in jax.local_devices().
```
Any help would be much appreciated.