Enable per-stream masking config override #1951
Conversation
Two issues caused the EMA teacher's effective rank to drop to ~8-10 (multi-GPU) or ~40 (single-GPU) at training start when rope_2D=True, while the student appeared unaffected:

1. pe_global zeroed with rope_2D: When rope_2D was enabled, pe_global was cleared to zero under the assumption that RoPE replaces it. However, RoPE only provides relative position in Q/K attention -- it does not affect V. pe_global is the sole source of per-cell token identity for masked cells (which have no content from local assimilation). Without it, all masked cells are identical, collapsing the teacher representation. The student metric was artificially inflated by dropout noise hiding the same underlying low-rank issue. Fix: always initialize pe_global -- it and RoPE serve complementary roles.

2. EMA reset ignores DDP key prefix: EMAModel.reset() loads the student state_dict directly via load_state_dict, but DDP wrapping adds a module. prefix to all keys. With strict=False, every key silently fails to match, leaving the teacher with uninitialized weights from to_empty(). The update() method already handled this mismatch, but reset() did not. Combined with q_cells being skipped in EMA updates, the teacher q_cells was permanently corrupted on multi-GPU runs. Fix: strip the module. prefix before loading.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
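The reset-path fix described in point 2 can be sketched as follows. This is a minimal illustration, not the actual EMAModel code; the helper name strip_ddp_prefix is hypothetical:

```python
def strip_ddp_prefix(state_dict: dict) -> dict:
    """Remove the 'module.' prefix that DistributedDataParallel wrapping
    adds to every key, so the unwrapped teacher's load_state_dict can
    actually match the student's parameters instead of silently skipping
    them under strict=False."""
    return {k.removeprefix("module."): v for k, v in state_dict.items()}

# A DDP-wrapped student produces keys like 'module.q_cells.weight';
# the teacher (built without DDP) expects plain 'q_cells.weight'.
```

With this applied before load_state_dict, every student key matches a teacher key and no weight is left in the uninitialized state produced by to_empty().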
self.perms = None
self.perms_num_forecast_steps = None

def _build_effective_masking_cfgs(self) -> dict[StreamName, Config]:
Shouldn't this be in Masker?
stay consistent across streams for batch assembly.
"""
cfgs: dict[StreamName, Config] = {}
for stream_info in self.streams:
I think the logic here should be:

for each stream:
    stream_merged_masking_config = config.merge(masking_config, stream_config.masking_override)
Generate source and target masks for all streams
Generate source and target masks for all streams.

Each stream uses its own effective masking config (which may include
I don't think this should be here (it would make sense if the masking_config were passed as an argument, which is a sensible possibility). It can go at line 631.
…eams_1947' into shmh40/dev/1950-per-stream-masking
Thanks for the comments @clessig. Agreed it should be handled in masking, and I have tried to do this with only minor changes to msds. Hopefully better now. We can think about whether we want it to be more general, but I am ok with this for now. I directly merged your PR in too. I haven't come up with a neater solution to the source/target distinction but can revisit it next week if helpful.
clessig
left a comment
Left some more comments. We need an example of what the overrides look like, both for documentation and to ensure things are handled correctly. I also made some more suggestions to further improve encapsulation.
src/weathergen/datasets/masking.py
Outdated
if override is None:
    return mode_cfg

effective = copy.deepcopy(mode_cfg)
Can we use a better variable name than effective? E.g. stream_cfg_masking seems appropriate.
src/weathergen/datasets/masking.py
Outdated
num_cells=num_cells,
strategy=target_cfg.get("masking_strategy"),
masking_strategy_config=target_cfg.get("masking_strategy_config", {}),
stream_cfg=stream_cfg_target,
The masking_strategy_config that is passed here should be the consolidated one. Then stream_cfg should not be needed here. See also the comment above.
Yes I think so too. Just checking -- this is done currently because we need to pass through the randomly_drop_as_source_rate? So we should put the randomly_drop in build_samples_from_stream, and then we can remove stream_cfg and stream_cfg_target deepcopy?
Where do we have build_samples_from_stream? randomly_drop_as_source_rate should be in the consolidated masking config.
src/weathergen/datasets/masking.py
Outdated
num_cells=num_cells,
strategy=source_cfg.get("masking_strategy"),
masking_strategy_config=masking_config,
stream_cfg=stream_cfg,
See above. Do we really need it?
for stream_info in self.streams:
    # Each stream uses its own effective masking config (which may include
    # per-stream ``masking_override`` merged on top of the global config).
    stage_cfg = self._effective_masking_cfgs[stream_info["name"]]
stage_cfg -> stream_cfg -- nothing related to stage here.
self.masker = Masker(cf.healpix_level, stage)
self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker)

self._effective_masking_cfgs = self.masker.build_effective_masking_cfgs(
I am wondering whether _effective_masking_cfgs could be kept in the Masker and constructed in its constructor?
This is possible, I don't think it makes too much difference though?
It helps encapsulation. MultiStreamDataSampler is complex enough.
…cfg to stream cfg
…_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources
… scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources" This reverts commit 70ce173.
… its own scaffold when building source inputs
clessig
left a comment
Minor comments, but please address them before merging and let me know if we should discuss anything.
src/weathergen/datasets/masking.py
Outdated
# determine if diagnostic dataset => mask is empty
if is_stream_diagnostic(stream_cfg, self.stage):
    source_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}
elif randomly_drop_rate > 0.0 and self.rng.uniform() < randomly_drop_rate:
The two branches are identical, so they should be merged. To make the code more readable, one can compute the condition first, e.g.:

is_stream_dropped = randomly_drop_rate > 0.0 and self.rng.uniform() < randomly_drop_rate
if is_stream_diagnostic(stream_cfg, self.stage) or is_stream_dropped:
    source_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}
def _generate_cell_mask(
self, num_cells: int, strategy: str, masking_strategy_config: dict
self,
This should fit again in one line (linters are stupid and they only add lines, never remove :))
Linter didn't want one line! :(
for stream_info in self.streams:
    # Each stream uses its own effective masking config (which may include
    # per-stream ``masking_override`` merged on top of the global config).
    stream_cfg = self.masker._effective_masking_cfgs[stream_info["name"]]
Why do we still need this here? Is it not directly available in the masker where it is needed?
Also, stream_cfg is a bit misleading; I would call it stream_masking_cfg.
src/weathergen/datasets/masking.py
Outdated
self.rng = rng

@staticmethod
def merge_masking_config(mode_cfg, override):
I don't think this should be static.
src/weathergen/datasets/masking.py
Outdated
return stream_cfg_masking

@staticmethod
I don't think this should be static.
To activate randomly_drop_as_source_rate, or to override the masking_strategy_config for a particular stream, an example snippet to be included in the stream config is shown below. Example usage:
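A hypothetical sketch of such a snippet. The field names masking_override, randomly_drop_as_source_rate, masking_strategy, and masking_strategy_config come from this thread; the stream name, the values, and the surrounding YAML structure are assumptions, not the actual weathergen schema:

```yaml
# Hypothetical per-stream override (exact schema may differ):
streams:
  - name: some_stream
    masking_override:
      randomly_drop_as_source_rate: 0.2
      masking_strategy: block
      masking_strategy_config:
        block_size: 8
```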
… dop independently per source sample
…treams only applies during training
clessig
left a comment
Much cleaner now. Thanks for cleaning up.
Description
Enables per-stream masking config override.
Issue Number
Closes #1950
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60