Description
What happened?
I could not use the FilteringLossWrapper in a fairly simple setup.
What are the steps to reproduce the bug?
I have a fairly simple setup where the only loss computed during training is an MSE loss comparing the tp output of my model against tp data from IMERG. FilteringLossWrapper is designed for exactly these scenarios.
training_loss:
  # loss class to initialise
  _target_: anemoi.training.losses.combined.CombinedLoss
  losses:
    - _target_: anemoi.training.losses.filtering.FilteringLossWrapper
      predicted_variables: ["tp"]
      target_variables: ["imerg"]
      scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
      loss:
        _target_: anemoi.training.losses.MSELoss
        scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
        ignore_nans: False

First, FilteringLossWrapper breaks because its set_data_indices method (here) needs to be called. There is a set_data_indices hook in BaseLoss (here), but it is not overridden or forwarded anywhere else, so when FilteringLossWrapper is used inside CombinedLoss the call hits the empty BaseLoss implementation and the wrapper's own set_data_indices never runs.
We could define the method inside CombinedLoss, though this is up for discussion:

def set_data_indices(self, data_indices: IndexCollection) -> None:
    for loss in self.losses:
        if hasattr(loss, "set_data_indices"):
            loss.set_data_indices(data_indices)

Second, the mechanics behind set_data_indices as coded in FilteringLossWrapper might not be correct. The method is currently defined like this:
def set_data_indices(self, data_indices: IndexCollection) -> None:
    """Hook to set the data indices for the loss."""
    self.data_indices = data_indices
    name_to_index = data_indices.data.output.name_to_index
    model_output = data_indices.model.output
    output_indices = model_output.full
    if self.predicted_variables is not None:
        predicted_indices = [model_output.name_to_index[name] for name in self.predicted_variables]
    else:
        predicted_indices = output_indices
    if self.target_variables is not None:
        target_indices = [name_to_index[name] for name in self.target_variables]
    else:
        target_indices = output_indices
    assert len(predicted_indices) == len(
        target_indices,
    ), "predicted and target variables must have the same length for loss computation"
    self.predicted_indices = predicted_indices
    self.target_indices = target_indices

To understand what is going on, it helps to print the predicted indices, the target indices, and the shapes of the prediction and the target in the very simple example above:
(inside the forward method of FilteringLossWrapper)

print(pred.shape, target.shape, self.predicted_indices, self.target_indices)
torch.Size([1, 1, 542080, 103]) torch.Size([1, 1, 542080, 103]) [35] [115]

The predicted indices look correct: [35] corresponds to tp. The target indices, however, are completely wrong: [115] is out of bounds in the target tensor, which produces a very messy CUDA error.
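For illustration, here is a minimal, anemoi-independent sketch of that failure mode (the spatial size is invented; only the channel count of 103 matches the example above):

import torch

# Target with 103 output channels, indexed with a data-space index (115)
# that does not exist in the already-sliced tensor.
target = torch.randn(1, 1, 16, 103)
target_indices = [115]

# On CPU this raises a clear IndexError ("index 115 is out of bounds for
# dimension 3 with size 103"); on CUDA the same indexing typically surfaces
# as an opaque device-side assert, i.e. the messy error reported above.
target[..., target_indices]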
The issue seems to come from obtaining the indices via data_indices.data.output.name_to_index. In the data space, the index of imerg, the target variable, is indeed [115], but the target tensor has already been sliced with these data-space indices well before the losses are computed, here:
y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
We cannot apply self.data_indices.data.output to "y", the target, a second time, since it has already been applied. Things get even messier when you consider that, in the case I described, imerg could be defined as a "target" variable:
data:
  target:
    - imerg
in which case:

print(pred.shape, target.shape, self.predicted_indices, self.target_indices)
torch.Size([1, 1, 542080, 102]) torch.Size([1, 1, 542080, 103]) [35] [115]

so we cannot use the model output space for the target either (102 channels vs 103). In conclusion, the logic of set_data_indices needs to be reworked.
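One possible direction, sketched under the assumption that the target is always built as batch[..., data_indices.data.output.full] (as shown above): keep looking the target variables up in the data-space name_to_index, but store the position of each index within data.output.full rather than the raw data-space index, so that the stored index matches the channel layout of the already-sliced target. The helper below is hypothetical and only illustrates the idea; the attribute names follow the snippets above.

def _target_positions(data_indices, target_variables):
    """Map target variable names to channel positions in the sliced target tensor.

    The target ``y`` is ``batch[..., data_indices.data.output.full]``, so a
    variable's channel in ``y`` is the position of its data-space index inside
    ``full``, not the data-space index itself (e.g. imerg: data-space index 115
    -> whatever position it occupies among the 103 kept channels).
    """
    name_to_index = data_indices.data.output.name_to_index
    full = [int(i) for i in data_indices.data.output.full]  # data-space indices kept in y
    return [full.index(name_to_index[name]) for name in target_variables]

With positions like these, target[..., self.target_indices] stays in bounds whether or not imerg is part of the model output; whether this is the right semantics for variables declared only as data targets still needs to be discussed, as noted above.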
Version
recent main version
Platform (OS and architecture)
linux
Relevant log output
Accompanying data
No response
Organisation
ecmwf