Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sanity checks for training data consistency #120

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Open
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ Keep it human-readable, your future self will thank you!
- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/)

- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)

### Changed
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)

### Removed
- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)

## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28


Expand Down
3 changes: 1 addition & 2 deletions src/anemoi/training/config/data/zarr.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
format: zarr
resolution: o96
# Time frequency requested from dataset
frequency: 6h
# Time step of model (must be multiple of frequency)
Expand Down Expand Up @@ -82,5 +81,5 @@ processors:
# _convert_: all
# config: ${data.remapper}

# Values set in the code
# Values set in the code
num_features: null # number of features in the forecast state
9 changes: 1 addition & 8 deletions src/anemoi/training/data/datamodule.py
JesperDramsch marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,6 @@ def __init__(self, config: DictConfig) -> None:
if not self.config.dataloader.get("pin_memory", True):
LOGGER.info("Data loader memory pinning disabled.")

def _check_resolution(self, resolution: str) -> None:
assert (
self.config.data.resolution.lower() == resolution.lower()
), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}"

@cached_property
def statistics(self) -> dict:
return self.ds_train.statistics
Expand Down Expand Up @@ -178,7 +173,7 @@ def _get_dataset(
label: str = "generic",
) -> NativeGridDataset:
r = max(rollout, self.rollout)
data = NativeGridDataset(
return NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
Expand All @@ -189,8 +184,6 @@ def _get_dataset(
shuffle=shuffle,
label=label,
)
self._check_resolution(data.resolution)
return data

def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
assert stage in {"training", "validation", "test"}
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging
from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback
from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder

if TYPE_CHECKING:
from pytorch_lightning.callbacks import Callback
Expand Down Expand Up @@ -204,6 +205,9 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901
# Parent UUID callback
trainer_callbacks.append(ParentUUIDCallback(config))

# Check variable order callback
trainer_callbacks.append(CheckVariableOrder())

return trainer_callbacks


Expand Down
169 changes: 169 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/sanity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import pytorch_lightning as pl

LOGGER = logging.getLogger(__name__)


class CheckVariableOrder(pl.callbacks.Callback):
"""Check the order of the variables in a pre-trained / fine-tuning model."""

def __init__(self) -> None:
super().__init__()
self._model_name_to_index = None

def on_load_checkpoint(self, trainer: pl.Trainer, _: pl.LightningModule, checkpoint: dict) -> None:
"""Cache the model mapping from the checkpoint.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
checkpoint : dict
Pytorch Lightning checkpoint
"""
self._model_name_to_index = checkpoint["hyper_parameters"]["data_indices"].name_to_index
data_name_to_index = trainer.datamodule.data_indices.name_to_index

self._compare_variables(data_name_to_index)

def on_sanity_check_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Cache the model mapping from the datamodule if not loaded from checkpoint.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
if self._model_name_to_index is None:
self._model_name_to_index = trainer.datamodule.data_indices.name_to_index

def on_train_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the training data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_train.name_to_index

self._compare_variables(data_name_to_index)

def on_validation_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the validation data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_valid.name_to_index

self._compare_variables(data_name_to_index)

def on_test_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None:
"""Check the order of the variables in the model from checkpoint and the test data.

Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
_ : pl.LightningModule
Not used
"""
data_name_to_index = trainer.datamodule.ds_test.name_to_index

self._compare_variables(data_name_to_index)

def _compare_variables(self, data_name_to_index: dict[str, int]) -> None:
"""Compare the order of the variables in the model from checkpoint and the data.

Parameters
----------
data_name_to_index : dict[str, int]
The dictionary mapping variable names to their indices in the data.

Raises
------
ValueError
If the variable order in the model and data is verifiably different.
"""
if self._model_name_to_index is None:
LOGGER.info("No variable order to compare. Skipping variable order check.")
return

if self._model_name_to_index == data_name_to_index:
LOGGER.info("The order of the variables in the model matches the order in the data.")
LOGGER.debug("%s, %s", self._model_name_to_index, data_name_to_index)
return

keys1 = set(self._model_name_to_index.keys())
keys2 = set(data_name_to_index.keys())

error_msg = ""

# Find keys unique to each dictionary
only_in_model = {key: self._model_name_to_index[key] for key in (keys1 - keys2)}
only_in_data = {key: data_name_to_index[key] for key in (keys2 - keys1)}

# Find common keys
common_keys = keys1 & keys2

# Compare values for common keys
different_values = {
k: (self._model_name_to_index[k], data_name_to_index[k])
for k in common_keys
if self._model_name_to_index[k] != data_name_to_index[k]
}

LOGGER.warning(
"The variables in the model do not match the variables in the data. "
"If you're fine-tuning or pre-training, you may have to adjust the "
"variable order and naming in your config.",
)
if only_in_model:
LOGGER.warning("Variables only in model: %s", only_in_model)
if only_in_data:
LOGGER.warning("Variables only in data: %s", only_in_data)
if set(only_in_model.values()) == set(only_in_data.values()):
# This checks if the order is the same, but the naming is different. This is not be treated as an error.
LOGGER.warning(
"The variable naming is different, but the order appears to be the same. Continuing with training.",
)
else:
# If the renamed variables are not in the same index locations, raise an error.
error_msg += (
"The variable order in the model and data is different.\n"
"Please adjust the variable order in your config, you may need to "
"use the 'reorder' and 'rename' key in the dataloader config.\n"
"Refer to the Anemoi Datasets documentation for more information.\n"
)
if different_values:
# If the variables are named the same but in different order, raise an error.
error_msg += (
f"Detected a different sort order of the same variables: {different_values}.\n"
"Please adjust the variable order in your config, you may need to use the "
f"'reorder' key in the dataloader config. With:\n `reorder: {self._model_name_to_index}`\n"
)

if error_msg:
LOGGER.error(error_msg)
raise ValueError(error_msg)
Loading
Loading