
Enable per stream masking config override #1951

Merged
clessig merged 39 commits into develop from shmh40/dev/1950-per-stream-masking
Mar 12, 2026

Conversation

@shmh40
Contributor

@shmh40 shmh40 commented Feb 27, 2026

Description

Enables per-stream masking config override.

Issue Number

Closes #1950

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Sophie Xhonneux and others added 22 commits February 4, 2026 19:53
Two issues caused the EMA teacher's effective rank to drop to ~8-10
(multi-GPU) or ~40 (single-GPU) at training start when rope_2D=True,
while the student appeared unaffected:

1. pe_global zeroed with rope_2D: When rope_2D was enabled,
   pe_global was cleared to zero under the assumption that RoPE
   replaces it. However, RoPE only provides relative position in
   Q/K attention -- it does not affect V. pe_global is the sole source
   of per-cell token identity for masked cells (which have no content
   from local assimilation). Without it, all masked cells are identical,
   collapsing the teacher representation. The student metric was
   artificially inflated by dropout noise hiding the same underlying
   low-rank issue. Fix: always initialize pe_global -- it and RoPE
   serve complementary roles.

2. EMA reset ignores DDP key prefix: EMAModel.reset() loads the
   student state_dict directly via load_state_dict, but DDP wrapping
   adds a module. prefix to all keys. With strict=False, every key
   silently fails to match, leaving the teacher with uninitialized
   weights from to_empty(). The update() method already handled this
   mismatch but reset() did not. Combined with q_cells being skipped
   in EMA updates, the teacher q_cells was permanently corrupted on
   multi-GPU runs. Fix: strip the module. prefix before loading.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
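
As an aside, the prefix-stripping fix described in point 2 can be sketched roughly like this (a minimal illustration; `strip_ddp_prefix` is a hypothetical helper name, not necessarily what the repo uses):

```python
def strip_ddp_prefix(state_dict: dict) -> dict:
    """Remove the 'module.' prefix that DDP wrapping adds to every key.

    EMAModel.reset() loads the student state_dict with strict=False, so
    prefixed keys silently fail to match and the teacher keeps its
    uninitialized to_empty() weights. Stripping the prefix first makes
    the keys line up again before load_state_dict is called.
    """
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```

A DDP-wrapped student then yields teacher-compatible keys, e.g. `module.q_cells.weight` becomes `q_cells.weight`, while unprefixed keys pass through unchanged.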
@shmh40 shmh40 self-assigned this Feb 27, 2026
@github-actions bot added the labels data (Anything related to the datasets used in the project), data:reading (Everything related to data reading), and model (Related to model training or definition, not generic infra) on Feb 27, 2026
self.perms = None
self.perms_num_forecast_steps = None

def _build_effective_masking_cfgs(self) -> dict[StreamName, Config]:
Collaborator:

Shouldn't this be in Masker?

stay consistent across streams for batch assembly.
"""
cfgs: dict[StreamName, Config] = {}
for stream_info in self.streams:
Collaborator:

I think the logic here should be:

for each stream:
stream_merged_masking_config = config.merge( masking_config, stream_config.masking_override)

Generate source and target masks for all streams
Generate source and target masks for all streams.

Each stream uses its own effective masking config (which may include
Collaborator:

I don't think this should be here (it would make sense if the masking_config were passed as an argument, which is a sensible possibility). It can be in line 631.

@shmh40
Contributor Author

shmh40 commented Feb 27, 2026

Thanks for the comments @clessig. Agreed it should be handled in masking, and I have tried to do this with only minor changes to msds. Hopefully it is better now. We can think about whether we want it to be more general, but I am OK with this for now.

I directly merged your PR in too. I haven't come up with a neater solution to the source/target distinction but can revisit it next week if helpful.

@clessig (Collaborator) left a comment:

Left some more comments. We need an example of what the overrides look like, both for documentation and to ensure things are handled correctly. I also made some more suggestions to further improve encapsulation.

if override is None:
return mode_cfg

effective = copy.deepcopy(mode_cfg)
Collaborator:

Can we use a better variable name than effective? E.g. stream_cfg_masking seems appropriate.

num_cells=num_cells,
strategy=target_cfg.get("masking_strategy"),
masking_strategy_config=target_cfg.get("masking_strategy_config", {}),
stream_cfg=stream_cfg_target,
Collaborator:

The masking_strategy_config that is passed here should be the consolidated one. Then stream_cfg should not be needed here. See also the comment above.

Contributor Author:

Yes I think so too. Just checking -- this is done currently because we need to pass through the randomly_drop_as_source_rate? So we should put the randomly_drop in build_samples_from_stream, and then we can remove stream_cfg and stream_cfg_target deepcopy?

Collaborator:

Where do we have build_samples_from_stream? randomly_drop_as_source_rate should be in the consolidated masking config.

num_cells=num_cells,
strategy=source_cfg.get("masking_strategy"),
masking_strategy_config=masking_config,
stream_cfg=stream_cfg,
Collaborator:

See above. Do we really need it?

for stream_info in self.streams:
# Each stream uses its own effective masking config (which may include
# per-stream ``masking_override`` merged on top of the global config).
stage_cfg = self._effective_masking_cfgs[stream_info["name"]]
Collaborator:

stage_cfg -> stream_cfg -- nothing related to stage here.

self.masker = Masker(cf.healpix_level, stage)
self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker)

self._effective_masking_cfgs = self.masker.build_effective_masking_cfgs(
Collaborator:

I am wondering if _effective_masking_cfgs cannot be kept in the Masker and constructed in the constructor?

Contributor Author:

This is possible, I don't think it makes too much difference though?

Collaborator:

It helps encapsulation. MultiStreamDataSampler is complex enough.

shmh40 added 6 commits March 3, 2026 18:37
…_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources
… scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) expected non-empty list of tensors. Also more scaffolding needed to make this work for masking, since we build the targets first and the source is just the ~target_mask, and there was stuff in the code not to drop target streams, only to drop as sources"

This reverts commit 70ce173.
… its own scaffold when building source inputs
@clessig (Collaborator) left a comment:

Minor comments, but please address them before merging and let me know if we should discuss anything.

# determine if diagnostic dataset => mask is empty
if is_stream_diagnostic(stream_cfg, self.stage):
source_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}
elif randomly_drop_rate > 0.0 and self.rng.uniform() < randomly_drop_rate:
Collaborator:

The two branches are identical so they should be merged. To make the code more readable, one can compute the condition first, e.g.

is_stream_dropped = randomly_drop_rate > 0.0 and self.rng.uniform() < randomly_drop_rate

if is_stream_diagnostic(stream_cfg, self.stage) or is_stream_dropped:
    source_mask, mask_params = torch.zeros(num_cells, dtype=torch.bool), {}


def _generate_cell_mask(
self, num_cells: int, strategy: str, masking_strategy_config: dict
self,
Collaborator:

This should fit again in one line (linters are stupid and they only add lines, never remove :))

Contributor Author:

Linter didn't want one line! :(

for stream_info in self.streams:
# Each stream uses its own effective masking config (which may include
# per-stream ``masking_override`` merged on top of the global config).
stream_cfg = self.masker._effective_masking_cfgs[stream_info["name"]]
Collaborator:

Why do we still need this here? Is it not directly available in the masker where it is needed?

Also stream_cfg is a bit misleading and I would call it stream_masking_cfg

self.rng = rng

@staticmethod
def merge_masking_config(mode_cfg, override):
Collaborator:

I don't think this should be static.


return stream_cfg_masking

@staticmethod
Collaborator:

I don't think this should be static.

@shmh40
Contributor Author

shmh40 commented Mar 11, 2026

To activate randomly_drop_as_source_rate, or to override the masking_strategy_config for a particular stream, include a snippet like the following in the stream config:

Example usage:

ERA5_smooth :
  type : anemoi
  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
  stream_id : 0
  source : ['2d', '2t', 'msl', 'skt', 'sp']
  target : ['2d', '2t', 'msl', 'skt', 'sp']
  ...
  masking_override :
    randomly_drop_as_source_rate: 0.9
    target_input :
      masking_strategy_config :
        hl_mask : 3
  embed:
     ...
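
For intuition, with randomly_drop_as_source_rate: 0.9 the stream's source mask is emptied on roughly 90% of draws. The drop check discussed in the review behaves like this sketch (a hypothetical standalone function using the stdlib RNG instead of the sampler's numpy rng):

```python
import random

def is_stream_dropped(randomly_drop_rate: float, rng: random.Random) -> bool:
    """Decide whether a stream is dropped as source for this sample.

    Mirrors the branch discussed in review: a dropped stream gets an
    all-False source mask, the same as a diagnostic stream.
    """
    return randomly_drop_rate > 0.0 and rng.uniform(0.0, 1.0) < randomly_drop_rate

rng = random.Random(0)
# Over many samples, a rate of 0.9 drops the stream about 90% of the time.
drop_fraction = sum(is_stream_dropped(0.9, rng) for _ in range(10_000)) / 10_000
```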

@clessig (Collaborator) left a comment:

Much cleaner now. Thanks for cleaning up.

@clessig clessig merged commit 9600183 into develop Mar 12, 2026
3 of 5 checks passed

Labels

data:reading (Everything related to data reading), data (Anything related to the datasets used in the project), model:pretrain, model (Related to model training or definition, not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Enable per-stream masking

3 participants