126 commits
ae96af7
fresh start for multiple output steps
dietervdb-meteo Oct 30, 2025
4e884b5
minor fixes
dietervdb-meteo Oct 30, 2025
692a67c
add example configs
dietervdb-meteo Oct 31, 2025
38f95e8
Add observation-informed interpolator
OpheliaMiralles Oct 31, 2025
0e257c9
Precommit
OpheliaMiralles Oct 31, 2025
69e4aef
Add training part
OpheliaMiralles Oct 31, 2025
2c8c57c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
7dc7921
Precommit
OpheliaMiralles Oct 31, 2025
bbcf274
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
41776ce
Fix bad alignment
OpheliaMiralles Oct 31, 2025
6638392
Remove the underscores
OpheliaMiralles Oct 31, 2025
9411fa4
Fix
OpheliaMiralles Oct 31, 2025
5fd2c2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
9071eb2
Merge branch 'main' into obsinterpolator
OpheliaMiralles Nov 3, 2025
27ca361
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Nov 4, 2025
28c9e8b
rename time step scaler
dietervdb-meteo Nov 4, 2025
ccffab3
enable boundary rollout
dietervdb-meteo Nov 4, 2025
5027007
update training/config
dietervdb-meteo Nov 5, 2025
a231168
first try to fix test
dietervdb-meteo Nov 5, 2025
6b8b26b
Revert "first try to fix test"
dietervdb-meteo Nov 5, 2025
2a3dd47
update more configs
dietervdb-meteo Nov 5, 2025
65af1ca
fix aicon test
dietervdb-meteo Nov 5, 2025
2cf14cf
fix loss function tests
dietervdb-meteo Nov 5, 2025
2dc0f9f
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Nov 5, 2025
5a7382c
Merge branch 'main' into obsinterpolator
OpheliaMiralles Nov 11, 2025
fe555fb
Merge branch 'main' into obsinterpolator
OpheliaMiralles Nov 12, 2025
8be3cec
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Nov 12, 2025
19d1d48
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Nov 13, 2025
81decd8
Merge branch 'main' into obsinterpolator
OpheliaMiralles Nov 17, 2025
3e431bc
change time dimension reduction in loss
dietervdb-meteo Nov 19, 2025
46458a1
Add lead time decay weights
OpheliaMiralles Nov 19, 2025
b8740a3
fix: basemodel.predict_step (#672)
japols Nov 13, 2025
99e4a2a
fix(models): assert no dropout (#638)
JPXKQX Nov 14, 2025
b8a2b0a
feat(training)!: remove support for EDA (#651)
JPXKQX Nov 14, 2025
e3cb040
fix: bug for mlflow offline logging (#675)
anaprietonem Nov 17, 2025
94b5362
fix(graphs,normalisation): add assert when dividing by 0 (#676)
JPXKQX Nov 17, 2025
a017784
feat(graphs): add LimitedAreaMask for stretched hidden nodes (#671)
JPXKQX Nov 17, 2025
e48b281
chore: Release main (#619)
DeployDuck Nov 18, 2025
a78aee2
feat(models): multibackend all_to_all wrapper (#95)
cathalobrien Nov 19, 2025
41184bf
fix(models): processor chunking (#629)
japols Nov 19, 2025
4a74379
fix: small pytorch boxcox inefficiency (#683)
elkir Nov 19, 2025
0d9dec0
fix!: cond layer norm (#658)
ssmmnn11 Nov 19, 2025
da0e358
fresh start for multiple output steps
dietervdb-meteo Oct 30, 2025
aae9092
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
dede8eb
Rebase on right commit
OpheliaMiralles Nov 19, 2025
557dab6
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Nov 19, 2025
cf462c0
Remove output scalers
OpheliaMiralles Nov 19, 2025
2a05162
Leftover output step
OpheliaMiralles Nov 20, 2025
6e1d70d
Update training/src/anemoi/training/losses/scalers/time_step.py
OpheliaMiralles Nov 20, 2025
3ae2412
Update training/src/anemoi/training/losses/scalers/time_step.py
OpheliaMiralles Nov 20, 2025
1b793d2
Add schema for LeadTimeDecayScaler
OpheliaMiralles Nov 20, 2025
275e937
fix schema
dietervdb-meteo Nov 20, 2025
444308a
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Nov 20, 2025
d98806b
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Nov 21, 2025
8e5c611
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Nov 25, 2025
3dbf705
Merge branch 'main' into obsinterpolator
OpheliaMiralles Nov 25, 2025
9269c27
Merge branch 'main' into obsinterpolator
OpheliaMiralles Dec 12, 2025
748a2f3
fix: added safeguard for contiguous memory after expand()
MicheleCattaneo Dec 15, 2025
762c0ca
Trial for the multi-output interpolator
Dec 15, 2025
0282f35
Merge remote-tracking branch 'origin/obsinterpolator' into feat/multi…
Dec 19, 2025
8bbbbd2
Merge remote-tracking branch 'origin/main' into feat/multi-output-steps
Dec 19, 2025
59b38f5
Add test, obs-informed interpolator and fix tests
Dec 22, 2025
d8248b8
Merge remote-tracking branch 'origin/main' into feat/multi-output-steps
OpheliaMiralles Dec 22, 2025
7ba8209
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
97a2b83
Fix tests bis
OpheliaMiralles Dec 22, 2025
c7955ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
7909d90
Remove hardcoding of schemas for model multi_out
OpheliaMiralles Dec 22, 2025
f9237cb
Same for datamodule
OpheliaMiralles Dec 22, 2025
81351fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
b6eac3f
Remove redundant losses tests
OpheliaMiralles Dec 22, 2025
4a3511d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
77c3958
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Jan 2, 2026
46a8ed7
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Jan 5, 2026
a626019
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Jan 5, 2026
2e92082
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Jan 5, 2026
e4877cb
temporary fix
dietervdb-meteo Jan 6, 2026
88c4437
Try to align dims
OpheliaMiralles Jan 6, 2026
b073398
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Jan 6, 2026
884f21d
Try to fix plots
OpheliaMiralles Jan 6, 2026
3f819c6
One last trial before giving up
OpheliaMiralles Jan 6, 2026
edcb4d5
Merge remote-tracking branch 'origin/main' into feat/multi-output-steps
OpheliaMiralles Jan 6, 2026
e735338
More plot fixes
OpheliaMiralles Jan 7, 2026
05fa26f
Diffusion model adaptation
OpheliaMiralles Jan 7, 2026
742ac97
Trial to fix diffusion tests
OpheliaMiralles Jan 8, 2026
6a96403
Fix diffusion
OpheliaMiralles Jan 8, 2026
1e56454
Fix ensemble diffusion integration tests
OpheliaMiralles Jan 8, 2026
b10c448
Fix unit test
OpheliaMiralles Jan 8, 2026
bfb9722
refactor: delegate init of `AnemoiModelEncProcDecHierarchical` to par…
dietervdb-meteo Jan 9, 2026
2062b1d
Merge branch 'main' into feat/multi-output-steps
OpheliaMiralles Jan 9, 2026
85c8fdb
fix rollout base class refactor
dietervdb-meteo Jan 9, 2026
570f4ed
clean-up
dietervdb-meteo Jan 9, 2026
3e166de
further clean-up
dietervdb-meteo Jan 9, 2026
9e28961
add multi-out test
dietervdb-meteo Jan 9, 2026
c9470b7
fix unit tests
dietervdb-meteo Jan 9, 2026
9cfaf36
revert masks
dietervdb-meteo Jan 12, 2026
69fcfdd
fix lam plots
dietervdb-meteo Jan 13, 2026
b4f5a0f
adapt ens rollout to multi-out
dietervdb-meteo Jan 14, 2026
4a5d43b
update docstrings
dietervdb-meteo Jan 14, 2026
85512aa
update docstring
dietervdb-meteo Jan 14, 2026
fe963a5
fix plots
dietervdb-meteo Jan 14, 2026
667a1dc
add comment
dietervdb-meteo Jan 14, 2026
6a8027f
enable ens multi-out
dietervdb-meteo Jan 16, 2026
a0ed6aa
add integration test multi-out-ensemble
dietervdb-meteo Jan 16, 2026
59e8023
fix dimension ordering
dietervdb-meteo Jan 19, 2026
af24f94
support multi-out for diffusion
dietervdb-meteo Jan 19, 2026
46be8a3
enable multi-out tendency diffusion
dietervdb-meteo Jan 19, 2026
0fd6857
add multi-out diffusion tests
dietervdb-meteo Jan 19, 2026
594a3db
Merge remote-tracking branch origin/main into local/multi-merge-test
dietervdb-meteo Jan 20, 2026
55c183f
fix ensemble-plot-mixin test
dietervdb-meteo Jan 20, 2026
d7356f4
fix: add edge_dim to init
dietervdb-meteo Jan 20, 2026
d18f1fa
tmp fix: select last output step
dietervdb-meteo Jan 21, 2026
7c4dd9a
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Jan 21, 2026
48df0cc
Merge branch 'main' into feat/multi-output-steps
dietervdb-meteo Jan 21, 2026
d194588
add warning for tendency diffusion
dietervdb-meteo Jan 21, 2026
4c83a1a
overwrite multistep config entries for interpolator
dietervdb-meteo Jan 21, 2026
6ba6a79
Merge remote-tracking branch 'origin/main' into feat/multi-output-steps
OpheliaMiralles Jan 22, 2026
096b2c5
diffusion tendency forecasting for multi output steps
ssmmnn11 Jan 22, 2026
af0f54e
test fixes
ssmmnn11 Jan 22, 2026
2a826a9
Tentative merge with multi dataset
OpheliaMiralles Jan 22, 2026
6d73628
Add timeincrement
OpheliaMiralles Jan 22, 2026
e4befb4
Adapt config to new structure
OpheliaMiralles Jan 22, 2026
de45d5e
scaler fix
ssmmnn11 Jan 22, 2026
7b0d192
Merge remote-tracking branch 'origin/feat/multi-output-steps-diffusio…
OpheliaMiralles Jan 22, 2026
60234d7
Manage stat tendencies
OpheliaMiralles Jan 22, 2026
67b77a1
Fix tests
OpheliaMiralles Jan 23, 2026
1162391
Merge branch 'main' into feat/obsinterpolator
OpheliaMiralles Jan 23, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -128,6 +128,7 @@ _dev/
_api/
./outputs
*tmp_data/
*/uv.lock

# Project specific
10 changes: 8 additions & 2 deletions models/src/anemoi/models/data_indices/collection.py
@@ -30,9 +30,15 @@ class IndexCollection:
def __init__(self, data_config, name_to_index) -> None:
self.config = OmegaConf.to_container(data_config, resolve=True)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
self.forcing = [] if data_config.forcing is None else OmegaConf.to_container(data_config.forcing, resolve=True)
self.forcing = (
[]
if data_config.get("forcing", None) is None
else OmegaConf.to_container(data_config.forcing, resolve=True)
)
self.diagnostic = (
[] if data_config.diagnostic is None else OmegaConf.to_container(data_config.diagnostic, resolve=True)
[]
if data_config.get("diagnostic", None) is None
else OmegaConf.to_container(data_config.diagnostic, resolve=True)
)
self.target = (
[] if data_config.get("target", None) is None else OmegaConf.to_container(data_config.target, resolve=True)
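A note on why the lookups above moved to `.get`: schema-backed OmegaConf configs are typically in struct mode, where plain attribute access on an absent key raises instead of returning `None`. A minimal sketch of the difference, with invented config keys and values (assumes OmegaConf >= 2.1, where `.get` with an explicit default works under struct mode):

```python
from omegaconf import OmegaConf

# Illustrative config only; not an actual Anemoi data_config.
data_config = OmegaConf.create({"forcing": ["cos_latitude"], "target": None})
OmegaConf.set_struct(data_config, True)  # schema-validated configs behave like this

# data_config.diagnostic  # would raise ConfigAttributeError: key not in struct
print(data_config.get("diagnostic", None))  # None -- no exception

forcing = (
    []
    if data_config.get("forcing", None) is None
    else OmegaConf.to_container(data_config.forcing, resolve=True)
)
print(forcing)  # ['cos_latitude']
```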
32 changes: 22 additions & 10 deletions models/src/anemoi/models/interface/__init__.py
@@ -16,6 +16,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.preprocessing import Processors
from anemoi.models.preprocessing import StepwiseProcessors
from anemoi.models.utils.config import get_multiple_datasets_config
from anemoi.utils.config import DotDict

@@ -103,25 +104,36 @@ def _build_processors_for_dataset(
tuple
(pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies)
"""

# Build processors for the dataset
processors = [
[name, instantiate(processor, data_indices=data_indices, statistics=statistics)]
for name, processor in processors_configs.items()
]
def build_processors(statistics: dict) -> list:
return [
[name, instantiate(processor, data_indices=data_indices, statistics=statistics)]
for name, processor in processors_configs.items()
]

processors = build_processors(statistics)
pre_processors = Processors(processors)
post_processors = Processors(processors, inverse=True)

# Build tendencies processors if provided
pre_processors_tendencies = None
post_processors_tendencies = None
if statistics_tendencies is not None:
processors_tendencies = [
[name, instantiate(processor, data_indices=data_indices, statistics=statistics_tendencies)]
for name, processor in processors_configs.items()
]
pre_processors_tendencies = Processors(processors_tendencies)
post_processors_tendencies = Processors(processors_tendencies, inverse=True)
assert isinstance(statistics_tendencies, dict), "Tendency statistics must be a dict with per-step entries."
lead_times = statistics_tendencies.get("lead_times")
assert isinstance(lead_times, list), "Tendency statistics must include 'lead_times'."
assert all(
lead_time in statistics_tendencies for lead_time in lead_times
), "Missing tendency statistics for one or more output steps."
pre_processors_tendencies = StepwiseProcessors(lead_times)
post_processors_tendencies = StepwiseProcessors(lead_times)
for lead_time in lead_times:
step_stats = statistics_tendencies[lead_time]
if step_stats is not None:
step_processors = build_processors(step_stats)
pre_processors_tendencies.set(lead_time, Processors(step_processors))
post_processors_tendencies.set(lead_time, Processors(step_processors, inverse=True))

return pre_processors, post_processors, pre_processors_tendencies, post_processors_tendencies

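The hunk above relies on a `StepwiseProcessors` container imported from `anemoi.models.preprocessing`, whose implementation is not part of this diff. A minimal sketch consistent with how it is used here (constructed from the lead times, filled via `set`, iterated per output step); the interface is inferred from the call sites, so the real class may differ:

```python
import torch.nn as nn


class StepwiseProcessors(nn.Module):
    """Sketch of a per-lead-time processor container (inferred interface)."""

    def __init__(self, lead_times: list) -> None:
        super().__init__()
        self.lead_times = list(lead_times)
        self.processors = nn.ModuleDict()  # ModuleDict keys must be strings

    def set(self, lead_time, processors: nn.Module) -> None:
        self.processors[str(lead_time)] = processors

    def __getitem__(self, lead_time) -> nn.Module:
        return self.processors[str(lead_time)]

    def __iter__(self):
        # yields None for output steps that were never set
        for lead_time in self.lead_times:
            key = str(lead_time)
            yield self.processors[key] if key in self.processors else None

    def __len__(self) -> int:
        return len(self.lead_times)
```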
6 changes: 6 additions & 0 deletions models/src/anemoi/models/models/base.py
@@ -67,6 +67,7 @@
model_config.graph.hidden
) # assumed to be all the same because this is how we construct the graphs
self.multi_step = model_config.training.multistep_input
self.multi_out = model_config.training.multistep_output
self.num_channels = model_config.model.num_channels

self.node_attributes = torch.nn.ModuleDict()
@@ -99,6 +100,7 @@ def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
self._internal_output_idx = {}
self.input_dim = {}
self.input_dim_latent = {}
self.output_dim = {}

for dataset_name, dataset_indices in data_indices.items():
self.num_input_channels[dataset_name] = len(dataset_indices.model.input)
@@ -107,6 +109,7 @@ def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
self._internal_input_idx[dataset_name] = dataset_indices.model.input.prognostic
self._internal_output_idx[dataset_name] = dataset_indices.model.output.prognostic
self.input_dim[dataset_name] = self._calculate_input_dim(dataset_name)
self.output_dim[dataset_name] = self._calculate_output_dim(dataset_name)
self.input_dim_latent[dataset_name] = self._calculate_input_dim_latent(dataset_name)

def _calculate_input_dim(self, dataset_name: str) -> int:
@@ -200,6 +203,9 @@ def _get_consistent_dim(self, x: dict[str, Tensor], dim: int) -> int:

return dim_sizes[0]

def _calculate_output_dim(self, dataset_name: str) -> int:
return self.multi_out * self.num_output_channels[dataset_name]

@abstractmethod
def _build_networks(self, model_config: DotDict) -> None:
"""Builds the networks for the model."""
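The effect of the new `_calculate_output_dim` in `base.py` is that the decoder's flattened output width scales with the number of output steps; the time axis is recovered later by an einops split. A toy check of the arithmetic (all numbers invented):

```python
multi_out = 3             # output steps produced in a single forward pass
num_output_channels = 80  # per-step output variables for one dataset

output_dim = multi_out * num_output_channels  # decoder out_channels_dst
assert output_dim == 240

# Downstream, the "(time vars)" group in einops.rearrange splits the
# 240 features back into time=3 steps of 80 variables each.
```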
@@ -12,6 +12,7 @@
import warnings
from typing import Callable
from typing import Optional
from typing import Sequence
from typing import Union

import einops
@@ -128,13 +129,14 @@ def _build_networks(self, model_config: DotDict) -> None:
in_channels_src=self.num_channels,
in_channels_dst=self.input_dim[dataset_name],
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels[dataset_name],
out_channels_dst=self.output_dim[dataset_name],
edge_dim=self.decoder_graph_provider[dataset_name].edge_dim,
)

def _calculate_input_dim(self, dataset_name: str) -> int:
input_dim = super()._calculate_input_dim(dataset_name)
input_dim += self.num_output_channels[dataset_name] # input + noised targets
input_dim += self._calculate_output_dim(dataset_name)  # input + noised targets
return input_dim

def _create_noise_conditioning_mlp(self) -> nn.Sequential:
@@ -159,7 +161,7 @@ def _assemble_input(self, x, y_noised, bse, grid_shard_shapes=None, model_comm_g
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
einops.rearrange(y_noised, "batch ensemble grid vars -> (batch ensemble grid) vars"),
einops.rearrange(y_noised, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
node_attributes_data,
),
dim=-1, # feature dimension
@@ -173,18 +175,21 @@ def _assemble_output(self, x_out, x_skip, batch_size, ensemble_size, dtype):
def _assemble_output(self, x_out, x_skip, batch_size, ensemble_size, dtype):
x_out = einops.rearrange(
x_out,
"(batch ensemble grid) vars -> batch ensemble grid vars",
"(batch ensemble grid) (time vars) -> batch time ensemble grid vars",
batch=batch_size,
ensemble=ensemble_size,
time=self.multi_out,
).to(dtype=dtype)

return x_out

def _make_noise_emb(self, noise_emb: torch.Tensor, repeat: int) -> torch.Tensor:
out = einops.repeat(
noise_emb, "batch ensemble noise_level vars -> batch ensemble (repeat noise_level) vars", repeat=repeat
noise_emb,
"batch time ensemble noise_level vars -> batch time ensemble (repeat noise_level) vars",
repeat=repeat,
)
out = einops.rearrange(out, "batch ensemble grid vars -> (batch ensemble grid) vars")
out = einops.rearrange(out, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)")
return out

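The recurring pattern in these hunks is folding the new time axis into the feature axis before the graph network and unfolding it afterwards. A standalone demonstration of the round trip, with invented sizes:

```python
import einops
import torch

batch, time, ensemble, grid, nvar = 2, 3, 4, 10, 5
y = torch.randn(batch, time, ensemble, grid, nvar)

# fold: one row per (batch, ensemble, grid) node; features carry (time vars)
flat = einops.rearrange(
    y, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"
)
assert flat.shape == (batch * ensemble * grid, time * nvar)

# unfold: recover the time axis after the network
out = einops.rearrange(
    flat,
    "(batch ensemble grid) (time vars) -> batch time ensemble grid vars",
    batch=batch,
    ensemble=ensemble,
    time=time,  # einops infers grid and vars from the axes given
)
assert torch.equal(out, y)
```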
def _generate_noise_conditioning(
@@ -648,7 +653,7 @@ def sample(

# Initialize output with noise
batch_size, ensemble_size, grid_size = x_data.shape[0], x_data.shape[2], x_data.shape[-2]
shape = (batch_size, ensemble_size, grid_size, self.num_output_channels)
shape = (batch_size, self.multi_out, ensemble_size, grid_size, self.num_output_channels)
y_init[dataset_name] = torch.randn(shape, device=x_data.device, dtype=sigmas.dtype) * sigmas[0]

# Build diffusion sampler config dict from all inference defaults
@@ -725,6 +730,12 @@ def __init__(
statistics=statistics,
graph_data=graph_data,
)
if self.multi_out > 1:
warnings.warn(
"The currently implemented normalization of the tendencies when the model has more than one output step is unconventional. Using"
"more than one output step with tendency diffusion models is currently highly experimental and results should be "
"cautiously interpreted."
)

def _calculate_input_dim(self, dataset_name: str) -> int:
input_dim = super()._calculate_input_dim(dataset_name)
@@ -746,6 +757,10 @@ def _assemble_input(
grid_shard_shapes = grid_shard_shapes[dataset_name] if grid_shard_shapes is not None else None

x_skip = self.residual[dataset_name](x, grid_shard_shapes, model_comm_group)
x_skip = x_skip.unsqueeze(1).expand(-1, self.multi_out, -1, -1, -1)
x_skip = einops.rearrange(x_skip, "batch time ensemble grid vars -> (batch ensemble) grid (time vars)")
# Get node attributes
node_attributes_data = self.node_attributes(self._graph_name_data, batch_size=bse)

# Shard node attributes if grid sharding is enabled
if grid_shard_shapes is not None:
@@ -758,7 +773,7 @@
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
einops.rearrange(y_noised, "batch ensemble grid vars -> (batch ensemble grid) vars"),
einops.rearrange(y_noised, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
node_attributes_data,
),
dim=-1, # feature dimension
@@ -953,7 +968,7 @@ def _after_sampling(
model_comm_group: Optional[ProcessGroup] = None,
grid_shard_shapes: Optional[list] = None,
gather_out: bool = True,
post_processors_tendencies: Optional[nn.Module] = None,
post_processors_tendencies: Optional[Sequence[Optional[nn.Module]]] = None,
**kwargs,
) -> torch.Tensor:
"""Process sampled tendency to get state prediction.
@@ -968,14 +983,23 @@

# truncate x_t0 if needed
x_t0 = self.apply_reference_state_truncation(x_t0, grid_shard_shapes, model_comm_group)

# Convert tendency to state
out = self.add_tendency_to_state(
x_t0,
out,
post_processors,
post_processors_tendencies,
)
assert post_processors_tendencies is not None, "Per-step tendency processors must be provided."
assert (
len(post_processors_tendencies) == out.shape[1]
), "Per-step tendency processors must match the number of output steps."
states = []
for step, post_proc in enumerate(post_processors_tendencies):
out_step = out[:, step : step + 1]
x_t0_step = x_t0[:, step : step + 1]
state_step = self.add_tendency_to_state(
x_t0_step,
out_step,
post_processors,
post_proc,
)
states.append(state_step)
out = torch.cat(states, dim=1)

# Gather if needed
if gather_out and model_comm_group is not None:
@@ -1004,7 +1028,8 @@ def apply_reference_state_truncation(

for dataset_name, in_x in x.items():
x_skip = self.residual[dataset_name](in_x, grid_shard_shapes[dataset_name], model_comm_group)
# x_skip.shape: (bs, ens, latlon, nvar)
x_skip = x_skip.unsqueeze(1).expand(-1, self.multi_out, -1, -1, -1)
# x_skip.shape: (bs, time, ens, latlon, nvar)
x_skips[dataset_name] = x_skip[..., self.data_indices[dataset_name].model.input.prognostic]

return x_skips
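The `_after_sampling` loop above converts each predicted tendency slice to a state using the post-processor belonging to that output step. A stripped-down sketch of the same control flow; `add_tendency_to_state` is stubbed out and every shape is invented:

```python
import torch


def add_tendency_to_state(x_t0_step, tendency_step, post_proc):
    # Stub: the real method denormalizes the tendency via post_proc and
    # adds it to the (truncated) reference state; here we just add.
    return x_t0_step + tendency_step


multi_out = 3
out = torch.randn(2, multi_out, 1, 10, 5)   # (bs, time, ens, grid, vars)
x_t0 = torch.randn(2, multi_out, 1, 10, 5)  # reference state per output step
post_processors_tendencies = [None] * multi_out  # one processor per step

assert len(post_processors_tendencies) == out.shape[1]
states = []
for step, post_proc in enumerate(post_processors_tendencies):
    # slicing with step:step+1 keeps the singleton time axis, so each
    # state_step stays 5-D and can be concatenated back along dim=1
    state_step = add_tendency_to_state(
        x_t0[:, step : step + 1], out[:, step : step + 1], post_proc
    )
    states.append(state_step)
out = torch.cat(states, dim=1)
assert out.shape == (2, multi_out, 1, 10, 5)
```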
11 changes: 8 additions & 3 deletions models/src/anemoi/models/models/encoder_processor_decoder.py
@@ -93,7 +93,7 @@ def _build_networks(self, model_config: DotDict) -> None:
in_channels_src=self.num_channels,
in_channels_dst=self.input_dim[dataset_name],
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels[dataset_name],
out_channels_dst=self.output_dim[dataset_name],
edge_dim=self.decoder_graph_provider[dataset_name].edge_dim,
)

@@ -143,21 +143,26 @@ def _assemble_output(
x_out = (
einops.rearrange(
x_out,
"(batch ensemble grid) vars -> batch ensemble grid vars",
"(batch ensemble grid) (time vars) -> batch time ensemble grid vars",
batch=batch_size,
ensemble=ensemble_size,
time=self.multi_out,
)
.to(dtype=dtype)
.clone()
)

# residual connection (just for the prognostic variables)
assert dataset_name is not None, "dataset_name must be provided for multi-dataset case"
x_out[..., self._internal_output_idx[dataset_name]] += x_skip[..., self._internal_input_idx[dataset_name]]
x_out[..., self._internal_output_idx[dataset_name]] += (
x_skip[..., self._internal_input_idx[dataset_name]].unsqueeze(1).expand(-1, self.multi_out, -1, -1, -1)
)

for bounding in self.boundings[dataset_name]:
# bounding performed in the order specified in the config file
x_out = bounding(x_out)
# TODO(dieter): verify if this is needed or can be solved alternatively
x_out = x_out.contiguous() # necessary after expand()
return x_out

def _assert_valid_sharding(
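On the `.contiguous()` guard and its TODO: `expand()` returns a zero-stride view, which breaks operations that need real memory, such as `view()` or in-place writes into overlapping elements. A quick illustration with invented shapes:

```python
import torch

x = torch.randn(2, 1, 4, 3)
v = x.expand(-1, 5, -1, -1)  # view only: stride 0 along the new time axis

print(v.is_contiguous())  # False
print(v.stride())         # e.g. (12, 0, 3, 1)

try:
    v.view(2 * 5, 4, 3)   # view() cannot merge a zero-stride axis
except RuntimeError as err:
    print("view failed:", err)

c = v.contiguous()        # materializes a real copy
print(c.is_contiguous())  # True
c.view(2 * 5, 4, 3)       # now fine
```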
12 changes: 10 additions & 2 deletions models/src/anemoi/models/models/ens_encoder_processor_decoder.py
@@ -117,16 +117,24 @@
):
ensemble_size = batch_ens_size // batch_size
x_out = (
einops.rearrange(x_out, "(bs e n) f -> bs e n f", bs=batch_size, e=ensemble_size).to(dtype=dtype).clone()
einops.rearrange(x_out, "(bs e) t n f -> bs t e n f", bs=batch_size, e=ensemble_size, t=self.multi_out)
.to(dtype=dtype)
.clone()
)

# residual connection (just for the prognostic variables)
assert dataset_name is not None, "dataset_name must be provided for multi-dataset case"
x_out[..., self._internal_output_idx[dataset_name]] += x_skip[..., self._internal_input_idx[dataset_name]]
x_out[..., self._internal_output_idx[dataset_name]] += einops.rearrange(
x_skip[..., self._internal_input_idx[dataset_name]].unsqueeze(1).expand(-1, self.multi_out, -1, -1),
"(batch ensemble) time grid var -> batch time ensemble grid var",
batch=batch_size,
).to(dtype=dtype)

for bounding in self.boundings[dataset_name]:
# bounding performed in the order specified in the config file
x_out = bounding(x_out)
# TODO(dieter): verify if this is needed or can be solved alternatively
x_out = x_out.contiguous() # necessary after expand()
return x_out

def forward(
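In the ensemble variant the skip tensor still has a flattened `(batch ensemble)` leading axis, so the residual is expanded over time and unflattened before being added to the prognostic channels. A toy version of that hunk, with invented indices and sizes:

```python
import einops
import torch

bs, ens, t, grid = 2, 4, 3, 10
nvar_in, nvar_out = 6, 8
prognostic_in = [0, 1, 2]   # prognostic positions in the input features
prognostic_out = [0, 1, 2]  # prognostic positions in the output features

x_out = torch.zeros(bs, t, ens, grid, nvar_out)
x_skip = torch.randn(bs * ens, grid, nvar_in)  # flattened (batch ensemble)

residual = einops.rearrange(
    x_skip[..., prognostic_in].unsqueeze(1).expand(-1, t, -1, -1),
    "(batch ensemble) time grid var -> batch time ensemble grid var",
    batch=bs,
)
x_out[..., prognostic_out] += residual
assert x_out.shape == (bs, t, ens, grid, nvar_out)
```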
8 changes: 5 additions & 3 deletions models/src/anemoi/models/models/interpolator.py
@@ -99,7 +99,7 @@ def _assemble_input(
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
einops.rearrange(target_forcing, "batch ensemble grid vars -> (batch ensemble grid) (vars)"),
einops.rearrange(target_forcing, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
node_attributes_data,
),
dim=-1, # feature dimension
@@ -116,7 +116,7 @@ def _assemble_output(self, x_out, x_skip, batch_size, ensemble_size, dtype, data
x_out = (
einops.rearrange(
x_out,
"(batch ensemble grid) vars -> batch ensemble grid vars",
"(batch ensemble grid) (time vars) -> batch time ensemble grid vars",
batch=batch_size,
ensemble=ensemble_size,
)
@@ -127,7 +127,9 @@
# residual connection (just for the prognostic variables)
if x_skip is not None:
# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx[dataset_name]] += x_skip[..., self._internal_input_idx[dataset_name]]
x_out[..., self._internal_output_idx[dataset_name]] += (
x_skip[..., self._internal_input_idx[dataset_name]].unsqueeze(1).expand(-1, self.multi_out, -1, -1, -1)
)

for bounding in self.boundings[dataset_name]:
# bounding performed in the order specified in the config file