Commit 7c66d8a

average_checkpoints
1 parent caab062 commit 7c66d8a

File tree: 1 file changed (+55, -1)

pytorch_toolbelt/inference/ensembling.py

Lines changed: 55 additions & 1 deletion

@@ -1,8 +1,10 @@
+import collections
+
 import torch
 from torch import nn, Tensor
 from typing import List, Union, Iterable, Optional, Dict, Tuple
 
-__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex"]
+__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]
 
 from pytorch_toolbelt.inference.tta import _deaugment_averaging
 
@@ -159,3 +161,55 @@ def __init__(self, key: Union[str, int]):
 
     def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
         return outputs[self.target_key]
+
+
+def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
+    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
+    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+    Args:
+      inputs (List[str]): An iterable of string paths of checkpoints to load from.
+    Returns:
+      A dict of string keys mapping to various values. The 'model_state_dict' key
+      from the returned dict should correspond to an OrderedDict mapping
+      string parameter names to torch Tensors.
+    """
+    params_dict = collections.OrderedDict()
+    params_keys = None
+    new_state = None
+    num_models = len(inputs)
+    for fpath in inputs:
+        with open(fpath, "rb") as f:
+            state = torch.load(
+                f,
+                map_location="cpu",
+            )
+        # Copies over the settings from the first checkpoint
+        if new_state is None:
+            new_state = state
+        model_params = state["model_state_dict"]
+        model_params_keys = list(model_params.keys())
+        if params_keys is None:
+            params_keys = model_params_keys
+        elif params_keys != model_params_keys:
+            raise KeyError(
+                "For checkpoint {}, expected list of params: {}, "
+                "but found: {}".format(fpath, params_keys, model_params_keys)
+            )
+        for k in params_keys:
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            if k not in params_dict:
+                params_dict[k] = p.clone()
+                # NOTE: clone() is needed in case p is a shared parameter
+            else:
+                params_dict[k] += p
+    averaged_params = collections.OrderedDict()
+    for k, v in params_dict.items():
+        averaged_params[k] = v
+        if averaged_params[k].is_floating_point():
+            averaged_params[k].div_(num_models)
+        else:
+            averaged_params[k] //= num_models
+    new_state["model_state_dict"] = averaged_params
+    return new_state
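
Not part of the commit itself, but as a usage sketch: the snippet below saves a few dummy checkpoints in the format average_checkpoints expects (a dict holding the weights under a "model_state_dict" key) and averages them. The tiny model, the file names, and the extra "epoch" field are illustrative assumptions, not part of the library API.

# Usage sketch (not part of this commit). The model, file names and "epoch"
# field below are hypothetical; only average_checkpoints comes from the library.
import torch
from torch import nn

from pytorch_toolbelt.inference.ensembling import average_checkpoints

model = nn.Linear(4, 2)

# Each checkpoint must be a dict with the weights stored under "model_state_dict".
paths = []
for fold in range(3):
    path = f"fold{fold}_best.pth"
    torch.save({"epoch": 10 + fold, "model_state_dict": model.state_dict()}, path)
    paths.append(path)

averaged = average_checkpoints(paths)

# Non-weight keys ("epoch" here) are copied from the first checkpoint;
# "model_state_dict" holds the element-wise average of all checkpoints.
model.load_state_dict(averaged["model_state_dict"])
torch.save(averaged, "averaged_best.pth")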
