TPU Multi-core Training Crashes with debug_single_process=False
Description
I encountered an issue when using PyTorch XLA to train a model on a Kaggle TPU. The code works correctly when debug_single_process=True (single core), but crashes when debug_single_process=False with the error shown below.
Multi-Processing Function for Distributed Training
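The function passed to torch_xla.launch is the standard per-core training entry point. A minimal sketch of that pattern is given here for reference (TRAIN_DS and num_workers are the FLAGS keys visible in the traceback below; build_model, the batch size, optimizer, and loss are illustrative placeholders rather than the original code):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index, flags):
    # torch_xla spawns one process per TPU core; each gets its own XLA device.
    device = xm.xla_device()

    # Per-core loader over the current fold's training split.
    loader = torch.utils.data.DataLoader(
        flags['TRAIN_DS'],
        batch_size=64,                     # illustrative value
        num_workers=flags['num_workers'],
        shuffle=True)
    device_loader = pl.MpDeviceLoader(loader, device)

    model = build_model().to(device)       # hypothetical model factory
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for images, labels in device_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        xm.optimizer_step(optimizer)       # all-reduce gradients across cores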
Error Message
Here’s the error I receive:

Training on fold 0...
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 188MB/s]
Process SpawnProcess-3:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
Process SpawnProcess-4:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
Process SpawnProcess-2:
Process SpawnProcess-1:
Traceback (most recent call last):
Traceback (most recent call last):
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
call_item = call_queue.get(block=True)
File "/usr/local/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
AttributeError: Can't get attribute '_mp_fn' on <module '__main__' (built-in)>
---------------------------------------------------------------------------
BrokenProcessPool Traceback (most recent call last)
Cell In[20], line 27
12 model = Model('efficientnet_b0', embedding_size=512, num_classes=4, margin=0.5, s=64.0, one_hot=one_hot)
15 FLAGS = {'FOLD_NO': fold_no,
16 'TRAIN_DS': train_dataset,
17 'VAL_DS': val_dataset,
(...)
24 'num_workers' : 8,
25 'DDP' : False}
---> 27 torch_xla.launch(
28 _mp_fn, args=(FLAGS,), debug_single_process=False)
File /usr/local/lib/python3.10/site-packages/torch_xla/torch_xla.py:233, in launch(fn, args, start_method, debug_single_process)
231 fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
232 else:
--> 233 xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:37, in spawn(fn, args, nprocs, join, daemon, start_method)
6 def spawn(fn,
7 args=(),
8 nprocs=None,
9 join=True,
10 daemon=False,
11 start_method='spawn'):
12 """Enables multi processing based replication.
13
14 Args:
(...)
35 return None.
36 """
---> 37 return pjrt.spawn(fn, nprocs, start_method, args)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:209, in spawn(fn, nprocs, start_method, args)
206 elif nprocs is not None:
207 logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 209 run_multiprocess(spawn_fn, start_method=start_method)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:169, in run_multiprocess(fn, start_method, *args, **kwargs)
163 mp_fn = functools.partial(
164 _run_thread_per_device,
165 local_world_size=num_processes,
166 fn=functools.partial(fn, *args, **kwargs),
167 initializer_fn=initialize_multiprocess)
168 process_results = executor.map(mp_fn, range(num_processes))
--> 169 replica_results = list(
170 itertools.chain.from_iterable(
171 result.items() for result in process_results))
173 return _merge_replica_results(replica_results)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:170, in <genexpr>(.0)
163 mp_fn = functools.partial(
164 _run_thread_per_device,
165 local_world_size=num_processes,
166 fn=functools.partial(fn, *args, **kwargs),
167 initializer_fn=initialize_multiprocess)
168 process_results = executor.map(mp_fn, range(num_processes))
169 replica_results = list(
--> 170 itertools.chain.from_iterable(
171 result.items() for result in process_results))
173 return _merge_replica_results(replica_results)
File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
569 def _chain_from_iterable_of_lists(iterable):
570 """
571 Specialized implementation of itertools.chain.from_iterable.
572 Each item in *iterable* should be a list. This function is
573 careful not to keep references to yielded objects.
574 """
--> 575 for element in iterable:
576 element.reverse()
577 while element:
File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
618 while fs:
619 # Careful not to keep a reference to the popped future
620 if timeout is None:
--> 621 yield _result_or_cancel(fs.pop())
622 else:
623 yield _result_or_cancel(fs.pop(), end_time - time.monotonic())
File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
317 try:
318 try:
--> 319 return fut.result(timeout)
320 finally:
321 fut.cancel()
File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
456 raise CancelledError()
457 elif self._state == FINISHED:
--> 458 return self.__get_result()
459 else:
460 raise TimeoutError()
File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
401 if self._exception:
402 try:
--> 403 raise self._exception
404 finally:
405 # Break a reference cycle with the exception in self._exception
406 self = None
BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
I would appreciate any insights or suggestions to resolve this issue. If there’s a specific way to handle _mp_fn for TPU multi-core training, please advise.
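For context, the AttributeError appears while the spawned workers unpickle their work items: with the spawn start method each child process re-imports __main__, and in a notebook __main__ is a built-in module that does not contain functions defined in cells, so _mp_fn cannot be resolved. A sketch of one possible workaround, assuming the entry point can live in its own file (untested on this exact setup):

# xla_train.py -- hypothetical module so that spawned workers can re-import _mp_fn
import torch_xla.core.xla_model as xm

def _mp_fn(index, flags):
    device = xm.xla_device()
    # ... training loop as in the sketch above ...

# notebook cell: import the module-level function and launch as before
import torch_xla
from xla_train import _mp_fn

torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=False)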