Skip to content
Open
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
337 commits
Select commit Hold shift + click to select a range
17678bf
remove unused code
JPXKQX Nov 28, 2025
34f1a0e
add test
JPXKQX Nov 28, 2025
c743ccb
start updating diffusion
JPXKQX Nov 28, 2025
10dfc38
diffusion
JPXKQX Dec 2, 2025
33bbcda
Merge remote-tracking branch 'origin/main' into feat/mapper-refactor
ssmmnn11 Dec 5, 2025
82305d3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 11, 2025
6471510
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
d1cf2c5
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
45d6ff1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
9524934
update after rebase (WIP)
JPXKQX Dec 11, 2025
3401990
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
79059f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
7715fcf
debug_multi_eracerra working
JPXKQX Dec 11, 2025
caaa131
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
437be12
ensemble_crps working
JPXKQX Dec 11, 2025
599a9f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
5a067b8
update predict_step()
JPXKQX Dec 11, 2025
885dcfa
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
74c13df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
a4ffdfa
move relative_date_indices inside MultiDataset
JPXKQX Dec 11, 2025
7b17ce7
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
368054e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
1f9d3b0
fix plotting after merge
VeraChristina Dec 11, 2025
7a4f34f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
ac8d393
working again
JPXKQX Dec 11, 2025
11038c8
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
de1f45f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
6142158
update
JPXKQX Dec 11, 2025
57d4d65
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Dec 11, 2025
7109c50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
89300b7
merged main
ssmmnn11 Dec 15, 2025
81fcecd
modifications to make sparse graph provider work
ssmmnn11 Dec 15, 2025
ab83ef5
revisions
ssmmnn11 Dec 15, 2025
41ea6c9
fix
ssmmnn11 Dec 15, 2025
366442e
fix for creation from graph
ssmmnn11 Dec 15, 2025
dc9e990
remove transpose option
ssmmnn11 Dec 15, 2025
d3e2f95
config fix
ssmmnn11 Dec 15, 2025
d968e2c
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 Dec 16, 2025
8761862
fix
JPXKQX Dec 17, 2025
fc5dc44
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 17, 2025
b22041f
feat: add ensemble plots for multi datasets (#747)
VeraChristina Dec 17, 2025
0e3acda
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 Dec 17, 2025
19fdb1b
remove predict_step
JPXKQX Dec 18, 2025
4a2798a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2025
1b6794c
new update
JPXKQX Dec 18, 2025
437c162
Fix predict_step
gmertes Dec 18, 2025
31fd027
revert unnecessary bool
gmertes Dec 18, 2025
14eb3aa
Update type hint
gmertes Dec 18, 2025
478a86d
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 18, 2025
d6a8eba
fix ensemble plots tests
VeraChristina Dec 18, 2025
5708d8c
Merge branch 'feature/multi-dataset-integration' of github.com:ecmwf/…
VeraChristina Dec 18, 2025
0cb28c4
feat: metadata for multi datasets (#762)
VeraChristina Dec 18, 2025
1894a7c
Merge remote-tracking branch 'origin/main' into feat/mapper-refactor
ssmmnn11 Dec 19, 2025
31b7610
adapt multiscale loss to graph provider
ssmmnn11 Dec 19, 2025
c994178
fix migration script
ssmmnn11 Dec 19, 2025
a801f22
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 19, 2025
76d83af
fix
JPXKQX Dec 19, 2025
c955060
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2025
6c42f3b
update crps
JPXKQX Dec 19, 2025
961605b
fix test_variable_order
VeraChristina Dec 19, 2025
eba0162
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 19, 2025
565161d
diffusion is working!
JPXKQX Dec 19, 2025
ad02aef
clean
JPXKQX Dec 19, 2025
b1ad392
time interpolation working
JPXKQX Dec 22, 2025
a80e061
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Dec 22, 2025
d755ab1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
4f6635e
style
JPXKQX Dec 22, 2025
72a6f2f
pre-commit passing
JPXKQX Dec 22, 2025
154561e
unit tests passing
JPXKQX Dec 23, 2025
81a95ce
configs
JPXKQX Dec 23, 2025
d33ca46
prepare integration tests
JPXKQX Dec 23, 2025
e3f4a20
prepare schemas
JPXKQX Dec 23, 2025
6b58f4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
2bdee19
.
JPXKQX Dec 23, 2025
ac762b3
configs
JPXKQX Dec 23, 2025
84351bd
update integration test configs
JPXKQX Dec 29, 2025
0393a82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2025
d64fbfd
update integration benchmark configs
JPXKQX Dec 29, 2025
3271327
truncation matrices handling in integration tests
JPXKQX Dec 29, 2025
f822370
fix: ens schemas
JPXKQX Dec 29, 2025
1965535
fix: lam schemas
JPXKQX Dec 29, 2025
c5d2ad9
fix: stretched schema
JPXKQX Dec 29, 2025
d6e5e5b
fix: interpolation schemas
JPXKQX Dec 29, 2025
307693c
remove empty space
JPXKQX Dec 29, 2025
2724a91
fix: aicon integration configs
JPXKQX Dec 29, 2025
df7f88f
fix
JPXKQX Dec 29, 2025
39850b1
fix: data indices tests
JPXKQX Dec 30, 2025
d162f55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2025
dbcc4da
fix: tend diffusion
JPXKQX Dec 30, 2025
9e0cee3
fix: interpolator & aicon integrations tests
JPXKQX Dec 30, 2025
61be08c
fix: integration tests
JPXKQX Dec 30, 2025
e44f969
fix: device
JPXKQX Dec 30, 2025
115f200
fix: hierarchical
JPXKQX Dec 30, 2025
c2e4865
fix: diffusion integration test
JPXKQX Dec 30, 2025
a6b2673
fix: interpolation integration test
JPXKQX Dec 30, 2025
047e840
fix: multiscale
JPXKQX Dec 30, 2025
99f608a
remove
JPXKQX Dec 30, 2025
a34b7f1
fix: compute_metrics
JPXKQX Dec 30, 2025
219c959
pre-commit
JPXKQX Dec 30, 2025
566bd04
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 Dec 30, 2025
ad54463
add migration script
JPXKQX Dec 31, 2025
95e16c0
diffusion predict()
JPXKQX Dec 31, 2025
94d0333
diffusion predict
JPXKQX Dec 31, 2025
4b839e8
update unit test:
JPXKQX Dec 31, 2025
fc47f71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 31, 2025
dccbba2
update unit test
JPXKQX Dec 31, 2025
900880d
diffusion: predict_step()
JPXKQX Dec 31, 2025
54e78fe
update diffusion sampler
JPXKQX Dec 31, 2025
2815f20
pre-commit
JPXKQX Dec 31, 2025
a968c6b
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 2, 2026
ca2fdfd
fix: Heun sampler test
JPXKQX Jan 2, 2026
217e40a
fix: DPMpp2M sampler tests
JPXKQX Jan 2, 2026
fb10385
fix: samplers comparison
JPXKQX Jan 2, 2026
7a3b9bc
pre-commit
JPXKQX Jan 2, 2026
4516f7f
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 Jan 5, 2026
e4fcc0f
Merge branch 'main' into feat/mapper-refactor
anaprietonem Jan 5, 2026
3404fe2
fix multiscale loss tests
ssmmnn11 Jan 5, 2026
8132b10
Merge remote-tracking branch 'origin/feat/mapper-refactor' into feat/…
ssmmnn11 Jan 5, 2026
a3a1cdc
update documentation
ssmmnn11 Jan 5, 2026
96e5838
add edge sharding and checkpointing in graph provider
ssmmnn11 Jan 7, 2026
ee65fc8
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 7, 2026
2509061
fix for splitting uneven seq
ssmmnn11 Jan 7, 2026
300ee5f
fix sharding for uneven head dimensions
ssmmnn11 Jan 7, 2026
f200956
fix for validation loss computation
ssmmnn11 Jan 7, 2026
5d1088b
fix hanging callback
ssmmnn11 Jan 7, 2026
c7ff82a
clear optimiser states in migration script
ssmmnn11 Jan 7, 2026
1d88dd6
remove checkpoint opt state reset in migration script
ssmmnn11 Jan 8, 2026
df36dc6
refactor of edge sharding
ssmmnn11 Jan 8, 2026
3a52cae
bug fix: do not persist edge_inc buffer
ssmmnn11 Jan 8, 2026
d2ce519
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 9, 2026
664c2c8
Merge branch 'main' into feat/mapper-refactor
JPXKQX Jan 9, 2026
6f81a43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
fea4055
update new checkpoint for gnn
anaprietonem Jan 9, 2026
c7af6e6
Merge branch 'feat/mapper-refactor' of github.com:ecmwf/anemoi-core i…
anaprietonem Jan 9, 2026
4688701
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX Jan 9, 2026
b8e1d27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
7fd8615
pre-commit
JPXKQX Jan 9, 2026
44da8ee
fix
JPXKQX Jan 10, 2026
789612c
fix: diffusion
JPXKQX Jan 12, 2026
95b4bbb
fix: model x_data_latent
JPXKQX Jan 12, 2026
63dff8a
fix: specify dataset grid_shard_slice
JPXKQX Jan 12, 2026
f7e6a7e
update migration
JPXKQX Jan 12, 2026
cc01462
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX Jan 12, 2026
549bcc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
f281ac7
fix
JPXKQX Jan 13, 2026
f863d02
fix: mapper refactor update gnn checkpoint (#793)
anaprietonem Jan 13, 2026
40cf199
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX Jan 13, 2026
5b38d82
Merge branch 'feature/multi-dataset-integration-mapper-refactor' of h…
JPXKQX Jan 13, 2026
e7e3195
fixes for plotting
ssmmnn11 Jan 13, 2026
6c66fc2
Merge remote-tracking branch 'origin/feat/mapper-refactor' into feat/…
ssmmnn11 Jan 13, 2026
f239a1f
Merge branch 'main' into feat/mapper-refactor
japols Jan 13, 2026
c9bbd8c
use balanced_partition for head splits
japols Jan 13, 2026
9676037
fix: rename graph_providers.py to graph_provider.py
JPXKQX Jan 13, 2026
4f6fb3d
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX Jan 13, 2026
07778f0
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 13, 2026
5294e8f
Merge branch 'feature/multi-dataset-integration-mapper-refactor' into…
JPXKQX Jan 13, 2026
da1c1e0
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 13, 2026
8c9fa2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
3b871f2
pre-commit
JPXKQX Jan 13, 2026
7c111f9
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Jan 13, 2026
b03a006
fix
JPXKQX Jan 13, 2026
e35e1d7
fix: integration tests
JPXKQX Jan 14, 2026
a855685
fix
JPXKQX Jan 14, 2026
8885656
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 14, 2026
2efd37a
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 14, 2026
16c2eb8
fix: lam with existing graph
JPXKQX Jan 14, 2026
65c270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
529bdfd
update configs
JPXKQX Jan 14, 2026
eb0ab83
adding integration test for 2 datasets
JPXKQX Jan 14, 2026
498821f
fix: new integration tests
JPXKQX Jan 14, 2026
b6d6056
change dataset_b to CERRA
JPXKQX Jan 14, 2026
2f3c896
undo
JPXKQX Jan 14, 2026
afa90c5
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 14, 2026
a9962a3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 15, 2026
3d67883
fix: test_training_cycle_lam_with_existing_graph
JPXKQX Jan 15, 2026
2a81ff6
fix: move to utils/config_utils.py to anemoi-utils
JPXKQX Jan 15, 2026
77f6a69
fix: profiler
JPXKQX Jan 15, 2026
061d0df
fix: configs
JPXKQX Jan 15, 2026
01593ce
fix: move config utility to anemoi/models
JPXKQX Jan 15, 2026
cceaea3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 15, 2026
9f9a057
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
ae66995
delete extra configs
JPXKQX Jan 16, 2026
822a5af
fix: grid_shard_slice
JPXKQX Jan 16, 2026
c7c3117
fix: import
JPXKQX Jan 16, 2026
09e9287
add comment
JPXKQX Jan 16, 2026
7db8df1
fixes
ssmmnn11 Jan 16, 2026
6c2d6e7
fixes
ssmmnn11 Jan 16, 2026
c3d7f28
fix
ssmmnn11 Jan 16, 2026
1900ce2
multi dataset test fix
ssmmnn11 Jan 16, 2026
f45224f
fix: mlflow integration test
JPXKQX Jan 19, 2026
6b79046
plotting overrides multi integration test
VeraChristina Jan 19, 2026
0fe5980
fix: update stretched config
JPXKQX Jan 19, 2026
5993d7d
update ADR
JPXKQX Jan 19, 2026
1f337f4
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 19, 2026
3c62e6d
update plotting config
JPXKQX Jan 19, 2026
3c8ede1
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX Jan 19, 2026
2cc2901
fix: plotting node attributes
JPXKQX Jan 19, 2026
889be6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
c39407c
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 19, 2026
300e2fb
get task type from task instead of config
VeraChristina Jan 20, 2026
b31747d
add draft section on multi datasets in docs
VeraChristina Jan 20, 2026
3c0ea34
update docs
VeraChristina Jan 20, 2026
01500ce
update ADR
JPXKQX Jan 20, 2026
d484bee
docstring
JPXKQX Jan 20, 2026
01a72c2
add supporting arrays from grid_indices (LAM)
JPXKQX Jan 20, 2026
e89ef7c
fix: supporting arrays
JPXKQX Jan 20, 2026
5644541
add multi dataset diagrams to docs
VeraChristina Jan 20, 2026
3b80ad6
update docs
VeraChristina Jan 20, 2026
35cce38
update config snippets in training docs
VeraChristina Jan 20, 2026
f6ddf4e
remove outdated config snippet in docs
VeraChristina Jan 20, 2026
5e80301
update docs/modules
JPXKQX Jan 20, 2026
7a00eda
update docs/user-guide
JPXKQX Jan 20, 2026
7107190
fix: autogenerated docstring
JPXKQX Jan 20, 2026
8036f20
update docs
VeraChristina Jan 20, 2026
0890231
bring back rollout & multi_step into interpolator task
JPXKQX Jan 21, 2026
25d4308
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 21, 2026
75e0a7e
refactor: streamline batch and ensemble size validation across datasets
JPXKQX Jan 21, 2026
b1bb7f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
b876baf
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX Jan 21, 2026
7b73af2
Merge branch 'feature/multi-dataset-integration' into multids/downsca…
JPXKQX Jan 21, 2026
a131cf0
fix for sharded model and loss that does not support sharding
ssmmnn11 Jan 21, 2026
d2c1d51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
8b8485f
configs
JPXKQX Jan 21, 2026
d3811af
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX Jan 21, 2026
87088e5
fix: update schema
JPXKQX Jan 21, 2026
a2c595d
Merge branch 'feature/multi-dataset-integration' into multids/downsca…
JPXKQX Jan 21, 2026
d400ccb
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX Jan 21, 2026
c71ed9f
undo base changes
JPXKQX Jan 21, 2026
dd4b49e
update
JPXKQX Jan 21, 2026
5f80ad2
undo base.py
JPXKQX Jan 21, 2026
c20bf80
update
JPXKQX Jan 21, 2026
9f50066
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
db4da13
Update training/src/anemoi/training/train/tasks/downscaler.py
JPXKQX Jan 21, 2026
7b8a5bd
update
JPXKQX Jan 21, 2026
57ae94b
fix
JPXKQX Jan 21, 2026
f163e7a
Merge branch 'main' into multids/downscaling
JPXKQX Jan 22, 2026
2915351
fix
JPXKQX Jan 22, 2026
69e50e8
undo forecaster changes
JPXKQX Jan 22, 2026
1c53881
Update training/src/anemoi/training/train/tasks/forecaster.py
JPXKQX Jan 22, 2026
b4dfc68
update
JPXKQX Jan 22, 2026
e94906b
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX Jan 22, 2026
a093446
Merge branch 'main' into multids/downscaling
JPXKQX Jan 26, 2026
08f08fd
pre-commit
JPXKQX Jan 26, 2026
6a71d35
specify target dataset names
JPXKQX Jan 26, 2026
54b1da9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2026
0ad1f2e
Update training/src/anemoi/training/train/tasks/base.py
JPXKQX Jan 27, 2026
1edfad2
integration tests
JPXKQX Jan 27, 2026
e18445f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2026
1b76bcb
update
JPXKQX Jan 27, 2026
b2514cb
Merge branch 'main' into multids/downscaling
JPXKQX Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions models/src/anemoi/models/data_indices/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
class IndexCollection:
"""Collection of data and model indices."""

def __init__(self, config, name_to_index) -> None:
self.config = OmegaConf.to_container(config, resolve=True)
def __init__(self, data_config, name_to_index) -> None:
self.config = OmegaConf.to_container(data_config, resolve=True)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
self.forcing = [] if data_config.forcing is None else OmegaConf.to_container(data_config.forcing, resolve=True)
self.diagnostic = (
[] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
[] if data_config.diagnostic is None else OmegaConf.to_container(data_config.diagnostic, resolve=True)
)
self.target = (
[] if config.data.get("target", None) is None else OmegaConf.to_container(config.data.target, resolve=True)
[] if data_config.get("target", None) is None else OmegaConf.to_container(data_config.target, resolve=True)
)
defined_variables = set.union(set(self.forcing), set(self.diagnostic), set(self.target))
self.prognostic = [v for v in self.name_to_index.keys() if v not in defined_variables]
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, config, name_to_index) -> None:
)

def __repr__(self) -> str:
return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"
return f"IndexCollection(data_config={self.config}, name_to_index={self.name_to_index})"

def __eq__(self, other):
if not isinstance(other, IndexCollection):
Expand Down
1 change: 1 addition & 0 deletions models/src/anemoi/models/data_indices/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def todict(self):
"forcing": self.forcing,
"target": self.target,
"full": self.full,
"name_to_index": self.name_to_index,
}

