feat(training): downscaling #811
Open: JPXKQX wants to merge 337 commits into main from multids/downscaling
Changes from 250 commits
Commits
17678bf
remove unused code
JPXKQX 34f1a0e
add test
JPXKQX c743ccb
start updating diffusion
JPXKQX 10dfc38
diffusion
JPXKQX 33bbcda
Merge remote-tracking branch 'origin/main' into feat/mapper-refactor
ssmmnn11 82305d3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 6471510
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d1cf2c5
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 45d6ff1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 9524934
update after rebase (WIP)
JPXKQX 3401990
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 79059f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7715fcf
debug_multi_eracerra working
JPXKQX caaa131
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 437be12
ensemble_crps working
JPXKQX 599a9f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5a067b8
update predict_step()
JPXKQX 885dcfa
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 74c13df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a4ffdfa
move relative_date_indices inside MultiDataset
JPXKQX 7b17ce7
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 368054e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1f9d3b0
fix plotting after merge
VeraChristina 7a4f34f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ac8d393
working again
JPXKQX 11038c8
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX de1f45f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6142158
update
JPXKQX 57d4d65
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 7109c50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 89300b7
merged main
ssmmnn11 81fcecd
modifications to make sparse graph provider work
ssmmnn11 ab83ef5
revisions
ssmmnn11 41ea6c9
fix
ssmmnn11 366442e
fix for creation from graph
ssmmnn11 dc9e990
remove transpose option
ssmmnn11 d3e2f95
config fix
ssmmnn11 d968e2c
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 8761862
fix
JPXKQX fc5dc44
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX b22041f
feat: add ensemble plots for multi datasets (#747)
VeraChristina 0e3acda
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 19fdb1b
remove predict_step
JPXKQX 4a2798a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1b6794c
new update
JPXKQX 437c162
Fix predict_step
gmertes 31fd027
revert unnecessary bool
gmertes 14eb3aa
Update type hint
gmertes 478a86d
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX d6a8eba
fix ensemble plots tests
VeraChristina 5708d8c
Merge branch 'feature/multi-dataset-integration' of github.com:ecmwf/…
VeraChristina 0cb28c4
feat: metadata for multi datasets (#762)
VeraChristina 1894a7c
Merge remote-tracking branch 'origin/main' into feat/mapper-refactor
ssmmnn11 31b7610
adapt multiscale loss to graph provider
ssmmnn11 c994178
fix migration script
ssmmnn11 a801f22
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 76d83af
fix
JPXKQX c955060
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6c42f3b
update crps
JPXKQX 961605b
fix test_variable_order
VeraChristina eba0162
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 565161d
diffusion is working!
JPXKQX ad02aef
clean
JPXKQX b1ad392
time interpolation working
JPXKQX a80e061
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX d755ab1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4f6635e
style
JPXKQX 72a6f2f
pre-commit passing
JPXKQX 154561e
unit tests passing
JPXKQX 81a95ce
configs
JPXKQX d33ca46
prepare integration tests
JPXKQX e3f4a20
prepare schemas
JPXKQX 6b58f4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2bdee19
.
JPXKQX ac762b3
configs
JPXKQX 84351bd
update integration test configs
JPXKQX 0393a82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d64fbfd
update integration benchmark configs
JPXKQX 3271327
truncation matrices handling in integration tests
JPXKQX f822370
fix: ens schemas
JPXKQX 1965535
fix: lam schemas
JPXKQX c5d2ad9
fix: stretched schema
JPXKQX d6e5e5b
fix: interpolation schemas
JPXKQX 307693c
remove empty space
JPXKQX 2724a91
fix: aicon integration configs
JPXKQX df7f88f
fix
JPXKQX 39850b1
fix: data indices tests
JPXKQX d162f55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] dbcc4da
fix: tend diffusion
JPXKQX 9e0cee3
fix: interpolator & aicon integrations tests
JPXKQX 61be08c
fix: integration tests
JPXKQX e44f969
fix: device
JPXKQX 115f200
fix: hierarchical
JPXKQX c2e4865
fix: diffusion integration test
JPXKQX a6b2673
fix: interpolation integration test
JPXKQX 047e840
fix: multiscale
JPXKQX 99f608a
remove
JPXKQX a34b7f1
fix: compute_metrics
JPXKQX 219c959
pre-commit
JPXKQX 566bd04
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 ad54463
add migration script
JPXKQX 95e16c0
diffusion predict()
JPXKQX 94d0333
diffusion predict
JPXKQX 4b839e8
update unit test:
JPXKQX fc47f71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] dccbba2
update unit test
JPXKQX 900880d
diffusion: predict_step()
JPXKQX 54e78fe
update diffusion sampler
JPXKQX 2815f20
pre-commit
JPXKQX a968c6b
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX ca2fdfd
fix: Heun sampler test
JPXKQX 217e40a
fix: DPMpp2M sampler tests
JPXKQX fb10385
fix: samplers comparison
JPXKQX 7a3b9bc
pre-commit
JPXKQX 4516f7f
Merge branch 'main' into feat/mapper-refactor
ssmmnn11 e4fcc0f
Merge branch 'main' into feat/mapper-refactor
anaprietonem 3404fe2
fix multiscale loss tests
ssmmnn11 8132b10
Merge remote-tracking branch 'origin/feat/mapper-refactor' into feat/…
ssmmnn11 a3a1cdc
update documentation
ssmmnn11 96e5838
add edge sharding and checkpointing in graph provider
ssmmnn11 ee65fc8
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 2509061
fix for splitting uneven seq
ssmmnn11 300ee5f
fix sharding for uneven head dimensions
ssmmnn11 f200956
fix for validation loss computation
ssmmnn11 5d1088b
fix hanging callback
ssmmnn11 c7ff82a
clear optimiser states in migration script
ssmmnn11 1d88dd6
remove checkpoint opt state reset in migration script
ssmmnn11 df36dc6
refactor of edge sharding
ssmmnn11 3a52cae
bug fix: do not persist edge_inc buffer
ssmmnn11 d2ce519
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 664c2c8
Merge branch 'main' into feat/mapper-refactor
JPXKQX 6f81a43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fea4055
update new checkpoint for gnn
anaprietonem c7af6e6
Merge branch 'feat/mapper-refactor' of github.com:ecmwf/anemoi-core i…
anaprietonem 4688701
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX b8e1d27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7fd8615
pre-commit
JPXKQX 44da8ee
fix
JPXKQX 789612c
fix: diffusion
JPXKQX 95b4bbb
fix: model x_data_latent
JPXKQX 63dff8a
fix: specify dataset grid_shard_slice
JPXKQX f7e6a7e
update migration
JPXKQX cc01462
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX 549bcc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f281ac7
fix
JPXKQX f863d02
fix: mapper refactor update gnn checkpoint (#793)
anaprietonem 40cf199
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX 5b38d82
Merge branch 'feature/multi-dataset-integration-mapper-refactor' of h…
JPXKQX e7e3195
fixes for plotting
ssmmnn11 6c66fc2
Merge remote-tracking branch 'origin/feat/mapper-refactor' into feat/…
ssmmnn11 f239a1f
Merge branch 'main' into feat/mapper-refactor
japols c9bbd8c
use balanced_partition for head splits
japols 9676037
fix: rename graph_providers.py to graph_provider.py
JPXKQX 4f6fb3d
Merge branch 'feat/mapper-refactor' into feature/multi-dataset-integr…
JPXKQX 07778f0
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 5294e8f
Merge branch 'feature/multi-dataset-integration-mapper-refactor' into…
JPXKQX da1c1e0
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 8c9fa2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3b871f2
pre-commit
JPXKQX 7c111f9
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX b03a006
fix
JPXKQX e35e1d7
fix: integration tests
JPXKQX a855685
fix
JPXKQX 8885656
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 2efd37a
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 16c2eb8
fix: lam with existing graph
JPXKQX 65c270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 529bdfd
update configs
JPXKQX eb0ab83
adding integration test for 2 datasets
JPXKQX 498821f
fix: new integration tests
JPXKQX b6d6056
change dataset_b to CERRA
JPXKQX 2f3c896
undo
JPXKQX afa90c5
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX a9962a3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 3d67883
fix: test_training_cycle_lam_with_existing_graph
JPXKQX 2a81ff6
fix: move to utils/config_utils.py to anemoi-utils
JPXKQX 77f6a69
fix: profiler
JPXKQX 061d0df
fix: configs
JPXKQX 01593ce
fix: move config utility to anemoi/models
JPXKQX cceaea3
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 9f9a057
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ae66995
delete extra configs
JPXKQX 822a5af
fix: grid_shard_slice
JPXKQX c7c3117
fix: import
JPXKQX 09e9287
add comment
JPXKQX 7db8df1
fixes
ssmmnn11 6c2d6e7
fixes
ssmmnn11 c3d7f28
fix
ssmmnn11 1900ce2
multi dataset test fix
ssmmnn11 f45224f
fix: mlflow integration test
JPXKQX 6b79046
plotting overrides multi integration test
VeraChristina 0fe5980
fix: update stretched config
JPXKQX 5993d7d
update ADR
JPXKQX 1f337f4
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 3c62e6d
update plotting config
JPXKQX 3c8ede1
Merge branch 'feature/multi-dataset-integration' of https://github.co…
JPXKQX 2cc2901
fix: plotting node attributes
JPXKQX 889be6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c39407c
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 300e2fb
get task type from task instead of config
VeraChristina b31747d
add draft section on multi datasets in docs
VeraChristina 3c0ea34
update docs
VeraChristina 01500ce
update ADR
JPXKQX d484bee
docstring
JPXKQX 01a72c2
add supporting arrays from grid_indices (LAM)
JPXKQX e89ef7c
fix: supporting arrays
JPXKQX 5644541
add multi dataset diagrams to docs
VeraChristina 3b80ad6
update docs
VeraChristina 35cce38
update config snippets in training docs
VeraChristina f6ddf4e
remove outdated config snippet in docs
VeraChristina 5e80301
update docs/modules
JPXKQX 7a00eda
update docs/user-guide
JPXKQX 7107190
fix: autogenerated docstring
JPXKQX 8036f20
update docs
VeraChristina 0890231
bring back rollout & multi_step into interpolator task
JPXKQX 25d4308
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 75e0a7e
refactor: streamline batch and ensemble size validation across datasets
JPXKQX b1bb7f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b876baf
Merge branch 'main' into feature/multi-dataset-integration
JPXKQX 7b73af2
Merge branch 'feature/multi-dataset-integration' into multids/downsca…
JPXKQX a131cf0
fix for sharded model and loss that does not support sharding
ssmmnn11 d2c1d51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8b8485f
configs
JPXKQX d3811af
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX 87088e5
fix: update schema
JPXKQX a2c595d
Merge branch 'feature/multi-dataset-integration' into multids/downsca…
JPXKQX d400ccb
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX c71ed9f
undo base changes
JPXKQX dd4b49e
update
JPXKQX 5f80ad2
undo base.py
JPXKQX c20bf80
update
JPXKQX 9f50066
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] db4da13
Update training/src/anemoi/training/train/tasks/downscaler.py
JPXKQX 7b8a5bd
update
JPXKQX 57ae94b
fix
JPXKQX f163e7a
Merge branch 'main' into multids/downscaling
JPXKQX 2915351
fix
JPXKQX 69e50e8
undo forecaster changes
JPXKQX 1c53881
Update training/src/anemoi/training/train/tasks/forecaster.py
JPXKQX b4dfc68
update
JPXKQX e94906b
Merge branch 'multids/downscaling' of https://github.com/ecmwf/anemoi…
JPXKQX a093446
Merge branch 'main' into multids/downscaling
JPXKQX 08f08fd
pre-commit
JPXKQX 6a71d35
specify target dataset names
JPXKQX 54b1da9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0ad1f2e
Update training/src/anemoi/training/train/tasks/base.py
JPXKQX 1edfad2
integration tests
JPXKQX e18445f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1b76bcb
update
JPXKQX b2514cb
Merge branch 'main' into multids/downscaling
JPXKQX
93 changes: 93 additions & 0 deletions
models/src/anemoi/models/migrations/scripts/1767108147_move_to_multiple_datasets.py
    @@ -0,0 +1,93 @@
    # (C) Copyright 2025 Anemoi contributors.
    #
    # This software is licensed under the terms of the Apache Licence Version 2.0
    # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
    #
    # In applying this licence, ECMWF does not waive the privileges and immunities
    # granted to it by virtue of its status as an intergovernmental organisation
    # nor does it submit to any jurisdiction.

    from anemoi.models.migrations import CkptType
    from anemoi.models.migrations import MigrationMetadata

    # DO NOT CHANGE -->
    metadata = MigrationMetadata(
        versions={
            "migration": "1.0.0",
            "anemoi-models": "%NEXT_ANEMOI_MODELS_VERSION%",
        },
    )
    # <-- END DO NOT CHANGE


    def migrate(ckpt: CkptType) -> CkptType:
        """Migrate the checkpoint.

        Parameters
        ----------
        ckpt : CkptType
            The checkpoint dict.

        Returns
        -------
        CkptType
            The migrated checkpoint dict.
        """
        dummy_dataset_name = "data"

        updates = {}
        for key in list(ckpt["state_dict"].keys()):
            # Update pre-processors
            if key.startswith("model.pre_processors."):
                new_key = key.replace("model.pre_processors.", f"model.pre_processors.{dummy_dataset_name}.")
                updates[new_key] = ckpt["state_dict"][key]
                del ckpt["state_dict"][key]

            # Update post-processors
            if key.startswith("model.post_processors."):
                new_key = key.replace("model.post_processors.", f"model.post_processors.{dummy_dataset_name}.")
                updates[new_key] = ckpt["state_dict"][key]
                del ckpt["state_dict"][key]

            # Update node attributes
            if key.startswith("model.model.node_attributes."):
                new_key = key.replace("model.model.node_attributes.", f"model.model.node_attributes.{dummy_dataset_name}.")
                updates[new_key] = ckpt["state_dict"][key]
                del ckpt["state_dict"][key]

            # Adjust model components
            for model_component in ["encoder", "encoder_graph_provider", "decoder", "decoder_graph_provider"]:
                prefix = f"model.model.{model_component}."

                if key.startswith(prefix):
                    new_key = key.replace(prefix, f"{prefix}{dummy_dataset_name}.")
                    updates[new_key] = ckpt["state_dict"][key]
                    del ckpt["state_dict"][key]

        ckpt["state_dict"].update(updates)

        ckpt["hyper_parameters"]["data_indices"] = {dummy_dataset_name: ckpt["hyper_parameters"].pop("data_indices")}
        ckpt["hyper_parameters"]["statistics"] = {dummy_dataset_name: ckpt["hyper_parameters"].pop("statistics")}
        ckpt["hyper_parameters"]["statistics_tendencies"] = {
            dummy_dataset_name: ckpt["hyper_parameters"].pop("statistics_tendencies")
        }
        ckpt["hyper_parameters"]["supporting_arrays"] = {
            dummy_dataset_name: ckpt["hyper_parameters"].pop("supporting_arrays")
        }
        return ckpt


    def rollback(ckpt: CkptType) -> CkptType:
        """Rollback the checkpoint.

        Parameters
        ----------
        ckpt : CkptType
            The checkpoint dict.

        Returns
        -------
        CkptType
            The rollbacked checkpoint dict.
        """
        return ckpt