Description
🚀 Feature
Torch 2.6.0 changes torch.load so that weights_only=True by default (https://pytorch.org/docs/stable/notes/serialization.html#torch-load-with-weights-only-true). As a result, torch.load fails on any checkpoint that contains objects outside torch's allowlist unless the caller either (a) sets weights_only=False or (b) calls torch.serialization.add_safe_globals([{__name__}]) for all of the classes they want to allowlist in torch.
Unfortunately, Detectron2 wraps the call to torch.load inside DetectionCheckpointer.load, so there is no easy way to pass weights_only=False through to it. You can work around this by collecting all of the globals in the checkpoint and manually allowlisting them, but it's annoying and takes some fiddling.
Instead, Detectron2 should accept this weights_only parameter and pass it through to torch.load.
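A minimal sketch of the proposed pass-through follows. PassThroughCheckpointer and Note are hypothetical names, not part of detectron2 or fvcore. In Detectron2 the real change would presumably go wherever the checkpointer calls torch.load (as far as we can tell, fvcore's Checkpointer._load_file, which DetectionCheckpointer builds on); the sketch only shows the caller's choice reaching torch.load instead of torch's default.

```python
import os
import tempfile
from typing import Any, Dict

import torch


class PassThroughCheckpointer:
    """Hypothetical loader (not detectron2's API) that forwards a
    weights_only flag to torch.load instead of relying on torch's default."""

    def __init__(self, weights_only: bool = True):
        self.weights_only = weights_only

    def load(self, path: str) -> Dict[str, Any]:
        # Forward the caller's choice; load onto CPU as checkpointers commonly do.
        return torch.load(
            path, map_location=torch.device("cpu"), weights_only=self.weights_only
        )


class Note:
    # Hypothetical custom object that is not allowlisted by default.
    def __init__(self, text):
        self.text = text


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save({"model": {"w": torch.ones(3)}, "note": Note("trusted")}, path)
    # A caller who trusts this file opts out of the restricted unpickler.
    loaded = PassThroughCheckpointer(weights_only=False).load(path)
```

Taking the flag in the constructor keeps load()'s signature unchanged; a per-call keyword argument would work just as well.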
Here is an example of the workaround code:
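```python
import importlib

import torch

path_to_checkpoint = "model_final.pth"  # placeholder: path to your checkpoint

# Globals in the checkpoint that torch.load(weights_only=True) would reject
unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(path_to_checkpoint)
imported_objects = []
for item in unsafe_globals:
    # Each entry is a fully qualified name such as "module.submodule.ClassName"
    module_name, _, class_name = item.rpartition(".")
    module = importlib.import_module(module_name)
    imported_objects.append(getattr(module, class_name))
# Allowlist them so torch.load(weights_only=True) can unpickle the checkpoint
torch.serialization.add_safe_globals(imported_objects)
```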