Restore GPU-saved Orbax checkpoint on CPU #2473

@prekshan

I’m trying to load a checkpoint saved on a Linux CUDA box onto a CPU-only machine. I keep hitting version-dependent errors around PyTreeRestore vs. ArrayRestoreArgs and their sharding requirements. I'm looking for a minimal, version-correct snippet that restores everything onto CPU (either as jax.Array on CPU or as plain np.ndarray), ignoring the original GPU sharding.

This is the barebones code:

import os
import orbax.checkpoint as ocp

orbax_restore = ocp.CheckpointManager(f"{os.getcwd()}/checkpoints/2025-10-09-17-20-33")
restored = orbax_restore.restore(orbax_restore.latest_step(), args=ocp.args.PyTreeRestore())

I've tried several variations of this code based on previously posted issues, but with no luck.

The exact error on running this:

ERROR:root:Device cuda:0 was not found in jax.local_devices().

Any help would be much appreciated.
