Closed
Description
Expected Behavior
When config.yaml is set up to start training from a custom pretrained checkpoint that was itself trained with sleap-nn, sleap-nn should load the checkpoint correctly without an error.
Actual Behavior
Running sleap-nn train --config-dir $(pwd) --config-name config throws a RuntimeError:
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/cli.py", line 98, in train
run_training(cfg)
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/train.py", line 29, in run_training
trainer = ModelTrainer.get_model_trainer_from_config(config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 134, in get_model_trainer_from_config
model_trainer.setup_config()
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 505, in setup_config
self._verify_model_input_channels()
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 389, in _verify_model_input_channels
pretrained_backbone_ckpt = torch.load(
^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 1530, in load
return _load(
^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2119, in _load
result = unpickler.load()
^^^^^^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2083, in persistent_load
typed_storage = load_tensor(
^^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2049, in load_tensor
wrap_storage = restore_location(storage, location)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 1859, in restore_location
return default_restore_location(storage, map_location)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 701, in default_restore_location
raise RuntimeError(
RuntimeError: don't know how to restore data location of torch.storage.UntypedStorage (tagged with gpu)
Hypothesis
torch.load() does not accept gpu as a map_location value; it expects a torch device string such as cuda, cpu, or mps. In sleap_nn/training/model_trainer.py, line 389:
pretrained_backbone_ckpt = torch.load(
    self.config.model_config.pretrained_backbone_weights,
    map_location=(
        self.config.trainer_config.trainer_accelerator
        if self.config.trainer_config.trainer_accelerator is not None  # <== In config it's `gpu`.
        or self.config.trainer_config.trainer_accelerator != "auto"
        else "cpu"
    ),
    weights_only=False,
)
Suggested Fix
Change the value to cuda, or use a cascade: try the GPU backends (cuda, then mps) first, then fall back to cpu.
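The cascade could look roughly like the sketch below. It is not the sleap-nn implementation: the helper name resolve_map_location and its two *_available parameters are made up for illustration, and the only torch calls it relies on are torch.cuda.is_available() and torch.backends.mps.is_available(). Note also that the quoted condition (is not None or != "auto") is true for every value, including None and "auto", so the "cpu" branch is never reached; the sketch sends those values to cpu explicitly.

```python
def resolve_map_location(accelerator, cuda_available=None, mps_available=None):
    """Translate a Lightning-style accelerator name into a torch.load map_location.

    Hypothetical helper. The *_available flags exist only so the logic can be
    exercised without a GPU; when left as None they are probed from torch.
    """
    # Device strings that torch.load already understands pass through unchanged.
    if accelerator in ("cuda", "mps", "cpu"):
        return accelerator

    # None, "auto", or anything unrecognized: load onto the CPU, which always works.
    if accelerator != "gpu":
        return "cpu"

    # "gpu" is a trainer-level name, not a torch device, so pick a real backend.
    if cuda_available is None or mps_available is None:
        import torch  # imported lazily so the pass-through paths need no probe

        if cuda_available is None:
            cuda_available = torch.cuda.is_available()
        if mps_available is None:
            mps_available = torch.backends.mps.is_available()

    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Intended use at the call site quoted above (assumed, not taken from sleap-nn):
# torch.load(path, map_location=resolve_map_location(accelerator), weights_only=False)
```

Loading onto cpu for None and "auto" keeps the fallback the original code appears to have intended; the model can still be moved to the training device afterwards.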