From 47fd13e0d6e2d1a167ec5baf0f9dfce9d998e981 Mon Sep 17 00:00:00 2001
From: BloodAxe
Date: Thu, 6 Feb 2025 20:31:58 +0200
Subject: [PATCH] Reimplement average_checkpoints using average_state_dicts

---
 pytorch_toolbelt/inference/ensembling.py | 140 ++++++++++++++++--------
 1 file changed, 92 insertions(+), 48 deletions(-)

diff --git a/pytorch_toolbelt/inference/ensembling.py b/pytorch_toolbelt/inference/ensembling.py
index 30d9b6287..9a8548f93 100644
--- a/pytorch_toolbelt/inference/ensembling.py
+++ b/pytorch_toolbelt/inference/ensembling.py
@@ -4,7 +4,15 @@
 from torch import nn, Tensor
 
-from typing import List, Union, Iterable, Optional, Dict, Tuple
+from typing import List, Union, Iterable, Optional, Dict, Tuple, Mapping
 
-__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]
+__all__ = [
+    "ApplySoftmaxTo",
+    "ApplySigmoidTo",
+    "Ensembler",
+    "PickModelOutput",
+    "SelectByIndex",
+    "average_checkpoints",
+    "average_state_dicts",
+]
 
 from pytorch_toolbelt.inference.tta import _deaugment_averaging
 
@@ -163,53 +171,89 @@ def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
         return outputs[self.target_key]
 
 
-def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
-    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
-    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+def average_state_dicts(state_dicts: List[Mapping[str, Tensor]]) -> Mapping[str, Tensor]:
+    """
+    Averages multiple state dicts key by key.
+
+    Args:
+      state_dicts: A list of state dicts with identical keys and tensor shapes.
+
+    Returns:
+      An OrderedDict with the same keys: floating-point tensors are averaged, integer
+      tensors are averaged with integer division, and boolean tensors must be identical.
+    """
+
+    keys = state_dicts[0].keys()
+    final_state_dict = collections.OrderedDict()
+
+    for key in keys:
+        # Collect the values (tensors) for this key from all checkpoints
+        values = [sd[key] for sd in state_dicts]
+
+        # Check the dtype of the first value (assuming all dtypes match)
+        first_val = values[0]
+
+        if not all(v.shape == first_val.shape for v in values):
+            raise ValueError(f"Tensor shapes for key '{key}' are not consistent across checkpoints.")
+
+        if first_val.dtype == torch.bool:
+            # For bool, ensure all are identical
+            for val in values[1:]:
+                if not torch.equal(val, first_val):
+                    raise ValueError(f"Boolean values for key '{key}' differ between checkpoints.")
+            final_state_dict[key] = first_val  # Use the first if all identical
+
+        elif torch.is_floating_point(first_val):
+            # Average float values
+            stacked = torch.stack(values, dim=0)
+            target_dtype = stacked.dtype
+            accum_dtype = torch.promote_types(target_dtype, torch.float32)  # Upcast to float32 if needed
+            averaged = stacked.to(accum_dtype).mean(dim=0).to(target_dtype)
+            final_state_dict[key] = averaged
+
+        elif first_val.dtype in {
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.uint8,
+            torch.uint16,
+            torch.uint32,
+            torch.uint64,
+        }:
+            # Average integer values (using integer division)
+            stacked = torch.stack(values, dim=0)
+            summed = stacked.sum(dim=0, dtype=torch.int64)
+            averaged = summed // len(values)
+            final_state_dict[key] = averaged.to(first_val.dtype)
+
+        else:
+            # If you have other special dtypes to handle, add logic here
+            # or simply copy the first value if that is your intended behavior.
+            raise TypeError(f"Unsupported dtype '{first_val.dtype}' encountered for key '{key}'.")
+
+    return final_state_dict
+
+
+def average_checkpoints(inputs: List[str], key: Optional[str] = None, map_location: str = "cpu", weights_only: bool = True) -> collections.OrderedDict:
+    """Loads checkpoints from inputs and returns their averaged state dict.
 
     Args:
       inputs (List[str]): An iterable of string paths of checkpoints to load from.
+      key (str): An optional key to select a sub-dictionary from each loaded checkpoint.
+      map_location (str): A string describing how to remap storage locations when loading the checkpoint.
+      weights_only (bool): If True, torch.load restricts unpickling to tensors and other safe types.
 
     Returns:
-      A dict of string keys mapping to various values. The 'model' key
-      from the returned dict should correspond to an OrderedDict mapping
-      string parameter names to torch Tensors.
+      An OrderedDict mapping parameter names to averaged tensors. If `key` is given,
+      the averaged state dict is wrapped under that key.
     """
-    params_dict = collections.OrderedDict()
-    params_keys = None
-    new_state = None
-    num_models = len(inputs)
-    for fpath in inputs:
-        with open(fpath, "rb") as f:
-            state = torch.load(
-                f,
-                map_location="cpu",
-            )
-        # Copies over the settings from the first checkpoint
-        if new_state is None:
-            new_state = state
-        model_params = state["model_state_dict"]
-        model_params_keys = list(model_params.keys())
-        if params_keys is None:
-            params_keys = model_params_keys
-        elif params_keys != model_params_keys:
-            raise KeyError(
-                "For checkpoint {}, expected list of params: {}, "
-                "but found: {}".format(f, params_keys, model_params_keys)
-            )
-        for k in params_keys:
-            p = model_params[k]
-            if isinstance(p, torch.HalfTensor):
-                p = p.float()
-            if k not in params_dict:
-                params_dict[k] = p.clone()
-                # NOTE: clone() is needed in case of p is a shared parameter
-            else:
-                params_dict[k] += p
-    averaged_params = collections.OrderedDict()
-    for k, v in params_dict.items():
-        averaged_params[k] = v
-        if averaged_params[k].is_floating_point():
-            averaged_params[k].div_(num_models)
-        else:
-            averaged_params[k] //= num_models
-    new_state["model_state_dict"] = averaged_params
-    return new_state
+    state_dicts = [torch.load(path, map_location=map_location, weights_only=weights_only) for path in inputs]
+    if key is not None:
+        state_dicts = [sd[key] for sd in state_dicts]
+
+    avg_state_dict = average_state_dicts(state_dicts)
+    if key is not None:
+        avg_state_dict = collections.OrderedDict({key: avg_state_dict})
+
+    return avg_state_dict
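
Usage sketch (illustrative only, not part of the patch): averaging weights from
in-memory models and from checkpoint files. The file paths and the
"model_state_dict" key below are assumptions for the example, not fixed by this
commit.

    import torch
    from torch import nn
    from pytorch_toolbelt.inference.ensembling import average_checkpoints, average_state_dicts

    # In-memory: average the state dicts of two identically-shaped models.
    model_a, model_b = nn.Linear(10, 2), nn.Linear(10, 2)
    merged = nn.Linear(10, 2)
    merged.load_state_dict(average_state_dicts([model_a.state_dict(), model_b.state_dict()]))

    # On disk: each file is assumed to store its weights under "model_state_dict";
    # when key is set, the averaged state dict comes back wrapped under that same key.
    paths = ["fold0.pth", "fold1.pth", "fold2.pth"]  # hypothetical checkpoint files
    averaged = average_checkpoints(paths, key="model_state_dict")
    merged.load_state_dict(averaged["model_state_dict"])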