Skip to content

Commit fffe136

Browse files
committed
Improve the safety of torch.load with weights_only=True
1 parent 32428a2 commit fffe136

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

hanlp/common/torch_component.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Date: 2020-05-08 21:20
44
import logging
55
import os
6+
import pickle
67
import re
78
import time
89
from abc import ABC, abstractmethod
@@ -97,7 +98,10 @@ def load_weights(self, save_dir, filename='model.pt', **kwargs):
9798
save_dir = get_resource(save_dir)
9899
filename = os.path.join(save_dir, filename)
99100
# flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
100-
self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
101+
try:
102+
self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False)
103+
except pickle.UnpicklingError:
104+
self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=False), strict=False)
101105
# flash('')
102106

103107
def save_config(self, save_dir, filename='config.json'):

0 commit comments

Comments
 (0)