17 changes: 6 additions & 11 deletions training/src/anemoi/training/config/dataloader/multi.yaml
@@ -35,17 +35,6 @@ limit_batches:
validation: null
test: 20

# Dataset-specific grid indices configurations
# If not specified, all datasets will use the default grid_indices above
grid_indices:
datasets:
era5:
_target_: anemoi.training.data.grid_indices.FullGrid
nodes_name: ${graph.data}
cerra:
_target_: anemoi.training.data.grid_indices.FullGrid # Could be different
nodes_name: ${graph.data}

# ============
# Multi-Dataset Configuration
# Define multiple datasets that will be synchronized during training
@@ -61,13 +50,15 @@ training:
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null
cerra:
dataset: ${system.input.dataset_b} # Using same dataset as duplicate for testing
start: 1985
end: 2020
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null

validation_rollout: 1 # number of rollouts to use for validation

@@ -81,13 +72,15 @@ validation:
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null
cerra:
dataset: ${system.input.dataset_b}
start: 2021
end: 2021
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null

# Multi-dataset test with same datasets, different time period
test:
@@ -99,10 +92,12 @@ test:
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null
cerra:
dataset: ${system.input.dataset_b}
start: 2022
end: null
frequency: ${data.frequency}
drop: []
trajectory: null
lam_mask_radius_km: null
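
Note: each dataset entry now carries its own `lam_mask_radius_km` instead of a shared `grid_indices` block. A minimal sketch of how the resolved option could be inspected, assuming the `<split>.datasets.<name>` layout used in these files; the OmegaConf usage here is illustrative only and not part of this PR (interpolations such as `${data.frequency}` only resolve inside the full Hydra config tree):

```python
# Illustrative sketch, assuming the <split>.datasets.<name> layout above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("training/src/anemoi/training/config/dataloader/multi.yaml")
for name, ds_cfg in cfg.training.datasets.items():
    # None (YAML null) -> full grid; a number -> LAM mask of that radius in km.
    radius = ds_cfg.get("lam_mask_radius_km")
    print(f"{name}: lam_mask_radius_km={radius}")
```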
13 changes: 4 additions & 9 deletions training/src/anemoi/training/config/dataloader/native_grid.yaml
@@ -37,23 +37,15 @@ limit_batches:
validation: null
test: 20

# set a custom mask for grid points.
# Useful for LAM (dropping unconnected nodes from forcing dataset)
grid_indices:
datasets:
data: # user-defined name for the dataset
_target_: anemoi.training.data.grid_indices.FullGrid
nodes_name: ${graph.data}

# ============
# Dataloader definitions
# These follow the anemoi-datasets patterns
# You can make these as complicated for merging as you like
# See https://anemoi-datasets.readthedocs.io
# ============

# Pointers to datasets and model run info
dataset: ${system.input.dataset}

model_run_info: null # Add for non-analysis training

training:
@@ -65,6 +57,7 @@ training:
frequency: ${data.frequency}
drop: []
trajectory: ${dataloader.model_run_info}
lam_mask_radius_km: null

validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks

@@ -77,6 +70,7 @@ validation:
frequency: ${data.frequency}
drop: []
trajectory: ${dataloader.model_run_info}
lam_mask_radius_km: null

test:
datasets:
@@ -87,3 +81,4 @@ test:
frequency: ${data.frequency}
drop: []
trajectory: ${dataloader.model_run_info}
lam_mask_radius_km: null
6 changes: 1 addition & 5 deletions training/src/anemoi/training/config/graph/limited_area.yaml
@@ -54,11 +54,7 @@ edges:
target_mask_attr_name: cutout_mask
attributes: ${graph.attributes.edges}

post_processors:
- _target_: anemoi.graphs.processors.RemoveUnconnectedNodes
nodes_name: data
ignore: cutout_mask # optional
save_mask_indices_to_attr: indices_connected_nodes # optional
post_processors: []

attributes:
nodes:
10 changes: 6 additions & 4 deletions training/src/anemoi/training/config/lam.yaml
@@ -26,12 +26,14 @@ dataloader:
- dataset: ${system.input.forcing_dataset}
adjust: all
min_distance_km: 0
grid_indices:
training:
datasets:
data:
_target_: anemoi.training.data.grid_indices.MaskedGrid
nodes_name: data
node_attribute_name: indices_connected_nodes
lam_mask_radius_km: 300
validation:
datasets:
data:
lam_mask_radius_km: 300
model:
output_mask:
_target_: anemoi.training.utils.masks.Boolean1DMask
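
Note: the LAM experiment config now enables the mask per split (`training`, `validation`) by overriding the `null` default from the dataloader config, rather than instantiating a grid-indices class via Hydra. A rough sketch of the override semantics, assuming plain OmegaConf merging (names illustrative):

```python
# Rough sketch of the config override, assuming standard OmegaConf merge semantics.
from omegaconf import OmegaConf

base = OmegaConf.create({"training": {"datasets": {"data": {"lam_mask_radius_km": None}}}})
lam = OmegaConf.create({"training": {"datasets": {"data": {"lam_mask_radius_km": 300}}}})
merged = OmegaConf.merge(base, lam)
assert merged.training.datasets.data.lam_mask_radius_km == 300  # LAM mask active
```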
31 changes: 2 additions & 29 deletions training/src/anemoi/training/data/datamodule.py
@@ -12,13 +12,10 @@
from functools import cached_property

import pytorch_lightning as pl
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from torch_geometric.data import HeteroData

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.utils.config import get_multiple_datasets_config
from anemoi.training.data.grid_indices import BaseGridIndices
from anemoi.training.data.multidataset import MultiDataset
from anemoi.training.schemas.base_schema import BaseSchema
from anemoi.training.utils.worker_init import worker_init_func
@@ -29,20 +26,17 @@
class AnemoiDatasetsDataModule(pl.LightningDataModule):
"""Anemoi Datasets data module for PyTorch Lightning."""

def __init__(self, config: BaseSchema, graph_data: HeteroData) -> None:
def __init__(self, config: BaseSchema) -> None:
"""Initialize Multi-dataset data module.

Parameters
----------
config : BaseSchema
Job configuration with multi-dataset specification
graph_data : HeteroData
Graph data for the model
"""
super().__init__()

self.config = config
self.graph_data = graph_data
self.train_dataloader_config = get_multiple_datasets_config(self.config.dataloader.training)
self.valid_dataloader_config = get_multiple_datasets_config(self.config.dataloader.validation)
self.test_dataloader_config = get_multiple_datasets_config(self.config.dataloader.test)
@@ -77,13 +71,7 @@ def metadata(self) -> dict:
@cached_property
def supporting_arrays(self) -> dict:
"""Return supporting arrays from all training datasets."""
supporting_arrays = self.ds_train.supporting_arrays
for dataset_name, grid_indices in self.grid_indices.items():
if dataset_name in supporting_arrays:
supporting_arrays[dataset_name] = supporting_arrays[dataset_name] | grid_indices.supporting_arrays
else:
supporting_arrays[dataset_name] = grid_indices.supporting_arrays
return supporting_arrays
return self.ds_train.supporting_arrays

@cached_property
def data_indices(self) -> dict[str, IndexCollection]:
@@ -118,20 +106,6 @@ def relative_date_indices(self, val_rollout: int = 1) -> list:
multi_step = self.config.training.multistep_input
return list(range(multi_step + rollout))

@cached_property
def grid_indices(self) -> dict[str, type[BaseGridIndices]]:
"""Initialize grid indices for spatial sharding for each dataset."""
grid_indices_dict = {}

# Each dataset can have its own grid indices configuration
grid_indices_config = get_multiple_datasets_config(self.config.dataloader.grid_indices)
for dataset_name, grid_config in grid_indices_config.items():
grid_indices = instantiate(grid_config, reader_group_size=self.config.dataloader.read_group_size)
grid_indices.setup(self.graph_data[dataset_name])
grid_indices_dict[dataset_name] = grid_indices

return grid_indices_dict

@cached_property
def ds_train(self) -> MultiDataset:
"""Create multi-dataset for training."""
@@ -164,7 +138,6 @@ def _get_dataset(
relative_date_indices=self.relative_date_indices(val_rollout),
timestep=self.config.data.timestep,
shuffle=shuffle,
grid_indices=self.grid_indices,
label=label,
)

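
Note: with grid-index selection moved into the dataset readers, the datamodule no longer needs the graph. A minimal before/after usage sketch (the call site is assumed, not shown in this diff):

```python
# Before this PR (assumed call site):
#   datamodule = AnemoiDatasetsDataModule(config, graph_data)
# After: the graph is no longer required; readers derive masks from the data itself.
datamodule = AnemoiDatasetsDataModule(config)
```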
62 changes: 59 additions & 3 deletions training/src/anemoi/training/data/dataset.py
@@ -10,6 +10,7 @@
import datetime
import logging
from abc import abstractmethod
from functools import cached_property

import numpy as np
import torch
@@ -19,6 +20,9 @@
from rich.tree import Tree

from anemoi.datasets import open_dataset
from anemoi.training.data.grid_indices import BaseIndices
from anemoi.training.data.grid_indices import FullGrid
from anemoi.training.data.grid_indices import MaskedGrid
from anemoi.utils.dates import frequency_to_seconds

LOGGER = logging.getLogger(__name__)
@@ -34,14 +38,41 @@ def __init__(
end: datetime.datetime | int | None = None,
frequency: str | None = None,
drop: list[str] | None = None,
lam_mask_radius_km: int | None = None,
):
self.data = open_dataset(dataset, start=start, end=end, frequency=frequency, drop=drop)
"""Initialize Anemoi data reader."""
ds_kwargs = {}
if drop is not None:
ds_kwargs["drop"] = drop

if frequency is not None:
ds_kwargs["frequency"] = frequency

self.data = open_dataset(dataset, start=start, end=end, **ds_kwargs)
self.lam_mask_radius_km = lam_mask_radius_km

@cached_property
def grid_indices(self) -> BaseIndices:
    """Return the grid-index selector: the full grid, or a LAM mask of radius `lam_mask_radius_km`."""
    if self.lam_mask_radius_km is None:
        return FullGrid()

    return MaskedGrid(
        latitudes=self.data.latitudes,
        longitudes=self.data.longitudes,
        mask=self.cutout_mask,
        mask_radius_km=self.lam_mask_radius_km,
    )

@property
def dates(self) -> list[datetime.datetime]:
"""Return dataset dates."""
return self.data.dates

@property
def grid_size(self) -> int:
"""Return dataset grid size."""
return sum(self.data.grids)

@property
def statistics(self) -> dict:
"""Return dataset statistics."""
@@ -77,7 +108,7 @@ def frequency(self) -> datetime.timedelta:
@property
def supporting_arrays(self) -> dict:
"""Return dataset supporting_arrays."""
return self.data.supporting_arrays()
return self.data.supporting_arrays() | self.grid_indices.supporting_arrays

@property
def name_to_index(self) -> dict[str, int]:
@@ -89,6 +120,21 @@ def resolution(self) -> str:
"""Return dataset resolution."""
return self.data.resolution

@cached_property
def cutout_mask(self) -> np.ndarray:
"""Return cutout mask."""
if len(self.data.grids) <= 1:
    err_msg = "The `cutout_mask` property requires a cutout (multi-grid) dataset, but only one grid is present."
    raise ValueError(err_msg)
cutout_mask = np.zeros(self.grid_size, dtype=bool)
cutout_mask[: self.data.grids[0]] = True
return cutout_mask

@cached_property
def boundary_mask(self) -> np.ndarray:
"""Return boundary mask."""
return ~self.cutout_mask

@property
@abstractmethod
def has_trajectories(self) -> bool:
@@ -100,6 +146,7 @@ def get_sample(
grid_shard_indices: np.ndarray | None = None,
) -> torch.Tensor:
"""Get a sample from the dataset."""
grid_shard_indices = self.grid_indices.get_shard_indices(grid_shard_indices)
if isinstance(grid_shard_indices, slice):
# Load only shards into CPU memory
x = self.data[time_indices, :, :, grid_shard_indices]
@@ -150,8 +197,16 @@ def __init__(
end: datetime.datetime | int | None = None,
frequency: str | None = None,
drop: list[str] | None = None,
lam_mask_radius_km: int | None = None,
):
super().__init__(dataset, start=start, end=end, frequency=frequency, drop=drop)
super().__init__(
dataset,
start=start,
end=end,
frequency=frequency,
drop=drop,
lam_mask_radius_km=lam_mask_radius_km,
)
self.trajectory_start = trajectory_start
self.trajectory_length = trajectory_length

@@ -176,6 +231,7 @@ def create_dataset(dataset_config: dict) -> BaseAnemoiReader:
"""Factory function to create dataset based on dataset configuration."""
if isinstance(dataset_config, DictConfig):
dataset_config = dict(dataset_config)

trajectory_config = dataset_config.pop("trajectory", {})
if trajectory_config is not None and hasattr(trajectory_config, "start") and hasattr(trajectory_config, "length"):
LOGGER.info("Creating TrajectoryDataset...")
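
Note: `MaskedGrid` itself is not shown in this diff. One plausible implementation of the radius selection it is handed (`latitudes`, `longitudes`, `mask`, `mask_radius_km`), assuming a haversine ball query with scikit-learn — a sketch under those assumptions, not the PR's actual code:

```python
# Hypothetical sketch of a lam_mask_radius_km selection; MaskedGrid's real
# implementation is not part of this diff.
import numpy as np
from sklearn.neighbors import BallTree

EARTH_RADIUS_KM = 6371.0

def radius_mask(latitudes, longitudes, cutout_mask, mask_radius_km):
    """Keep all cutout (LAM) points plus boundary points within the radius."""
    # BallTree's haversine metric expects (lat, lon) pairs in radians.
    coords = np.deg2rad(np.stack([latitudes, longitudes], axis=-1))
    tree = BallTree(coords[cutout_mask], metric="haversine")
    # Distance from every grid point to the nearest LAM point, radians -> km.
    dist_km = tree.query(coords, k=1)[0][:, 0] * EARTH_RADIUS_KM
    return cutout_mask | (dist_km <= mask_radius_km)
```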