
Add support for masking by dropping of streams #1948

Closed
clessig wants to merge 19 commits into develop from
clessig/develop/feature_drop_streams_1947

Conversation

@clessig
Collaborator

@clessig clessig commented Feb 27, 2026

Description

Add support for masking by dropping of streams

Issue Number

Closes #1947

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Sophie Xhonneux and others added 19 commits February 4, 2026 19:53
Two issues caused the EMA teacher's effective rank to drop to ~8-10
(multi-GPU) or ~40 (single-GPU) at training start when rope_2D=True,
while the student appeared unaffected:

1. pe_global zeroed with rope_2D: When rope_2D was enabled,
   pe_global was cleared to zero under the assumption that RoPE
   replaces it. However, RoPE only provides relative position in
   Q/K attention -- it does not affect V. pe_global is the sole source
   of per-cell token identity for masked cells (which have no content
   from local assimilation). Without it, all masked cells are identical,
   collapsing the teacher representation. The student metric was
   artificially inflated by dropout noise hiding the same underlying
   low-rank issue. Fix: always initialize pe_global -- it and RoPE
   serve complementary roles.
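
The collapse described above can be sketched in plain Python. This is an illustrative toy, not the repo's actual model code; all names (token_inputs, pe_global) and the embedding size are assumptions. It only shows the core point: RoPE rotates Q/K inside attention and never touches the token inputs or V, so a masked cell with zero content is indistinguishable from every other masked cell unless an additive absolute embedding gives it an identity.

```python
# Toy sketch (pure Python, no torch) of the rank-collapse mechanism.
# All names and shapes here are illustrative, not the repo's actual API.

def token_inputs(masked_positions, pe_global=None):
    """Build per-cell input vectors for masked cells.

    Masked cells carry no content from local assimilation, so their
    content embedding is all zeros. Only an additive absolute position
    embedding (pe_global) can make them distinguishable; RoPE acts on
    Q/K inside attention and does not affect these inputs or V.
    """
    dim = 4
    content = [0.0] * dim  # masked: no content signal
    tokens = []
    for pos in masked_positions:
        if pe_global is not None:
            pe = pe_global[pos]
            tokens.append([c + p for c, p in zip(content, pe)])
        else:
            tokens.append(list(content))  # identical for every position
    return tokens

# Without pe_global, every masked cell is the same vector: the set of
# distinct token inputs collapses to one, mirroring the low effective rank.
no_pe = token_inputs([0, 1, 2])
assert no_pe[0] == no_pe[1] == no_pe[2]

# With pe_global, each masked cell acquires a distinct identity.
pe_global = {0: [0.1, 0, 0, 0], 1: [0, 0.1, 0, 0], 2: [0, 0, 0.1, 0]}
with_pe = token_inputs([0, 1, 2], pe_global)
assert with_pe[0] != with_pe[1] and with_pe[1] != with_pe[2]
```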

2. EMA reset ignores DDP key prefix: EMAModel.reset() loads the
   student state_dict directly via load_state_dict, but DDP wrapping
   adds a module. prefix to all keys. With strict=False, every key
   silently fails to match, leaving the teacher with uninitialized
   weights from to_empty(). The update() method already handled this
   mismatch but reset() did not. Combined with q_cells being skipped
   in EMA updates, the teacher q_cells was permanently corrupted on
   multi-GPU runs. Fix: strip the module. prefix before loading.
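
A minimal sketch of the prefix-stripping fix, assuming DDP's standard behavior of prepending `module.` to every state_dict key. The helper name and the surrounding dicts are hypothetical; only the key transformation reflects the fix described above.

```python
# Sketch of the EMA reset fix: strip the "module." prefix that DDP
# wrapping adds to state_dict keys before loading into the unwrapped
# teacher. With strict=False, unmatched keys fail silently, which is
# exactly how the teacher was left with uninitialized weights.

DDP_PREFIX = "module."

def strip_ddp_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed."""
    return {
        (k[len(DDP_PREFIX):] if k.startswith(DDP_PREFIX) else k): v
        for k, v in state_dict.items()
    }

# Hypothetical example: a DDP-wrapped student emits prefixed keys,
# while the teacher module expects bare ones.
student_sd = {"module.encoder.weight": 1.0, "module.q_cells": 2.0}
teacher_keys = {"encoder.weight", "q_cells"}

fixed = strip_ddp_prefix(student_sd)
# Before the fix, zero keys matched; after stripping, all of them do.
assert set(fixed) == teacher_keys
```

The same normalization is what `update()` already did implicitly; applying it in `reset()` as well keeps the two paths consistent.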

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Feb 27, 2026
@clessig
Collaborator Author

clessig commented Mar 2, 2026

Will be merged with #1951


Labels

model Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Allow for dropping of streams in masking

2 participants