Description
I'm encountering a TypeError when training a Cell2location model (which uses Pyro under the hood) on the Apple Silicon MPS backend. The error occurs during the model.to(device) call, when Lightning tries to move the model to the MPS device.
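For context, a minimal sketch of the call that triggers the error (adata_vis and inf_aver come from my own pipeline; argument values are illustrative, and the real code runs inside a process_sample helper shown in the traceback below):

```python
import cell2location

# adata_vis (spatial AnnData) and inf_aver (reference cell-type signatures)
# come from earlier steps of my pipeline.
cell2location.models.Cell2location.setup_anndata(adata=adata_vis)
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver, N_cells_per_location=30
)

# Passing accelerator='mps' triggers the TypeError below when Lightning
# moves the Pyro module to the MPS device.
mod.train(max_epochs=30000, batch_size=None, train_size=1, accelerator="mps")
```

The setup log and full traceback follow.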
Processing sample: S82_ca
Max value in X: 24116.0
Anndata setup with scvi-tools version 1.3.1.post1.
Setup via Cell2location.setup_anndata with arguments:
{
│ 'layer': None,
│ 'batch_key': None,
│ 'labels_key': None,
│ 'categorical_covariate_keys': None,
│ 'continuous_covariate_keys': None
}
Summary Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Summary Stat Key ┃ Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│ n_batch │ 1 │
│ n_cells │ 2913 │
│ n_extra_categorical_covs │ 0 │
│ n_extra_continuous_covs │ 0 │
│ n_labels │ 1 │
│ n_vars │ 13438 │
└──────────────────────────┴───────┘
Data Registry
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Registry Key ┃ scvi-tools Location ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ X │ adata.X │
│ batch │ adata.obs['_scvi_batch'] │
│ ind_x │ adata.obs['_indices'] │
│ labels │ adata.obs['_scvi_labels'] │
└──────────────┴───────────────────────────┘
batch State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_batch'] │ 0 │ 0 │
└──────────────────────────┴────────────┴─────────────────────┘
labels State Registry
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.obs['_scvi_labels'] │ 0 │ 0 │
└───────────────────────────┴────────────┴─────────────────────┘
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/dataloaders/_data_splitting.py:631: UserWarning: accelerator has been set to mps. Please note that not all PyTorch/Jax operations are supported with this backend. as a result, some models might be slower and less accurate than usuall. Please verify your analysis!Refer to pytorch/pytorch#77764 for more details.
_, _, self.device = parse_device_args(
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainrunner.py:69: UserWarning: accelerator has been set to mps. Please note that not all PyTorch/Jax operations are supported with this backend. as a result, some models might be slower and less accurate than usuall. Please verify your analysis!Refer to pytorch/pytorch#77764 for more details.
accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:69: UserWarning: You passed in a val_dataloader but have no validation_step. Skipping val loop.
rank_zero_warn("You passed in a val_dataloader but have no validation_step. Skipping val loop.")
TypeError Traceback (most recent call last)
Cell In[17], line 80
77 if sample_key in ['S89_ca','S86_ca']:
78 continue
---> 80 results[sample_key] = process_sample(sample_key, adata, inf_aver)
82 print("All samples processed successfully!")
Cell In[17], line 47, in process_sample(sample_key, adata, inf_aver, output_dir)
44 mod.module.model.cell_state = mod.module.model.cell_state.to(torch.float32)
46 # train the model
---> 47 mod.train(max_epochs=30000,
48 # use all the data for training (batch_size=None)
49 batch_size=None,
50 # train on all data points, because
51 # we need to estimate cell abundance at every location
52 train_size=1,
53 accelerator='mps' # lr=1e-4
54 )
55 #use_gpu=True,
56
57 # export the estimated cell abundance (a summary of the posterior distribution)
58 adata_vis = mod.export_posterior(
59 adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs}
60 )
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/cell2location/models/_cell2location_model.py:216, in Cell2location.train(self, max_epochs, batch_size, train_size, lr, num_particles, scale_elbo, **kwargs)
213 scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"])
214 kwargs["plan_kwargs"]["scale_elbo"] = scale_elbo
--> 216 super().train(**kwargs)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/model/base/_pyromixin.py:199, in PyroSviTrainMixin.train(self, max_epochs, accelerator, device, train_size, validation_size, shuffle_set_split, batch_size, early_stopping, lr, training_plan, datasplitter_kwargs, plan_kwargs, **trainer_kwargs)
188 trainer_kwargs["callbacks"].append(PyroJitGuideWarmup())
190 runner = self._train_runner_cls(
191 self,
192 training_plan=training_plan,
(...)
197 **trainer_kwargs,
198 )
--> 199 return runner()
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainrunner.py:113, in TrainRunner.__call__(self)
110 self.training_plan.n_obs_validation = self.data_splitter.n_val
112 try:
--> 113 self.trainer.fit(self.training_plan, self.data_splitter)
114 except NameError:
115 import gc
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/scvi/train/_trainer.py:215, in Trainer.fit(self, *args, **kwargs)
209 warnings.filterwarnings(
210 action="ignore",
211 category=UserWarning,
212 message="LightningModule.configure_optimizers returned None",
213 )
214 try:
--> 215 super().fit(*args, **kwargs)
216 except NameError:
217 import gc
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
530 self.strategy._lightning_module = model
531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
534 )
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
41 if trainer.strategy.launcher is not None:
42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43 return trainer_fn(*args, **kwargs)
45 except _TunerExitException:
46 _call_teardown_hook(trainer)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
561 self._data_connector.attach_data(
562 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
563 )
565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
566 self.state.fn,
567 ckpt_path,
568 model_provided=True,
569 model_connected=self.lightning_module is not None,
570 )
--> 571 self._run(model, ckpt_path=ckpt_path)
573 assert self.state.stopped
574 self.training = False
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:956, in Trainer._run(self, model, ckpt_path)
953 self._logger_connector.reset_metrics()
955 # strategy will configure model and move it to the device
--> 956 self.strategy.setup(self)
958 # hook
959 if self.state.fn == TrainerFn.FITTING:
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/strategies/single_device.py:75, in SingleDeviceStrategy.setup(self, trainer)
74 def setup(self, trainer: pl.Trainer) -> None:
---> 75 self.model_to_device()
76 super().setup(trainer)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/pytorch/strategies/single_device.py:72, in SingleDeviceStrategy.model_to_device(self)
70 def model_to_device(self) -> None:
71 assert self.model is not None, "self.model must be set before self.model.to()"
---> 72 self.model.to(self.root_device)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mixin.py:54, in _DeviceDtypeModuleMixin.to(self, *args, **kwargs)
52 device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
53 self.__update_properties(device=device, dtype=dtype)
---> 54 return super().to(*args, **kwargs)
Cell In[2], line 72, in patched_to(self, *args, **kwargs)
69 cpu_module._buffers[name] = buf.to(torch.float32)
71 # now move to MPS
---> 72 return original_module_to(cpu_module, device_arg)
73 return original_module_to(self, *args, **kwargs)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:1343, in Module.to(self, *args, **kwargs)
1340 else:
1341 raise
-> 1343 return self._apply(convert)
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:903, in Module._apply(self, fn, recurse)
901 if recurse:
902 for module in self.children():
--> 903 module._apply(fn)
905 def compute_should_use_set_data(tensor, tensor_applied):
906 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
907 # If the new tensor has compatible tensor type as the existing tensor,
908 # the current behavior is to change the tensor in-place using .data =,
(...)
913 # global flag to let the user control whether they want the future
914 # behavior of overwriting the existing tensor or not.
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:903, in Module._apply(self, fn, recurse)
901 if recurse:
902 for module in self.children():
--> 903 module._apply(fn)
905 def compute_should_use_set_data(tensor, tensor_applied):
906 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
907 # If the new tensor has compatible tensor type as the existing tensor,
908 # the current behavior is to change the tensor in-place using .data =,
(...)
913 # global flag to let the user control whether they want the future
914 # behavior of overwriting the existing tensor or not.
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:991, in Module._apply(self, fn, recurse)
989 for key, buf in self._buffers.items():
990 if buf is not None:
--> 991 self._buffers[key] = fn(buf)
993 return self
File /opt/anaconda3/envs/scanpy/lib/python3.10/site-packages/torch/nn/modules/module.py:1329, in Module.to.<locals>.convert(t)
1322 if convert_to_format is not None and t.dim() in (4, 5):
1323 return t.to(
1324 device,
1325 dtype if t.is_floating_point() or t.is_complex() else None,
1326 non_blocking,
1327 memory_format=convert_to_format,
1328 )
-> 1329 return t.to(
1330 device,
1331 dtype if t.is_floating_point() or t.is_complex() else None,
1332 non_blocking,
1333 )
1334 except NotImplementedError as e:
1335 if str(e) == "Cannot copy out of meta tensor; no data!":
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
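Reading the traceback, the failure happens while Module._apply recurses into a child module's buffers, so it looks like at least one buffer deep in the Pyro module tree is still float64 when Lightning calls model.to("mps"). (My patched_to shim in Cell In[2] only downcasts the top-level module's buffers, which would explain why it doesn't help.) A quick diagnostic sketch I use to list the offending tensors before training; mod is the Cell2location model from above:

```python
import torch

# Enumerate every float64 parameter/buffer in the Pyro module tree;
# any hit here will break model.to("mps"), since MPS has no float64.
tensors = list(mod.module.named_parameters()) + list(mod.module.named_buffers())
for name, t in tensors:
    if t.dtype == torch.float64:
        print(f"float64 tensor: {name} shape={tuple(t.shape)}")
```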
Questions
Is there a recommended way to ensure all Pyro model parameters and buffers use float32 when working with MPS backend?
Should this be handled at the Pyro level, or is this a Cell2location-specific issue?
Are there any known workarounds for using Pyro models with MPS acceleration? (A sketch of what I'm currently trying is below.)
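For reference, a minimal sketch of the workaround I'm experimenting with, assuming every float64 parameter and buffer in the module tree can be safely downcast to float32 (force_float32 is my own helper, not a Pyro/scvi-tools API):

```python
import torch

def force_float32(module: torch.nn.Module) -> None:
    """Downcast all float64 parameters and buffers, including those of
    nested submodules, to float32, since MPS has no float64 support."""
    for param in module.parameters():
        if param.dtype == torch.float64:
            param.data = param.data.to(torch.float32)
    # Walk every submodule so buffers of children are covered too
    # (unlike my patched_to shim, which only touched the top level).
    for submodule in module.modules():
        for name, buf in submodule.named_buffers(recurse=False):
            if buf is not None and buf.dtype == torch.float64:
                submodule._buffers[name] = buf.to(torch.float32)

# Downcast before training so Lightning's model.to("mps") never sees float64.
force_float32(mod.module)
mod.train(max_epochs=30000, batch_size=None, train_size=1, accelerator="mps")
```

This gets past the model.to(device) step for me, but I'm not sure whether silently downcasting is numerically safe for Cell2location's posterior estimates, which is why I'm asking where the fix should live.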