Commit 9600183
Enable per stream masking config override (#1951)
* Add collapse monitoring
* Fix bug
* Fix SVD computation failing
* Reduce variables logged
* Fix EMA beta value computation
* Refactor get_current_beta to ema.py
* Sensible default for ema in jepa
* Allow collapse monitoring for forecasting
* Fix no collapse monitoring for forecasting
* Try to fix forecasting
* Fix teacher rank collapse when rope_2D is enabled
Two issues caused the EMA teacher's effective rank to drop to ~8-10
(multi-GPU) or ~40 (single-GPU) at training start when rope_2D=True,
while the student appeared unaffected:
1. pe_global zeroed with rope_2D: When rope_2D was enabled,
pe_global was cleared to zero under the assumption that RoPE
replaces it. However, RoPE only provides relative position in
Q/K attention -- it does not affect V. pe_global is the sole source
of per-cell token identity for masked cells (which have no content
from local assimilation). Without it, all masked cells are identical,
collapsing the teacher representation. The student metric was
artificially inflated by dropout noise hiding the same underlying
low-rank issue. Fix: always initialize pe_global -- it and RoPE
serve complementary roles.
2. EMA reset ignores DDP key prefix: EMAModel.reset() loads the
student state_dict directly via load_state_dict, but DDP wrapping
adds a module. prefix to all keys. With strict=False, every key
silently fails to match, leaving the teacher with uninitialized
weights from to_empty(). The update() method already handled this
mismatch but reset() did not. Combined with q_cells being skipped
in EMA updates, the teacher q_cells was permanently corrupted on
multi-GPU runs. Fix: strip the module. prefix before loading.
Co-Authored-By: Claude Opus 4.6 <[email protected]>
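The prefix-mismatch fix in issue 2 follows a common PyTorch pattern. A minimal sketch — the helper name `strip_ddp_prefix` is illustrative, not the repository's actual `EMAModel.reset()` code:

```python
import torch
import torch.nn as nn

def strip_ddp_prefix(state_dict: dict) -> dict:
    """Remove the 'module.' prefix that DistributedDataParallel adds to
    every key, so the dict loads into an unwrapped (e.g. EMA teacher) model."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Simulate a DDP-wrapped student: every key carries the "module." prefix.
student = nn.Linear(4, 4)
ddp_style_sd = {f"module.{k}": v for k, v in student.state_dict().items()}

teacher = nn.Linear(4, 4)
# Loading the prefixed dict with strict=False would silently match nothing;
# after stripping the prefix, strict=True verifies that every key matched.
teacher.load_state_dict(strip_ddp_prefix(ddp_style_sd), strict=True)
assert torch.equal(teacher.weight, student.weight)
```

Using `strict=True` after stripping turns the silent failure described above into a hard error if any key still fails to match.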
* Try adding 2d rope to Query engine
* Fix shape mismatch
* Run linter
* Adding support for dropping of streams
* enable healpix masking at the level of the data
* enable per stream masking strategy config override
* per stream masking override test
* move perstream masking to masker
* fix moving per stream config in masker
* lint
* tidy up
* better naming and docs of per stream override, and msds rename stage cfg to stream cfg
* addressed comments but now broken with tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources
* Revert "addressed comments but now broken with tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources"
This reverts commit 70ce173.
* move per stream config overrides to masker, and now randomly drop has its own scaffold when building source inputs
* address reviewer comments on static method and consolidated config
* update merge_masking_config docstring to reflect that randomly_drop is applied independently per source sample
* the drop decision for a stream applies to all source strategies; streams are only dropped during training
* update the test
---------
Co-authored-by: Sophie Xhonneux <[email protected]>
Co-authored-by: sophiex <[email protected]>
Co-authored-by: Claude Opus 4.6 <[email protected]>
Co-authored-by: Christian Lessig <[email protected]>

1 parent e8b1f8e · commit 9600183
File tree (4 files changed: +370 −19 lines)
- src/weathergen/datasets
- tests
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
| |||
111 | 112 | | |
112 | 113 | | |
113 | 114 | | |
114 | | - | |
| 115 | + | |
115 | 116 | | |
116 | 117 | | |
117 | 118 | | |
| |||
123 | 124 | | |
124 | 125 | | |
125 | 126 | | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
126 | 133 | | |
127 | 134 | | |
128 | 135 | | |
129 | 136 | | |
130 | 137 | | |
131 | 138 | | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
132 | 210 | | |
133 | 211 | | |
134 | 212 | | |
| |||
257 | 335 | | |
258 | 336 | | |
259 | 337 | | |
260 | | - | |
261 | | - | |
| 338 | + | |
262 | 339 | | |
263 | 340 | | |
264 | 341 | | |
265 | 342 | | |
266 | 343 | | |
267 | 344 | | |
| 345 | + | |
| 346 | + | |
268 | 347 | | |
269 | | - | |
270 | | - | |
| 348 | + | |
| 349 | + | |
271 | 350 | | |
272 | 351 | | |
273 | 352 | | |
274 | 353 | | |
275 | 354 | | |
276 | | - | |
| 355 | + | |
277 | 356 | | |
278 | 357 | | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
279 | 365 | | |
280 | 366 | | |
281 | 367 | | |
| |||
285 | 371 | | |
286 | 372 | | |
287 | 373 | | |
288 | | - | |
| 374 | + | |
289 | 375 | | |
290 | 376 | | |
| 377 | + | |
291 | 378 | | |
292 | 379 | | |
293 | 380 | | |
| |||
312 | 399 | | |
313 | 400 | | |
314 | 401 | | |
| 402 | + | |
315 | 403 | | |
316 | 404 | | |
317 | 405 | | |
| |||
336 | 424 | | |
337 | 425 | | |
338 | 426 | | |
339 | | - | |
340 | | - | |
| 427 | + | |
| 428 | + | |
341 | 429 | | |
342 | 430 | | |
343 | 431 | | |
| |||
427 | 515 | | |
428 | 516 | | |
429 | 517 | | |
430 | | - | |
| 518 | + | |
| 519 | + | |
| 520 | + | |
| 521 | + | |
431 | 522 | | |
432 | 523 | | |
433 | 524 | | |
| |||
692 | 783 | | |
693 | 784 | | |
694 | 785 | | |
695 | | - | |
696 | | - | |
| 786 | + | |
| 787 | + | |
697 | 788 | | |
698 | 789 | | |
699 | 790 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
243 | 243 | | |
244 | 244 | | |
245 | 245 | | |
246 | | - | |
| 246 | + | |
| 247 | + | |
247 | 248 | | |
248 | 249 | | |
249 | 250 | | |
| |||
575 | 576 | | |
576 | 577 | | |
577 | 578 | | |
578 | | - | |
| 579 | + | |
579 | 580 | | |
580 | | - | |
581 | 581 | | |
582 | 582 | | |
583 | 583 | | |
584 | 584 | | |
585 | 585 | | |
586 | 586 | | |
587 | | - | |
588 | 587 | | |
589 | 588 | | |
590 | 589 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
81 | 81 | | |
82 | 82 | | |
83 | 83 | | |
84 | | - | |
85 | | - | |
| 84 | + | |
86 | 85 | | |
87 | 86 | | |
88 | 87 | | |
89 | 88 | | |
90 | | - | |
| 89 | + | |
91 | 90 | | |
92 | 91 | | |
93 | 92 | | |
| |||
0 commit comments