
Commit 9600183

Authored by: shmh40, Sophie Xhonneux (sophie-xhonneux), claude, clessig
Enable per stream masking config override (#1951)
* Add collapse monitoring
* Fix bug
* Fix SVD computation failing
* Reduce variables logged
* Fix EMA beta value computation
* Refactor get_current_beta to ema.py
* Sensible default for ema in jepa
* Allow collapse monitoring for forecasting
* Fix no collapse monitoring for forecasting
* Try to fix forecasting
* Fix teacher rank collapse when rope_2D is enabled

  Two issues caused the EMA teacher's effective rank to drop to ~8-10 (multi-GPU) or ~40 (single-GPU) at training start when rope_2D=True, while the student appeared unaffected:

  1. pe_global zeroed with rope_2D: When rope_2D was enabled, pe_global was cleared to zero under the assumption that RoPE replaces it. However, RoPE only provides relative position in Q/K attention -- it does not affect V. pe_global is the sole source of per-cell token identity for masked cells (which have no content from local assimilation). Without it, all masked cells are identical, collapsing the teacher representation. The student metric was artificially inflated by dropout noise hiding the same underlying low-rank issue. Fix: always initialize pe_global -- it and RoPE serve complementary roles.

  2. EMA reset ignores DDP key prefix: EMAModel.reset() loads the student state_dict directly via load_state_dict, but DDP wrapping adds a module. prefix to all keys. With strict=False, every key silently fails to match, leaving the teacher with uninitialized weights from to_empty(). The update() method already handled this mismatch but reset() did not. Combined with q_cells being skipped in EMA updates, the teacher q_cells was permanently corrupted on multi-GPU runs. Fix: strip the module. prefix before loading.

  Co-Authored-By: Claude Opus 4.6 <[email protected]>
* Try adding 2d rope to Query engine
* Fix shape mismatch
* Run linter
* Add support for dropping of streams
* Enable healpix masking at the level of the data
* Enable per stream masking strategy config override
* Per stream masking override test
* Move per stream masking to masker
* Fix moving per stream config in masker
* Lint
* Tidy up
* Better naming and docs of per stream override; msds: rename stage cfg to stream cfg
* Addressed comments, but now broken: tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) raises "expected non-empty list of tensors". More scaffolding is also needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and the existing code only dropped streams as sources, never as targets
* Revert "addressed comments but now broken with tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources". This reverts commit 70ce173.
* Move per stream config overrides to masker; randomly drop now has its own scaffold when building source inputs
* Address reviewer comments on static method and consolidated config
* Update merge_masking_config docstring to reflect that randomly_drop is done independently per source sample
* Drop decision for a stream applies to all source strategies; dropping streams only applies during training
* Update the test

---------

Co-authored-by: Sophie Xhonneux <[email protected]>
Co-authored-by: sophiex <[email protected]>
Co-authored-by: Claude Opus 4.6 <[email protected]>
Co-authored-by: Christian Lessig <[email protected]>
1 parent e8b1f8e commit 9600183
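The second EMA fix described above (reset() silently failing because DDP wrapping adds a `module.` prefix to every state-dict key) comes down to a small key-renaming step before `load_state_dict`. A minimal sketch of that step, assuming a helper name `strip_ddp_prefix` that is illustrative rather than the repository's actual function:

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict):
    """Remove the 'module.' prefix that DDP wrapping adds to every key.

    With load_state_dict(..., strict=False), unmatched keys fail silently,
    which is the failure mode described in the commit message: every key
    misses, and the teacher keeps its uninitialized to_empty() weights.
    """
    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# A DDP-wrapped student produces keys like these:
ddp_state = {"module.encoder.weight": 1.0, "module.q_cells": 2.0}
plain = strip_ddp_prefix(ddp_state)
assert set(plain) == {"encoder.weight", "q_cells"}
```

The teacher would then call `load_state_dict(strip_ddp_prefix(student_state))` so keys match whether or not the student is DDP-wrapped.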

File tree

4 files changed: +370 −19 lines changed


src/weathergen/datasets/masking.py

Lines changed: 103 additions & 12 deletions
@@ -10,6 +10,7 @@
 from weathergen.datasets.batch import SampleMetaData
 from weathergen.train.utils import Stage
+from weathergen.utils.distributed import is_root
 from weathergen.utils.utils import is_stream_diagnostic, is_stream_forcing
 
 logger = logging.getLogger(__name__)
@@ -111,7 +112,7 @@ class Masker:
     specific to the masking strategy. See above.
     """
 
-    def __init__(self, healpix_level: int, stage: Stage):
+    def __init__(self, healpix_level: int, stage: Stage, streams=None, mode_cfg=None):
         self.rng = None
 
         self.mask_value = 0.0
@@ -123,12 +124,89 @@ def __init__(self, healpix_level: int, stage: Stage):
 
         self.stage = stage
 
+        # Build and store per-stream effective masking configs
+        if streams is not None and mode_cfg is not None:
+            self._effective_masking_cfgs = self.build_effective_masking_cfgs(streams, mode_cfg)
+        else:
+            self._effective_masking_cfgs = {}
+
     def reset_rng(self, rng) -> None:
         """
         Reset rng after mini_epoch to ensure proper randomization
         """
         self.rng = rng
 
+    def merge_masking_config(self, mode_cfg, override):
+        """Merge a stream's masking override into the base mode config.
+
+        Only masking strategy fields are overridden. Structural keys like
+        ``num_samples`` and ``num_steps_input`` remain unchanged.
+
+        The override is flat per section (``model_input`` / ``target_input``),
+        not per named strategy. If a section has multiple strategies (e.g.
+        ``"input_physical"`` and ``"input_jepa"``), masking strategy fields are
+        broadcast to all of them. ``randomly_drop_as_source_rate`` is a
+        per-stream rate; the drop decision is made once per call to
+        ``build_samples_for_stream`` and applies to all source strategies
+        uniformly (training only).
+
+        Expected YAML in a stream config, e.g.:
+
+            STREAM_NAME:
+              type: ...
+              filenames: ...
+              ...
+              masking_override:
+                target_input:
+                  masking_strategy_config:
+                    hl_mask: 3
+                    ...
+
+        This overrides only ``hl_mask`` within ``masking_strategy_config`` for
+        every target strategy, inheriting rate, rate_sampling, etc. from the
+        global config. ``masking_strategy`` itself can also be replaced.
+        """
+        if override is None:
+            return mode_cfg
+
+        stream_cfg_masking = copy.deepcopy(mode_cfg)
+
+        # Copy top-level masking keys from override
+        if "randomly_drop_as_source_rate" in override:
+            stream_cfg_masking["randomly_drop_as_source_rate"] = override[
+                "randomly_drop_as_source_rate"
+            ]
+
+        for section_key in ("model_input", "target_input"):
+            override_values = override.get(section_key, None)
+            if override_values is None:
+                continue
+            section = stream_cfg_masking.get(section_key, None)
+            if section is None:
+                continue
+            for strategy_cfg in section.values():
+                if "masking_strategy" in override_values:
+                    strategy_cfg["masking_strategy"] = override_values["masking_strategy"]
+                if "masking_strategy_config" in override_values:
+                    strategy_cfg["masking_strategy_config"] = omegaconf.OmegaConf.merge(
+                        strategy_cfg.get("masking_strategy_config", omegaconf.OmegaConf.create({})),
+                        override_values["masking_strategy_config"],
+                    )
+
+        return stream_cfg_masking
+
+    def build_effective_masking_cfgs(self, streams, mode_cfg):
+        """Build effective masking configs for all streams."""
+        cfgs = {}
+        for stream_info in streams:
+            name = stream_info["name"]
+            override = stream_info.get("masking_override", None)
+            cfgs[name] = self.merge_masking_config(mode_cfg, override)
+            if override is not None and is_root():
+                logger.info(f"Stream '{name}' using masking override: {override}")
+
+        return cfgs
+
     def _get_sampling_rate(self, cfg):
         """
         Get the sampling, if requested by sampling it itself
@@ -257,25 +335,33 @@ def build_samples_for_stream(
         self,
         training_mode: str,
         num_cells: int,
-        stage_cfg: dict,
-        stream_cfg: dict,
+        stream_info: dict,
     ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]:
         """
         Construct teacher/student keep masks for a stream.
         SampleMetaData is currently just a dict with the masking params used.
         """
 
+        stream_masking_cfg = self._effective_masking_cfgs[stream_info["name"]]
+
         # target and source configs
-        target_cfgs = stage_cfg.get("target_input", [])
-        source_cfgs = stage_cfg.get("model_input", [])
+        target_cfgs = stream_masking_cfg.get("target_input", [])
+        source_cfgs = stream_masking_cfg.get("model_input", [])
 
         # target and source are assumed identical when target is not specified
         if len(target_cfgs) == 0:
             target_cfgs = copy.deepcopy(source_cfgs)
 
-        losses = stage_cfg.losses
+        losses = stream_masking_cfg.losses
         corr_dict = self.parse_src_target_correspondence(losses, target_cfgs, source_cfgs)
 
+        # randomly_drop_as_source_rate from consolidated masking config (training only)
+        randomly_drop_rate = (
+            stream_masking_cfg.get("randomly_drop_as_source_rate", 0.0)
+            if self.stage == "train"
+            else 0.0
+        )
+
         target_masks = MaskData()
 
         # iterate over all target samples
@@ -285,9 +371,10 @@ def build_samples_for_stream(
             # different samples/view per strategy
             for _ in range(target_cfg.get("num_samples", 1)):
                 # determine if forcing dataset => mask is empty
-                if is_stream_forcing(stream_cfg, self.stage):
+                if is_stream_forcing(stream_info, self.stage):
                     target_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}
                 else:
+                    # targets are never randomly dropped
                     target_mask, mask_params = self._get_mask(
                         num_cells=num_cells,
                         strategy=target_cfg.get("masking_strategy"),
@@ -312,6 +399,7 @@ def build_samples_for_stream(
         source_masks = MaskData()
         source_target_mapping = []
         target_num_samples = get_num_samples(target_cfgs)
+        is_stream_dropped = randomly_drop_rate > 0.0 and self.rng.uniform() < randomly_drop_rate
         i_source = 0
         for i_src_cfg, (_, source_cfg) in enumerate(source_cfgs.items()):
             # skip items that do not appear in loss
@@ -336,8 +424,8 @@ def build_samples_for_stream(
                 # target is specified)
                 target_idx += i_sample % target_num_samples[target_cfg_idx].item()
 
-                # determine if forcing dataset => mask is empty
-                if is_stream_diagnostic(stream_cfg, self.stage):
+                # determine if diagnostic dataset or randomly dropped => mask is empty
+                if is_stream_diagnostic(stream_info, self.stage) or is_stream_dropped:
                     source_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}
                 else:
                     source_mask, mask_params = self._get_mask(
@@ -427,7 +515,10 @@ def _get_mask(
         return (mask, params)
 
     def _generate_cell_mask(
-        self, num_cells: int, strategy: str, masking_strategy_config: dict
+        self,
+        num_cells: int,
+        strategy: str,
+        masking_strategy_config: dict,
     ) -> (np.typing.NDArray, dict):
         """Generate a boolean keep mask at data healpix level (True = keep cell).
 
@@ -692,8 +783,8 @@ def _prepare_healpix_based_masking(self, cfg, keep_rate):
 
         hl_data = self.healpix_level_data
         hl_mask = cfg.get("hl_mask")
-        assert hl_mask is not None and hl_mask < hl_data, (
-            "For healpix keep mask generation, cfg['hl_mask'] must be set and < data level."
+        assert hl_mask is not None and hl_mask <= hl_data, (
+            "For healpix keep mask generation, cfg['hl_mask'] must be set and <= data level."
         )
         num_parent_cells = 12 * (4**hl_mask)
         level_diff = hl_data - hl_mask
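The last hunk relaxes the assertion so hl_mask may equal the data level. The parent-to-child arithmetic it relies on — 12 · 4^hl_mask parent cells at the mask level, each covering 4^(hl_data − hl_mask) nested children at the data level — can be sketched as follows; this is a simplified stand-in for the masker's actual implementation, with illustrative names:

```python
import random


def healpix_keep_mask(hl_mask, hl_data, keep_rate, rng):
    """Draw a keep mask at level hl_mask and expand it to level hl_data.

    In HEALPix nested ordering, each parent cell at level hl_mask covers a
    contiguous block of 4**(hl_data - hl_mask) child cells at the data level,
    so a parent decision broadcasts to a contiguous run of children.
    """
    assert hl_mask is not None and hl_mask <= hl_data, "hl_mask must be set and <= data level"
    num_parent_cells = 12 * (4 ** hl_mask)
    children_per_parent = 4 ** (hl_data - hl_mask)
    parent_keep = [rng.random() < keep_rate for _ in range(num_parent_cells)]
    # Broadcast each parent decision to its nested children.
    return [keep for keep in parent_keep for _ in range(children_per_parent)]


mask = healpix_keep_mask(hl_mask=1, hl_data=2, keep_rate=0.5, rng=random.Random(0))
assert len(mask) == 12 * 4 ** 2  # 192 data-level cells
```

When hl_mask == hl_data (the newly allowed case), each parent has exactly one child and the mask is drawn independently per data-level cell.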

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 3 additions & 4 deletions
@@ -243,7 +243,8 @@ def __init__(
             else cf.data_loading.rng_seed * 97
         )
 
-        self.tokenizer = TokenizerMasking(cf.healpix_level, Masker(cf.healpix_level, stage))
+        self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg)
+        self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker)
 
         self.mini_epoch = 0
 
@@ -575,16 +576,14 @@ def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, s
 
     def _get_source_target_masks(self, training_mode):
         """
-        Generate source and target masks for all streams
+        Generate source and target masks for all streams.
         """
-
         masks = {}
         for stream_info in self.streams:
             # Build source and target sample masks
             masks[stream_info["name"]] = self.tokenizer.build_samples_for_stream(
                 training_mode,
                 self.num_healpix_cells,
-                self.mode_cfg,
                 stream_info,
             )
         # identical for all streams

src/weathergen/datasets/tokenizer_masking.py

Lines changed: 2 additions & 3 deletions
@@ -81,13 +81,12 @@ def build_samples_for_stream(
         self,
         training_mode: str,
         num_cells: int,
-        stage_cfg: dict,
-        stream_cfg: dict,
+        stream_info: dict,
     ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]:
         """
         Create masks for samples
         """
-        return self.masker.build_samples_for_stream(training_mode, num_cells, stage_cfg, stream_cfg)
+        return self.masker.build_samples_for_stream(training_mode, num_cells, stream_info)
 
     def cell_to_token_mask(self, idxs_cells, idxs_cells_lens, mask):
         """ """
