Description
We've been unable to run a single kauldron training job on a standard TPU VM v4-64, regardless of whether the workdir is on /tmp, /nfs_share, or GCS. This has led us to wonder whether kauldron actually supports v4 at all. We have followed all up-to-date recommendations from fellow TRX program members, such as using tpux to set up the pod, using `podrun -iwc` to run the code on all machines from the same working directory, and cleaning up TPU state afterwards. We used `gemma/examples/dpo.py` with `batch_size` modified to the allowed minimum (in our case, 64) and the `HF_TOKEN` environment variable set to allow downloading model weights from Hugging Face.
How to reproduce:

- Create a TPU VM v4-64 from `tpu-ubuntu-2204-baseimage`
- Follow the tpux recommended installation procedure
- All operations below are executed on every worker:
- Confirm the slice is operational via `python -c 'import jax; jax.distributed.initialize(); print(jax.distributed.is_initialized())'` (a fuller sanity check follows this list)
- Install `libgl1` (`sudo apt install libgl1`), which is apparently required
- Install kauldron: `cd kauldron; pip install -e .`
- Install gemma: `cd gemma; pip install -e .`
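
As a slightly stronger check than `is_initialized()`, we also confirmed that every worker sees the whole slice. This is our own snippet, not part of kauldron; the expected counts assume a v4-64 slice with 8 hosts and 4 chips per host:

```python
# Our own sanity check, run on every worker via podrun; not part of
# kauldron. On a v4-64 slice we expect 8 processes and 4 local chips
# per host, i.e. 32 global devices.
import jax

jax.distributed.initialize()
print(
    f"process {jax.process_index()}/{jax.process_count()}: "
    f"{jax.local_device_count()} local / {jax.device_count()} global devices"
)
```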
Potentially relevant
In all cases, the output from `kauldron.main` is full of messages like these:

```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
```

We made some additional effort to keep the distributed system initialised at all times:

```diff
diff --git a/examples/dpo.py b/examples/dpo.py
index ec6c6fb..9160cd5 100644
--- a/examples/dpo.py
+++ b/examples/dpo.py
@@ -35,6 +35,9 @@ python -m kauldron.main \
"""
+import jax
+if not jax.distributed.is_initialized():
+ jax.distributed.initialize()
from kauldron import konfig
@@ -43,10 +46,17 @@ with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
+ import jax
+ if not jax.distributed.is_initialized():
+ jax.distributed.initialize()
# pylint: enable=g-import-not-at-top
def get_config():
+ import jax
+ if not jax.distributed.is_initialized():
+ jax.distributed.initialize()
+
"""Get the default hyperparameter configuration."""
return kd.train.Trainer(
seed=42,
@@ -78,10 +88,10 @@ def get_config():
),
# Evaluation
evals={
- # "test": kd.evals.Evaluator(
- # run=kd.evals.EveryNSteps(1000),
- # ds=_make_dataset(training=False),
- # ),
+ "test": kd.evals.Evaluator(
+ run=kd.evals.EveryNSteps(1000),
+ ds=_make_dataset(training=False),
+ ),
},
)
@@ -89,7 +99,7 @@ def get_config():
def _make_dataset(training: bool) -> kd.data.Pipeline:
# TODO(epot): !!!!
max_length = 512
- batch_size = 16
+ batch_size = 64
tokenizer = gm.text.Gemma3Tokenizer()
```

Working directory in either /tmp or /nfs_share:
The command executed on all machines in the slice:

```sh
cd gemma
python -m kauldron.main \
  --cfg=examples/dpo.py \
  --cfg.workdir=/tmp/kd
```

Shortly after the computation begins, the run fails with:

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x175 (from TensorCoreSequencer:1:0x21f): scheckne:
***************
An unexpected peer shows up in the launch group with a different launch id than the current group leader. If using single controller backends, this signals a bug in the runtime scheduling system or above. If using multi-controller backends, non-determinism of 1) model code or 2) XLA compiler may also cause this, enable HLO dump for all workers and check: 1) if before_optimizations.txt are all the same or 2) if after_optimizations.txt are all the same. no HLO mapping
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:180
train: 0%| | 18/10001 [07:37<9:26:05, 3.40s/it] I0428 15:20:29.934811 139621177185856 grain_pool.py:520] Grain pool is exiting.
I0428 15:20:29.934961 139621177185856 grain_pool.py:525] Shutting down multiprocessing system.
train: 0%| | 30/10001 [07:38<1:13:27, 2.26it/s]I0428 15:20:31.514525 139621177185856 grain_pool.py:525] Shutting down multiprocessing system.
train: 0%| | 32/10001 [07:50<54:33, 3.05it/s]/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 54 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
```
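
We have not yet run the HLO-dump comparison that the error message suggests, but it would look roughly like this. The `--xla_dump_to` flag is standard XLA; the dump path and the hashing loop are our own:

```python
# Sketch of the debugging step the error message suggests: dump HLO on
# every worker, then compare before_optimizations.txt across hosts.
# XLA_FLAGS must be set before JAX creates its backends, so it comes
# before the jax import. The dump path is our own choice.
import glob
import hashlib
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/hlo_dump"
)

import jax  # imported only after XLA_FLAGS is set

jax.distributed.initialize()

# ... run the failing training step here, then fingerprint the dumps so
# they can be compared across workers:
for f in sorted(glob.glob("/tmp/hlo_dump/*before_optimizations.txt")):
    digest = hashlib.md5(open(f, "rb").read()).hexdigest()
    print(jax.process_index(), digest, f)
```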
Working directory in GCS bucket:
The command executed on all machines in the slice:

```sh
cd gemma
python -m kauldron.main \
  --cfg=examples/dpo.py \
  --cfg.workdir=gs://bucket
```

Output hangs on:

```
checkpoint_manager.py:1400] [process=2] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=0] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=6] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=7] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=5] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=4] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=3] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=1] Saving checkpoint at step 0
async_checkpointer.py:433] [process=0] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=3] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=6] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=1] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=2] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=7] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=4] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=5] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
atomicity.py:144] Creating tmp directory gs://bucket/checkpoints/ckpt_0
```
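
Before blaming the checkpointer, it may be worth ruling out a plain GCS permission or connectivity problem. A minimal write smoke test we would run on every worker, assuming etils is available (kauldron depends on it) and with `gs://bucket` as the placeholder from above:

```python
# Our own smoke test: confirm each worker can write to the bucket
# directly, independent of orbax/kauldron. gs://bucket is the same
# placeholder as above.
import jax
from etils import epath

jax.distributed.initialize()

path = epath.Path(f"gs://bucket/write_test/worker_{jax.process_index()}.txt")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("ok")
print(f"process {jax.process_index()} wrote {path}")
```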