diff --git a/CHANGELOG.md b/CHANGELOG.md index 287d76ea..cccb15a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ Keep it human-readable, your future self will thank you! - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) +- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) @@ -36,6 +37,9 @@ Keep it human-readable, your future self will thank you! - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67) - Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135) +### Removed +- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120) + ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 diff --git a/src/anemoi/training/config/data/zarr.yaml b/src/anemoi/training/config/data/zarr.yaml index 3b9a4537..943899da 100644 --- a/src/anemoi/training/config/data/zarr.yaml +++ b/src/anemoi/training/config/data/zarr.yaml @@ -1,5 +1,4 @@ format: zarr -resolution: o96 # Time frequency requested from dataset frequency: 6h # Time step of model (must be multiple of frequency) @@ -82,5 +81,5 @@ processors: # _convert_: all # config: ${data.remapper} -# Values set in the code + # Values set in the code num_features: null # number of features in the forecast state diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 303266fc..1497d27f 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -86,11 +86,6 @@ def __init__(self, config: DictConfig) -> None: if not self.config.dataloader.get("pin_memory", True): LOGGER.info("Data loader memory pinning disabled.") - def _check_resolution(self, resolution: str) -> None: - assert ( - self.config.data.resolution.lower() == resolution.lower() - ), f"Network resolution {self.config.data.resolution=} does not match dataset resolution {resolution=}" - @cached_property def statistics(self) -> dict: return self.ds_train.statistics @@ -177,7 +172,7 @@ def _get_dataset( label: str = "generic", ) -> NativeGridDataset: r = max(rollout, self.rollout) - data = NativeGridDataset( + return NativeGridDataset( data_reader=data_reader, rollout=r, multistep=self.config.training.multistep_input, @@ -188,8 +183,6 @@ def _get_dataset( shuffle=shuffle, label=label, ) - self._check_resolution(data.resolution) - return data def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader: assert stage in {"training", "validation", "test"} diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f3597843..4297ff2d 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -23,6 +23,7 @@ from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback +from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback @@ -198,6 +199,9 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # Parent UUID callback trainer_callbacks.append(ParentUUIDCallback(config)) + # Check variable order callback + trainer_callbacks.append(CheckVariableOrder()) + return trainer_callbacks diff --git a/src/anemoi/training/diagnostics/callbacks/sanity.py b/src/anemoi/training/diagnostics/callbacks/sanity.py new file mode 100644 index 00000000..751bf273 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/sanity.py @@ -0,0 +1,169 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import pytorch_lightning as pl + +LOGGER = logging.getLogger(__name__) + + +class CheckVariableOrder(pl.callbacks.Callback): + """Check the order of the variables in a pre-trained / fine-tuning model.""" + + def __init__(self) -> None: + super().__init__() + self._model_name_to_index = None + + def on_load_checkpoint(self, trainer: pl.Trainer, _: pl.LightningModule, checkpoint: dict) -> None: + """Cache the model mapping from the checkpoint. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + checkpoint : dict + Pytorch Lightning checkpoint + """ + self._model_name_to_index = checkpoint["hyper_parameters"]["data_indices"].name_to_index + data_name_to_index = trainer.datamodule.data_indices.name_to_index + + self._compare_variables(data_name_to_index) + + def on_sanity_check_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Cache the model mapping from the datamodule if not loaded from checkpoint. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + if self._model_name_to_index is None: + self._model_name_to_index = trainer.datamodule.data_indices.name_to_index + + def on_train_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the training data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_train.name_to_index + + self._compare_variables(data_name_to_index) + + def on_validation_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the validation data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_valid.name_to_index + + self._compare_variables(data_name_to_index) + + def on_test_epoch_start(self, trainer: pl.Trainer, _: pl.LightningModule) -> None: + """Check the order of the variables in the model from checkpoint and the test data. + + Parameters + ---------- + trainer : pl.Trainer + Pytorch Lightning trainer + _ : pl.LightningModule + Not used + """ + data_name_to_index = trainer.datamodule.ds_test.name_to_index + + self._compare_variables(data_name_to_index) + + def _compare_variables(self, data_name_to_index: dict[str, int]) -> None: + """Compare the order of the variables in the model from checkpoint and the data. + + Parameters + ---------- + data_name_to_index : dict[str, int] + The dictionary mapping variable names to their indices in the data. + + Raises + ------ + ValueError + If the variable order in the model and data is verifiably different. + """ + if self._model_name_to_index is None: + LOGGER.info("No variable order to compare. Skipping variable order check.") + return + + if self._model_name_to_index == data_name_to_index: + LOGGER.info("The order of the variables in the model matches the order in the data.") + LOGGER.debug("%s, %s", self._model_name_to_index, data_name_to_index) + return + + keys1 = set(self._model_name_to_index.keys()) + keys2 = set(data_name_to_index.keys()) + + error_msg = "" + + # Find keys unique to each dictionary + only_in_model = {key: self._model_name_to_index[key] for key in (keys1 - keys2)} + only_in_data = {key: data_name_to_index[key] for key in (keys2 - keys1)} + + # Find common keys + common_keys = keys1 & keys2 + + # Compare values for common keys + different_values = { + k: (self._model_name_to_index[k], data_name_to_index[k]) + for k in common_keys + if self._model_name_to_index[k] != data_name_to_index[k] + } + + LOGGER.warning( + "The variables in the model do not match the variables in the data. " + "If you're fine-tuning or pre-training, you may have to adjust the " + "variable order and naming in your config.", + ) + if only_in_model: + LOGGER.warning("Variables only in model: %s", only_in_model) + if only_in_data: + LOGGER.warning("Variables only in data: %s", only_in_data) + if set(only_in_model.values()) == set(only_in_data.values()): + # This checks if the order is the same, but the naming is different. This is not be treated as an error. + LOGGER.warning( + "The variable naming is different, but the order appears to be the same. Continuing with training.", + ) + else: + # If the renamed variables are not in the same index locations, raise an error. + error_msg += ( + "The variable order in the model and data is different.\n" + "Please adjust the variable order in your config, you may need to " + "use the 'reorder' and 'rename' key in the dataloader config.\n" + "Refer to the Anemoi Datasets documentation for more information.\n" + ) + if different_values: + # If the variables are named the same but in different order, raise an error. + error_msg += ( + f"Detected a different sort order of the same variables: {different_values}.\n" + "Please adjust the variable order in your config, you may need to use the " + f"'reorder' key in the dataloader config. With:\n `reorder: {self._model_name_to_index}`\n" + ) + + if error_msg: + LOGGER.error(error_msg) + raise ValueError(error_msg) diff --git a/tests/diagnostics/callbacks/test_variable_order.py b/tests/diagnostics/callbacks/test_variable_order.py new file mode 100644 index 00000000..6f91cdc2 --- /dev/null +++ b/tests/diagnostics/callbacks/test_variable_order.py @@ -0,0 +1,289 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from typing import Any + +import pytest +from anemoi.models.data_indices.collection import IndexCollection + +from anemoi.training.diagnostics.callbacks.sanity import CheckVariableOrder +from anemoi.training.train.train import AnemoiTrainer + + +@pytest.fixture +def name_to_index() -> dict: + return {"a": 0, "b": 1, "c": 2} + + +@pytest.fixture +def name_to_index_permute() -> dict: + return {"a": 0, "b": 2, "c": 1} + + +@pytest.fixture +def name_to_index_rename() -> dict: + return {"a": 0, "b": 1, "d": 2} + + +@pytest.fixture +def name_to_index_partial_rename_permute() -> dict: + return {"a": 2, "b": 1, "d": 0} + + +@pytest.fixture +def name_to_index_rename_permute() -> dict: + return {"x": 2, "b": 1, "d": 0} + + +@pytest.fixture +def fake_trainer(mocker: Any, name_to_index: dict) -> AnemoiTrainer: + trainer = mocker.Mock(spec=AnemoiTrainer) + trainer.datamodule.data_indices.name_to_index = name_to_index + return trainer + + +@pytest.fixture +def checkpoint(mocker: Any, name_to_index: dict) -> dict[str, dict[str, IndexCollection]]: + data_index = mocker.Mock(spec=IndexCollection) + data_index.name_to_index = name_to_index + return {"hyper_parameters": {"data_indices": data_index}} + + +@pytest.fixture +def callback() -> CheckVariableOrder: + callback = CheckVariableOrder() + assert callback is not None + assert hasattr(callback, "on_load_checkpoint") + assert hasattr(callback, "on_sanity_check_start") + assert hasattr(callback, "on_train_epoch_start") + assert hasattr(callback, "on_validation_epoch_start") + assert hasattr(callback, "on_test_epoch_start") + + assert callback._model_name_to_index is None + + return callback + + +def test_on_load_checkpoint( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + checkpoint: dict, + name_to_index: dict, +) -> None: + assert callback._model_name_to_index is None + callback.on_load_checkpoint(fake_trainer, None, checkpoint) + assert callback._model_name_to_index == name_to_index + + assert callback._compare_variables(name_to_index) is None + + +def test_on_sanity(fake_trainer: AnemoiTrainer, callback: CheckVariableOrder, name_to_index: dict) -> None: + assert callback._model_name_to_index is None + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + assert callback._compare_variables(name_to_index) is None + + +def test_on_epoch(fake_trainer: AnemoiTrainer, callback: CheckVariableOrder, name_to_index: dict) -> None: + """Test all epoch functions with "working" indices.""" + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index + fake_trainer.datamodule.ds_test.name_to_index = name_to_index + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + assert callback._compare_variables(name_to_index) is None + + +def test_on_epoch_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_permute: dict, +) -> None: + """Test all epoch functions with permuted indices. + + Expecting errors in all cases. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_permute + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_train_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_validation_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_test_epoch_start(fake_trainer, None) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback._compare_variables(name_to_index_permute) + assert "{'c': (2, 1), 'b': (1, 2)}" in str(exc_info.value) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + + +def test_on_epoch_rename( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_rename: dict, +) -> None: + """Test all epoch functions with renamed indices. + + Expecting passes in all cases. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_rename + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_rename + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_rename + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + callback._compare_variables(name_to_index_rename) + + +def test_on_epoch_rename_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_rename_permute: dict, +) -> None: + """Test all epoch functions with renamed and permuted indices. + + Expects all passes (but warnings). + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_rename_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_rename_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_rename_permute + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + + callback._compare_variables(name_to_index_rename_permute) + + +def test_on_epoch_partial_rename_permute( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_partial_rename_permute: dict, +) -> None: + """Test all epoch functions with partially renamed and permuted indices. + + Expects all errors. + """ + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index_partial_rename_permute + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_partial_rename_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_partial_rename_permute + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_train_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_validation_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback.on_test_epoch_start(fake_trainer, None) + + with pytest.raises(ValueError, match="The variable order in the model and data is different."): + callback._compare_variables(name_to_index_partial_rename_permute) + + +def test_on_epoch_wrong_validation( + fake_trainer: AnemoiTrainer, + callback: CheckVariableOrder, + name_to_index: dict, + name_to_index_permute: dict, + name_to_index_rename: dict, +) -> None: + """Test all epoch functions with "working" indices, but different validation indices.""" + assert callback._model_name_to_index is None + callback.on_train_epoch_start(fake_trainer, None) + callback.on_validation_epoch_start(fake_trainer, None) + callback.on_test_epoch_start(fake_trainer, None) + assert callback._model_name_to_index is None + + assert callback._compare_variables(name_to_index) is None + + # Test with initialised model_name_to_index + callback.on_sanity_check_start(fake_trainer, None) + assert callback._model_name_to_index == name_to_index + + fake_trainer.datamodule.ds_train.name_to_index = name_to_index + fake_trainer.datamodule.ds_valid.name_to_index = name_to_index_permute + fake_trainer.datamodule.ds_test.name_to_index = name_to_index_rename + callback.on_train_epoch_start(fake_trainer, None) + with pytest.raises(ValueError, match="Detected a different sort order of the same variables:") as exc_info: + callback.on_validation_epoch_start(fake_trainer, None) + assert " {'c': (2, 1), 'b': (1, 2)}" in str( + exc_info.value, + ) or "{'b': (1, 2), 'c': (2, 1)}" in str(exc_info.value) + callback.on_test_epoch_start(fake_trainer, None) + + assert callback._compare_variables(name_to_index) is None diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py index a61b19f1..58ea6440 100644 --- a/tests/diagnostics/test_callbacks.py +++ b/tests/diagnostics/test_callbacks.py @@ -14,6 +14,8 @@ from anemoi.training.diagnostics.callbacks import get_callbacks +NUM_FIXED_CALLBACKS = 2 # ParentUUIDCallback, CheckVariableOrder + default_config = """ diagnostics: callbacks: [] @@ -39,7 +41,7 @@ def test_no_extra_callbacks_set(): # No extra callbacks set config = omegaconf.OmegaConf.create(yaml.safe_load(default_config)) callbacks = get_callbacks(config) - assert len(callbacks) == 1 # ParentUUIDCallback + assert len(callbacks) == NUM_FIXED_CALLBACKS # ParentUUIDCallback, CheckVariableOrder, etc def test_add_config_enabled_callback(): @@ -47,7 +49,7 @@ def test_add_config_enabled_callback(): config = omegaconf.OmegaConf.create(default_config) config.diagnostics.callbacks.append({"log": {"mlflow": {"enabled": True}}}) callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1 def test_add_callback(): @@ -56,7 +58,7 @@ def test_add_callback(): {"_target_": "anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback"}, ) callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1 def test_add_plotting_callback(monkeypatch): @@ -73,4 +75,4 @@ def __init__(self, config: omegaconf.DictConfig): config.diagnostics.plot.enabled = True config.diagnostics.plot.callbacks = [{"_target_": "anemoi.training.diagnostics.callbacks.plot.PlotLoss"}] callbacks = get_callbacks(config) - assert len(callbacks) == 2 + assert len(callbacks) == NUM_FIXED_CALLBACKS + 1