Kaggle TPU Multi-core Training Crashes with debug_single_process=False #8569

Open
mohamedamara7 opened this issue Jan 14, 2025 · 0 comments
TPU Multi-core Training Crashes with debug_single_process=False

Description

I encountered an issue when using PyTorch/XLA to train a model on a Kaggle TPU. The code works correctly when debug_single_process=True (single core), but crashes when debug_single_process=False with AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)> (full traceback in the Error Message section below).
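
For reference, the only difference between the working and the failing run is the debug_single_process flag passed to torch_xla.launch (the full call appears at the end of the code below):

# Runs fine: a single process on a single core
torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=True)

# Crashes with the AttributeError shown below: one process per TPU core
torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=False)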

Multi-Processing Function for Distributed Training

# Imports used by the code below
import gc

import numpy as np
import torch
import torch.distributed as dist
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.runtime as xr
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

def _mp_fn(rank, flags):
    # Initialize the distributed process group
    if not dist.is_initialized():
        dist.init_process_group(backend='xla', init_method='xla://')

    # Setup distributed data samplers
    train_sampler, val_sampler, orig_sampler, test_sampler = None, None, None, None
    if xr.world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            flags['TRAIN_DS'], num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=True
        )
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            flags['VAL_DS'], num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=False
        )
        orig_sampler = torch.utils.data.distributed.DistributedSampler(
            flags['ORIG_DS'], num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=False
        )
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            flags['TEST_DS'], num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=False
        )

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        flags['TRAIN_DS'], batch_size=flags['BATCH_SIZE'], sampler=train_sampler, drop_last=True,
        shuffle=(train_sampler is None), num_workers=flags['num_workers']
    )
    val_loader = torch.utils.data.DataLoader(
        flags['VAL_DS'], batch_size=flags['BATCH_SIZE'], sampler=val_sampler, drop_last=False,
        shuffle=False, num_workers=flags['num_workers']
    )
    orig_loader = torch.utils.data.DataLoader(
        flags['ORIG_DS'], batch_size=flags['BATCH_SIZE'], sampler=orig_sampler, drop_last=False,
        shuffle=False, num_workers=flags['num_workers']
    )
    test_loader = torch.utils.data.DataLoader(
        flags['TEST_DS'], batch_size=flags['BATCH_SIZE'], sampler=test_sampler, drop_last=False,
        shuffle=False, num_workers=flags['num_workers']
    )

    del train_sampler, val_sampler, orig_sampler, test_sampler
    gc.collect()

    # Setup device, model, and optimizer
    device = xm.xla_device()
    model = flags['FOLD_MODEL']
    model.to(device)

    xm.broadcast_master_param(model)
    if flags['DDP']:
        model = DDP(model, broadcast_buffers=False, gradient_as_bucket_view=True)

    optimizer = optim.SGD(model.parameters(), lr=flags['LR'], momentum=0.9, weight_decay=1e-4)
    lr_scheduler = None  # Optionally add learning rate scheduler

    # Initialize results dictionary
    results = {
        "train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []
    }

    # Define training loop
    def train_loop_fn(loader):
        model.train()
        train_loss, train_acc = 0, 0
        for data, target in tqdm(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            train_loss += xm.mesh_reduce('train_loss', loss.item(), np.mean)
            loss.backward()
            if flags['DDP']:
                optimizer.step()
                xm.mark_step()
            else:
                xm.optimizer_step(optimizer)
            if lr_scheduler:
                lr_scheduler.step()
            acc_metric = (output.argmax(dim=1) == target.argmax(dim=1)).sum().item() / len(output) \
                if one_hot else (output.argmax(dim=1) == target).sum().item() / len(output)
            train_acc += xm.mesh_reduce('train_acc', acc_metric, np.mean)

        return train_loss / len(loader), train_acc / len(loader)

    # Define testing loop
    def test_loop_fn(loader):
        model.eval()
        val_loss, val_acc = 0, 0
        for data, target in tqdm(loader):
            output = model(data)
            loss = loss_fn(output, target)
            val_loss += xm.mesh_reduce('val_loss', loss.item(), np.mean)
            acc_metric = (output.argmax(dim=1) == target.argmax(dim=1)).sum().item() / len(output) \
                if one_hot else (output.argmax(dim=1) == target).sum().item() / len(output)
            val_acc += xm.mesh_reduce('val_acc', acc_metric, np.mean)

        return val_loss / len(loader), val_acc / len(loader)

    # Training and validation
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    val_device_loader = pl.MpDeviceLoader(val_loader, device)

    for epoch in range(1, flags['EPOCHS'] + 1):
        train_loss, train_acc = train_loop_fn(train_device_loader)
        val_loss, val_acc = test_loop_fn(val_device_loader)
        xm.master_print(
            f"Epoch: {epoch} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"val_loss: {val_loss:.4f} | "
            f"val_acc: {val_acc:.4f}"
        )

    # Save the trained model for this fold
    xm.save(model.state_dict(), f"fold_model_{flags['FOLD_NO']}.pth")

# Iterate through the folds using k-fold cross-validation
for fold_no, (train_index, val_index) in enumerate(
        kf.split(processed_images_paths, np.argmax(processed_images_labels, axis=1))):
    print(f"Training on fold {fold_no}...")

    # Split data into training and validation sets
    train_paths, val_paths = processed_images_paths[train_index], processed_images_paths[val_index]
    train_labels, val_labels = processed_images_labels[train_index], processed_images_labels[val_index]

    # Create datasets
    train_dataset = CustomDS(train_paths, train_labels, transform=train_transforms)
    val_dataset = CustomDS(val_paths, val_labels, transform=train_transforms)

    # Initialize the model with appropriate parameters
    model = Model(
        backbone='efficientnet_b0',
        embedding_size=512,
        num_classes=4,
        margin=0.5,
        scale=64.0,
        one_hot=one_hot
    )

    # Define flags to pass to the `_mp_fn`
    FLAGS = {
        'FOLD_NO': fold_no,
        'TRAIN_DS': train_dataset,
        'VAL_DS': val_dataset,
        'ORIG_DS': orig_ds,
        'TEST_DS': test_ds,
        'FOLD_MODEL': model,
        'LR': 0.001,
        'BATCH_SIZE': 64,
        'EPOCHS': 2,
        'num_workers': 8,
        'DDP': False
    }

    # Launch the training process
    torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=False)
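
If it helps narrow things down, this is the kind of stripped-down repro I would try next (untested sketch; _tiny_mp_fn is a hypothetical function that ignores my datasets and model entirely):

import torch_xla
import torch_xla.core.xla_model as xm

def _tiny_mp_fn(rank):
    # Minimal function: just report which XLA device each process got
    print(rank, xm.xla_device())

torch_xla.launch(_tiny_mp_fn, args=(), debug_single_process=False)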

Error Message

Here’s the error I receive:

Training on fold 0...
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 188MB/s]
Process SpawnProcess-3:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
Process SpawnProcess-4:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
Process SpawnProcess-2:
Process SpawnProcess-1:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
---------------------------------------------------------------------------
BrokenProcessPool                         Traceback (most recent call last)
Cell In[20], line 27
     12 model = Model('efficientnet_b0', embedding_size=512, num_classes=4, margin=0.5, s=64.0, one_hot=one_hot)    
     15 FLAGS = {'FOLD_NO': fold_no,
     16          'TRAIN_DS': train_dataset,
     17          'VAL_DS': val_dataset,
   (...)
     24          'num_workers' : 8,
     25          'DDP' : False}
---> 27 torch_xla.launch(
     28   _mp_fn, args=(FLAGS,), debug_single_process=False)

File /usr/local/lib/python3.10/site-packages/torch_xla/torch_xla.py:233, in launch(fn, args, start_method, debug_single_process)
    231   fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
    232 else:
--> 233   xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:37, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 def spawn(fn,
      7           args=(),
      8           nprocs=None,
      9           join=True,
     10           daemon=False,
     11           start_method='spawn'):
     12   """Enables multi processing based replication.
     13 
     14   Args:
   (...)
     35     return None.
     36   """
---> 37   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:209, in spawn(fn, nprocs, start_method, args)
    206 elif nprocs is not None:
    207   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 209 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:169, in run_multiprocess(fn, start_method, *args, **kwargs)
    163   mp_fn = functools.partial(
    164       _run_thread_per_device,
    165       local_world_size=num_processes,
    166       fn=functools.partial(fn, *args, **kwargs),
    167       initializer_fn=initialize_multiprocess)
    168   process_results = executor.map(mp_fn, range(num_processes))
--> 169   replica_results = list(
    170       itertools.chain.from_iterable(
    171           result.items() for result in process_results))
    173 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:170, in <genexpr>(.0)
    163   mp_fn = functools.partial(
    164       _run_thread_per_device,
    165       local_world_size=num_processes,
    166       fn=functools.partial(fn, *args, **kwargs),
    167       initializer_fn=initialize_multiprocess)
    168   process_results = executor.map(mp_fn, range(num_processes))
    169   replica_results = list(
--> 170       itertools.chain.from_iterable(
    171           result.items() for result in process_results))
    173 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

I would appreciate any insights or suggestions to resolve this issue. If there’s a specific way to handle _mp_fn for TPU multi-core training, please advise.
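
For context, my current guess (untested) is that with the spawn start method the worker processes cannot unpickle _mp_fn because it is defined in the notebook's __main__, which would explain the AttributeError above. One workaround I am considering is writing the function to an importable file first, roughly as follows (xla_train.py is just a placeholder name):

%%writefile xla_train.py
# Hypothetical module so spawned workers can import _mp_fn by name
def _mp_fn(rank, flags):
    ...  # same body as above, plus the imports it needs

# In the next cell: import the function from the module and launch as before
from xla_train import _mp_fn
torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=False)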
