Draft
29 commits
27d06e1
Make any loss able to filter based on variables
OpheliaMiralles Jan 14, 2026
3f3828f
Revert
OpheliaMiralles Jan 14, 2026
1360904
Make filtering and combination compatible
OpheliaMiralles Jan 15, 2026
3b8f669
Fix
OpheliaMiralles Jan 15, 2026
4eaa568
Fix print_variable_scaling
OpheliaMiralles Jan 15, 2026
8478612
Add test for non-filtered loss
OpheliaMiralles Jan 15, 2026
19381dc
Make integ tests pass
OpheliaMiralles Jan 16, 2026
dded475
feat: drop python 3.10 (#795)
floriankrb Jan 15, 2026
8051183
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 16, 2026
be451b9
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 20, 2026
f6488ca
loss stuff
OpheliaMiralles Jan 16, 2026
55087cf
Add small test for printing variables
OpheliaMiralles Jan 21, 2026
d4dfc32
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 21, 2026
36c5fb1
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 21, 2026
376dbd6
Fix combined loss
OpheliaMiralles Jan 21, 2026
d5059a1
Adapt scalers to FLW
OpheliaMiralles Jan 23, 2026
d76c2af
Merge remote-tracking branch 'origin/main' into fix/filteringlosswrapper
OpheliaMiralles Jan 23, 2026
f29f4f9
Merge multi dataset
OpheliaMiralles Jan 23, 2026
bb448cf
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 23, 2026
725140e
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 23, 2026
f724955
restrict scaler filtering (#826)
gabrieloks Jan 23, 2026
dc6454d
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Jan 23, 2026
ebeea2a
Filter every loss, apply scalers only if pred var
OpheliaMiralles Jan 23, 2026
7f3c141
zip
OpheliaMiralles Jan 23, 2026
5047118
added second combined loss with filtering test
gabrieloks Jan 26, 2026
71e390c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2026
bde5b0e
added tests
gabrieloks Jan 29, 2026
b7c84b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2026
0c485a1
Merge branch 'main' into fix/filteringlosswrapper
OpheliaMiralles Feb 2, 2026
10 changes: 8 additions & 2 deletions models/src/anemoi/models/data_indices/collection.py
@@ -30,9 +30,15 @@ class IndexCollection:
def __init__(self, data_config, name_to_index) -> None:
self.config = OmegaConf.to_container(data_config, resolve=True)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
self.forcing = [] if data_config.forcing is None else OmegaConf.to_container(data_config.forcing, resolve=True)
self.forcing = (
[]
if data_config.get("forcing", None) is None
else OmegaConf.to_container(data_config.forcing, resolve=True)
)
self.diagnostic = (
[] if data_config.diagnostic is None else OmegaConf.to_container(data_config.diagnostic, resolve=True)
[]
if data_config.get("diagnostic", None) is None
else OmegaConf.to_container(data_config.diagnostic, resolve=True)
)
self.target = (
[] if data_config.get("target", None) is None else OmegaConf.to_container(data_config.target, resolve=True)
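A minimal sketch of the behaviour this change enables (the config content below is illustrative): a data config that omits the `forcing` or `diagnostic` keys now falls back to an empty list instead of raising.

from omegaconf import OmegaConf

# Hypothetical data config with no "forcing"/"diagnostic" keys.
data_config = OmegaConf.create({"target": ["tp"]})
forcing = (
    []
    if data_config.get("forcing", None) is None
    else OmegaConf.to_container(data_config.forcing, resolve=True)
)
assert forcing == []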
@@ -787,9 +787,9 @@ def compute_tendency(
Parameters
----------
x_t1 : torch.Tensor
The state at time t1 with full input variables.
The state at time t1.
x_t0 : torch.Tensor
The state at time t0 with prognostic input variables.
The state at time t0.
pre_processors_state : callable
Function to pre-process the state variables.
pre_processors_tendencies : callable
19 changes: 11 additions & 8 deletions training/src/anemoi/training/losses/base.py
@@ -17,7 +17,6 @@
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.distributed.graph import reduce_tensor
from anemoi.training.losses.scaler_tensor import ScaleTensor
from anemoi.training.utils.enums import TensorDim
@@ -30,7 +29,10 @@ class BaseLoss(nn.Module, ABC):

scaler: ScaleTensor

def __init__(self, ignore_nans: bool = False) -> None:
def __init__(
self,
ignore_nans: bool = False,
) -> None:
"""Node- and feature_weighted Loss.

Exposes:
@@ -71,8 +73,9 @@ def add_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor, *, name:
def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = False) -> None:
self.scaler.update_scaler(name=name, scaler=scaler, override=override)

def set_data_indices(self, data_indices: IndexCollection) -> None:
"""Hook to set the data indices for the loss."""
@functools.wraps(ScaleTensor.has_scaler_for_dim)
def has_scaler_for_dim(self, dim: TensorDim) -> bool:
return self.scaler.has_scaler_for_dim(dim=dim)

def scale(
self,
@@ -112,7 +115,7 @@ def scale(
"Scaler tensor must be at least applied to the GRID dimension. "
"Please add a scaler here, use `UniformWeights` for simple uniform scaling.",
)
raise RuntimeError(error_msg)
LOGGER.warning(error_msg)

scale_tensor = self.scaler
if without_scalers is not None and len(without_scalers) > 0:
@@ -257,7 +260,7 @@ def forward(
without_scalers: list[str] | list[int] | None = None,
grid_shard_slice: slice | None = None,
group: ProcessGroup | None = None,
**kwargs, # noqa: ARG002
**kwargs,
) -> torch.Tensor:
"""Calculates the area-weighted scaled loss.

@@ -287,5 +290,5 @@
is_sharded = grid_shard_slice is not None
out = self.calculate_difference(pred, target)
out = self.scale(out, scaler_indices, without_scalers=without_scalers, grid_shard_slice=grid_shard_slice)

return self.reduce(out, squash, group=group if is_sharded else None)
squash_mode = kwargs.get("squash_mode", "avg")
return self.reduce(out, squash, group=group if is_sharded else None, squash_mode=squash_mode)
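A short sketch of the new `has_scaler_for_dim` hook (the helper below is illustrative; `loss` is assumed to be a constructed BaseLoss subclass): query whether any attached scaler acts on the variable dimension before inspecting or overriding it, as the filtering logic in loss.py does.

from anemoi.training.utils.enums import TensorDim

def variable_scaler_names(loss) -> list[str]:
    """Names of the scalers acting on the variable dimension, if any."""
    if not loss.has_scaler_for_dim(TensorDim.VARIABLE):
        return []
    return list(loss.scaler.subset_by_dim(TensorDim.VARIABLE.value).tensors.keys())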
7 changes: 6 additions & 1 deletion training/src/anemoi/training/losses/combined.py
@@ -119,13 +119,18 @@ def __init__(
if loss_weights is None:
loss_weights = (1.0,) * len(losses)

data_indices = kwargs.pop("data_indices", None)
scalers = kwargs.pop("scalers", {})

assert len(losses) == len(loss_weights), "Number of losses and weights must match"
assert len(losses) > 0, "At least one loss must be provided"

for i, loss in enumerate(losses):
if isinstance(loss, DictConfig | dict):
if "scalers" not in loss:
loss.update({"scalers": ["*"]})
self.losses.append(get_loss_function(loss, scalers=scalers, data_indices=data_indices, **dict(kwargs)))
self._loss_scaler_specification[i] = loss.pop("scalers", ["*"])
self.losses.append(get_loss_function(loss, scalers={}, **dict(kwargs)))
elif isinstance(loss, type):
self._loss_scaler_specification[i] = ["*"]
self.losses.append(loss(**kwargs))
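An illustrative sketch of a CombinedLoss config under the new per-loss scaler specification (the member `_target_` paths and scaler names are hypothetical): each member may declare its own `scalers` list, and members without the key fall back to ["*"].

from omegaconf import DictConfig
from anemoi.training.losses.loss import get_loss_function

combined_config = DictConfig(
    {
        "_target_": "anemoi.training.losses.combined.CombinedLoss",
        "losses": [
            # hypothetical member losses
            {"_target_": "anemoi.training.losses.mse.WeightedMSELoss", "scalers": ["variable"]},
            {"_target_": "anemoi.training.losses.mse.WeightedMSELoss"},  # no "scalers" key -> ["*"]
        ],
        "loss_weights": [1.0, 0.5],
        "scalers": ["*"],
    }
)

def build_combined_loss(scalers, data_indices):
    # `scalers` is the dict of named scaler tensors, `data_indices` the training IndexCollection.
    return get_loss_function(combined_config, scalers=scalers, data_indices=data_indices)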
29 changes: 9 additions & 20 deletions training/src/anemoi/training/losses/filtering.py
@@ -12,11 +12,9 @@
from typing import Any

import torch
from omegaconf import DictConfig

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.training.losses.base import BaseLoss
from anemoi.training.losses.loss import get_loss_function


# TODO(Harrison): Consider renaming and reworking to a RemappingLossWrapper or similar, as it remaps variables
@@ -28,7 +26,6 @@ def __init__(
loss: dict[str, Any] | Callable | BaseLoss,
predicted_variables: list[str] | None = None,
target_variables: list[str] | None = None,
**kwargs,
):
"""Loss wrapper to filter variables to compute the loss on.

@@ -49,26 +46,15 @@ def __init__(
super().__init__()

self._loss_scaler_specification = {}
if isinstance(loss, str):
self._loss_scaler_specification = ["*"]
self.loss = get_loss_function(DictConfig({"_target_": loss}), scalers={}, **dict(kwargs))
elif isinstance(loss, DictConfig | dict):
self._loss_scaler_specification = loss.pop("scalers", ["*"])
self.loss = get_loss_function(loss, scalers={}, **dict(kwargs))
elif isinstance(loss, type):
self._loss_scaler_specification = ["*"]
self.loss = loss(**kwargs)
elif isinstance(loss, BaseLoss):
self._loss_scaler_specification = loss.scaler
self.loss = loss
else:
msg = f"Invalid loss type provided: {type(loss)}. Expected a str or dict or BaseLoss."
raise TypeError(msg)

assert isinstance(
loss,
BaseLoss,
), f"Invalid loss type provided: {type(loss)}. Expected a str or dict or BaseLoss."
self.loss = loss
self.predicted_variables = predicted_variables
self.target_variables = target_variables

def set_data_indices(self, data_indices: IndexCollection) -> None:
def set_data_indices(self, data_indices: IndexCollection) -> BaseLoss:
"""Hook to set the data indices for the loss."""
self.data_indices = data_indices
name_to_index = data_indices.data.output.name_to_index
@@ -79,17 +65,20 @@ def set_data_indices(self, data_indices: IndexCollection) -> None:
predicted_indices = [model_output.name_to_index[name] for name in self.predicted_variables]
else:
predicted_indices = output_indices
self.predicted_variables = list(name_to_index.keys())
if self.target_variables is not None:
target_indices = [name_to_index[name] for name in self.target_variables]
else:
target_indices = output_indices
self.target_variables = list(name_to_index.keys())

assert len(predicted_indices) == len(
target_indices,
), "predicted and target variables must have the same length for loss computation"

self.predicted_indices = predicted_indices
self.target_indices = target_indices
return self

def forward(self, pred: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
squash = kwargs.get("squash", True)
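A small sketch of the new construction contract (names and arguments below are illustrative): the wrapper now takes an already-instantiated BaseLoss, and `set_data_indices` returns the wrapper itself so it can be chained at build time.

from anemoi.training.losses.filtering import FilteringLossWrapper

def wrap_for_subset(base_loss, data_indices, variables):
    """Restrict an existing BaseLoss to a subset of output variables."""
    return FilteringLossWrapper(
        loss=base_loss,
        predicted_variables=variables,  # names must exist in the model output
        target_variables=variables,     # names must exist in data_indices.data.output
    ).set_data_indices(data_indices)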
57 changes: 51 additions & 6 deletions training/src/anemoi/training/losses/loss.py
@@ -15,9 +15,12 @@
from omegaconf import DictConfig
from omegaconf import OmegaConf

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.data_indices.tensor import OutputTensorIndex
from anemoi.training.losses.base import BaseLoss
from anemoi.training.losses.filtering import FilteringLossWrapper
from anemoi.training.losses.scaler_tensor import TENSOR_SPEC
from anemoi.training.utils.enums import TensorDim
from anemoi.training.utils.variables_metadata import ExtractVariableGroupAndLevel

METRIC_RANGE_DTYPE = dict[str, list[int]]
@@ -30,7 +33,7 @@
def get_loss_function(
config: DictConfig,
scalers: dict[str, TENSOR_SPEC] | None = None,
data_indices: dict | None = None,
data_indices: IndexCollection | None = None,
**kwargs,
) -> BaseLoss:
"""Get loss functions from config.
@@ -46,7 +49,7 @@ def get_loss_function(
If a scaler is to be added to the loss, ensure it is in `scalers` in the loss config.
For instance, if `scalers: ['variable']` is set in the config and `variable` is in `scalers`,
`variable` will be added to the scaler of the loss function.
data_indices : dict, optional
data_indices : IndexCollection, optional
Indices of the training data
kwargs : Any
Additional arguments to pass to the loss function
@@ -65,6 +68,8 @@
"""
loss_config = OmegaConf.to_container(config, resolve=True)
scalers_to_include = loss_config.pop("scalers", [])
predicted_variables = loss_config.pop("predicted_variables", None)
target_variables = loss_config.pop("target_variables", None)

if "_target_" in loss_config and loss_config["_target_"] in NESTED_LOSSES:
per_scale_loss_config = loss_config.pop("per_scale_loss")
@@ -77,21 +82,64 @@
if "*" in scalers_to_include:
scalers_to_include = [s for s in list(scalers.keys()) if f"!{s}" not in scalers_to_include]

if "CombinedLoss" in loss_config.get("_target_", ""):
if data_indices is not None:
loss_config.update(
{"data_indices": data_indices},
)
data_indices = None # for combined loss we want the individual losses to handle data indices
loss_config.update(
{"scalers": scalers},
)
scalers_to_include = []
scalers = {} # for combined loss we want the individual losses to handle scalers
loss_function = instantiate(loss_config, **kwargs, _recursive_=False)

if not isinstance(loss_function, BaseLoss):
error_msg = f"Loss must be a subclass of 'BaseLoss', not {type(loss_function)}"
raise TypeError(error_msg)
_apply_scalers(loss_function, scalers_to_include, scalers, data_indices)
if data_indices is not None:
loss_function = _wrap_loss_with_filtering(
loss_function,
predicted_variables,
target_variables,
data_indices,
)
return loss_function


def _wrap_loss_with_filtering(
loss_function: BaseLoss,
predicted_variables: list[str] | None,
target_variables: list[str] | None,
data_indices: IndexCollection,
) -> BaseLoss:
"""Wrap loss function with FilteringLossWrapper if predicted or target variables are specified."""
loss_function = FilteringLossWrapper(
loss=loss_function,
predicted_variables=predicted_variables,
target_variables=target_variables,
).set_data_indices(data_indices)
subloss = loss_function.loss
if subloss.has_scaler_for_dim(TensorDim.VARIABLE) and predicted_variables is not None:
# filter scaler to only predicted variables
n_variables = len(data_indices.model.output.full)
for key, (dims, tens) in subloss.scaler.subset_by_dim(TensorDim.VARIABLE).tensors.items():
dims = (dims,) if isinstance(dims, int) else tuple(dims) if not isinstance(dims, tuple) else dims
var_dim_pos = list(dims).index(TensorDim.VARIABLE)
# Only filter if the scaler has the full number of variables
if tens.shape[var_dim_pos] == n_variables:
scaling = tens[loss_function.predicted_indices]
loss_function.loss.update_scaler(name=key, scaler=scaling, override=True)
return loss_function


def _apply_scalers(
loss_function: BaseLoss,
scalers_to_include: list,
scalers: dict[str, TENSOR_SPEC] | None,
data_indices: dict | None,
data_indices: IndexCollection | None,
) -> None:
"""Attach scalers to a loss function and set data indices if needed."""
for key in scalers_to_include:
@@ -107,9 +155,6 @@ def _apply_scalers(
LOGGER.info("Parameter %s is being scaled by statistic_tendencies by %.2f", var_key, scaling)
loss_function.add_scaler(*scalers[key], name=key)

if hasattr(loss_function, "set_data_indices"):
loss_function.set_data_indices(data_indices)


def get_metric_ranges(
extract_variable_group_and_level: ExtractVariableGroupAndLevel,
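A sketch of driving the new filtering path from a loss config (the `_target_` path and variable name are illustrative): when `predicted_variables`/`target_variables` are present and `data_indices` is supplied, `get_loss_function` wraps the instantiated loss in FilteringLossWrapper and, if a variable-dimension scaler covers the full set of output variables, restricts it to the predicted subset.

from omegaconf import DictConfig
from anemoi.training.losses.loss import get_loss_function

def build_filtered_loss(scalers, data_indices):
    loss_config = DictConfig(
        {
            "_target_": "anemoi.training.losses.mse.WeightedMSELoss",  # hypothetical loss
            "scalers": ["*"],
            "predicted_variables": ["2t"],  # hypothetical variable name
            "target_variables": ["2t"],
        }
    )
    return get_loss_function(loss_config, scalers=scalers, data_indices=data_indices)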
4 changes: 4 additions & 0 deletions training/src/anemoi/training/losses/scaler_tensor.py
@@ -167,6 +167,10 @@ def get_dim_shape(dimension: int) -> int:

return Shape(get_dim_shape)

def has_scaler_for_dim(self, dim: TensorDim) -> bool:
"""Check if there is a scaler for the given dimension."""
return len(self.subset_by_dim(dim.value).tensors) > 0

def validate_scaler(self, dimension: int | tuple[int], scaler: torch.Tensor) -> None:
"""Check if the scaler is compatible with the given dimension.

13 changes: 4 additions & 9 deletions training/src/anemoi/training/losses/spectral.py
@@ -65,8 +65,6 @@
self,
transform: Literal["fft2d", "sht"] = "fft2d",
*,
x_dim: int | None = None,
y_dim: int | None = None,
ignore_nans: bool = False,
scalers: list | None = None,
**kwargs,
@@ -93,17 +91,14 @@
# Backwards-compatibility: older configs pass scalers to the loss ctor.
_ = scalers # intentionally unused
kwargs.pop("scalers", None)

if x_dim is not None:
kwargs.setdefault("x_dim", x_dim)
if y_dim is not None:
kwargs.setdefault("y_dim", y_dim)
x_dim = kwargs.get("x_dim")
y_dim = kwargs.get("y_dim")

if transform == "fft2d":
self.transform = FFT2D(**kwargs)
# expose dims on the loss (legacy API + tests)
self.x_dim = int(kwargs.get("x_dim"))
self.y_dim = int(kwargs.get("y_dim"))
self.x_dim = int(x_dim) if x_dim is not None else None
self.y_dim = int(y_dim) if y_dim is not None else None
elif transform == "sht":
self.transform = SHT()
else:
21 changes: 16 additions & 5 deletions training/src/anemoi/training/losses/utils.py
@@ -14,6 +14,8 @@
from typing import TYPE_CHECKING

from anemoi.training.losses.combined import CombinedLoss
from anemoi.training.losses.filtering import FilteringLossWrapper
from anemoi.training.losses.multiscale import MultiscaleLossWrapper
from anemoi.training.utils.enums import TensorDim

if TYPE_CHECKING:
@@ -55,18 +57,27 @@ def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dic
variable_scaling[sub_loss.__class__.__name__] = print_variable_scaling(sub_loss, data_indices)
return variable_scaling

variable_scaling = loss.scaler.subset_by_dim(TensorDim.VARIABLE.value).get_scaler(len(TensorDim)).reshape(-1)
if isinstance(loss, MultiscaleLossWrapper):
return print_variable_scaling(loss.loss, data_indices)

if isinstance(loss, FilteringLossWrapper):
subloss = loss.loss
subset_vars = zip(loss.predicted_indices, loss.predicted_variables, strict=False)
else:
subloss = loss
subset_vars = enumerate(data_indices.model.output.name_to_index.keys())

variable_scaling = subloss.scaler.subset_by_dim(TensorDim.VARIABLE.value).get_scaler(len(TensorDim)).reshape(-1)
log_text = f"Final Variable Scaling in {loss.__class__.__name__}: "
scaling_values, scaling_sum = {}, 0.0

for idx, name in enumerate(data_indices.model.output.name_to_index.keys()):
for idx, name in subset_vars:
value = float(variable_scaling[idx])
log_text += f"{name}: {value:.4g}, "
scaling_values[name] = value
scaling_sum += value

log_text += f"Total scaling sum: {scaling_sum:.4g}, "
scaling_values["total_sum"] = scaling_sum
log_text += f"Total scaling sum: {scaling_sum:.4g}, "
scaling_values["total_sum"] = scaling_sum
LOGGER.debug(log_text)

return scaling_values
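A brief sketch of the returned mapping (the values below are made up): for a filtered loss the dictionary now covers only the filtered variables, plus the aggregate `total_sum` entry.

from anemoi.training.losses.utils import print_variable_scaling

def log_scaling(loss, data_indices) -> dict[str, float]:
    scaling = print_variable_scaling(loss, data_indices)
    # e.g. {"2t": 0.52, "10u": 0.48, "total_sum": 1.0} for a loss filtered to two variables
    return scaling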