@staticmethod
Expand Down
97 changes: 77 additions & 20 deletions models/src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.preprocessing import Processors
from anemoi.models.utils.config import get_multiple_datasets_config
from anemoi.utils.config import DotDict


Expand Down Expand Up @@ -75,28 +76,77 @@ def __init__(
self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {}
self.data_indices = data_indices
self._build_model()
self._update_metadata()

def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Instantiate processors
def _build_processors_for_dataset(
self,
processors_configs: dict,
statistics: dict,
data_indices: dict,
statistics_tendencies: dict = None,
):
"""Build processors for a single dataset.

Parameters
----------
processors_configs : dict
Configuration for the processors
statistics : dict
Statistics for the dataset
data_indices : dict
Data indices for the dataset
statistics_tendencies : dict, optional
Tendencies statistics for the dataset

Returns
-------
tuple
(pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies)
"""
# Build processors for the dataset
processors = [
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)]
for name, processor in self.config.data.processors.items()
[name, instantiate(processor, data_indices=data_indices, statistics=statistics)]
for name, processor in processors_configs.items()
]

# Assign the processor list pre- and post-processors
self.pre_processors = Processors(processors)
self.post_processors = Processors(processors, inverse=True)
pre_processors = Processors(processors)
post_processors = Processors(processors, inverse=True)

