Checkpointer memory leak #2090

@sirfz

Description

TL;DR: version 0.11.8 works; the latest release leaks (I haven't tested the versions in between).

I originally opened an issue about this at google-deepmind/gemma#354 but have since narrowed it down to the current latest orbax-checkpoint (v0.11.19):

Running lora.py:

python -m kauldron.main --cfg=lora.py --cfg.workdir=/tmp/ckpt_lora

This prints warnings like:

replica_slices.py:419] Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=True
W external/xla/xla/stream_executor/integrations/stream_executor_allocator.cc:66] could not allocate pinned host of size: 4294967296
...

RAM (not VRAM) grows rapidly and indefinitely.

Disabling the checkpointer in the config eliminates the leak.
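
For anyone who wants to quantify the growth, a minimal psutil-based helper like the sketch below can be dropped into the training loop (the helper name and the psutil dependency are my own additions, not part of kauldron):

import os

import psutil

_proc = psutil.Process(os.getpid())

def log_host_rss(step: int) -> None:
    # Resident set size in GiB; with the checkpointer enabled this climbs
    # steadily, with it disabled it stays flat.
    rss_gib = _proc.memory_info().rss / 2**30
    print(f"step={step} host_rss={rss_gib:.2f} GiB")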

Versions:
gemma: 3.0.2
jax: 0.6.2
orbax-checkpoint: 0.11.19
kauldron: 1.2.2
NVIDIA Driver Version: 570.124.06, CUDA Version: 12.8

After a quick search through the issues here, I saw that #1713 refers to use_replica_parallel, which the CHANGELOG shows was re-enabled in version 0.11.9. I installed 0.11.8 instead and re-ran the experiment, and voilà: checkpointing works normally again.
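
For reference, the downgrade is just a version pin:

pip install orbax-checkpoint==0.11.8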

I also noticed that when training GEMMA3_1B_IT, the printed message says use_replica_parallel=False, enable_pinned_host_transfer=False (both False), so the issue seems related to one or both of these settings.
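
In case it helps with triage, below is a rough sketch of a possible workaround: forcing use_replica_parallel off by registering a custom ArrayHandler. I have not verified that the constructor accepts this kwarg in v0.11.19 (the name is copied from the log line above), so treat it as an assumption rather than a confirmed fix; enable_pinned_host_transfer would presumably need similar handling.

import jax
import orbax.checkpoint as ocp

# Assumed kwarg name, taken from the logged option; not verified against v0.11.19.
handler = ocp.type_handlers.ArrayHandler(use_replica_parallel=False)
# Override the default jax.Array handler so saves go through this one.
ocp.type_handlers.register_type_handler(jax.Array, handler, override=True)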
