Commit 7c66d8a: average_checkpoints
BloodAxe committed Jan 30, 2024 (1 parent: caab062)
Showing 1 changed file with 55 additions and 1 deletion.

pytorch_toolbelt/inference/ensembling.py (55 additions, 1 deletion)
@@ -1,8 +1,10 @@
import collections

import torch
from torch import nn, Tensor
from typing import List, Union, Iterable, Optional, Dict, Tuple

__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex"]
__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]

from pytorch_toolbelt.inference.tta import _deaugment_averaging

@@ -159,3 +161,55 @@ def __init__(self, key: Union[str, int]):

    def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
        return outputs[self.target_key]


def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
    """Loads checkpoints from inputs and returns a checkpoint with averaged model weights. Original implementation taken from:
    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16

    Args:
        inputs (List[str]): An iterable of string paths of checkpoints to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model_state_dict' key
        of the returned dict is an OrderedDict mapping string parameter names
        to the averaged torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(f, map_location="cpu")

        # Copy over the settings from the first checkpoint; its non-model entries
        # (epoch, optimizer state, etc.) are kept in the returned checkpoint.
        if new_state is None:
            new_state = state

        model_params = state["model_state_dict"]
        model_params_keys = list(model_params.keys())

        # All checkpoints must contain exactly the same set of parameters.
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(fpath, params_keys, model_params_keys)
            )

        # Accumulate a running sum of each parameter across checkpoints.
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case p is a shared parameter
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    # Divide the accumulated sums by the number of checkpoints.
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models

    new_state["model_state_dict"] = averaged_params
    return new_state
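
For reference, a minimal usage sketch (not part of the commit; the toy model, file names, and extra checkpoint keys are illustrative assumptions). Each input path is expected to point at a checkpoint saved with torch.save and containing a "model_state_dict" entry; all other entries of the first checkpoint are carried over unchanged:

import torch
from torch import nn

from pytorch_toolbelt.inference.ensembling import average_checkpoints

# Toy model standing in for a real network.
model = nn.Linear(10, 2)

# Write a few checkpoints in the expected format: a dict with a "model_state_dict" key.
# In practice these would come from different epochs or cross-validation folds.
paths = []
for i in range(3):
    path = f"checkpoint_{i}.pth"
    torch.save({"model_state_dict": model.state_dict(), "epoch": i}, path)
    paths.append(path)

# Average the weights and load them back into the model.
averaged = average_checkpoints(paths)
model.load_state_dict(averaged["model_state_dict"])

# The averaged checkpoint can be saved back to disk in the same format.
torch.save(averaged, "averaged_checkpoint.pth")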
