
Commit 89d10ca

Configurable weight averaging with lightning>=2.6.
1 parent e7453ae commit 89d10ca


11 files changed: +212 -84 lines


training/docs/user-guide/training.rst

Lines changed: 47 additions & 0 deletions

@@ -570,3 +570,50 @@ frozen and only the encoder and decoder will be trained:
 Freezing can be particularly beneficial in scenarios such as fine-tuning
 when only specific components (e.g., the encoder, the decoder) need to
 adapt to a new task while keeping others (e.g., the processor) fixed.
+
+******************
+ Weight Averaging
+******************
+
+Weight averaging is a technique to improve model generalization by
+averaging model weights during training. Anemoi Training supports weight
+averaging methods through PyTorch Lightning callbacks:
+
+- **Exponential Moving Average (EMA)**: Maintains an exponential moving
+  average of model weights, which can lead to smoother convergence
+  and better generalization.
+
+  .. code:: yaml
+
+     weight_averaging:
+       _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+       decay: 0.999
+
+  The ``decay`` parameter (typically between 0.99 and 0.9999)
+  controls the smoothing factor. Higher values give more weight to
+  historical weights, resulting in a more stable average. By
+  default, the decay is set to 0.999.
+
+- **Stochastic Weight Averaging (SWA)**: Averages weights from multiple
+  points along the training trajectory, typically resulting in wider
+  optima and improved generalization.
+
+  .. code:: yaml
+
+     weight_averaging:
+       _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+       swa_lrs: 1.e-4
+
+  The ``swa_lrs`` parameter specifies the learning rate to use
+  during the SWA phase. By default, the learning rate is set to
+  1e-4. Additional parameters can be configured as described in the
+  `PyTorch Lightning documentation
+  <https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html#lightning.pytorch.callbacks.StochasticWeightAveraging>`__.
+
+By default, weight averaging is disabled. To explicitly disable it or to
+override a parent configuration, set ``weight_averaging`` to null.
+
+.. note::
+
+   Weight averaging is only supported in PyTorch Lightning 2.6 and later
+   versions.
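
For reference, the ``decay`` behaviour described above reduces to a simple exponential-moving-average update applied per parameter. A minimal sketch (not part of the commit; the helper name and values are illustrative, assuming only ``torch`` is installed):

import torch


def ema_update(averaged: torch.Tensor, current: torch.Tensor, decay: float = 0.999) -> torch.Tensor:
    """One EMA step: averaged = decay * averaged + (1 - decay) * current."""
    return decay * averaged + (1.0 - decay) * current


avg = torch.zeros(3)
for _ in range(1000):
    current_weights = torch.ones(3)         # stand-in for the weights after an optimiser step
    avg = ema_update(avg, current_weights)  # with decay=0.999 the average drifts only slowly towards 1.0

Higher decay values keep more of the historical average, which is why the documentation above suggests values between 0.99 and 0.9999.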

training/src/anemoi/training/config/training/default.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 32.
   algorithm: value
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # =====================================================================
 # Optimizer configuration
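
The same replacement is applied to every training preset below (diffusion, ensemble, interpolator, lam, stretched). A small sketch of how the two forms of the new key behave once the YAML is loaded (assumes ``omegaconf`` is installed; values are illustrative):

from omegaconf import OmegaConf

# ``weight_averaging: null`` loads as Python None, so no callback is added.
disabled = OmegaConf.create("weight_averaging: null")
assert disabled.weight_averaging is None

# A mapping with ``_target_`` carries the callback class to instantiate.
enabled = OmegaConf.create(
    {
        "weight_averaging": {
            "_target_": "pytorch_lightning.callbacks.EMAWeightAveraging",
            "decay": 0.999,
        }
    }
)
assert enabled.weight_averaging["_target_"].endswith("EMAWeightAveraging")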

training/src/anemoi/training/config/training/diffusion.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 1.
   algorithm: norm
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # Optimizer settings
 optimizer:

training/src/anemoi/training/config/training/ensemble.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 32.
   algorithm: value
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # Optimizer settings
 optimizer:

training/src/anemoi/training/config/training/interpolator.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 32.
   algorithm: value
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # Optimizer settings
 optimizer:

training/src/anemoi/training/config/training/lam.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 32.
   algorithm: value
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # Optimizer settings
 optimizer:

training/src/anemoi/training/config/training/stretched.yaml

Lines changed: 8 additions & 5 deletions

@@ -31,11 +31,14 @@ gradient_clip:
   val: 32.
   algorithm: value
 
-# stochastic weight averaging
-# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
-swa:
-  enabled: False
-  lr: 1.e-4
+# weight averaging only supported in Lightning 2.6+
+weight_averaging: null
+# For EMA:
+#   _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+#   decay: 0.999
+# For SWA:
+#   _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#   swa_lrs: 1.e-4
 
 # Optimizer settings
 optimizer:

training/src/anemoi/training/diagnostics/callbacks/__init__.py

Lines changed: 36 additions & 2 deletions

@@ -23,7 +23,6 @@
 
 from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint
 from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
-from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
 from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
 from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder
 from anemoi.training.schemas.base_schema import BaseSchema

@@ -49,7 +48,6 @@ def nestedget(config: DictConfig, key: str, default: Any) -> Any:
 # Callbacks to add according to flags in the config
 # Can be function to check status from config
 CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str | Callable[[DictConfig], bool], type[Callback]]] = [
-    ("training.swa.enabled", StochasticWeightAveraging),
     (
         lambda config: nestedget(config, "diagnostics.log.wandb.enabled", False)
         or nestedget(config, "diagnostics.log.mlflow.enabled", False),

@@ -58,6 +56,39 @@ def nestedget(config: DictConfig, key: str, default: Any) -> Any:
 ]
 
 
+def _get_weight_averaging_callback(config: DictConfig) -> list[Callback]:
+    """Get weight averaging callback.
+
+    Supported are EMAWeightAveraging and StochasticWeightAveraging.
+
+    Example config:
+        weight_averaging:
+            _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+            decay: 0.999
+
+    Parameters
+    ----------
+    config : DictConfig
+        Job configuration
+
+    Returns
+    -------
+    list[Callback]
+        List containing the weight averaging callback, or empty list if not configured.
+    """
+    weight_averaging_config = nestedget(config, "training.weight_averaging", None)
+
+    if weight_averaging_config is not None:
+        try:
+            weight_averaging = instantiate(weight_averaging_config)
+            LOGGER.info("Using weight averaging: %s", type(weight_averaging))
+        except InstantiationException:
+            LOGGER.warning("Failed to instantiate weight averaging callback from config: %s", weight_averaging_config)
+        else:
+            return [weight_averaging]
+    return []
+
+
 def _get_checkpoint_callback(config: BaseSchema) -> list[AnemoiCheckpoint]:
     """Get checkpointing callbacks."""
     if not config.diagnostics.enable_checkpointing:

@@ -226,6 +257,9 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
     # Plotting callbacks
     trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.plot.callbacks)
 
+    # Weight averaging callback (SWA or EMA)
+    trainer_callbacks.extend(_get_weight_averaging_callback(config))
+
     # Extend with config enabled callbacks
     trainer_callbacks.extend(_get_config_enabled_callbacks(config))
 
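
For illustration, a standalone sketch of what ``_get_weight_averaging_callback`` effectively does with Hydra's ``instantiate`` (assumes ``hydra-core``, ``omegaconf`` and a Lightning version providing ``StochasticWeightAveraging`` are installed; the config values are examples, not defaults):

from hydra.utils import instantiate
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "training": {
            "weight_averaging": {
                "_target_": "pytorch_lightning.callbacks.StochasticWeightAveraging",
                "swa_lrs": 1.0e-4,
            }
        }
    }
)

# ``_target_`` is resolved to the callback class; the remaining keys become its kwargs.
callback = instantiate(config.training.weight_averaging)
print(type(callback).__name__)  # StochasticWeightAveraging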

training/src/anemoi/training/diagnostics/callbacks/optimiser.py

Lines changed: 0 additions & 44 deletions

@@ -12,7 +12,6 @@
 
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor
-from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging
 
 LOGGER = logging.getLogger(__name__)
 

@@ -28,46 +27,3 @@ def __init__(
     ) -> None:
         super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)
         self.config = config
-
-
-class StochasticWeightAveraging(pl_StochasticWeightAveraging):
-    """Provide StochasticWeightAveraging from pytorch_lightning as a callback."""
-
-    def __init__(
-        self,
-        config: DictConfig,
-        swa_lrs: int | None = None,
-        swa_epoch_start: int | None = None,
-        annealing_epochs: int | None = None,
-        annealing_strategy: str | None = None,
-        device: str | None = None,
-        **kwargs,
-    ) -> None:
-        """Stochastic Weight Averaging Callback.
-
-        Parameters
-        ----------
-        config : OmegaConf
-            Full configuration object
-        swa_lrs : int, optional
-            Stochastic Weight Averaging Learning Rate, by default None
-        swa_epoch_start : int, optional
-            Epoch start, by default 0.75 * config.training.max_epochs
-        annealing_epochs : int, optional
-            Annealing Epoch, by default 0.25 * config.training.max_epochs
-        annealing_strategy : str, optional
-            Annealing Strategy, by default 'cos'
-        device : str, optional
-            Device to use, by default None
-        """
-        kwargs["swa_lrs"] = swa_lrs or config.training.swa.lr
-        kwargs["swa_epoch_start"] = swa_epoch_start or min(
-            int(0.75 * config.training.max_epochs),
-            config.training.max_epochs - 1,
-        )
-        kwargs["annealing_epochs"] = annealing_epochs or max(int(0.25 * config.training.max_epochs), 1)
-        kwargs["annealing_strategy"] = annealing_strategy or "cos"
-        kwargs["device"] = device
-
-        super().__init__(**kwargs)
-        self.config = config
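
The removed wrapper used to derive ``swa_epoch_start`` and the annealing settings from ``config.training.max_epochs``. With it gone, Lightning's built-in callback is used directly, which is what the Hydra-instantiated config now produces. A minimal sketch of the equivalent direct usage (parameter values are illustrative, not project defaults):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Lightning's own SWA callback, passed straight to the Trainer.
swa = StochasticWeightAveraging(swa_lrs=1.0e-4, swa_epoch_start=0.8, annealing_epochs=10)
trainer = pl.Trainer(max_epochs=10, callbacks=[swa])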

training/src/anemoi/training/schemas/training.py

Lines changed: 34 additions & 8 deletions

@@ -39,16 +39,42 @@ class GradientClip(BaseModel):
     "The gradient clipping algorithm to use"
 
 
-class SWA(BaseModel):
-    """Stochastic weight averaging configuration.
+class WeightAveragingSchema(BaseModel):
+    """Weight averaging configuration (SWA or EMA).
 
+    Uses Hydra instantiate pattern with _target_ to specify the callback class.
     See https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
+
+    Example:
+        weight_averaging:
+            _target_: pytorch_lightning.callbacks.EMAWeightAveraging
+            decay: 0.999
+            update_starting_at_step: 1000
     """
 
-    enabled: bool = Field(example=False)
-    "Enable stochastic weight averaging."
-    lr: NonNegativeFloat = Field(example=1.0e-4)
-    "Learning rate for SWA."
+    target_: Literal[
+        "pytorch_lightning.callbacks.EMAWeightAveraging",
+        "pytorch_lightning.callbacks.StochasticWeightAveraging",
+    ] = Field(..., alias="_target_")
+    "Target callback class for weight averaging. Either EMAWeightAveraging or StochasticWeightAveraging."
+    # EMA specific
+    decay: NonNegativeFloat | None = Field(default=None)
+    "EMA decay rate (only used for EMAWeightAveraging)."
+    update_every_n_steps: PositiveInt = Field(default=1)
+    "Update every n steps (only used for EMAWeightAveraging)."
+    update_starting_at_step: PositiveInt | None = Field(default=None)
+    "Update starting at step (only used for EMAWeightAveraging)."
+    update_starting_at_epoch: PositiveInt | None = Field(default=None)
+    "Update starting at epoch (only used for EMAWeightAveraging)."
+    # SWA specific
+    swa_lrs: NonNegativeFloat | list[NonNegativeFloat] = Field(default=0.8)
+    "SWA learning rate (only used for StochasticWeightAveraging)."
+    swa_epoch_start: NonNegativeFloat = Field(default=0.8)
+    "SWA epoch start (only used for StochasticWeightAveraging)."
+    annealing_epochs: PositiveInt = Field(default=10)
+    "Annealing epochs (only used for StochasticWeightAveraging)."
+    annealing_strategy: Literal["cos", "linear"] = Field(default="cos")
+    "Annealing strategy (only used for StochasticWeightAveraging)."
 
 
 class Rollout(BaseModel):

@@ -334,8 +360,8 @@ class BaseTrainingSchema(BaseModel):
     "Config for gradient clipping."
     strategy: StrategySchemas
     "Strategy to use."
-    swa: SWA = Field(default_factory=SWA)
-    "Config for stochastic weight averaging."
+    weight_averaging: WeightAveragingSchema | None = Field(default=None)
+    "Config for weight averaging (SWA or EMA). Set to null to disable."
     training_loss: LossSchemas
     "Training loss configuration."
     loss_gradient_scaling: bool = False
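
A short sketch of the alias-based validation pattern used by ``WeightAveragingSchema`` (the class below is a simplified stand-in, not the schema itself; assumes pydantic v2):

from typing import Literal

from pydantic import BaseModel, Field


class WeightAveragingExample(BaseModel):
    # The YAML key ``_target_`` populates ``target_`` via the field alias.
    target_: Literal[
        "pytorch_lightning.callbacks.EMAWeightAveraging",
        "pytorch_lightning.callbacks.StochasticWeightAveraging",
    ] = Field(..., alias="_target_")
    decay: float | None = None  # EMA-only option, illustrative subset of the real fields


cfg = WeightAveragingExample.model_validate(
    {"_target_": "pytorch_lightning.callbacks.EMAWeightAveraging", "decay": 0.999}
)
print(cfg.target_)  # accepted only if it names one of the two supported callbacks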
