
[BUG] MPS Backend: float64 to float32 conversion error and NaN values in distributions #3437

@Kleinzhang99

Description

I'm encountering a TypeError when training a Cell2location model (which uses Pyro under the hood) on the Apple Silicon MPS backend. The error occurs during the model.to(device) call, when Lightning tries to move the model to the MPS device.
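A minimal, standalone sketch of the underlying MPS limitation (plain PyTorch, independent of Pyro/Cell2location; it reproduces the same TypeError):

```python
import torch

# MPS has no float64 support: moving any float64 tensor to the MPS device
# raises the same TypeError shown in the traceback below.
x = torch.zeros(3, dtype=torch.float64)
x.to("mps")  # TypeError: Cannot convert a MPS Tensor to float64 dtype ...
```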

Processing sample: S82_ca
Max value in X: 24116.0
Anndata setup with scvi-tools version 1.3.1.post1.
Setup via Cell2location.setup_anndata with arguments:
{
│ 'layer': None,
│ 'batch_key': None,
│ 'labels_key': None,
│ 'categorical_covariate_keys': None,
│ 'continuous_covariate_keys': None
}
Summary Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Summary Stat Key ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│ n_batch │ 1 │
│ n_cells │ 2913 │
│ n_extra_categorical_covs │ 0 │
│ n_extra_continuous_covs │ 0 │
│ n_labels │ 1 │
│ n_vars │ 13438 │
└──────────────────────────┴───────┘
Data Registry
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Registry Key ┃ scvi-tools Location ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ X │ adata.X │
│ batch │ adata.obs['_scvi_batch'] │
│ ind_x │ adata.obs['_indices'] │
│ labels │ adata.obs['_scvi_labels'] │
└──────────────┴───────────────────────────┘
batch State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_batch'] │ 0 │ 0 │
└──────────────────────────┴────────────┴─────────────────────┘
labels State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_labels'] │ 0 │ 0 │
└───────────────────────────┴────────────┴─────────────────────┘
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/dataloaders/_data_splitting.py:631: UserWarning: accelerator has been set to mps. Please note that not all PyTorch/Jax operations are supported with this backend. as a result, some models might be slower and less accurate than usuall. Please verify your analysis!Refer to pytorch/pytorch#77764 for more details.
_, _, self.device = parse_device_args(
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainrunner.py:69: UserWarning: accelerator has been set to mps. Please note that not all PyTorch/Jax operations are supported with this backend. as a result, some models might be slower and less accurate than usuall. Please verify your analysis!Refer to pytorch/pytorch#77764 for more details.
accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:69: UserWarning: You passed in a val_dataloader but have no validation_step. Skipping val loop.
rank_zero_warn("You passed in a val_dataloader but have no validation_step. Skipping val loop.")

TypeError Traceback (most recent call last)
Cell In[17], line 80
77 if sample_key in ['S89_ca','S86_ca']:
78 continue
---> 80 results[sample_key] = process_sample(sample_key, adata, inf_aver)
82 print("All samples processed successfully!")

Cell In[17], line 47, in process_sample(sample_key, adata, inf_aver, output_dir)
44 mod.module.model.cell_state = mod.module.model.cell_state.to(torch.float32)
46 # train the model
---> 47 mod.train(max_epochs=30000,
48 # train with the full dataset (batch_size=None)
49 batch_size=None,
50 # train on all data points, because
51 # we need to estimate cell abundance at every location
52 train_size=1,
53 accelerator='mps' # lr=1e-4
54 )
55 #use_gpu=True,
56
57 # export the estimated cell abundance (a summary of the posterior distribution)
58 adata_vis = mod.export_posterior(
59 adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs}
60 )

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/cell2location/models/_cell2location_model.py:216, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
213 scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"])
214 kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 216 super().train(**kwargs)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/model/base/_pyromixin.py:199, in PyroSviTrainMixin.train(self, max_epochs, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, datasplitter_kwargs, plan_kwargs, **trainer_kwargs)
188 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
190 runner = self._train_runner_cls(
191 self,
192 training_plan=training_plan,
(...)
197 **trainer_kwargs,
198 )
--> 199 return runner()

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainrunner.py:113, in TrainRunner.__call__(self)
110 self.training_plan.n_obs_validation = self.data_splitter.n_val
112 try:
--> 113 self.trainer.fit(self.training_plan, self.data_splitter)
114 except NameError:
115 import gc

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainer.py:215, in Trainer.fit(self, *args, **kwargs)
209 warnings.filterwarnings(
210 action="ignore",
211 category=UserWarning,
212 message="LightningModule.configure_optimizers returned None",
213 )
214 try:
--> 215 super().fit(*args, **kwargs)
216 except NameError:
217 import gc

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
530 self.strategy._lightning_module = model
531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
534 )

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
41 if trainer.strategy.launcher is not None:
42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43 return trainer_fn(*args, **kwargs)
45 except _TunerExitException:
46 _call_teardown_hook(trainer)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
561 self._data_connector.attach_data(
562 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
563 )
565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
566 self.state.fn,
567 ckpt_path,
568 model_provided=True,
569 model_connected=self.lightning_module is not None,
570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
573 assert self.state.stopped
574 self.training = False

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:956, in Trainer._run(self, model, ckpt_path)
953 self._logger_connector.reset_metrics()
955 # strategy will configure model and move it to the device
--> 956 self.strategy.setup(self)
958 # hook
959 if self.state.fn == TrainerFn.FITTING:

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/strategies/single_device.py:75, in SingleDeviceStrategy.setup(self, trainer)
74 def setup(self, trainer: pl.Trainer) -> None:
---> 75 self.model_to_device()
76 super().setup(trainer)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/strategies/single_device.py:72, in SingleDeviceStrategy.model_to_device(self)
70 def model_to_device(self) -> None:
71 assert self.model is not None, "self.model must be set before self.model.to()"
---> 72 self.model.to(self.root_device)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mixin.py:54, in _DeviceDtypeModuleMixin.to(self, *args, **kwargs)
52 device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
53 self.__update_properties(device=device, dtype=dtype)
---> 54 return super().to(*args, **kwargs)

Cell In[2], line 72, in patched_to(self, *args, **kwargs)
69 cpu_module._buffers[name] = buf.to(torch.float32)
71 # now move to MPS
---> 72 return original_module_to(cpu_module, device_arg)
73 return original_module_to(self, *args, **kwargs)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:1343, in Module.to(self, *args, **kwargs)
1340 else:
1341 raise
-> 1343 return self._apply(convert)

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:903, in Module._apply(self, fn, recurse)
901 if recurse:
902 for module in self.children():
--> 903 module._apply(fn)
905 def compute_should_use_set_data(tensor, tensor_applied):
906 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
907 # If the new tensor has compatible tensor type as the existing tensor,
908 # the current behavior is to change the tensor in-place using .data =,
(...)
913 # global flag to let the user control whether they want the future
914 # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:903, in Module._apply(self, fn, recurse)
901 if recurse:
902 for module in self.children():
--> 903 module._apply(fn)
905 def compute_should_use_set_data(tensor, tensor_applied):
906 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
907 # If the new tensor has compatible tensor type as the existing tensor,
908 # the current behavior is to change the tensor in-place using .data =,
(...)
913 # global flag to let the user control whether they want the future
914 # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:991, in Module._apply(self, fn, recurse)
989 for key, buf in self._buffers.items():
990 if buf is not None:
--> 991 self._buffers[key] = fn(buf)
993 return self

File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:1329, in Module.to.<locals>.convert(t)
1322 if convert_to_format is not None and t.dim() in (4, 5):
1323 return t.to(
1324 device,
1325 dtype if t.is_floating_point() or t.is_complex() else None,
1326 non_blocking,
1327 memory_format=convert_to_format,
1328 )
-> 1329 return t.to(
1330 device,
1331 dtype if t.is_floating_point() or t.is_complex() else None,
1332 non_blocking,
1333 )
1334 except NotImplementedError as e:
1335 if str(e) == "Cannot copy out of meta tensor; no data!":

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Questions
Is there a recommended way to ensure all Pyro model parameters and buffers use float32 when working with MPS backend?
Should this be handled at the Pyro level, or is this a Cell2location-specific issue?
Are there any known workarounds for using Pyro models with MPS acceleration?
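In case it is useful, here is a minimal sketch of the mitigations I am currently trying (assumptions: `mod` is the Cell2location model from the code above, and `cell_state_np` is a hypothetical NumPy array standing in for the reference signatures; none of this is an official Pyro or cell2location recommendation):

```python
import torch

# 1. Keep freshly created tensors at float32. This is already PyTorch's
#    default, but it guards against code that changed the default dtype.
torch.set_default_dtype(torch.float32)

# 2. Tensors built from NumPy float64 arrays stay float64, so cast at the
#    boundary where data enters the model (hypothetical variable name):
# cell_state = torch.tensor(cell_state_np, dtype=torch.float32)

# 3. Cast every floating-point parameter and buffer before Lightning calls
#    model.to(device); nn.Module.to(dtype) converts them in place.
mod.module.to(torch.float32)

mod.train(max_epochs=30000, batch_size=None, train_size=1, accelerator="mps")
```

Even with these casts, parameters created lazily in the Pyro param store during the first training step may still come out as float64 if they are derived from float64 data, which is why I am asking where the fix belongs.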
