Loading customized pretrained_backbone_weights throws PyTorch error #321

@tom21100227

Expected Behavior

When config.yaml is set up to start training from a custom pretrained checkpoint produced by sleap-nn, sleap-nn should load the checkpoint without an error.

Actual Behavior

Running sleap-nn train --config-dir $(pwd) --config-name config raises a RuntimeError:

  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/cli.py", line 98, in train
    run_training(cfg)
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/train.py", line 29, in run_training
    trainer = ModelTrainer.get_model_trainer_from_config(config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 134, in get_model_trainer_from_config
    model_trainer.setup_config()
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 505, in setup_config
    self._verify_model_input_channels()
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/sleap_nn/training/model_trainer.py", line 389, in _verify_model_input_channels
    pretrained_backbone_ckpt = torch.load(
                               ^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 1530, in load
    return _load(
           ^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2119, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2083, in persistent_load
    typed_storage = load_tensor(
                    ^^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 2049, in load_tensor
    wrap_storage = restore_location(storage, location)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 1859, in restore_location
    return default_restore_location(storage, map_location)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/sleap-nn/lib/python3.12/site-packages/torch/serialization.py", line 701, in default_restore_location
    raise RuntimeError(
RuntimeError: don't know how to restore data location of torch.storage.UntypedStorage (tagged with gpu)

Hypothesis

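The map_location argument of torch.load() does not accept gpu as a device string; it expects values such as cuda, cpu, or mps. In sleap_nn/training/model_trainer.py, line 389, the configured trainer_accelerator is passed straight through as map_location: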

pretrained_backbone_ckpt = torch.load(
    self.config.model_config.pretrained_backbone_weights,
    map_location=(
        self.config.trainer_config.trainer_accelerator
        if self.config.trainer_config.trainer_accelerator is not None  # <== In config it's `gpu`.
        or self.config.trainer_config.trainer_accelerator != "auto"
        else "cpu"
    ),
    weights_only=False,
)
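
A side observation on the condition itself: as written, `x is not None or x != "auto"` appears to be true for every value of x, so the "cpu" fallback would never be taken. The sketch below reproduces just that boolean expression with a plain argument so it can be checked in isolation. Both function names are hypothetical and are not part of sleap-nn.

```python
def map_location_as_written(accelerator):
    """Same boolean expression as the snippet above, with a plain argument."""
    return (
        accelerator
        if accelerator is not None or accelerator != "auto"
        else "cpu"
    )


def map_location_with_and(accelerator):
    """Variant using `and`, which is presumably what was intended."""
    return (
        accelerator
        if accelerator is not None and accelerator != "auto"
        else "cpu"
    )


if __name__ == "__main__":
    # Compare both variants on the accelerator values a config might contain.
    for value in (None, "auto", "gpu", "cpu"):
        print(
            repr(value),
            "as written:", repr(map_location_as_written(value)),
            "with and:", repr(map_location_with_and(value)),
        )
```

Switching `or` to `and` restores the fallback for None and "auto", but "gpu" is still passed through unchanged, so that change alone would not resolve the error reported here.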

Suggested Fix

Map gpu to cuda, or use a cascade: try GPU backends such as cuda and mps first, then fall back to cpu.
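
A minimal sketch of the cascade, assuming the accelerator value is one of the usual trainer strings (gpu, cuda, mps, cpu, auto) or None. The helper name `resolve_map_location` and its availability flags are hypothetical; they are not part of sleap-nn or PyTorch. The flags are passed in as arguments so the logic can be exercised without a GPU.

```python
from typing import Optional


def resolve_map_location(
    accelerator: Optional[str],
    cuda_available: bool,
    mps_available: bool,
) -> str:
    """Translate a trainer accelerator name into a device string for torch.load."""
    # An explicit CPU request needs no translation.
    if accelerator == "cpu":
        return "cpu"
    # Honor an explicit MPS request when that backend is usable.
    if accelerator == "mps" and mps_available:
        return "mps"
    # "gpu", "cuda", "auto", and None all go through the cascade:
    # CUDA first, then MPS, then CPU.
    if accelerator in ("gpu", "cuda", "auto", None):
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
    # Unknown accelerators and unavailable backends fall back to CPU,
    # which is always a valid place to load a checkpoint.
    return "cpu"
```

With PyTorch installed, the two flags could come from torch.cuda.is_available() and torch.backends.mps.is_available(), and the returned string could be passed as map_location in the torch.load call above. Falling back to cpu is a conservative choice here, since the checkpoint is only being inspected at this point in setup.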
