Reimplement average_checkpoints which uses average_state_dicts now
BloodAxe committed Feb 6, 2025
1 parent 05e07ac commit 47fd13e
Showing 1 changed file with 83 additions and 44 deletions.
127 changes: 83 additions & 44 deletions pytorch_toolbelt/inference/ensembling.py
@@ -4,7 +4,15 @@
from torch import nn, Tensor
from typing import List, Union, Iterable, Optional, Dict, Tuple, Mapping

__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]
__all__ = [
    "ApplySoftmaxTo",
    "ApplySigmoidTo",
    "Ensembler",
    "PickModelOutput",
    "SelectByIndex",
    "average_checkpoints",
    "average_state_dicts",
]

from pytorch_toolbelt.inference.tta import _deaugment_averaging

@@ -163,53 +171,84 @@ def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
        return outputs[self.target_key]


def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
def average_state_dicts(state_dicts: List[Mapping[str, Tensor]]) -> Mapping[str, Tensor]:
    """
    Averages multiple state dicts key by key. All state dicts must share the same keys and tensor shapes.
    Floating-point tensors are averaged, integer tensors are averaged with integer division,
    and boolean tensors must be identical across all inputs.
    """

    keys = state_dicts[0].keys()
    final_state_dict = collections.OrderedDict()

    for key in keys:
        # Collect the values (tensors) for this key from all checkpoints
        values = [sd[key] for sd in state_dicts]

        # Use the first value as the reference (assuming all dtypes match)
        first_val = values[0]

        if not all(v.shape == first_val.shape for v in values):
            raise ValueError(f"Tensor shapes for key '{key}' are not consistent across checkpoints.")

        if first_val.dtype == torch.bool:
            # For bool, ensure all are identical
            for val in values[1:]:
                if not torch.equal(val, first_val):
                    raise ValueError(f"Boolean values for key '{key}' differ between checkpoints.")
            final_state_dict[key] = first_val  # Use the first if all identical

        elif torch.is_floating_point(first_val):
            # Average float values
            stacked = torch.stack(values, dim=0)
            target_dtype = stacked.dtype
            accum_dtype = torch.promote_types(target_dtype, torch.float32)  # Upcast to float32 if needed
            averaged = stacked.to(accum_dtype).mean(dim=0).to(target_dtype)
            final_state_dict[key] = averaged

        elif first_val.dtype in {
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        }:
            # Average integer values (using integer division)
            stacked = torch.stack(values, dim=0)
            summed = stacked.sum(dim=0, dtype=torch.int64)
            averaged = summed // len(values)
            final_state_dict[key] = averaged.to(first_val.dtype)

        else:
            # If you have other special dtypes to handle, add logic here
            # or simply copy the first value if that is your intended behavior.
            raise TypeError(f"Unsupported dtype '{first_val.dtype}' encountered for key '{key}'.")

    return final_state_dict


def average_checkpoints(inputs: List[str], key=None, map_location="cpu", weights_only=True) -> collections.OrderedDict:
    """Loads checkpoints from inputs and returns their averaged weights.

    Args:
        inputs (List[str]): An iterable of string paths of checkpoints to load from.
        key (str): An optional key to select a sub-dictionary (e.g. the model state dict) from each checkpoint.
        map_location (str): A string describing how to remap storage locations when loading the checkpoints.
        weights_only (bool): Passed through to torch.load; if True, restricts loading to tensors and other
            primitive types for safety.
    Returns:
        An OrderedDict mapping string parameter names to averaged torch Tensors. If key is given,
        the averaged state dict is wrapped back under that key.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location="cpu",
            )
            # Copies over the settings from the first checkpoint
            if new_state is None:
                new_state = state
            model_params = state["model_state_dict"]
            model_params_keys = list(model_params.keys())
            if params_keys is None:
                params_keys = model_params_keys
            elif params_keys != model_params_keys:
                raise KeyError(
                    "For checkpoint {}, expected list of params: {}, "
                    "but found: {}".format(f, params_keys, model_params_keys)
                )
            for k in params_keys:
                p = model_params[k]
                if isinstance(p, torch.HalfTensor):
                    p = p.float()
                if k not in params_dict:
                    params_dict[k] = p.clone()
                    # NOTE: clone() is needed in case of p is a shared parameter
                else:
                    params_dict[k] += p
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model_state_dict"] = averaged_params
    return new_state
    state_dicts = [torch.load(path, map_location=map_location, weights_only=weights_only) for path in inputs]
    if key is not None:
        state_dicts = [sd[key] for sd in state_dicts]

    avg_state_dict = average_state_dicts(state_dicts)
    if key is not None:
        avg_state_dict = collections.OrderedDict({key: avg_state_dict})

    return avg_state_dict
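
A minimal usage sketch of the new average_state_dicts helper, assuming two identically shaped toy modules (the layer sizes below are illustrative, not taken from the commit):

from torch import nn
from pytorch_toolbelt.inference.ensembling import average_state_dicts

# Two toy modules with identical parameter shapes; their float weights are averaged.
model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)
averaged = average_state_dicts([model_a.state_dict(), model_b.state_dict()])

# The result is an OrderedDict that can be loaded back into a matching module.
merged = nn.Linear(4, 2)
merged.load_state_dict(averaged)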

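A similar sketch for the reworked average_checkpoints, assuming checkpoints saved under a "model_state_dict" entry (the layout the previous implementation expected); the file names are placeholders for this example:

import torch
from torch import nn
from pytorch_toolbelt.inference.ensembling import average_checkpoints

# Save two toy checkpoints in the {"model_state_dict": ...} layout.
for path in ("fold0.pth", "fold1.pth"):
    torch.save({"model_state_dict": nn.Linear(4, 2).state_dict()}, path)

# key="model_state_dict" selects that sub-dictionary from each checkpoint before
# averaging and wraps the averaged weights back under the same key.
averaged = average_checkpoints(["fold0.pth", "fold1.pth"], key="model_state_dict")

model = nn.Linear(4, 2)
model.load_state_dict(averaged["model_state_dict"])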