# If tendencies statistics are provided, instantiate the tendencies processors
if self.statistics_tendencies is not None:
processors = [
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics_tendencies)]
for name, processor in self.config.data.processors.items()
# Build tendencies processors if provided
pre_processors_tendencies = None
post_processors_tendencies = None
if statistics_tendencies is not None:
processors_tendencies = [
[name, instantiate(processor, data_indices=data_indices, statistics=statistics_tendencies)]
for name, processor in processors_configs.items()
]
# Assign the processor list pre- and post-processors
self.pre_processors_tendencies = Processors(processors)
self.post_processors_tendencies = Processors(processors, inverse=True)
pre_processors_tendencies = Processors(processors_tendencies)
post_processors_tendencies = Processors(processors_tendencies, inverse=True)

return pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies

def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Multi-dataset mode: create processors for each dataset
self.pre_processors = torch.nn.ModuleDict()
self.post_processors = torch.nn.ModuleDict()
self.pre_processors_tendencies = torch.nn.ModuleDict()
self.post_processors_tendencies = torch.nn.ModuleDict()

data_config = get_multiple_datasets_config(self.config.data)
for dataset_name in self.statistics.keys():
# Build processors for each dataset
pre, post, pre_tend, post_tend = self._build_processors_for_dataset(
data_config[dataset_name].processors,
self.statistics[dataset_name],
self.data_indices[dataset_name],
self.statistics_tendencies[dataset_name] if self.statistics_tendencies is not None else None,
)
self.pre_processors[dataset_name] = pre
self.post_processors[dataset_name] = post
if pre_tend is not None:
self.pre_processors_tendencies[dataset_name] = pre_tend
self.post_processors_tendencies[dataset_name] = post_tend

