Skip to content

Commit 5008d7b

Browse files
committed
Improve the safety of torch.load with weights_only=True
1 parent 32428a2 commit 5008d7b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

hanlp/common/torch_component.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ def load_weights(self, save_dir, filename='model.pt', **kwargs):
9797
save_dir = get_resource(save_dir)
9898
filename = os.path.join(save_dir, filename)
9999
# flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
100-
self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
100+
try:
101+
self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False)
102+
except TypeError:
103+
self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
101104
# flash('')
102105

103106
def save_config(self, save_dir, filename='config.json'):

0 commit comments

Comments
 (0)