78 commits
317501e
Replace cf.rank==0 with utils.distributed.is_root
Jul 16, 2025
77de417
replace cf.rank==0 with weathergen.utils.distributed.is_root
Jul 16, 2025
6439618
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 22, 2025
8993875
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 25, 2025
f4a9d85
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 28, 2025
f8fdef4
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 29, 2025
ca89e7b
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 30, 2025
49d7a4d
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 31, 2025
f39f094
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 31, 2025
ebb03ea
Merge branch 'ecmwf:develop' into develop
csjfwang Aug 25, 2025
f40737d
Merge branch 'ecmwf:develop' into develop
csjfwang Aug 28, 2025
87fa078
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 10, 2025
5dfe275
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 19, 2025
b7244d9
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 22, 2025
5be41f5
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 22, 2025
39d3965
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 23, 2025
015ec88
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 24, 2025
cb1b7cc
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 1, 2025
90da4cf
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 20, 2025
f04891b
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 21, 2025
105d992
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 24, 2025
5f56073
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 26, 2025
95ee18a
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 3, 2025
3c702d3
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 10, 2025
6f14a30
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 13, 2025
5e87881
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 14, 2025
0c7d305
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 24, 2025
e43ac94
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 25, 2025
5f63bcc
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
c51eb94
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
dd5acc2
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
f03672d
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 27, 2025
49c52e1
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 28, 2025
c6356a2
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 1, 2025
36c709a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 1, 2025
765276a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 9, 2025
734a96b
add 2d rope to develop
Dec 9, 2025
e4be7c3
simplify assimilate global, forecast mode config
Dec 9, 2025
b193655
add 2d rope to forecast engine only once
Dec 9, 2025
77d95ed
only keep global & forecast engine add 2d rope
Dec 10, 2025
3928e59
simplify the code
Dec 10, 2025
b542067
fix lint
Dec 10, 2025
8538291
small fix
Dec 10, 2025
81a52b5
fix annotation
Dec 10, 2025
a8afd35
fix lint
Dec 10, 2025
5a48898
add annotation
Dec 10, 2025
f3eb78a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 10, 2025
805293f
Merge branch 'develop' into develop_2d_rope_final_gf
Dec 10, 2025
dc914ea
default config
Dec 10, 2025
d9e0504
fix default use_reentrant
Dec 10, 2025
6e7f1ed
use_2d_rope false as default
Dec 11, 2025
ba3f579
Add copyright notice for RoPE functions and update naming
Dec 12, 2025
7e0aff2
fix lint
Dec 12, 2025
2bc76dd
fix lint
Dec 12, 2025
11761bd
add 2d rope to all forecast steps
Dec 14, 2025
0d35319
merge confs
kctezcan Dec 30, 2025
c6938fb
more confs
kctezcan Dec 30, 2025
cca2c23
add missing enumerate
kctezcan Dec 30, 2025
34378c6
def forecast config
kctezcan Dec 30, 2025
e39f56b
aux_info=None in Forecast Eng forward
kctezcan Dec 30, 2025
3abf188
lint
kctezcan Dec 30, 2025
d3cebcf
add rope to global engine, which was moved to encoder
Jan 23, 2026
131de8a
1)init attention module with_2d_rope and rope_learnable_freq
Jan 23, 2026
542f23e
Merge branch 'ecmwf:develop' into develop
csjfwang Jan 23, 2026
5992541
solve some reviews
Jan 23, 2026
44ff5e4
Merge branch 'develop_student_teacher' into ktezcan-csjfwang-2drope
Jan 23, 2026
e7ccc23
fix lint
Jan 23, 2026
c4604f3
fix 2 bugs: remove rope in QueryAggregation, and change bs in model.py
Jan 23, 2026
1a64cd4
temporarily remove learnable rope
Jan 30, 2026
d14f61f
Merge branch 'ecmwf:develop' into develop
csjfwang Jan 30, 2026
8fbca29
Merge branch 'develop_student_teacher' into ktezcan-csjfwang-2drope
Jan 30, 2026
00a7fdd
add rope for register and class tokens; fix lint
Jan 30, 2026
d1634f9
rename aux_info in queryaggregation
Jan 30, 2026
b60224e
remove position_ids and change raise valueError to assert
Jan 31, 2026
7efef08
batch size get from get_batch_size_from_config()
Jan 31, 2026
ecf75b2
revert to default config without batchsize
Jan 31, 2026
cdd6088
use self.batch_size
Jan 31, 2026
b76ceb1
fix lint
Jan 31, 2026
5 changes: 5 additions & 0 deletions config/default_config.yml
@@ -66,6 +66,11 @@ forecast_att_dense_rate: 1.0

healpix_level: 5

# Use 2D RoPE instead of traditional global positional encoding
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
rope_2D: False

with_mixed_precision: True
with_flash_attention: True
compile_model: False
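For reference, a hedged sketch of how this flag is consumed: the YAML is loaded into the run config and the model reads it with a safe default, as in cf.get("rope_2D", False) in model.py. The plain-YAML load below is an assumption for illustration; the repository's Config class may wrap this differently.

# Hedged sketch: read the new flag from the default config with a plain YAML load.
# Only the key name and the cf.get("rope_2D", False) default are taken from the diff above.
import yaml

with open("config/default_config.yml") as f:
    cf = yaml.safe_load(f)

use_rope_2d = cf.get("rope_2D", False)  # False -> keep the sinusoidal pe_global path
print("2D RoPE enabled:", use_rope_2d)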
2 changes: 2 additions & 0 deletions config/default_forecast_config.yml
@@ -63,6 +63,8 @@ fe_impute_latent_noise_std: 1e-4 # 1e-4

healpix_level: 5

rope_2D: False

with_mixed_precision: True
with_flash_attention: True
compile_model: False
26 changes: 24 additions & 2 deletions src/weathergen/model/attention.py
@@ -14,6 +14,14 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from weathergen.model.norms import AdaLayerNorm, RMSNorm
from weathergen.model.positional_encoding import rotary_pos_emb_2d

"""
Attention blocks used by WeatherGenerator.

Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D
coordinates aligned with the token order (lat, lon in radians).
"""


class MultiSelfAttentionHeadVarlen(torch.nn.Module):
@@ -197,13 +205,15 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
):
super(MultiSelfAttentionHeadLocal, self).__init__()

self.num_heads = num_heads
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -242,7 +252,7 @@ def mask_block_local(batch, head, idx_q, idx_kv):
# compile for efficiency
self.flex_attention = torch.compile(flex_attention, dynamic=False)

def forward(self, x, ada_ln_aux=None):
def forward(self, x, coords=None, ada_ln_aux=None):
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -253,6 +263,11 @@ def forward(self, x, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3])
vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3])

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1)

outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2)

out = self.proj_out(self.dropout(outs.flatten(-2, -1)))
@@ -487,6 +502,7 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
):
super(MultiSelfAttentionHead, self).__init__()

@@ -495,6 +511,7 @@ def __init__(
self.softcap = softcap
self.dropout_rate = dropout_rate
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -527,7 +544,7 @@ def __init__(
self.att = self.attention
self.softmax = torch.nn.Softmax(dim=-1)

def forward(self, x, ada_ln_aux=None):
def forward(self, x, coords=None, ada_ln_aux=None):
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -539,6 +556,11 @@ def forward(self, x, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s).to(self.dtype)

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0

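rotary_pos_emb_2d itself lives in src/weathergen/model/positional_encoding.py and is not shown in this diff. Purely as an illustration of the technique, here is a minimal sketch of one common way to apply 2D RoPE over (lat, lon) coordinates; the function name, frequency layout, and pairing scheme are assumptions and may differ from the actual implementation.

# Illustrative sketch only, NOT the repository's rotary_pos_emb_2d.
# Rotates the first half of each head dimension with latitude-derived angles and the
# second half with longitude-derived angles, using GPT-NeoX-style (i, i + d/2) pairing.
import torch


def rotary_pos_emb_2d_sketch(qs, ks, coords, unsqueeze_dim=1, base=10000.0):
    # qs, ks: (..., seq_len, dim_head); coords: (batch, seq_len, 2) holding (lat, lon) in radians
    dim_head = qs.shape[-1]
    assert dim_head % 4 == 0, "dim_head must be divisible by 4 for this 2D layout"
    exponents = torch.arange(0, dim_head // 4, device=qs.device, dtype=torch.float32) * 4.0 / dim_head
    freqs = 1.0 / (base**exponents)
    lat, lon = coords[..., 0].float(), coords[..., 1].float()  # (batch, seq_len)
    angles = torch.cat(
        [lat.unsqueeze(-1) * freqs, lon.unsqueeze(-1) * freqs], dim=-1
    )  # (batch, seq_len, dim_head // 2)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1).to(qs.dtype).unsqueeze(unsqueeze_dim)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1).to(qs.dtype).unsqueeze(unsqueeze_dim)

    def rotate_half(x):
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

    return qs * cos + rotate_half(qs) * sin, ks * cos + rotate_half(ks) * sin

The unsqueeze_dim values at the two call sites above match the tensor layouts: the flex-attention path keeps heads in dim 1 (batch, heads, seq, dim_head), while the flash-attention path keeps heads in dim 2 (batch, seq, heads, dim_head).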
7 changes: 6 additions & 1 deletion src/weathergen/model/encoder.py
@@ -125,7 +125,12 @@ def forward(self, model_params, batch):
self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False
)

tokens_global = checkpoint(self.ae_global_engine, tokens_global, use_reentrant=False)
tokens_global = checkpoint(
self.ae_global_engine,
tokens_global,
coords=model_params.rope_coords,
use_reentrant=False,
)

return tokens_global, posteriors

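A small but easy-to-miss detail in this change: passing coords as a keyword argument through torch.utils.checkpoint.checkpoint works with use_reentrant=False, whereas the reentrant variant does not support forwarding keyword arguments to the wrapped module (hence the explicit use_reentrant=False above). A minimal, self-contained sketch of the pattern with hypothetical names:

# Minimal sketch of the checkpoint-with-kwargs pattern used above; the module and tensor
# names here are hypothetical, not from the repository.
import torch
from torch.utils.checkpoint import checkpoint


class TinyEngine(torch.nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, tokens, coords=None):
        # coords would be consumed by RoPE-enabled attention blocks; ignored in this toy module
        return self.proj(tokens)


engine = TinyEngine()
tokens = torch.randn(2, 8, 16, requires_grad=True)
coords = torch.zeros(2, 8, 2)
out = checkpoint(engine, tokens, coords=coords, use_reentrant=False)  # kwargs need use_reentrant=False
out.sum().backward()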
18 changes: 12 additions & 6 deletions src/weathergen/model/engines.py
@@ -304,12 +304,13 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
)
)

def forward(self, tokens, batch_lens, use_reentrant):
def forward(self, tokens, batch_lens, use_reentrant, coords=None):
for block in self.ae_aggregation_blocks:
aux_info = None
if isinstance(block, MultiSelfAttentionHeadVarlen):
tokens = block(tokens, x_lens=batch_lens)
else:
tokens = block(tokens)
tokens = block(tokens, coords, aux_info)
return tokens


@@ -345,6 +346,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.rope_2D,
)
)
else:
@@ -360,6 +362,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.rope_2D,
)
)
# MLP block
@@ -379,9 +382,10 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False)
)

def forward(self, tokens):
def forward(self, tokens, coords=None):
aux_info = None
for block in self.ae_global_blocks:
tokens = block(tokens)
tokens = block(tokens, coords, aux_info)
return tokens


@@ -416,6 +420,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int =
dim_aux=dim_aux,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.rope_2D,
)
)
else:
@@ -432,6 +437,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int =
dim_aux=dim_aux,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.rope_2D,
)
)
# Add MLP block
@@ -461,7 +467,7 @@ def init_weights_final(m):
for block in self.fe_blocks:
block.apply(init_weights_final)

def forward(self, tokens, fstep):
def forward(self, tokens, fstep, coords=None):
if self.training:
# Impute noise to the latent state
noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0)
@@ -473,7 +479,7 @@ def forward(self, tokens, fstep):
if isinstance(block, torch.nn.modules.normalization.LayerNorm):
tokens = block(tokens)
else:
tokens = block(tokens, aux_info)
tokens = block(tokens, coords, aux_info)
return tokens


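The forward loops in these engines dispatch on block type: LayerNorm (and the varlen attention heads in the aggregation engine) keep their original call signatures, while the RoPE-capable attention blocks receive (tokens, coords, aux_info). A condensed, hypothetical sketch of that dispatch, not the actual engine code:

# Condensed sketch of the per-block dispatch in the engine forward passes above.
# The block list here is hypothetical; the real engines build their blocks from cf.
import torch


def run_blocks(blocks, tokens, coords=None):
    aux_info = None
    for block in blocks:
        if isinstance(block, torch.nn.LayerNorm):
            tokens = block(tokens)                    # norm layers take tokens only
        else:
            tokens = block(tokens, coords, aux_info)  # RoPE-capable blocks accept coords
    return tokens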
93 changes: 72 additions & 21 deletions src/weathergen/model/model.py
@@ -21,6 +21,7 @@

from weathergen.common.config import Config
from weathergen.datasets.batch import ModelBatch
from weathergen.datasets.utils import healpix_verts_rots, r3tos2
from weathergen.model.encoder import EncoderModule
from weathergen.model.engines import (
BilinearDecoder,
@@ -35,6 +36,7 @@
)
from weathergen.model.layers import MLP, NamedLinear
from weathergen.model.utils import get_num_parameters
from weathergen.train.utils import get_batch_size_from_config
from weathergen.utils.distributed import is_root
from weathergen.utils.utils import get_dtype

@@ -89,6 +91,7 @@ def __init__(self, cf) -> None:
self.healpix_level = cf.healpix_level
self.num_healpix_cells = 12 * 4**cf.healpix_level
self.dtype = get_dtype(cf.attention_dtype)
self.batch_size_per_gpu = get_batch_size_from_config(cf.training_config)

### POSITIONAL EMBEDDINGS ###
len_token_seq = 1024
@@ -104,6 +107,25 @@
)
self.pe_global = torch.nn.Parameter(pe, requires_grad=False)

### ROPE COORDS ###
self.rope_2D = cf.get("rope_2D", False)
if self.rope_2D:
self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens
total_tokens = (
self.num_healpix_cells + self.num_extra_tokens
) * cf.ae_local_num_queries
self.register_buffer(
"rope_coords",
torch.zeros(
self.batch_size_per_gpu,
total_tokens,
2,
dtype=self.dtype,
),
)
else:
self.rope_coords = None

### HEALPIX NEIGHBOURS ###
hlc = self.healpix_level
with warnings.catch_warnings(action="ignore"):
@@ -161,28 +183,57 @@ def reset_parameters(self, cf: Config) -> "ModelParams":
self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]])

dim_embed = cf.ae_global_dim_embed
self.pe_global.data.fill_(0.0)
xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed
self.pe_global.data[..., 0::2] = 0.5 * torch.sin(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 0::2] += (
torch.sin(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)

if self.rope_2D:
# Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE.
# Shape: (num_healpix_cells, ae_local_num_queries, 2)
verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5)
coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype)
coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1)
coords_flat = coords.flatten(0, 1).unsqueeze(0).repeat(self.batch_size_per_gpu, 1, 1)
offset = self.num_extra_tokens * cf.ae_local_num_queries
self.rope_coords.data.fill_(0.0)
self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat)

# Clear pe_global when using 2D RoPE
self.pe_global.data.fill_(0.0)
else:
# Original pe_global initialization
self.pe_global.data.fill_(0.0)
xs = (
2.0
* np.pi
* torch.arange(0, dim_embed, 2, device=self.pe_global.device)
/ dim_embed
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)
self.pe_global.data[..., 1::2] = 0.5 * torch.cos(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 1::2] += (
torch.cos(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)
self.pe_global.data[..., 0::2] = 0.5 * torch.sin(
torch.outer(
8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs
)
)
self.pe_global.data[..., 0::2] += (
torch.sin(
torch.outer(
torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs
)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)
self.pe_global.data[..., 1::2] = 0.5 * torch.cos(
torch.outer(
8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs
)
)
self.pe_global.data[..., 1::2] += (
torch.cos(
torch.outer(
torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs
)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)

# healpix neighborhood structure

@@ -585,7 +636,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
for step in batch.get_output_idxs():
# apply forecasting engine (if present)
if self.forecast_engine:
tokens = self.forecast_engine(tokens, step)
tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords)

# decoder predictions
output = self.predict_decoders(model_params, step, tokens, batch, output)
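The rope_coords buffer filled in reset_parameters depends on healpix_verts_rots and r3tos2 from weathergen.datasets.utils, neither of which appears in this diff. As a rough, hedged equivalent, the per-cell (lat, lon) centers can be computed directly with healpy: at healpix_level 5 there are 12 * 4**5 = 12288 cells, each center repeated ae_local_num_queries times, with zeroed slots reserved in front for the register/class tokens. The query and extra-token counts below are assumed values for illustration, as is the nested pixel ordering.

# Hedged sketch: (lat, lon) centers in radians for a 2D-RoPE coordinate buffer, built
# with healpy instead of the repository's healpix_verts_rots/r3tos2 helpers.
import healpy as hp
import numpy as np
import torch

healpix_level = 5
num_queries = 4        # assumed stand-in for cf.ae_local_num_queries
num_extra_tokens = 2   # assumed register + class token count

nside = 2**healpix_level
num_cells = hp.nside2npix(nside)  # 12 * 4**healpix_level = 12288
lon_deg, lat_deg = hp.pix2ang(nside, np.arange(num_cells), nest=True, lonlat=True)
coords = torch.tensor(
    np.stack([np.deg2rad(lat_deg), np.deg2rad(lon_deg)], axis=-1), dtype=torch.float32
)  # (num_cells, 2) as (lat, lon)

# repeat per local query and reserve zeroed slots for register/class tokens at the front,
# mirroring the offset logic in reset_parameters (batch dimension omitted here)
coords = coords.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1)
extra = torch.zeros(num_extra_tokens * num_queries, 2)
rope_coords = torch.cat([extra, coords], dim=0)  # ((num_cells + num_extra_tokens) * num_queries, 2)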