Issues with FilteringLossWrapper (CombinedLoss) #803

@gabrieloks

Description

What happened?

I could not get FilteringLossWrapper to work in a fairly simple setup.

What are the steps to reproduce the bug?

I have a fairly simple setup where the only loss computed during training is an MSE loss comparing the tp output of my model against tp data from IMERG. FilteringLossWrapper is designed for exactly this type of scenario.

  training_loss:
    # loss class to initialise
    _target_: anemoi.training.losses.combined.CombinedLoss
    losses:
      - _target_: anemoi.training.losses.filtering.FilteringLossWrapper
        predicted_variables: ["tp"]
        target_variables: ["imerg"]
        scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
        loss:
          _target_: anemoi.training.losses.MSELoss
          scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
          ignore_nans: False

First, FilteringLossWrapper breaks because its set_data_indices method is never called. There is a set_data_indices hook in BaseLoss, but nothing invokes it on the wrapped losses, so when FilteringLossWrapper is used inside CombinedLoss only the empty base-class implementation ever runs and the filtering indices are never set.

We could define the method inside CombinedLoss, but this is to be discussed:

    def set_data_indices(self, data_indices: IndexCollection) -> None:
        """Propagate the data indices to every wrapped loss."""
        for loss in self.losses:
            if hasattr(loss, "set_data_indices"):
                loss.set_data_indices(data_indices)
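
With this delegation in place, a single set_data_indices call on the top-level CombinedLoss would propagate the IndexCollection to every wrapped loss, including FilteringLossWrapper.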

Second, the index logic in FilteringLossWrapper's set_data_indices might not be correct. The method is currently defined like this:

    def set_data_indices(self, data_indices: IndexCollection) -> None:
        """Hook to set the data indices for the loss."""
        self.data_indices = data_indices
        name_to_index = data_indices.data.output.name_to_index
        model_output = data_indices.model.output
        output_indices = model_output.full

        if self.predicted_variables is not None:
            predicted_indices = [model_output.name_to_index[name] for name in self.predicted_variables]
        else:
            predicted_indices = output_indices
        if self.target_variables is not None:
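            # Suspect line: these names are looked up in the *data* output
            # space, not in the space of the already-sliced target tensor.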
            target_indices = [name_to_index[name] for name in self.target_variables]
        else:
            target_indices = output_indices

        assert len(predicted_indices) == len(
            target_indices,
        ), "predicted and target variables must have the same length for loss computation"

        self.predicted_indices = predicted_indices
        self.target_indices = target_indices

To understand what is going on, it is helpful to print the target indices, the predicted indices, and the shapes of the prediction and the target in the very simple example mentioned above:

(inside the forward method of FilteringLossWrapper)

    print(pred.shape, target.shape, self.predicted_indices, self.target_indices)

    torch.Size([1, 1, 542080, 103]) torch.Size([1, 1, 542080, 103]) [35] [115]

The predicted indices seem to be correct, [35], which corresponds to tp, but the target indices are completely wrong: [115] is out of bounds for the target tensor and produces a very messy CUDA error.

The issue seems to come from obtaining the indices via data_indices.data.output.name_to_index. In the data space the index of imerg, the target variable, is indeed 115, but the target tensor has already been sliced with exactly these indices well before the loss computation, in the training step:

            y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
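
For illustration, here is a self-contained toy version of that slicing; the sizes and index values are made up, and a plain index tensor stands in for the real data_indices objects:

    import torch

    # Toy stand-in for data_indices.data.output.full: 6 data-space
    # variables, of which the output keeps 4.
    full = torch.tensor([0, 2, 3, 5])
    batch = torch.randn(1, 1, 8, 6)   # [..., n_data_vars]
    y = batch[..., full]              # target tensor, shape [..., 4]

    raw_index = 5                     # data-space index of a variable (cf. imerg = 115)
    # y[..., raw_index] would be out of bounds: y has only 4 channels.
    # The valid channel is the *position* of raw_index inside `full`:
    channel = (full == raw_index).nonzero().item()   # -> 3
    assert torch.equal(y[..., channel], batch[..., raw_index])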

We cannot apply self.data_indices.data.output to "y", the target, a second time, since it has already been consumed by the slice above. The problem gets messier still when you consider that, in the case described, imerg could also be declared as a "target" variable in the data config:

data:
  target:
  - imerg

in which case

    print(pred.shape, target.shape, self.predicted_indices, self.target_indices)

    torch.Size([1, 1, 542080, 102]) torch.Size([1, 1, 542080, 103]) [35] [115]

so we cannot use the model output space for the target either, since the channel counts no longer match (102 vs. 103). In conclusion, the logic of set_data_indices needs to be reworked; one possible direction is sketched below.
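
One possible direction, purely for discussion: keep the predicted indices in the model output space, but map each target variable's raw data-space index to its position inside the data.output.full slice. This assumes the target is always built with data_indices.data.output.full as shown above; data_output and full below are local helpers, not existing API:

    def set_data_indices(self, data_indices: IndexCollection) -> None:
        """Hook to set the data indices for the loss (sketch, to be discussed)."""
        self.data_indices = data_indices
        model_output = data_indices.model.output

        # Predicted indices live in the model output space, as before.
        if self.predicted_variables is not None:
            predicted_indices = [model_output.name_to_index[name] for name in self.predicted_variables]
        else:
            predicted_indices = model_output.full

        # The target tensor was sliced with data_indices.data.output.full,
        # so a variable's channel in the target is its *position* within
        # that index list, not its raw data-space index.
        data_output = data_indices.data.output
        full = [int(i) for i in data_output.full]  # assumption: iterable of data-space indices
        if self.target_variables is not None:
            target_indices = [full.index(data_output.name_to_index[name]) for name in self.target_variables]
        else:
            target_indices = list(range(len(full)))

        assert len(predicted_indices) == len(
            target_indices,
        ), "predicted and target variables must have the same length for loss computation"

        self.predicted_indices = predicted_indices
        self.target_indices = target_indices

In the toy example above this would yield channel 3 instead of the out-of-bounds raw index 5; in the real setup it would yield imerg's position within the 103 retained output channels instead of 115.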

Version

recent main version

Platform (OS and architecture)

Linux

Relevant log output

Accompanying data

No response

Organisation

ECMWF

Labels

bug (Something isn't working)
