
Conversation

@oleksost (Contributor) commented Dec 11, 2025

✨ Description

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

for begin, end in loss_masking_spans:
    loss_mask[sample_index, begin:end] = False
- if self._config.output_layer.distillation_model is not None:
+ if self._config.decoder.block.distillation_model is not None:
Collaborator:

We still need this for logits distillation. (output_layer has been renamed to head)

Contributor Author:

True, I didn't see that the head also has a distillation_model attribute.
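For context, a minimal standalone sketch (hypothetical shapes and span values, not the repo's preprocessing code) of how loss masking spans translate into a boolean mask, as in the hunk above:

    import torch

    loss_mask = torch.ones(2, 16, dtype=torch.bool)  # (batch, sequence): True = token counts toward the loss
    loss_masking_spans = [(3, 7), (10, 12)]          # token ranges to exclude from the loss
    sample_index = 0
    for begin, end in loss_masking_spans:
        loss_mask[sample_index, begin:end] = False   # mask out the spanned tokens for this sample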

for document in documents:
    for begin, end in document.ranges:
-         ranges.extend((begin + sample_size, end + sample_size))
+         ranges.append((begin + sample_size, end + sample_size))
Collaborator:

Any idea why this wasn't caught in the tests?

Contributor Author (@oleksost, Dec 12, 2025):

I guess we never test a forward pass with loss masks (so preprocess_batch is never called with loss masking spans), and when testing loss masking spans we do not test sampling from SampledIndexedDataset correctly. I will try to add a test.

Contributor Author:

Added a llama_with_loss_masking model config that should catch these bugs.
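For reference, a minimal standalone sketch (hypothetical values) of what the one-line fix above changes: list.extend flattens the (begin, end) tuple into two separate integers, while list.append keeps the span as a pair.

    sample_size = 100
    begin, end = 3, 7

    broken, fixed = [], []
    broken.extend((begin + sample_size, end + sample_size))  # -> [103, 107]: the span structure is lost
    fixed.append((begin + sample_size, end + sample_size))   # -> [(103, 107)]: the span is kept intact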

@oleksost changed the title from "Fixes masked loss distillation" to "Fixes loss distillation + manual KL grads" on Dec 12, 2025
for begin, end in loss_masking_spans:
    loss_mask[sample_index, begin:end] = False
- if self._config.output_layer.distillation_model is not None:
+ if self._config.head.distillation_model is not None:
Collaborator:

Checking for the decoder distillation model was a good idea; please make it check for both.
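A minimal sketch of what "check for both" could look like, assuming the config attribute names shown in the hunks above (head for logits distillation, decoder.block for the decoder-side distillation model):

    def needs_loss_mask(config) -> bool:
        # Distillation loss masking is needed if either the head or the decoder block
        # has a distillation model configured.
        return (
            config.head.distillation_model is not None
            or config.decoder.block.distillation_model is not None
        )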

    target_format=target_format,
)
- _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4)
+ _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4, 1e-6)
Collaborator:

This shouldn't be needed. The gradients are O(1e-6), so using 1e-6 as an absolute threshold amounts to dropping the check entirely.

Contributor Author:

fixed
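For illustration, a minimal standalone sketch (using torch.allclose as a stand-in for the repo's Assert helpers) of why an absolute threshold of 1e-6 on gradients that are themselves O(1e-6) effectively disables the comparison:

    import torch

    grad = torch.full((4, 4), 2.0e-6)      # gradients on the order of 1e-6
    ref_grad = torch.full((4, 4), 1.0e-6)  # a "reference" that is off by 100%

    # An absolute tolerance of 1e-6 accepts the mismatch:
    print(torch.allclose(grad, ref_grad, rtol=0.0, atol=1e-6))  # True
    # A relative tolerance catches it:
    print(torch.allclose(grad, ref_grad, rtol=1e-4, atol=0.0))  # False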


config = (
-     {"type": "file", "path": config_paths[0]}
+     {"type": "file", "path": config_paths[0]}  # TODO: shouldn't this be {"training": {...}}?
Collaborator:

No, this is a dataset config. The "training" part is for a Data config; see for example https://github.com/ServiceNow/Fast-LLM/blob/main/tests/data/test_blending.py#L151

)


def get_dataset_with_loss_masking_spans(
Collaborator:

This needs to go with get_model_test_dataset and match the parametrization (vocab size). I recommend just adding spans to get_model_test_dataset instead since it does nothing unless used.

},
)

_update_and_add_testing_config(
Collaborator:

I recommend adding to one or more of the existing distillation configs instead; it will cover more code and avoid the extra config.

@jlamypoirier (Collaborator):

The distributed gradients still don't seem to match the reference; here is an example with low thresholds (for demonstration) showing an error above 10%:

 >>>>>> Failed reverse_kl_forward_backward, target_format, use_mask=False
Traceback (most recent call last):
  File "/app/tests/functional/test_cross_entropy.py", line 200, in compare_parallel_cross_entropy
    _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking)
  File "/app/tests/functional/test_cross_entropy.py", line 191, in _compare_parallel_cross_entropy
    _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4, 1e-6)
  File "/app/tests/functional/test_cross_entropy.py", line 47, in _compare_cross_entropy_outputs
    Assert.rms_close_relative(grad, ref_grad, 1e-8, 1e-12)
  File "/app/fast_llm/utils.py", line 164, in rms_close_relative
    rms <= threshold
AssertionError: Rms diff too big (1.92e-07 > 1.00e-12, scale = 2.29e-06) between tensors tensor([[-7.1526e-07, -2.2165e-07, -2.1607e-06,  ..., -6.4448e-07,
         -6.8173e-07, -1.0803e-06],
        [-5.3272e-07,  1.7360e-06,  3.9935e-06,  ...,  7.6294e-06,
          1.3411e-06, -1.6415e-08],
        [ 5.4762e-07,  9.0152e-07, -1.0757e-07,  ..., -8.1956e-08,
         -4.9919e-07, -5.8860e-07],
        ...,
        [ 1.5646e-07, -1.3188e-06, -1.9930e-07,  ..., -2.6776e-09,
          9.1642e-07, -1.5497e-06],
        [-1.8161e-07,  4.8429e-07, -2.2631e-07,  ..., -2.3097e-06,
         -4.8801e-07, -8.8289e-07],
        [-1.6671e-07,  1.5087e-07, -1.7807e-06,  ..., -4.3027e-07,
         -3.8650e-08,  2.7865e-06]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>) and tensor([[-8.0839e-07, -2.7567e-07, -2.2650e-06,  ..., -8.0839e-07,
         -7.4506e-07, -1.2517e-06],
        [-6.0722e-07,  1.5721e-06,  3.5614e-06,  ...,  6.9141e-06,
          1.2368e-06, -2.5611e-08],
        [ 5.1782e-07,  8.1956e-07, -1.3597e-07,  ..., -9.2667e-08,
         -5.8860e-07, -6.0722e-07],
        ...,
        [ 1.2945e-07, -1.5423e-06, -3.3528e-07,  ..., -1.2922e-08,
          8.6799e-07, -1.6540e-06],
        [-2.0210e-07,  4.5821e-07, -2.3842e-07,  ..., -2.4587e-06,
         -5.2154e-07, -1.2890e-06],
        [-1.8720e-07,  1.1781e-07, -1.9222e-06,  ..., -4.7125e-07,
         -5.2620e-08,  2.6524e-06]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SplitBackward0>)

As a comparison, cross-entropy has an error of ~1e-5 (Rms diff too big (3.22e-11 > 1.00e-12, scale = 2.13e-06))

The single-GPU version seems to be working though, also within ~1e-5:

       AssertionError: Rms diff too big (4.05e-11 > 1.00e-12, scale = 2.31e-06) between tensors tensor([[-9.9652e-08, -8.1211e-07, -1.4715e-07,  ...,  1.0133e-06,
E                 8.1025e-08, -2.5844e-08],
E               [ 8.9779e-07, -2.1048e-07,  7.2643e-07,  ..., -1.9968e-06,
E                 1.1250e-06,  6.2212e-07],
E               [ 9.6858e-08, -1.3225e-07, -1.6153e-09,  ..., -4.6566e-07,
E                -6.3330e-07,  2.8871e-07],
E               ...,
E               [-1.4529e-06, -4.0233e-07, -1.2144e-06,  ..., -1.9465e-07,
E                -1.0654e-06, -5.4389e-07],
E               [ 1.0058e-06,  1.0571e-07, -1.5423e-06,  ..., -6.7800e-07,
E                -9.6858e-08, -9.6043e-09],
E               [-2.4796e-08, -7.1898e-07, -1.6466e-06,  ...,  3.7812e-07,
E                -1.9837e-07, -8.7172e-07]], device='cuda:0', dtype=torch.bfloat16,
E              grad_fn=<ToCopyBackward0>) and tensor([[-9.9652e-08, -8.1211e-07, -1.4715e-07,  ...,  1.0133e-06,
E                 8.1025e-08, -2.5844e-08],
E               [ 8.9779e-07, -2.1048e-07,  7.2643e-07,  ..., -1.9968e-06,
E                 1.1250e-06,  6.2212e-07],
E               [ 9.6858e-08, -1.3225e-07, -1.6153e-09,  ..., -4.6566e-07,
E                -6.3330e-07,  2.8871e-07],
E               ...,
E               [-1.4529e-06, -4.0233e-07, -1.2144e-06,  ..., -1.9465e-07,
E                -1.0654e-06, -5.4389e-07],
E               [ 1.0058e-06,  1.0571e-07, -1.5423e-06,  ..., -6.7800e-07,
E                -9.6858e-08, -9.6043e-09],
E               [-2.4796e-08, -7.1898e-07, -1.6466e-06,  ...,  3.7812e-07,
E                -1.9837e-07, -8.7172e-07]], device='cuda:0', dtype=torch.bfloat16)

with torch.enable_grad():
    logits_ = logits.float().detach().requires_grad_(grad_output is not None)
    student_log_probs = distributed_log_softmax(logits_, group=group)
    # logits_ = logits.float()#.detach().requires_grad_(grad_output is not None)
Collaborator:

Remove?
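For context, a minimal single-GPU sketch of the pattern in the hunk above (torch.log_softmax stands in for distributed_log_softmax, and the function name and signature are hypothetical, not the repo's reverse_kl_forward_backward): the float().detach() copy isolates the manual backward from the surrounding autograd graph, and requires_grad_ is only enabled when a gradient is actually requested.

    import torch

    def reverse_kl_grad_sketch(logits, teacher_log_probs, grad_output=None):
        with torch.enable_grad():
            logits_ = logits.float().detach().requires_grad_(grad_output is not None)
            student_log_probs = torch.log_softmax(logits_, dim=-1)
            # Reverse KL: sum over the vocab of p_student * (log p_student - log p_teacher), averaged over tokens.
            loss = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(-1).mean()
            if grad_output is None:
                return loss.detach(), None
            (grad,) = torch.autograd.grad(loss, logits_, grad_outputs=torch.full_like(loss, grad_output))
        return loss.detach(), grad.to(logits.dtype)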

torch.manual_seed(0)
world_size = torch.distributed.get_world_size(group)
- logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format)
+ logits, target, loss_mask = _get_cross_entropy_inputs(
Collaborator:

Let's just use @requires_cuda?

return _get_test_dataset(
    DATASET_CACHE / "model_dataset",
    seed=1234,
    max_loss_masking_spans=5 if use_loss_masking else 0,
Collaborator:

This won't work: the dataset is cached, so the argument needs to be constant. max_loss_masking_spans=5 should work.
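To illustrate the caching point, a minimal hypothetical sketch (not the repo's _get_test_dataset) of why the argument has to be constant: the dataset is built once under a fixed cache path, so a later call with a different argument silently reuses the first build.

    from pathlib import Path

    def get_cached_dataset(cache_dir: Path, max_loss_masking_spans: int) -> Path:
        marker = cache_dir / "built.flag"
        if marker.exists():
            # Cached build is reused as-is; max_loss_masking_spans has no effect here,
            # so callers passing different values get whatever was built first.
            return cache_dir
        cache_dir.mkdir(parents=True, exist_ok=True)
        # ... build the dataset with up to max_loss_masking_spans spans per document ...
        marker.touch()
        return cache_dir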


Development

Successfully merging this pull request may close these issues.

[bug] Bug with maksed distillation with loss spans