# Instantiate the model
# Only pass _target_ and _convert_ from model config to avoid passing diffusion as kwarg
Expand All @@ -117,13 +167,17 @@ def _build_model(self) -> None:
self.forward = self.model.forward

def predict_step(
self, batch: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None, gather_out: bool = True, **kwargs
) -> torch.Tensor:
self,
batch: dict[str, torch.Tensor],
model_comm_group: Optional[ProcessGroup] = None,
gather_out: bool = True,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Prediction step for the model.

Parameters
----------
batch : torch.Tensor
batch : dict[str, torch.Tensor]
Input batched data.
model_comm_group : Optional[ProcessGroup], optional
model communication group, specifies which GPUs work together
Expand All @@ -132,7 +186,7 @@ def predict_step(

Returns
-------
torch.Tensor
dict[str, torch.Tensor]
Predicted data.
"""
# Prepare kwargs for model's predict_step
Expand All @@ -152,3 +206,6 @@ def predict_step(

# Delegate to the model's predict_step implementation with processors
return self.model.predict_step(**predict_kwargs, **kwargs)

def _update_metadata(self) -> None:
self.model.fill_metadata(self.metadata)
39 changes: 38 additions & 1 deletion models/src/anemoi/models/layers/bounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def build_boundings(
def _build_dataset_boundings(
model_config: Any,
data_indices: Any,
statistics: dict | None,
Expand Down Expand Up @@ -355,3 +355,40 @@ def build_boundings(
for cfg in bounding_cfgs
]
)


def build_boundings(
    model_config: Any,
    data_indices: Any,
    statistics: dict | None,
) -> nn.ModuleDict:
    """Build the per-dataset model-output bounding modules from configuration.

    This is a thin factory that creates a ``nn.ModuleDict`` with one entry
    per dataset found in ``data_indices``; each entry is whatever
    ``_build_dataset_boundings`` returns for that dataset.

    Parameters
    ----------
    model_config : Any
        Model configuration, forwarded unchanged to
        ``_build_dataset_boundings`` for every dataset.
    data_indices : Any
        Mapping from dataset name to that dataset's data indices object.
        It must provide ``.items()``; each value is forwarded to
        ``_build_dataset_boundings``.
    statistics : dict | None
        Mapping from dataset name to the statistics of that dataset, or
        ``None`` when the configured bounding classes need no statistics.
        When a mapping is given it must contain every dataset name present
        in ``data_indices``.

    Returns
    -------
    nn.ModuleDict
        Bounding modules keyed by dataset name. Empty when ``data_indices``
        is empty.

    Raises
    ------
    KeyError
        If ``statistics`` is a mapping that lacks one of the dataset names.
    """
    bounding_modules = nn.ModuleDict()
    for dataset_name, dataset_indices in data_indices.items():
        # ``statistics`` is optional as a whole: indexing ``None`` would raise
        # a TypeError, so forward ``None`` for every dataset in that case.
        dataset_statistics = None if statistics is None else statistics[dataset_name]
        bounding_modules[dataset_name] = _build_dataset_boundings(
            model_config, dataset_indices, dataset_statistics
        )
    return bounding_modules
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from anemoi.models.migrations import CkptType
from anemoi.models.migrations import MigrationMetadata

# DO NOT CHANGE -->
metadata = MigrationMetadata(
versions={
"migration": "1.0.0",
"anemoi-models": "%NEXT_ANEMOI_MODELS_VERSION%",
},
)
# <-- END DO NOT CHANGE


def migrate(ckpt: CkptType) -> CkptType:
    """Migrate a single-dataset checkpoint to the multi-dataset layout.

    Every state-dict key that starts with one of the dataset-keyed module
    prefixes gets the placeholder dataset name ``"data"`` inserted right
    after that prefix (``<prefix><rest>`` becomes ``<prefix>data.<rest>``).
    The ``data_indices``, ``statistics``, ``statistics_tendencies`` and
    ``supporting_arrays`` hyper-parameters are each wrapped into a
    one-entry dict keyed by the same dataset name.

    The checkpoint is modified in place and also returned.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict. Must contain ``"state_dict"`` and
        ``"hyper_parameters"``; the latter must contain the four
        hyper-parameters listed above.

    Returns
    -------
    CkptType
        The migrated checkpoint dict (the same object as ``ckpt``).

    Raises
    ------
    KeyError
        If one of the required entries is missing from the checkpoint.
    """
    dummy_dataset_name = "data"

    # Prefixes of sub-modules that are keyed by dataset name after the
    # migration. The trailing dot keeps e.g. "encoder." and
    # "encoder_graph_provider." (or "pre_processors." and
    # "pre_processors_tendencies.") from matching each other, so at most one
    # prefix applies to any key.
    dataset_keyed_prefixes = (
        "model.pre_processors.",
        "model.post_processors.",
        # The tendency processors are dataset-keyed containers as well, so
        # their buffers need the dataset segment too.
        "model.pre_processors_tendencies.",
        "model.post_processors_tendencies.",
        "model.model.node_attributes.",
        "model.model.encoder.",
        "model.model.encoder_graph_provider.",
        "model.model.decoder.",
        "model.model.decoder_graph_provider.",
    )

    state_dict = ckpt["state_dict"]
    updates = {}
    # Iterate over a snapshot of the keys because entries are popped on the fly.
    for key in list(state_dict):
        for prefix in dataset_keyed_prefixes:
            if key.startswith(prefix):
                # Rewrite only the leading prefix; the remainder of the key is
                # copied verbatim even if it happens to contain the prefix text.
                new_key = f"{prefix}{dummy_dataset_name}.{key[len(prefix):]}"
                updates[new_key] = state_dict.pop(key)
                break
    state_dict.update(updates)

    hyper_parameters = ckpt["hyper_parameters"]
    for name in ("data_indices", "statistics", "statistics_tendencies", "supporting_arrays"):
        hyper_parameters[name] = {dummy_dataset_name: hyper_parameters.pop(name)}
    return ckpt


def rollback(ckpt: CkptType) -> CkptType:
    """Roll the checkpoint back to the single-dataset layout.

    For every state-dict key of the form ``<prefix>data.<rest>``, where
    ``<prefix>`` is one of the dataset-keyed module prefixes, the ``data.``
    segment is removed again. The ``data_indices``, ``statistics``,
    ``statistics_tendencies`` and ``supporting_arrays`` hyper-parameters are
    unwrapped when they are a dict whose only key is ``"data"``.

    Entries that do not match this shape (for example a checkpoint holding
    several datasets, or a dataset with another name) are left untouched,
    because they have no single-dataset equivalent.

    The checkpoint is modified in place and also returned.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict. Must contain ``"state_dict"`` and
        ``"hyper_parameters"``.

    Returns
    -------
    CkptType
        The rolled-back checkpoint dict (the same object as ``ckpt``).
    """
    dummy_dataset_name = "data"

    # Same prefixes as the forward migration; the trailing dot keeps
    # overlapping names such as "encoder." / "encoder_graph_provider." apart.
    dataset_keyed_prefixes = (
        "model.pre_processors.",
        "model.post_processors.",
        "model.pre_processors_tendencies.",
        "model.post_processors_tendencies.",
        "model.model.node_attributes.",
        "model.model.encoder.",
        "model.model.encoder_graph_provider.",
        "model.model.decoder.",
        "model.model.decoder_graph_provider.",
    )

    state_dict = ckpt["state_dict"]
    restored = {}
    # Iterate over a snapshot of the keys because entries are popped on the fly.
    for key in list(state_dict):
        for prefix in dataset_keyed_prefixes:
            nested_prefix = f"{prefix}{dummy_dataset_name}."
            if key.startswith(nested_prefix):
                restored[prefix + key[len(nested_prefix):]] = state_dict.pop(key)
                break
    state_dict.update(restored)

    hyper_parameters = ckpt["hyper_parameters"]
    for name in ("data_indices", "statistics", "statistics_tendencies", "supporting_arrays"):
        wrapped = hyper_parameters.get(name)
        # Only unwrap the exact one-entry shape the forward migration creates.
        if isinstance(wrapped, dict) and set(wrapped) == {dummy_dataset_name}:
            hyper_parameters[name] = wrapped[dummy_dataset_name]
    return ckpt
Loading
Loading