+import collections
+
 import torch
 from torch import nn, Tensor
 from typing import List, Union, Iterable, Optional, Dict, Tuple
 
-__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex"]
+__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]
 
 from pytorch_toolbelt.inference.tta import _deaugment_averaging
 
@@ -159,3 +161,55 @@ def __init__(self, key: Union[str, int]):
 
     def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
         return outputs[self.target_key]
+
+
+def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
+    """Loads checkpoints from inputs and returns a checkpoint with averaged weights.
+
+    Original implementation taken from:
+    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+
+    Args:
+        inputs (List[str]): An iterable of string paths of checkpoints to load from.
+
+    Returns:
+        A dict of string keys mapping to various values. The 'model_state_dict' key
+        of the returned dict maps string parameter names to torch Tensors containing
+        the averaged weights.
+    """
+    params_dict = collections.OrderedDict()
+    params_keys = None
+    new_state = None
+    num_models = len(inputs)
+    for fpath in inputs:
+        with open(fpath, "rb") as f:
+            state = torch.load(f, map_location="cpu")
+        # Copy all non-weight entries (epoch, optimizer state, etc.) from the first checkpoint
+        if new_state is None:
+            new_state = state
+        model_params = state["model_state_dict"]
+        model_params_keys = list(model_params.keys())
+        if params_keys is None:
+            params_keys = model_params_keys
+        elif params_keys != model_params_keys:
+            raise KeyError(
+                "For checkpoint {}, expected list of params: {}, "
+                "but found: {}".format(fpath, params_keys, model_params_keys)
+            )
+        # Accumulate parameter sums; they are divided by num_models below
+        for k in params_keys:
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            if k not in params_dict:
+                # clone() is needed in case p is a shared parameter
+                params_dict[k] = p.clone()
+            else:
+                params_dict[k] += p
+    averaged_params = collections.OrderedDict()
+    for k, v in params_dict.items():
+        averaged_params[k] = v
+        if averaged_params[k].is_floating_point():
+            averaged_params[k].div_(num_models)
+        else:
+            # Integer buffers (e.g. BatchNorm's num_batches_tracked) use floor division
+            averaged_params[k] //= num_models
+    new_state["model_state_dict"] = averaged_params
+    return new_state
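
A minimal usage sketch for the new helper (not part of the patch itself), assuming the checkpoints were saved with a "model_state_dict" entry and using a hypothetical SomeModel class as a stand-in for the real architecture:

    # Hypothetical checkpoint paths, e.g. per-fold "best" snapshots
    checkpoints = ["fold0_best.pth", "fold1_best.pth", "fold2_best.pth"]
    averaged = average_checkpoints(checkpoints)

    model = SomeModel()  # placeholder for the actual nn.Module
    model.load_state_dict(averaged["model_state_dict"])
    model.eval()

Because the helper copies everything except the weights from the first checkpoint and only replaces its "model_state_dict", the returned dict can be written back with torch.save and later loaded like any regular checkpoint.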