Description
TL;DR: version 0.11.8 works; the latest release leaks memory (I haven't tested the versions in between).
I originally opened an issue about this at google-deepmind/gemma#354, but narrowed it down to the current latest orbax-checkpoint (v0.11.19):
Running lora.py:
python -m kauldron.main --cfg=lora.py --cfg.workdir=/tmp/ckpt_lora
This prints warnings like:
replica_slices.py:419] Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=True
W external/xla/xla/stream_executor/integrations/stream_executor_allocator.cc:66] could not allocate pinned host of size: 4294967296
...
RAM (not VRAM) grows rapidly and indefinitely.
Disabling the checkpointer in the config eliminates the leak (sketch of what I changed below).
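For reference, this is roughly what I changed to rule checkpointing out. The exact field and class names are from memory, and kd.ckpts.NoopCheckpointer in particular is an assumption on my side, so it may be spelled differently in kauldron 1.2.2:

from kauldron import kd

def get_config():
  cfg = ...  # build the original lora.py config as before (unchanged)
  # Swap the regular checkpointer for a no-op one so nothing is saved.
  cfg.checkpointer = kd.ckpts.NoopCheckpointer()  # assumed class name
  return cfg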
Versions:
gemma: 3.0.2
jax: 0.6.2
orbax-checkpoint: 0.11.19
kauldron: 1.2.2
NVIDIA driver version: 570.124.06, CUDA version: 12.8
After a quick search through the issues here, I found #1713, which refers to use_replica_parallel; the CHANGELOG shows it was re-enabled in version 0.11.9. So I installed 0.11.8 instead, re-ran the experiment, and voilà: checkpointing works normally now.
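In case anyone else hits this, pinning the older release is the workaround I'm using for now:

pip install orbax-checkpoint==0.11.8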
I also noticed that when training GEMMA3_1B_IT, the printed message says use_replica_parallel=False, enable_pinned_host_transfer=False (both are False), so the issue seems to be related to one or both of these settings.
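If it helps narrow things down, the next thing I planned to try is forcing use_replica_parallel=False through the type handler registry and checking whether the leak goes away. This is an untested sketch; I'm assuming ArrayHandler exposes a use_replica_parallel keyword in 0.11.19:

import jax
from orbax.checkpoint import type_handlers

# Replace the default jax.Array handler with one that disables replica-parallel
# host transfers (the use_replica_parallel kwarg is my assumption, untested).
type_handlers.register_type_handler(
    jax.Array,
    type_handlers.ArrayHandler(use_replica_parallel=False),
    override=True,
)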