Make Detectron2 Work Better with Torch 2.6.0 torch.load semantics #5456

Open
@mritterfigma

🚀 Feature

Torch 2.6.0 changes torch.load so that weights_only=True is the default (https://pytorch.org/docs/stable/notes/serialization.html#torch-load-with-weights-only-true). As a result, torch.load fails on any checkpoint that contains objects outside the default allowlist, unless the caller either (a) sets weights_only=False or (b) calls torch.serialization.add_safe_globals(...) with every class they want to allowlist.

Unfortunately, Detectron2 wraps the call to torch.load inside DetectionCheckpointer.load, so there is no easy way to pass weights_only=False. You can work around this by collecting all of the unsafe globals in the checkpoint and allowlisting them manually, but it's annoying and takes some fiddling.

Instead, Detectron2 should support this weights_only parameter and pass it through to torch.load.
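To make the request concrete, here is a minimal sketch of what such a pass-through could look like. It is not Detectron2's actual code: PassThroughCheckpointer and its load_fn argument are hypothetical names used only for illustration. The loader is injected so the sketch runs without PyTorch installed; in real use it would be torch.load.

```python
from typing import Any, Callable, Optional


class PassThroughCheckpointer:
    """Hypothetical stand-in for a checkpointer; not Detectron2's real class."""

    def __init__(self, load_fn: Callable[..., Any]) -> None:
        # Assumption: in real use load_fn would be torch.load.
        self._load_fn = load_fn

    def load(self, path: str, weights_only: Optional[bool] = None) -> Any:
        kwargs = {"map_location": "cpu"}
        # None means "defer to the loader's own default", which is
        # weights_only=True for torch.load in Torch 2.6.0.
        if weights_only is not None:
            kwargs["weights_only"] = weights_only
        return self._load_fn(path, **kwargs)
```

With something like this in place, a caller who trusts the checkpoint could write checkpointer.load(path, weights_only=False) and skip the allowlisting step entirely. Callers who pass nothing keep the safer default.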

Here is an example of the workaround code:

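  import importlib

  import torch

  # Qualified names of the globals that weights_only=True would reject.
  unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(path_to_checkpoint)

  # Resolve each dotted name to the object it refers to.
  imported_objects = []
  for item in unsafe_globals:
      module_name, _, class_name = item.rpartition('.')
      module = importlib.import_module(module_name)
      imported_objects.append(getattr(module, class_name))

  # Allowlist them so the default torch.load accepts the checkpoint.
  torch.serialization.add_safe_globals(imported_objects)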
