22 changes: 22 additions & 0 deletions models/src/anemoi/models/preprocessing/normalizer.py
@@ -106,6 +106,26 @@ def __init__(
self.register_buffer("_input_idx", data_indices.data.input.full, persistent=True)
self.register_buffer("_output_idx", self.data_indices.data.output.full, persistent=True)

# We need some special handling when target variables are defined.
model_output_names = set(self.data_indices.model.output.name_to_index.keys())
data_output_names = list(self.data_indices.data.output.name_to_index.keys())

# Create a boolean mask with same length as _output_idx
model_mask = torch.zeros(len(self._output_idx), dtype=torch.bool)

for i, var_name in enumerate(data_output_names):
if var_name in model_output_names:
# Get the index value for this variable in data.output
index_in_data_output = self.data_indices.data.output.name_to_index[var_name]
# Find which position in _output_idx has this index value
position_in_output_idx = (self._output_idx == index_in_data_output).nonzero(as_tuple=True)[0].item()
# Mark this position as True (keep it for model output)
model_mask[position_in_output_idx] = True

_model_output_idx = self._output_idx[model_mask]

self.register_buffer("_model_output_idx", _model_output_idx, persistent=True)

def _validate_normalization_inputs(self, name_to_index_training_input: dict, minimum, maximum, mean, stdev):
assert len(self.methods) == sum(len(v) for v in self.method_config.values()), (
f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) "
@@ -201,6 +221,8 @@ def inverse_transform(
x.subtract_(self._norm_add[data_index]).div_(self._norm_mul[data_index])
elif x.shape[-1] == len(self._output_idx):
x.subtract_(self._norm_add[self._output_idx]).div_(self._norm_mul[self._output_idx])
elif x.shape[-1] == len(self._model_output_idx):
x.subtract_(self._norm_add[self._model_output_idx]).div_(self._norm_mul[self._model_output_idx])
else:
x.subtract_(self._norm_add).div_(self._norm_mul)
return x
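
The loop above builds _model_output_idx by keeping only the positions of _output_idx whose variable also appears in model.output. A minimal stand-alone sketch of the same selection using torch.isin (the helper name and its inputs are illustrative, not part of this PR):

import torch

def model_output_index(data_indices, output_idx: torch.Tensor) -> torch.Tensor:
    # Index values in data.output that correspond to variables the model predicts
    model_names = set(data_indices.model.output.name_to_index)
    kept = torch.tensor(
        [idx for name, idx in data_indices.data.output.name_to_index.items() if name in model_names],
        dtype=output_idx.dtype,
    )
    # Keep only the positions of output_idx whose index value is a model output
    return output_idx[torch.isin(output_idx, kept)]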
128 changes: 128 additions & 0 deletions models/tests/preprocessing/test_preprocessor_normalizer.py
@@ -129,3 +129,131 @@ def test_normalize_remap(remap_normalizer) -> None:
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
expected_output = torch.Tensor([[0.0, 2 / 11, 3.0, -0.5, 1 / 7], [5 / 9, 7 / 11, 8.0, 4.5, 0.5]])
assert torch.allclose(remap_normalizer.transform(x), expected_output)


# ============================================================================
# Tests for target-only variables (e.g., satellite observations like 'imerg')
# ============================================================================


@pytest.fixture()
def normalizer_with_target_only():
"""Create normalizer with target-only variable 'imerg'.

Setup:
- 4 regular variables (x, y, z, q) in both model.output and data.output
- 1 target-only variable (imerg) only in data.output
- data.output has 5 variables, model.output has 4
"""
config = DictConfig(
{
"diagnostics": {"log": {"code": {"level": "DEBUG"}}},
"data": {
"normalizer": {"default": "mean-std"},
"forcing": [],
"diagnostic": [],
"target": ["imerg"], # target-only variable
},
},
)
statistics = {
"mean": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
"stdev": np.array([0.5, 0.5, 0.5, 0.5, 0.5]),
"minimum": np.array([0.0, 0.0, 0.0, 0.0, 0.0]),
"maximum": np.array([10.0, 10.0, 10.0, 10.0, 10.0]),
}
# 5 variables: x, y, z, q are regular; imerg is target-only
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "imerg": 4}
data_indices = IndexCollection(data_config=config.data, name_to_index=name_to_index)
return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics)


def test_model_output_idx_buffer_created(normalizer_with_target_only) -> None:
"""Test that _model_output_idx buffer is created and has correct size."""
assert hasattr(normalizer_with_target_only, "_model_output_idx")
# model.output has 4 variables (excludes imerg)
assert len(normalizer_with_target_only._model_output_idx) == 4
# data.output has 5 variables (includes imerg)
assert len(normalizer_with_target_only._output_idx) == 5


def test_model_output_idx_excludes_target_only(normalizer_with_target_only) -> None:
"""Test that _model_output_idx excludes target-only variables."""
model_idx = normalizer_with_target_only._model_output_idx
output_idx = normalizer_with_target_only._output_idx

# imerg is at index 4, should not be in model_output_idx
assert 4 not in model_idx.tolist()
# But should be in output_idx
assert 4 in output_idx.tolist()


def test_inverse_transform_model_output_size(normalizer_with_target_only) -> None:
"""Test inverse_transform with model output size (excludes target-only vars).

This tests the scenario where the model predicts 4 variables but
data.output has 5 variables (including target-only 'imerg').
"""
# Normalized tensor with 4 variables (model output, no imerg)
# Normalized values: (x - mean) / stdev
normalized = torch.Tensor(
[
[0.0, 0.0, 0.0, 0.0], # All at mean
[2.0, 2.0, 2.0, 2.0], # All at mean + stdev
]
)

# After inverse transform: x * stdev + mean = x * 0.5 + [1,2,3,4]
expected = torch.Tensor(
[
[1.0, 2.0, 3.0, 4.0], # means
[2.0, 3.0, 4.0, 5.0], # mean + stdev
]
)

result = normalizer_with_target_only.inverse_transform(normalized, in_place=False)
assert torch.allclose(result, expected)


def test_inverse_transform_data_output_size(normalizer_with_target_only) -> None:
"""Test inverse_transform with data output size (includes target-only vars).

This tests the scenario where we have the full data.output with 5 variables
including the target-only 'imerg'.
"""
# Normalized tensor with 5 variables (data output, includes imerg)
normalized = torch.Tensor(
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[2.0, 2.0, 2.0, 2.0, 2.0],
]
)

# After inverse transform
expected = torch.Tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0], # means
[2.0, 3.0, 4.0, 5.0, 6.0], # mean + stdev
]
)

result = normalizer_with_target_only.inverse_transform(normalized, in_place=False)
assert torch.allclose(result, expected)


def test_inverse_transform_different_tensor_sizes(normalizer_with_target_only) -> None:
"""Test that inverse_transform handles both model and data output sizes correctly."""
# Model output size (4 variables)
model_output = torch.zeros(2, 4)
result_model = normalizer_with_target_only.inverse_transform(model_output, in_place=False)
assert result_model.shape[-1] == 4

# Data output size (5 variables)
data_output = torch.zeros(2, 5)
result_data = normalizer_with_target_only.inverse_transform(data_output, in_place=False)
assert result_data.shape[-1] == 5

# Values should be correct for each
# Model output gets indices [0,1,2,3], data output gets indices [0,1,2,3,4]
assert torch.allclose(result_model, torch.Tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]))
assert torch.allclose(result_data, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]))
8 changes: 7 additions & 1 deletion training/src/anemoi/training/losses/filtering.py
@@ -58,6 +58,12 @@ def set_data_indices(self, data_indices: IndexCollection) -> BaseLoss:
"""Hook to set the data indices for the loss."""
self.data_indices = data_indices
name_to_index = data_indices.data.output.name_to_index
reindexed = {
k: i
for i, (k, _) in enumerate(
kv for kv in name_to_index.items() if kv[1] in set(data_indices.data.output.full.tolist())
)
}
model_output = data_indices.model.output
output_indices = model_output.full

@@ -67,7 +73,7 @@ def set_data_indices(self, data_indices: IndexCollection) -> BaseLoss:
predicted_indices = output_indices
self.predicted_variables = list(name_to_index.keys())
if self.target_variables is not None:
target_indices = [name_to_index[name] for name in self.target_variables]
target_indices = [reindexed[name] for name in self.target_variables]
else:
target_indices = output_indices
self.target_variables = list(name_to_index.keys())
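
The reindexed mapping converts each variable name to its position within data.output.full rather than to its raw data.output index, which is the coordinate system the filtered output is indexed by. A small worked example with hypothetical names and indices:

name_to_index = {"x": 0, "y": 1, "tp": 2, "imerg": 3}
full = {0, 2, 3}  # set(data_indices.data.output.full.tolist()); "y" not produced
reindexed = {
    k: i
    for i, (k, _) in enumerate(kv for kv in name_to_index.items() if kv[1] in full)
}
# reindexed == {"x": 0, "tp": 1, "imerg": 2}: positions within the filtered
# output, so target_indices built from it can index that tensor directly.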
16 changes: 12 additions & 4 deletions training/src/anemoi/training/losses/loss.py
@@ -99,7 +99,7 @@ def get_loss_function(
error_msg = f"Loss must be a subclass of 'BaseLoss', not {type(loss_function)}"
raise TypeError(error_msg)
_apply_scalers(loss_function, scalers_to_include, scalers, data_indices)
if data_indices is not None:
if data_indices is not None and (predicted_variables is not None or target_variables is not None):
loss_function = _wrap_loss_with_filtering(
loss_function,
predicted_variables,
@@ -124,13 +124,21 @@ def _wrap_loss_with_filtering(
subloss = loss_function.loss
if subloss.has_scaler_for_dim(TensorDim.VARIABLE) and predicted_variables is not None:
# filter scaler to only predicted variables
n_variables = len(data_indices.model.output.full)
# Map predicted variables to data output indices for scaler filtering
data_indices_for_vars = [
(data_indices.data.output.full == data_indices.data.output.name_to_index[var])
.nonzero(as_tuple=True)[0]
.item()
for var in loss_function.predicted_variables
]

for key, (dims, tens) in subloss.scaler.subset_by_dim(TensorDim.VARIABLE).tensors.items():
dims = (dims,) if isinstance(dims, int) else tuple(dims) if not isinstance(dims, tuple) else dims
var_dim_pos = list(dims).index(TensorDim.VARIABLE)
# Only filter if the scaler has the full number of variables
if tens.shape[var_dim_pos] == n_variables:
scaling = tens[loss_function.predicted_indices]

if tens.shape[var_dim_pos] in [len(data_indices.model.output.full), len(data_indices.data.output.full)]:
scaling = tens[data_indices_for_vars]
loss_function.loss.update_scaler(name=key, scaler=scaling, override=True)
return loss_function
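
A stand-alone illustration of the position lookup used for data_indices_for_vars above, with hypothetical values: for each predicted variable, the variable-dimension scaler is subset at that variable's position within data.output.full rather than at its raw index value.

import torch

full = torch.tensor([0, 1, 3, 4])               # hypothetical data.output.full
name_to_index = {"x": 0, "y": 1, "tp": 3, "imerg": 4}
position = (full == name_to_index["tp"]).nonzero(as_tuple=True)[0].item()
assert position == 2  # the index value 3 sits at position 2 of full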

19 changes: 11 additions & 8 deletions training/src/anemoi/training/losses/scaler_tensor.py
@@ -335,15 +335,18 @@ def update_scaler(self, name: str, scaler: torch.Tensor, *, override: bool = Fal
original_scaler = self._tensors.pop(name)
original_scaler_buffer = self._buffers.pop(name, None)

if not override:
if override:
# Directly replace without validation
self._tensors[name] = (dimension, None)
self.register_buffer(name, scaler, persistent=False)
else:
self.validate_scaler(dimension, scaler)

try:
self.add_scaler(dimension, scaler, name=name)
except ValueError:
self._tensors[name] = original_scaler
self.register_buffer(name, original_scaler_buffer, persistent=False)
raise
try:
self.add_scaler(dimension, scaler, name=name)
except ValueError:
self._tensors[name] = original_scaler
self.register_buffer(name, original_scaler_buffer, persistent=False)
raise

def add(self, new_scalers: dict[str, TENSOR_SPEC] | list[TENSOR_SPEC] | None = None, **kwargs) -> None:
"""Add multiple scalers to the existing scalers.
16 changes: 8 additions & 8 deletions training/src/anemoi/training/losses/scalers/variable.py
@@ -95,13 +95,13 @@ def get_scaling_values(self, **_kwargs) -> torch.Tensor:
variable_ref,
self.weights.get("default", 1.0),
)
if variable_ref != variable_name:
assert (
self.weights.get(
variable_name,
None,
)
is None
), f"Variable {variable_name} is not allowed to have a separate scaling besides {variable_ref}."
# if variable_ref != variable_name:
# assert (
# self.weights.get(
# variable_name,
# None,
# )
# is None
# ), f"Variable {variable_name} is not allowed to have a separate scaling besides {variable_ref}."

return variable_loss_scaling
2 changes: 1 addition & 1 deletion training/src/anemoi/training/losses/utils.py
@@ -62,7 +62,7 @@ def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dic

if isinstance(loss, FilteringLossWrapper):
subloss = loss.loss
subset_vars = zip(loss.predicted_indices, loss.predicted_variables, strict=False)
subset_vars = enumerate(loss.predicted_variables)
else:
subloss = loss
subset_vars = enumerate(data_indices.model.output.name_to_index.keys())
20 changes: 20 additions & 0 deletions training/src/anemoi/training/train/tasks/base.py
@@ -291,6 +291,24 @@ def __init__(
self.grid_indices[dataset_name].setup(graph_data[dataset_name])
self.grid_dim = -2

# We need some special handling when target variables are defined.
model_output_names = set(self.data_indices[dataset_name].model.output.name_to_index.keys())
data_output_names = list(self.data_indices[dataset_name].data.output.name_to_index.keys())

# Create a boolean mask with same length as output size (nvars)
model_mask = torch.zeros(len(self.data_indices[dataset_name].data.output.full), dtype=torch.bool)
for i, var_name in enumerate(data_output_names):
if var_name in model_output_names:
# Get the index value for this variable in data.output
full_output_indices = self.data_indices[dataset_name].data.output.full
index = self.data_indices[dataset_name].data.output.name_to_index[var_name]
# Find which position in full_output_indices has this index value
position = (full_output_indices == index).nonzero(as_tuple=True)[0].item()
# Mark this position as True (keep it for model output)
model_mask[position] = True

self.mask_target = model_mask

# check sharding support
self.keep_batch_sharded = self.config.model.keep_batch_sharded
read_group_supports_sharding = reader_group_size == self.config.system.hardware.num_gpus_per_model
@@ -849,7 +867,9 @@ def calculate_val_metrics(

y_postprocessed = post_processor(y, in_place=False)
y_pred_postprocessed = post_processor(y_pred, in_place=False)
if y_postprocessed.shape[3] != y_pred_postprocessed.shape[3]: # maybe 3 shouldn't be hardcoded

y_postprocessed = y_postprocessed[..., self.mask_target]
Review comments on lines +870 to +872:
Member: Definitely not, is it to index the variable dimension? If so we have an enum for this.
Contributor: Var dim could also just be -1, this one rarely changes.
suffix = "" if step is None else f"/{step + 1}"
for metric_name, metric in metrics_dict.items():
if not isinstance(metric, BaseLoss):
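
A stand-alone illustration of what the shape check and mask_target do in calculate_val_metrics when data.output carries a target-only variable; all shapes, values, and the variable name are hypothetical:

import torch

mask_target = torch.tensor([True, True, True, True, False])  # drop target-only 'imerg'
y_postprocessed = torch.randn(2, 1, 40320, 5)       # bs, ensemble, grid, vars
y_pred_postprocessed = torch.randn(2, 1, 40320, 4)  # model predicts 4 vars
if y_postprocessed.shape[3] != y_pred_postprocessed.shape[3]:
    y_postprocessed = y_postprocessed[..., mask_target]
assert y_postprocessed.shape == y_pred_postprocessed.shape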
7 changes: 6 additions & 1 deletion training/src/anemoi/training/train/tasks/forecaster.py
@@ -75,7 +75,12 @@ def _rollout_step(

y = {}
for dataset_name, dataset_batch in batch.items():
y[dataset_name] = dataset_batch[:, self.multi_step + rollout_step]
y[dataset_name] = dataset_batch[
:,
self.multi_step + rollout_step,
...,
self.data_indices[dataset_name].data.output.full,
]
# y includes the auxiliary variables, so we must leave those out when computing the loss
# Compute loss for each dataset and sum them up
loss, metrics_next, y_pred = checkpoint(
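
The widened slice above picks out only the data.output.full channels of the target time step, so y matches what the postprocessors and losses operate on. A small sketch of the indexing with hypothetical shapes:

import torch

batch = torch.randn(4, 6, 40320, 8)             # bs, time, grid, vars
output_full = torch.tensor([0, 1, 2, 3, 5, 7])  # hypothetical data.output.full
multi_step, rollout_step = 2, 1
y = batch[:, multi_step + rollout_step, ..., output_full]
# y.shape == (4, 40320, 6): output variables only, at the target step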
4 changes: 3 additions & 1 deletion training/tests/unit/losses/test_combined_loss.py
@@ -209,4 +209,6 @@ def test_combined_loss_filtered_and_unfiltered_with_scalers() -> None:

# Should not raise IndexError - print_variable_scaling should work with filtered losses
scaling_values = print_variable_scaling(loss, data_indices)
assert "tp" in scaling_values # The filtered variable should be in the output
# scaling_values is a dict with loss names as keys
assert "MSELoss" in scaling_values
assert "tp" in scaling_values["MSELoss"]