
Race condition in update_subtomo_missing_wedges during multi-GPU training #40

Description

@uermel

Similar to the hparams-file issue from #27, I encountered a second issue during the update_subtomo_missing_wedges step when training with multiple GPUs.

Full Trace
  Computing model-input normalization statistics:  60%|█████▉    | 37/62 [04:31<00:34,  1.37s/it]
[rank1]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/call.py:38 │
[rank1]: │ in _call_and_handle_interrupt                                                │
[rank1]: │                                                                              │
[rank1]: │   35 │   │   if trainer.strategy.launcher is not None:                       │
[rank1]: │   36 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args,  │
[rank1]: │      trainer=trainer, **kwargs)                                              │
[rank1]: │   37 │   │   else:                                                           │
[rank1]: │ ❱ 38 │   │   │   return trainer_fn(*args, **kwargs)                          │
[rank1]: │   39 │                                                                       │
[rank1]: │   40 │   except _TunerExitException:                                         │
[rank1]: │   41 │   │   trainer._call_teardown_hook()                                   │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :650 in _fit_impl                                                            │
[rank1]: │                                                                              │
[rank1]: │    647 │   │   │   model_provided=True,                                      │
[rank1]: │    648 │   │   │   model_connected=self.lightning_module is not None,        │
[rank1]: │    649 │   │   )                                                             │
[rank1]: │ ❱  650 │   │   self._run(model, ckpt_path=self.ckpt_path)                    │
[rank1]: │    651 │   │                                                                 │
[rank1]: │    652 │   │   assert self.state.stopped                                     │
[rank1]: │    653 │   │   self.training = False                                         │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :1112 in _run                                                                │
[rank1]: │                                                                              │
[rank1]: │   1109 │   │                                                                 │
[rank1]: │   1110 │   │   self._checkpoint_connector.resume_end()                       │
[rank1]: │   1111 │   │                                                                 │
[rank1]: │ ❱ 1112 │   │   results = self._run_stage()                                   │
[rank1]: │   1113 │   │                                                                 │
[rank1]: │   1114 │   │   log.detail(f"{self.__class__.__name__}: trainer tearing       │
[rank1]: │        down")                                                                │
[rank1]: │   1115 │   │   self._teardown()                                              │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :1191 in _run_stage                                                          │
[rank1]: │                                                                              │
[rank1]: │   1188 │   │   │   return self._run_evaluate()                               │
[rank1]: │   1189 │   │   if self.predicting:                                           │
[rank1]: │   1190 │   │   │   return self._run_predict()                                │
[rank1]: │ ❱ 1191 │   │   self._run_train()                                             │
[rank1]: │   1192 │                                                                     │
[rank1]: │   1193 │   def _pre_training_routine(self) -> None:                          │
[rank1]: │   1194 │   │   # wait for all to join if on distributed                      │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :1214 in _run_train                                                          │
[rank1]: │                                                                              │
[rank1]: │   1211 │   │   self.fit_loop.trainer = self                                  │
[rank1]: │   1212 │   │                                                                 │
[rank1]: │   1213 │   │   with torch.autograd.set_detect_anomaly(self._detect_anomaly): │
[rank1]: │ ❱ 1214 │   │   │   self.fit_loop.run()                                       │
[rank1]: │   1215 │                                                                     │
[rank1]: │   1216 │   def _run_evaluate(self) -> _EVALUATE_OUTPUT:                      │
[rank1]: │   1217 │   │   assert self.evaluating                                        │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/loop.py:194  │
[rank1]: │ in run                                                                       │
[rank1]: │                                                                              │
[rank1]: │   191 │   │                                                                  │
[rank1]: │   192 │   │   self.reset()                                                   │
[rank1]: │   193 │   │                                                                  │
[rank1]: │ ❱ 194 │   │   self.on_run_start(*args, **kwargs)                             │
[rank1]: │   195 │   │                                                                  │
[rank1]: │   196 │   │   while not self.done:                                           │
[rank1]: │   197 │   │   │   try:                                                       │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py: │
[rank1]: │ 218 in on_run_start                                                          │
[rank1]: │                                                                              │
[rank1]: │   215 │   │   self._results.to(device=self.trainer.lightning_module.device)  │
[rank1]: │   216 │   │                                                                  │
[rank1]: │   217 │   │   self.trainer._call_callback_hooks("on_train_start")            │
[rank1]: │ ❱ 218 │   │   self.trainer._call_lightning_module_hook("on_train_start")     │
[rank1]: │   219 │   │   self.trainer._call_strategy_hook("on_train_start")             │
[rank1]: │   220 │                                                                      │
[rank1]: │   221 │   def on_advance_start(self) -> None:                                │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :1356 in _call_lightning_module_hook                                         │
[rank1]: │                                                                              │
[rank1]: │   1353 │   │   pl_module._current_fx_name = hook_name                        │
[rank1]: │   1354 │   │                                                                 │
[rank1]: │   1355 │   │   with                                                          │
[rank1]: │        self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name_ │
[rank1]: │        _}.{hook_name}"):                                                     │
[rank1]: │ ❱ 1356 │   │   │   output = fn(*args, **kwargs)                              │
[rank1]: │   1357 │   │                                                                 │
[rank1]: │   1358 │   │   # restore current_fx when nested context                      │
[rank1]: │   1359 │   │   pl_module._current_fx_name = prev_fx_name                     │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/ddw/utils/unet.py:78 in              │
[rank1]: │ on_train_start                                                               │
[rank1]: │                                                                              │
[rank1]: │    75 │                                                                      │
[rank1]: │    76 │   def on_train_start(self) -> None:                                  │
[rank1]: │    77 │   │   if self.current_epoch == 0:                                    │
[rank1]: │ ❱  78 │   │   │   self.update_normalization()                                │
[rank1]: │    79 │                                                                      │
[rank1]: │    80 │   def on_train_epoch_end(self) -> None:                              │
[rank1]: │    81 │   │   if (                                                           │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/ddw/utils/unet.py:156 in             │
[rank1]: │ update_normalization                                                         │
[rank1]: │                                                                              │
[rank1]: │   153 │   │   """
[rank1]: │   154 │   │   Updates the average model input mean and standard deviation    │
[rank1]: │       used to normalize the sub-tomograms.                                   │
[rank1]: │   155 │   │   """                                                            │
[rank1]: │ ❱ 156 │   │   loc, scale = get_avg_model_input_mean_and_std_from_dataloader( │
[rank1]: │   157 │   │   │   dataloader=self.trainer.train_dataloader, verbose=True     │
[rank1]: │   158 │   │   )                                                              │
[rank1]: │   159                                                                        │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/ddw/utils/normalization.py:60 in     │
[rank1]: │ get_avg_model_input_mean_and_std_from_dataloader                             │
[rank1]: │                                                                              │
[rank1]: │   57 │   iter_loader = iter(dataloader)                                      │
[rank1]: │   58 │   for _ in bar:                                                       │
[rank1]: │   59 │   │   try:                                                            │
[rank1]: │ ❱ 60 │   │   │   batch = next(iter_loader)                                   │
[rank1]: │   61 │   │   except StopIteration:                                           │
[rank1]: │   62 │   │   │   iter_loader = iter(dataloader)                              │
[rank1]: │   63 │   │   │   batch = next(iter_loader)                                   │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/supporters │
[rank1]: │ .py:571 in __next__                                                          │
[rank1]: │                                                                              │
[rank1]: │   568 │   │   Returns:                                                       │
[rank1]: │   569 │   │   │   a collections of batch data                                │
[rank1]: │   570 │   │   """
[rank1]: │ ❱ 571 │   │   return self.request_next_batch(self.loader_iters)              │
[rank1]: │   572 │                                                                      │
[rank1]: │   573 │   @staticmethod                                                      │
[rank1]: │   574 │   def request_next_batch(loader_iters: Union[Iterator, Sequence,     │
[rank1]: │       Mapping]) -> Any:                                                      │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/supporters │
[rank1]: │ .py:583 in request_next_batch                                                │
[rank1]: │                                                                              │
[rank1]: │   580 │   │   Returns                                                        │
[rank1]: │   581 │   │   │   Any: a collections of batch data                           │
[rank1]: │   582 │   │   """                                                            │
[rank1]: │ ❱ 583 │   │   return apply_to_collection(loader_iters, Iterator, next)       │
[rank1]: │   584 │                                                                      │
[rank1]: │   585 │   @staticmethod                                                      │
[rank1]: │   586 │   def create_loader_iters(                                           │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/lightning_utilities/core/apply_func. │
[rank1]: │ py:70 in apply_to_collection                                                 │
[rank1]: │                                                                              │
[rank1]: │    67 │   │   )                                                              │
[rank1]: │    68 │   # fast path for the most common cases:                             │
[rank1]: │    69 │   if isinstance(data, dtype):  # single element                      │
[rank1]: │ ❱  70 │   │   return function(data, *args, **kwargs)                         │
[rank1]: │    71 │   if data.__class__ is list and all(isinstance(x, dtype) for x in    │
[rank1]: │       data):  # 1d homogeneous list                                          │
[rank1]: │    72 │   │   return [function(x, *args, **kwargs) for x in data]            │
[rank1]: │    73 │   if data.__class__ is tuple and all(isinstance(x, dtype) for x in   │
[rank1]: │       data):  # 1d homogeneous tuple                                         │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/ddw/utils/dataloader.py:27 in        │
[rank1]: │ __iter__                                                                     │
[rank1]: │                                                                              │
[rank1]: │   24 │                                                                       │
[rank1]: │   25 │   def __iter__(self):                                                 │
[rank1]: │   26 │   │   for i in range(len(self)):                                      │
[rank1]: │ ❱ 27 │   │   │   yield next(self.iterator)                                   │
[rank1]: │   28                                                                         │
[rank1]: │   29                                                                         │
[rank1]: │   30 class _RepeatSampler(BatchSampler):                                     │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:718   │
[rank1]: │ in __next__                                                                  │
[rank1]: │                                                                              │
[rank1]: │    715 │   │   │   if self._sampler_iter is None:                            │
[rank1]: │    716 │   │   │   │   #
[rank1]: │        TODO(https://github.com/pytorch/pytorch/issues/76750)                 │
[rank1]: │    717 │   │   │   │   self._reset()  # type: ignore[call-arg]               │
[rank1]: │ ❱  718 │   │   │   data = self._next_data()                                  │
[rank1]: │    719 │   │   │   self._num_yielded += 1                                    │
[rank1]: │    720 │   │   │   if (                                                      │
[rank1]: │    721 │   │   │   │   self._dataset_kind == _DatasetKind.Iterable           │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:1495  │
[rank1]: │ in _next_data                                                                │
[rank1]: │                                                                              │
[rank1]: │   1492 │   │   │   if len(self._task_info[self._rcvd_idx]) == 2:             │
[rank1]: │   1493 │   │   │   │   worker_id, data = self._task_info.pop(self._rcvd_idx) │
[rank1]: │   1494 │   │   │   │   self._rcvd_idx += 1                                   │
[rank1]: │ ❱ 1495 │   │   │   │   return self._process_data(data, worker_id)            │
[rank1]: │   1496 │   │   │                                                             │
[rank1]: │   1497 │   │   │   if self._shutdown or self._tasks_outstanding <= 0:        │
[rank1]: │   1498 │   │   │   │   raise AssertionError(                                 │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:1563  │
[rank1]: │ in _process_data                                                             │
[rank1]: │                                                                              │
[rank1]: │   1560 │   │   self._workers_num_tasks[worker_idx] -= 1                      │
[rank1]: │   1561 │   │   self._try_put_index()                                         │
[rank1]: │   1562 │   │   if isinstance(data, ExceptionWrapper):                        │
[rank1]: │ ❱ 1563 │   │   │   data.reraise()                                            │
[rank1]: │   1564 │   │   return data                                                   │
[rank1]: │   1565 │                                                                     │
[rank1]: │   1566 │   def _mark_worker_as_unavailable(self, worker_id, shutdown=False)  │
[rank1]: │        -> None:                                                              │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/torch/_utils.py:774 in reraise       │
[rank1]: │                                                                              │
[rank1]: │    771 │   │   │   # If the exception takes multiple arguments or otherwise  │
[rank1]: │        can't                                                                 │
[rank1]: │    772 │   │   │   # be constructed, don't try to instantiate since we don't │
[rank1]: │        know how to                                                           │
[rank1]: │    773 │   │   │   raise RuntimeError(msg) from None                         │
[rank1]: │ ❱  774 │   │   raise exception                                               │
[rank1]: │    775                                                                       │
[rank1]: │    776                                                                       │
[rank1]: │    777 def cpu_count() -> int | None:                                        │
[rank1]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank1]: RuntimeError: Caught RuntimeError in DataLoader worker process 2.
[rank1]: Original Traceback (most recent call last):
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", 
[rank1]: line 374, in _worker_loop
[rank1]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line
[rank1]: 54, in fetch
[rank1]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line
[rank1]: 54, in <listcomp>
[rank1]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank1]:             ~~~~~~~~~~~~^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 76, in __getitem__
[rank1]:     subtomo0 = safe_load(subtomo0_file)
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 24, in safe_load
[rank1]:     raise e  # Reraise if it's the last attempt
[rank1]:     ^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 19, in safe_load
[rank1]:     return torch.load(file_path)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 
[rank1]: 1537, in load
[rank1]:     with _open_zipfile_reader(opened_file) as opened_zipfile:
[rank1]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 
[rank1]: 807, in __init__
[rank1]:     super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: PytorchStreamReader failed reading zip archive: failed finding 
[rank1]: central directory. This is an internal miniz error. If you are seeing this 
[rank1]: error, there is a high likelihood that your checkpoint file is corrupted. This 
[rank1]: can happen if the checkpoint was not saved properly, was transferred 
[rank1]: incorrectly, or the file was modified after saving.


[rank1]: During handling of the above exception, another exception occurred:

[rank1]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank1]: │ /usr/local/lib/python3.11/dist-packages/ddw/fit_model.py:273 in fit_model    │
[rank1]: │                                                                              │
[rank1]: │   270 │   # fit the model                                                    │
[rank1]: │   271 │   if val_data_exists and resume_from_checkpoint is None:             │
[rank1]: │   272 │   │   trainer.validate(lit_unet, val_dataloader)                     │
[rank1]: │ ❱ 273 │   trainer.fit(                                                       │
[rank1]: │   274 │   │   #ckpt_path=resume_from_checkpoint,  # for pytorch-lightning >= │
[rank1]: │       2.0                                                                    │
[rank1]: │   275 │   │   model=lit_unet,                                                │
[rank1]: │   276 │   │   train_dataloaders=fitting_dataloader,                          │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py │
[rank1]: │ :608 in fit                                                                  │
[rank1]: │                                                                              │
[rank1]: │    605 │   │   """
[rank1]: │    606 │   │   model = self._maybe_unwrap_optimized(model)                   │
[rank1]: │    607 │   │   self.strategy._lightning_module = model                       │
[rank1]: │ ❱  608 │   │   call._call_and_handle_interrupt(                              │
[rank1]: │    609 │   │   │   self, self._fit_impl, model, train_dataloaders,           │
[rank1]: │        val_dataloaders, datamodule, ckpt_path                                │
[rank1]: │    610 │   │   )                                                             │
[rank1]: │    611                                                                       │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/call.py:59 │
[rank1]: │ in _call_and_handle_interrupt                                                │
[rank1]: │                                                                              │
[rank1]: │   56 │   │   trainer.state.status = TrainerStatus.INTERRUPTED                │
[rank1]: │   57 │   │   if _distributed_available() and trainer.world_size > 1:         │
[rank1]: │   58 │   │   │   # try syncing remaining processes, kill otherwise           │
[rank1]: │ ❱ 59 │   │   │                                                               │
[rank1]: │      trainer.strategy.reconciliate_processes(traceback.format_exc())         │
[rank1]: │   60 │   │   trainer._call_callback_hooks("on_exception", exception)         │
[rank1]: │   61 │   │   for logger in trainer.loggers:                                  │
[rank1]: │   62 │   │   │   logger.finalize("failed")                                   │
[rank1]: │                                                                              │
[rank1]: │ /usr/local/lib/python3.11/dist-packages/pytorch_lightning/strategies/ddp.py: │
[rank1]: │ 460 in reconciliate_processes                                                │
[rank1]: │                                                                              │
[rank1]: │   457 │   │   │   if pid != os.getpid():                                     │
[rank1]: │   458 │   │   │   │   os.kill(pid, signal.SIGKILL)                           │
[rank1]: │   459 │   │   shutil.rmtree(sync_dir)                                        │
[rank1]: │ ❱ 460 │   │   raise DeadlockDetectedException(f"DeadLock detected from rank: │
[rank1]: │       {self.global_rank} \n {trace}")                                        │
[rank1]: │   461 │                                                                      │
[rank1]: │   462 │   def teardown(self) -> None:                                        │
[rank1]: │   463 │   │   log.detail(f"{self.__class__.__name__}: tearing down           │
[rank1]: │       strategy")                                                             │
[rank1]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank1]: DeadlockDetectedException: DeadLock detected from rank: 1 
[rank1]:  Traceback (most recent call last):
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/call.py", 
[rank1]: line 38, in _call_and_handle_interrupt
[rank1]:     return trainer_fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py", 
[rank1]: line 650, in _fit_impl
[rank1]:     self._run(model, ckpt_path=self.ckpt_path)
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py", 
[rank1]: line 1112, in _run
[rank1]:     results = self._run_stage()
[rank1]:               ^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py", 
[rank1]: line 1191, in _run_stage
[rank1]:     self._run_train()
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py", 
[rank1]: line 1214, in _run_train
[rank1]:     self.fit_loop.run()
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/loop.py", line 
[rank1]: 194, in run
[rank1]:     self.on_run_start(*args, **kwargs)
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py", 
[rank1]: line 218, in on_run_start
[rank1]:     self.trainer._call_lightning_module_hook("on_train_start")
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/trainer.py", 
[rank1]: line 1356, in _call_lightning_module_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/unet.py", line 78, in 
[rank1]: on_train_start
[rank1]:     self.update_normalization()
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/unet.py", line 156, in
[rank1]: update_normalization
[rank1]:     loc, scale = get_avg_model_input_mean_and_std_from_dataloader(
[rank1]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/normalization.py", 
[rank1]: line 60, in get_avg_model_input_mean_and_std_from_dataloader
[rank1]:     batch = next(iter_loader)
[rank1]:             ^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/supporters.py
[rank1]: ", line 571, in __next__
[rank1]:     return self.request_next_batch(self.loader_iters)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/supporters.py
[rank1]: ", line 583, in request_next_batch
[rank1]:     return apply_to_collection(loader_iters, Iterator, next)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/lightning_utilities/core/apply_func.py"
[rank1]: , line 70, in apply_to_collection
[rank1]:     return function(data, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/dataloader.py", line 
[rank1]: 27, in __iter__
[rank1]:     yield next(self.iterator)
[rank1]:           ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py",
[rank1]: line 718, in __next__
[rank1]:     data = self._next_data()
[rank1]:            ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py",
[rank1]: line 1495, in _next_data
[rank1]:     return self._process_data(data, worker_id)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py",
[rank1]: line 1563, in _process_data
[rank1]:     data.reraise()
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/_utils.py", line 774, in 
[rank1]: reraise
[rank1]:     raise exception
[rank1]: RuntimeError: Caught RuntimeError in DataLoader worker process 2.
[rank1]: Original Traceback (most recent call last):
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", 
[rank1]: line 374, in _worker_loop
[rank1]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line
[rank1]: 54, in fetch
[rank1]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File 
[rank1]: "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line
[rank1]: 54, in <listcomp>
[rank1]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank1]:             ~~~~~~~~~~~~^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 76, in __getitem__
[rank1]:     subtomo0 = safe_load(subtomo0_file)
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 24, in safe_load
[rank1]:     raise e  # Reraise if it's the last attempt
[rank1]:     ^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/ddw/utils/subtomo_dataset.py", 
[rank1]: line 19, in safe_load
[rank1]:     return torch.load(file_path)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 
[rank1]: 1537, in load
[rank1]:     with _open_zipfile_reader(opened_file) as opened_zipfile:
[rank1]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 
[rank1]: 807, in __init__
[rank1]:     super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: PytorchStreamReader failed reading zip archive: failed finding 
[rank1]: central directory. This is an internal miniz error. If you are seeing this 
[rank1]: error, there is a high likelihood that your checkpoint file is corrupted. This 
[rank1]: can happen if the checkpoint was not saved properly, was transferred 
[rank1]: incorrectly, or the file was modified after saving.

When ddw fit-model is run on more than one GPU (PyTorch Lightning DDP),
the periodic missing-wedge update corrupts subtomo .pt files on disk. The run
later dies when a corrupted file is loaded:

PytorchStreamReader failed reading zip archive: failed finding central directory

The corrupted file is a few hundred bytes shorter than its siblings: it is a truncated
torch.save zip archive whose central directory was never written.
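To find which subtomo files are already damaged before restarting a run, a check along the following lines may help. It is a minimal sketch, not part of ddw: `find_truncated_pt_files` is a hypothetical helper, and it assumes the files were written with torch.save's default zip-based format (the default since PyTorch 1.6), which Python's standard `zipfile` module can read. A file truncated before its end-of-central-directory record was written fails `zipfile.is_zipfile`, which is the same condition miniz reports as "failed finding central directory".

```python
import zipfile
from pathlib import Path


def find_truncated_pt_files(root):
    """Return the .pt files under `root` that are not readable zip archives.

    Hypothetical helper. Assumes the files were written by torch.save in its
    default zip-based format. A file cut off before the end-of-central-directory
    record is written fails zipfile.is_zipfile().
    """
    bad = []
    for path in sorted(Path(root).rglob("*.pt")):
        if not zipfile.is_zipfile(path):
            bad.append(path)
            continue
        try:
            # Opening the archive parses the central directory; a damaged
            # directory raises BadZipFile even if the end record was found.
            with zipfile.ZipFile(path) as zf:
                zf.namelist()
        except zipfile.BadZipFile:
            bad.append(path)
    return bad


if __name__ == "__main__":
    import sys

    for p in find_truncated_pt_files(sys.argv[1] if len(sys.argv) > 1 else "."):
        print(p)
```

Pointing it at the directory holding the subtomo .pt files lists the truncated ones; those would need to be regenerated, since their contents cannot be recovered from the archive.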

Encountered when:

  • ddw fit-model with --gpu specifying 2+ devices (i.e. len(devices) > 1,
    which activates DDPStrategy in ddw/fit_model.py).
  • Triggered every update_subtomo_missing_wedges_every_n_epochs epochs (default
    10), i.e. at the first wedge-update milestone and onward.
  • Single-GPU training is not affected.

Reason:
LitUnet3D.on_train_epoch_end runs on every DDP rank and, at each milestone,
calls update_subtomo_missing_wedges(). That method builds a non-distributed
DataLoader (no DistributedSampler) over the full subtomo set and writes each
result back to disk:

# ddw/utils/unet.py  (update_subtomo_missing_wedges)
train_set = train_loader.dataset                      # full dataset (sampler stripped)
dataset   = torch.utils.data.ConcatDataset(datasets)  # full set
loader    = torch.utils.data.DataLoader(dataset, ...)  # no sampler -> every rank sees all items
...
for subtomo, file in zip(subtomo_batch, batch["subtomo0_file"]):
    torch.save(subtomo.cpu().clone(), file)            # all ranks write the SAME paths

So with N ranks, N processes write to the same file paths concurrently.
torch.save truncates the target in place before rewriting the zip archive, so a
second writer (or a reader) can see a half-written, truncated archive, which
produces the corruption above.
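One possible direction for a fix, sketched below under stated assumptions and not the project's actual code, is to (1) give every file exactly one writer by sharding the file indices across ranks, and (2) make each write atomic so a reader never observes a truncated archive. `indices_for_rank` and `atomic_write` are hypothetical helpers. The strided sharding is used instead of `torch.utils.data.DistributedSampler` because that sampler, with its default `drop_last=False`, pads the index list with repeated indices when the dataset size is not divisible by the world size, which would again give some files two writers.

```python
import os
import tempfile


def indices_for_rank(num_items, rank, world_size):
    """Hypothetical helper: strided, non-overlapping shard of range(num_items).

    Never pads with repeated indices, so each index (i.e. each subtomo
    file) is assigned to exactly one rank.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return list(range(rank, num_items, world_size))


def atomic_write(path, write_fn):
    """Hypothetical helper: write via a temp file in the same directory, then rename.

    On POSIX, os.replace() renames atomically when source and target are on
    the same filesystem, so a concurrent reader sees either the old complete
    file or the new complete file, never a truncated one.
    `write_fn` receives a binary file object, e.g. lambda f: torch.save(t, f).
    Note: mkstemp creates the file with 0600 permissions.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            write_fn(f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes hit disk before the rename
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

Inside update_subtomo_missing_wedges, the idea would be to wrap the dataset in `torch.utils.data.Subset(dataset, indices_for_rank(len(dataset), trainer.global_rank, trainer.world_size))`, replace `torch.save(subtomo.cpu().clone(), file)` with `atomic_write(file, lambda f: torch.save(subtomo.cpu().clone(), f))`, and call a barrier (e.g. `torch.distributed.barrier()`) afterwards so no rank starts the next epoch before every file has been rewritten. A simpler alternative is to let only rank 0 perform the update, followed by the same barrier, at the cost of doing the work on one GPU.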
