
Training jobs stall in TPU VM v4-64 environment regardless of workdir #1120

@tucnak

Description


We have been unable to run a single kauldron training job on a standard TPU VM v4-64 with any of the workdirs we tried: /tmp, /nfs_share, or a GCS bucket. This has led us to wonder whether kauldron supports v4 at all. We followed all up-to-date recommendations from fellow TRX program members, such as using tpux to set up the pod, running the code on all machines from the same working directory via podrun -iwc, and cleaning up TPU state afterwards. We used gemma/examples/dpo.py with batch_size raised to the minimum allowed for this slice (64 in our case) and with the HF_TOKEN environment variable set so the model weights can be downloaded from Hugging Face.

How to reproduce:

  1. Create a TPU VM v4-64 from the tpu-ubuntu-2204-base image
  2. Follow the tpux recommended installation procedure
  3. Run all of the following steps on every worker:
  4. Confirm the slice is operational via python -c 'import jax; jax.distributed.initialize(); print(jax.distributed.is_initialized())' (an expanded version of this check follows the list)
  5. Install libgl1 (sudo apt install libgl1), which is apparently required
  6. Install kauldron: cd kauldron; pip install -e .
  7. Install gemma: cd gemma; pip install -e .
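
For reference, here is an expanded version of the check in step 4 (a sketch; the expected counts assume the v4-64 topology of 8 hosts with 4 chips each):

# check_slice.py -- run on every worker via podrun; prints how many hosts and
# devices JAX sees once the multi-host runtime is up.
import jax

jax.distributed.initialize()

# On a healthy v4-64 slice this should report 8 processes and 32 global
# devices (4 local devices per host); anything else points at a slice setup
# problem rather than at kauldron.
print(
    f"process {jax.process_index()}/{jax.process_count()}: "
    f"{jax.local_device_count()} local / {jax.device_count()} global devices"
)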

Potentially relevant

In all cases, the output from kauldron.main is full of messages like these:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

We made some additional effort to keep the distributed system initialised at all times:

diff --git a/examples/dpo.py b/examples/dpo.py
index ec6c6fb..9160cd5 100644
--- a/examples/dpo.py
+++ b/examples/dpo.py
@@ -35,6 +35,9 @@ python -m kauldron.main \

 """
+import jax
+if not jax.distributed.is_initialized():
+    jax.distributed.initialize()

 from kauldron import konfig

@@ -43,10 +46,17 @@ with konfig.imports():
   from gemma import gm
   from kauldron import kd
   import optax
+  import jax
+  if not jax.distributed.is_initialized():
+      jax.distributed.initialize()
 # pylint: enable=g-import-not-at-top


 def get_config():
+  import jax
+  if not jax.distributed.is_initialized():
+      jax.distributed.initialize()
+
   """Get the default hyperparameter configuration."""
   return kd.train.Trainer(
       seed=42,
@@ -78,10 +88,10 @@ def get_config():
       ),
       # Evaluation
       evals={
-          # "test": kd.evals.Evaluator(
-          #     run=kd.evals.EveryNSteps(1000),
-          #     ds=_make_dataset(training=False),
-          # ),
+          "test": kd.evals.Evaluator(
+              run=kd.evals.EveryNSteps(1000),
+              ds=_make_dataset(training=False),
+          ),
       },
   )

@@ -89,7 +99,7 @@ def get_config():
 def _make_dataset(training: bool) -> kd.data.Pipeline:
   # TODO(epot): !!!!
   max_length = 512
-  batch_size = 16
+  batch_size = 64

   tokenizer = gm.text.Gemma3Tokenizer()
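
In hindsight, the three guarded calls in the patch above could be collapsed into a single helper at the very top of the config file, before kauldron or gemma touch JAX. A minimal sketch of what we mean (assuming kauldron tolerates the runtime being initialised before it sets up its own state):

# examples/dpo.py (top of file) -- initialise the multi-host runtime exactly
# once, instead of repeating the guard in three places.
import jax


def _ensure_distributed() -> None:
  if not jax.distributed.is_initialized():
    jax.distributed.initialize()


_ensure_distributed()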

Working directory in either /tmp or /nfs_share

The command executed on all machines in the slice:

cd gemma
python -m kauldron.main \
    --cfg=examples/dpo.py \
    --cfg.workdir=/tmp/kd

Shortly after the computation begins, the output hangs with:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x175 (from TensorCoreSequencer:1:0x21f): scheckne:
***************
An unexpected peer shows up in the launch group with a different launch id than the current group leader. If using single controller backends, this signals a bug in the runtime scheduling system or above. If using multi-controller backends, non-determinism of 1) model code or 2) XLA compiler may also cause this, enable HLO dump for all workers and check: 1) if before_optimizations.txt are all the same or 2) if after_optimizations.txt are all the same. no HLO mapping
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:180

train:   0%|          | 18/10001 [07:37<9:26:05,  3.40s/it] I0428 15:20:29.934811 139621177185856 grain_pool.py:520] Grain pool is exiting.
I0428 15:20:29.934961 139621177185856 grain_pool.py:525] Shutting down multiprocessing system.
train:   0%|          | 30/10001 [07:38<1:13:27,  2.26it/s]I0428 15:20:31.514525 139621177185856 grain_pool.py:525] Shutting down multiprocessing system.
train:   0%|          | 32/10001 [07:50<54:33,  3.05it/s]/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 54 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
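
The error message itself suggests enabling HLO dumps on all workers and comparing before_optimizations.txt / after_optimizations.txt across hosts. We have not done that yet; a sketch of how we would set it up with the standard --xla_dump_to flag (the dump path is just an example, and exporting XLA_FLAGS in the shell before python -m kauldron.main would be equivalent):

# Put this at the very top of examples/dpo.py, or export XLA_FLAGS in the
# shell on every worker; the flag has to be set before the XLA backend is
# first used.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/hlo_dump"
)

# After a run, diff the *before_optimizations.txt (and *after_optimizations.txt)
# files under /tmp/hlo_dump across all eight workers.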

Working directory in a GCS bucket

The command executed on all machines in the slice:

cd gemma
python -m kauldron.main \
    --cfg=examples/dpo.py \
    --cfg.workdir=gs://bucket

The output hangs at:

checkpoint_manager.py:1400] [process=2] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=0] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=6] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=7] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=5] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=4] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=3] Saving checkpoint at step 0
checkpoint_manager.py:1400] [process=1] Saving checkpoint at step 0
async_checkpointer.py:433] [process=0] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=3] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=6] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=1] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=2] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=7] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=4] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
async_checkpointer.py:433] [process=5] Started async saving checkpoint to gs://bucket/checkpoints/ckpt_0.
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
signaling_client.py:257] Using JaxDistributedSignalingClient
atomicity.py:144] Creating tmp directory gs://bucket/checkpoints/ckpt_0
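
We have not yet ruled out a bucket permission or connectivity problem on individual hosts; a per-worker write smoke test along these lines would check that (a sketch assuming the google-cloud-storage client is installed and gs://bucket is the workdir bucket from the command above):

# gcs_smoke_test.py -- run on every worker via podrun; checks that each host
# can write into the bucket the checkpointer hangs on.
import jax
from google.cloud import storage

jax.distributed.initialize()

client = storage.Client()  # uses the VM's default service-account credentials
bucket = client.bucket("bucket")  # same bucket as --cfg.workdir=gs://bucket
blob = bucket.blob(f"write_test/worker_{jax.process_index()}.txt")
blob.upload_from_string("ok")
print(f"process {jax.process_index()}: wrote gs://bucket/{blob.name}")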
