47 changes: 47 additions & 0 deletions training/docs/user-guide/training.rst
@@ -570,3 +570,50 @@ frozen and only the encoder and decoder will be trained:
Freezing can be particularly beneficial in scenarios such as fine-tuning
when only specific components (e.g., the encoder, the decoder) need to
adapt to a new task while keeping others (e.g., the processor) fixed.

******************
Weight Averaging
******************

Weight averaging is a technique to improve model generalization by
averaging model weights during training. Anemoi Training supports weight
averaging methods through PyTorch Lightning callbacks:

- **Exponential Moving Average (EMA)**: Maintains an exponential moving
average of model weights, which can lead to smoother convergence
and better generalization.

  .. code:: yaml

     weight_averaging:
       _target_: pytorch_lightning.callbacks.EMAWeightAveraging
       decay: 0.999

  The ``decay`` parameter (typically between 0.99 and 0.9999)
  controls the smoothing factor. Higher values give more weight to
  historical weights, resulting in a more stable average. By
  default, the decay is set to 0.999. A conceptual sketch of the
  update rule is given after this list.

- **Stochastic Weight Averaging (SWA)**: Averages weights from multiple
points along the training trajectory, typically resulting in wider
optima and improved generalization.

  .. code:: yaml

     weight_averaging:
       _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
       swa_lrs: 1.e-4

  The ``swa_lrs`` parameter specifies the learning rate to use
  during the SWA phase. By default, the learning rate is set to
  1e-4. Additional parameters can be configured as described in the
  `PyTorch Lightning documentation
  <https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html#lightning.pytorch.callbacks.StochasticWeightAveraging>`__.
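
Conceptually, EMA keeps a shadow copy of the model weights that is
updated as a convex combination of the previous average and the
current weights. The short, framework-agnostic sketch below
illustrates the role of ``decay``; it is not the Lightning
implementation:

.. code:: python

   import torch


   def ema_update(
       ema_params: list[torch.Tensor],
       model_params: list[torch.Tensor],
       decay: float = 0.999,
   ) -> None:
       """Update the EMA copy in place: ema = decay * ema + (1 - decay) * current."""
       with torch.no_grad():
           for ema_p, p in zip(ema_params, model_params):
               ema_p.mul_(decay).add_(p, alpha=1.0 - decay)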

By default, weight averaging is disabled. To explicitly disable it or to
override a parent configuration, set ``weight_averaging`` to null.
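
Under the hood, the ``_target_`` entry is resolved with Hydra's
``instantiate`` and the resulting callback is appended to the
trainer's callback list. The sketch below mirrors the
``_get_weight_averaging_callback`` helper added in
``diagnostics/callbacks/__init__.py``; the function name and
structure here are simplified and illustrative, not the exact
implementation:

.. code:: python

   from hydra.utils import instantiate
   from omegaconf import DictConfig, OmegaConf


   def build_weight_averaging_callbacks(config: DictConfig) -> list:
       # ``training.weight_averaging`` is either null or a mapping with a ``_target_`` key.
       wa_cfg = OmegaConf.select(config, "training.weight_averaging")
       if wa_cfg is None:
           return []  # weight averaging disabled
       # Hydra resolves ``_target_`` and passes the remaining keys as keyword
       # arguments, e.g. EMAWeightAveraging(decay=0.999).
       return [instantiate(wa_cfg)]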

.. note::

   Weight averaging is only supported in PyTorch Lightning 2.6 and later versions.
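
Whether the required callbacks are available in the installed
Lightning version can be checked with a defensive import, mirroring
the guard used in the unit test added alongside this feature (a
minimal sketch, assuming the import path shown above):

.. code:: python

   try:
       from pytorch_lightning.callbacks import EMAWeightAveraging
   except ImportError:  # older Lightning release without the weight averaging callbacks
       EMAWeightAveraging = None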
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/default.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 32.
algorithm: value

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# =====================================================================
# Optimizer configuration
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/diffusion.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 1.
algorithm: norm

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# Optimizer settings
optimizer:
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/ensemble.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 32.
algorithm: value

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# Optimizer settings
optimizer:
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/interpolator.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 32.
algorithm: value

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# Optimizer settings
optimizer:
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/lam.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 32.
algorithm: value

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# Optimizer settings
optimizer:
13 changes: 8 additions & 5 deletions training/src/anemoi/training/config/training/stretched.yaml
@@ -31,11 +31,14 @@ gradient_clip:
val: 32.
algorithm: value

# stochastic weight averaging
# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
swa:
enabled: False
lr: 1.e-4
# weight averaging only supported in Lightning 2.6+
weight_averaging: null
# For EMA:
# _target_: pytorch_lightning.callbacks.EMAWeightAveraging
# decay: 0.999
# For SWA:
# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
# swa_lrs: 1.e-4

# Optimizer settings
optimizer:
38 changes: 36 additions & 2 deletions training/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -23,7 +23,6 @@

from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint
from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder
from anemoi.training.schemas.base_schema import BaseSchema
@@ -49,7 +48,6 @@ def nestedget(config: DictConfig, key: str, default: Any) -> Any:
# Callbacks to add according to flags in the config
# Can be function to check status from config
CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str | Callable[[DictConfig], bool], type[Callback]]] = [
("training.swa.enabled", StochasticWeightAveraging),
(
lambda config: nestedget(config, "diagnostics.log.wandb.enabled", False)
or nestedget(config, "diagnostics.log.mlflow.enabled", False),
@@ -58,6 +56,39 @@ def nestedget(config: DictConfig, key: str, default: Any) -> Any:
]


def _get_weight_averaging_callback(config: DictConfig) -> list[Callback]:
"""Get weight averaging callback.
Supported are ExponentialMovingAverage and StochasticWeightAveraging.
Example config:
weight_averaging:
_target_: anemoi.training.diagnostics.callbacks.optimiser.ExponentialMovingAverage
decay: 0.999
Parameters
----------
config : DictConfig
Job configuration
Returns
-------
list[Callback]
List containing the weight averaging callback, or empty list if not configured.
"""
weight_averaging_config = nestedget(config, "training.weight_averaging", None)

if weight_averaging_config is not None:
try:
weight_averaging = instantiate(weight_averaging_config)
LOGGER.info("Using weight averaging: %s", type(weight_averaging))
except InstantiationException:
LOGGER.warning("Failed to instantiate weight averaging callback from config: %s", weight_averaging_config)
else:
return [weight_averaging]
return []


def _get_checkpoint_callback(config: BaseSchema) -> list[AnemoiCheckpoint]:
"""Get checkpointing callbacks."""
if not config.diagnostics.enable_checkpointing:
@@ -226,6 +257,9 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
# Plotting callbacks
trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.plot.callbacks)

# Weight averaging callback (SWA or EMA)
trainer_callbacks.extend(_get_weight_averaging_callback(config))

# Extend with config enabled callbacks
trainer_callbacks.extend(_get_config_enabled_callbacks(config))

44 changes: 0 additions & 44 deletions training/src/anemoi/training/diagnostics/callbacks/optimiser.py
@@ -12,7 +12,6 @@

from omegaconf import DictConfig
from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging

LOGGER = logging.getLogger(__name__)

Expand All @@ -28,46 +27,3 @@ def __init__(
) -> None:
super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)
self.config = config


class StochasticWeightAveraging(pl_StochasticWeightAveraging):
"""Provide StochasticWeightAveraging from pytorch_lightning as a callback."""

def __init__(
self,
config: DictConfig,
swa_lrs: int | None = None,
swa_epoch_start: int | None = None,
annealing_epochs: int | None = None,
annealing_strategy: str | None = None,
device: str | None = None,
**kwargs,
) -> None:
"""Stochastic Weight Averaging Callback.
Parameters
----------
config : OmegaConf
Full configuration object
swa_lrs : int, optional
Stochastic Weight Averaging Learning Rate, by default None
swa_epoch_start : int, optional
Epoch start, by default 0.75 * config.training.max_epochs
annealing_epochs : int, optional
Annealing Epoch, by default 0.25 * config.training.max_epochs
annealing_strategy : str, optional
Annealing Strategy, by default 'cos'
device : str, optional
Device to use, by default None
"""
kwargs["swa_lrs"] = swa_lrs or config.training.swa.lr
kwargs["swa_epoch_start"] = swa_epoch_start or min(
int(0.75 * config.training.max_epochs),
config.training.max_epochs - 1,
)
kwargs["annealing_epochs"] = annealing_epochs or max(int(0.25 * config.training.max_epochs), 1)
kwargs["annealing_strategy"] = annealing_strategy or "cos"
kwargs["device"] = device

super().__init__(**kwargs)
self.config = config
42 changes: 34 additions & 8 deletions training/src/anemoi/training/schemas/training.py
@@ -39,16 +39,42 @@ class GradientClip(BaseModel):
"The gradient clipping algorithm to use"


class SWA(BaseModel):
"""Stochastic weight averaging configuration.
class WeightAveragingSchema(BaseModel):
"""Weight averaging configuration (SWA or EMA).
Uses Hydra instantiate pattern with _target_ to specify the callback class.
See https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
Example:
weight_averaging:
_target_: pytorch_lightning.callbacks.EMAWeightAveraging
decay: 0.999
update_starting_at_step: 1000
"""

enabled: bool = Field(example=False)
"Enable stochastic weight averaging."
lr: NonNegativeFloat = Field(example=1.0e-4)
"Learning rate for SWA."
target_: Literal[
"pytorch_lightning.callbacks.EMAWeightAveraging",
"pytorch_lightning.callbacks.StochasticWeightAveraging",
] = Field(..., alias="_target_")
"Target callback class for weight averaging. Either EMAWeightAveraging or StochasticWeightAveraging."
# EMA specific
decay: NonNegativeFloat | None = Field(default=None)
"EMA decay rate (only used for EMAWeightAveraging)."
update_every_n_steps: PositiveInt = Field(default=1)
"Update every n steps (only used for EMAWeightAveraging)."
update_starting_at_step: PositiveInt | None = Field(default=None)
"Update starting at step (only used for EMAWeightAveraging)."
update_starting_at_epoch: PositiveInt | None = Field(default=None)
"Update starting at epoch (only used for EMAWeightAveraging)."
# SWA specific
swa_lrs: NonNegativeFloat | list[NonNegativeFloat] = Field(default=1.0e-4)
"SWA learning rate (only used for StochasticWeightAveraging)."
swa_epoch_start: NonNegativeFloat = Field(default=0.8)
"SWA epoch start (only used for StochasticWeightAveraging)."
annealing_epochs: PositiveInt = Field(default=10)
"Annealing epochs (only used for StochasticWeightAveraging)."
annealing_strategy: Literal["cos", "linear"] = Field(default="cos")
"Annealing strategy (only used for StochasticWeightAveraging)."


class Rollout(BaseModel):
@@ -334,8 +360,8 @@ class BaseTrainingSchema(BaseModel):
"Config for gradient clipping."
strategy: StrategySchemas
"Strategy to use."
swa: SWA = Field(default_factory=SWA)
"Config for stochastic weight averaging."
weight_averaging: WeightAveragingSchema | None = Field(default=None)
"Config for weight averaging (SWA or EMA). Set to null to disable."
training_loss: LossSchemas
"Training loss configuration."
loss_gradient_scaling: bool = False
47 changes: 47 additions & 0 deletions training/tests/unit/diagnostics/callbacks/test_weight_averaging.py
@@ -0,0 +1,47 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""Unit tests for weight averaging callback functionality."""

import omegaconf
import pytest
import yaml

from anemoi.training.diagnostics.callbacks import _get_weight_averaging_callback

default_config = """
training:
weight_averaging: null
"""


def test_weight_averaging_disabled_when_null() -> None:
"""Test that weight averaging is disabled when set to null."""
config = omegaconf.OmegaConf.create(yaml.safe_load(default_config))
callbacks = _get_weight_averaging_callback(config)
assert callbacks == []


def test_ema_callback_available() -> None:
"""Test that EMA weight averaging callback can be instantiated."""
pytest.importorskip("pytorch_lightning.callbacks", reason="EMA requires PyTorch Lightning 2.6+")

try:
from pytorch_lightning.callbacks import EMAWeightAveraging
except ImportError:
pytest.skip("EMAWeightAveraging not available in this PyTorch Lightning version")

config = omegaconf.OmegaConf.create(yaml.safe_load(default_config))
config.training.weight_averaging = {
"_target_": "pytorch_lightning.callbacks.EMAWeightAveraging",
"decay": 0.999,
}
callbacks = _get_weight_averaging_callback(config)
assert len(callbacks) == 1
assert isinstance(callbacks[0], EMAWeightAveraging)