Fixes loss distillation + manual KL grads #418
base: main
Conversation
fast_llm/models/gpt/model.py
Outdated
```diff
 for begin, end in loss_masking_spans:
     loss_mask[sample_index, begin:end] = False
-if self._config.output_layer.distillation_model is not None:
+if self._config.decoder.block.distillation_model is not None:
```
We still need this for logits distillation. (output_layer has been renamed to head)
True, I didn't see that head also has a distillation_model attribute.
```diff
 for document in documents:
     for begin, end in document.ranges:
-        ranges.extend((begin + sample_size, end + sample_size))
+        ranges.append((begin + sample_size, end + sample_size))
```
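For context on the fix: `list.extend` with a tuple splices the tuple's elements into the list one by one, while `list.append` adds the tuple as a single item, so the original code flattened each (begin, end) span into two bare integers. A minimal illustration:

```python
begin, end, sample_size = 2, 5, 100

ranges = []
ranges.extend((begin + sample_size, end + sample_size))
print(ranges)  # [102, 105] -- span structure lost, two bare ints

ranges = []
ranges.append((begin + sample_size, end + sample_size))
print(ranges)  # [(102, 105)] -- one (begin, end) tuple, as intended
```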
Any idea why this wasn't caught in the tests?
I guess we never test a forward pass with loss masks (so preprocess_batch is never called with loss masking spans), and when testing loss masking spans we do not test sampling from SampledIndexedDataset correctly. I will try to add a test.
Added a llama_with_loss_masking model config that should catch these bugs.
fast_llm/models/gpt/model.py
Outdated
```diff
 for begin, end in loss_masking_spans:
     loss_mask[sample_index, begin:end] = False
-if self._config.output_layer.distillation_model is not None:
+if self._config.head.distillation_model is not None:
```
Checking for the decoder distillation model was a good idea; please make it check for both.
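A minimal sketch of the combined check (attribute paths taken from the two diffs above; the surrounding loss-mask code is elided):

```python
# Sketch only: apply loss masking if either distillation path is configured.
# `head.distillation_model` covers logits distillation,
# `decoder.block.distillation_model` covers the decoder path.
if (
    self._config.head.distillation_model is not None
    or self._config.decoder.block.distillation_model is not None
):
    ...  # build loss_mask from loss_masking_spans as above
```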
```diff
     target_format=target_format,
 )
-_compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4)
+_compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4, 1e-6)
```
This shouldn't be needed. The gradients are O(1e-6), so using 1e-6 as an absolute threshold amounts to dropping the check entirely.
fixed
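To illustrate the point with `torch.testing.assert_close` (assuming `_compare_cross_entropy_outputs` wraps a similar rtol/atol check): when the values being compared are themselves O(1e-6), an absolute tolerance of 1e-6 accepts almost anything.

```python
import torch

grad = torch.full((4,), 1e-6)
wrong = torch.zeros(4)  # 100% relative error

# Passes: |0 - 1e-6| <= atol + rtol * |1e-6| holds trivially with atol=1e-6.
torch.testing.assert_close(wrong, grad, rtol=1e-4, atol=1e-6)

# Fails as it should once the absolute threshold is removed:
# torch.testing.assert_close(wrong, grad, rtol=1e-4, atol=0)  # AssertionError
```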
tests/utils/dataset.py
Outdated
```diff
 config = (
-    {"type": "file", "path": config_paths[0]}
+    {"type": "file", "path": config_paths[0]}  # TODO: shouldn't this be {"training": {...}}?
```
No, this is a dataset config. The "training" part is for a Data config; see for example https://github.com/ServiceNow/Fast-LLM/blob/main/tests/data/test_blending.py#L151
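A rough sketch of the distinction (the Data config shape is inferred from the linked test, not the exact schema):

```python
# Dataset config: points at a single dataset source.
dataset_config = {"type": "file", "path": "model_dataset/config.yaml"}  # path illustrative

# Data config (hypothetical shape): maps named phases such as "training"
# to dataset configs.
data_config = {"datasets": {"training": dataset_config}}
```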
tests/utils/dataset.py
Outdated
```diff
 )
+
+
+def get_dataset_with_loss_masking_spans(
```
This needs to go with get_model_test_dataset and match the parametrization (vocab size). I recommend just adding spans to get_model_test_dataset instead since it does nothing unless used.
```diff
     },
 )
+
+
+_update_and_add_testing_config(
```
I recommend adding to one or more of the existing distillation configs instead; it will cover more code and avoid the extra config.
The distributed gradients still don't seem to match the reference. Here is an example with low thresholds (for demonstration) showing an error above 10%. As a comparison, cross-entropy has an error of ~1e-5, and the single-GPU version seems to be working too, also within ~1e-5.
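For reference, the closed-form gradient presumably being implemented manually here (assuming forward KL against a fixed teacher distribution t, with s = softmax(z)): dKL(t, s)/dz = s - t per row. A quick single-GPU sanity check against autograd:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
teacher = torch.softmax(torch.randn(4, 10), dim=-1)

# Forward KL(teacher || student), summed over batch and vocab.
loss = torch.sum(teacher * (teacher.log() - torch.log_softmax(logits, dim=-1)))
loss.backward()

# Closed-form gradient: softmax(logits) - teacher, row by row.
manual = torch.softmax(logits, dim=-1) - teacher
torch.testing.assert_close(logits.grad, manual)
```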
```diff
 with torch.enable_grad():
     logits_ = logits.float().detach().requires_grad_(grad_output is not None)
     student_log_probs = distributed_log_softmax(logits_, group=group)
+    # logits_ = logits.float()#.detach().requires_grad_(grad_output is not None)
```
Remove?
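For context on what the test exercises, a minimal sketch of a log-softmax over a vocabulary dimension sharded across ranks (the actual `distributed_log_softmax` in this PR may differ in details):

```python
import torch
import torch.distributed as dist

def distributed_log_softmax_sketch(logits: torch.Tensor, group=None) -> torch.Tensor:
    # Assumes `logits` holds this rank's shard of the vocab on the last dim.
    # Global max for numerical stability.
    max_logits = logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=group)
    shifted = logits - max_logits
    # Global normalizer: sum of exp over the full, sharded vocabulary.
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    return shifted - sum_exp.log()
```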
```diff
 torch.manual_seed(0)
 world_size = torch.distributed.get_world_size(group)
-logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format)
+logits, target, loss_mask = _get_cross_entropy_inputs(
```
Let's just use @requires_cuda?
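Fast-LLM's test utilities presumably already provide this marker; a typical definition looks like:

```python
import pytest
import torch

# Hypothetical definition; use the project's existing marker if available.
requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")

@requires_cuda
def test_parallel_cross_entropy():  # name illustrative
    ...
```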
```diff
 return _get_test_dataset(
     DATASET_CACHE / "model_dataset",
     seed=1234,
+    max_loss_masking_spans=5 if use_loss_masking else 0,
```
This won't work: the dataset is cached, so the argument needs to be constant. max_loss_masking_spans=5 should work.