CompositeLossMetrics now performs a weighted sum of losses.#1251
Merged
CompositeLossMetrics now performs a weighted sum of losses.#1251
CompositeLossMetrics now performs a weighted sum of losses.#1251Conversation
Contributor
Author
|
@markblee Could you take a look? From 1399 |
markblee
reviewed
Jun 10, 2025
Contributor
markblee
left a comment
There was a problem hiding this comment.
(Will approve after the internal review completes.)
00d1611 to
29f13f7
Compare
Currently, `CompositeLossMetrics` sums the losses without considering their weights (i.e., the number of live targets). To make this a weighted sum, downstream code has been implementing `CompositeLossWeights` to inject the number of live targets into `loss_weights`. This is essentially patching a surprising logic (initail loss sum) with complex logic (CompositeLossWeights) into a straightforward one (weighted sum). Therefore, we’re changing the default loss aggregation logic to be straightforward from the beginning. From now on, our standarized loss aggregation logic is ``` loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(each_loss_weight * num_each_samples) ``` Historically, the complex logic was introduced because the weights of losses returned by child metrics were unknown. But now that child metrics return losses as `WeightedScalar`, we can adopt a simpler, cleaner aggregation logic. Note: alternative formulation could be ``` loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(num_each_samples) ``` However, when num_each_samples is large and each_loss_weight is small, the denominator can become disproportionately large. So we discard this option.
29f13f7 to
1bb0551
Compare
Contributor
Author
|
@markblee could you take a look again?
All reviewers approved internally at 23540 |
markblee
approved these changes
Jul 15, 2025
loofahcus
pushed a commit
to loofahcus/axlearn
that referenced
this pull request
Oct 11, 2025
…apple#1251)" (#1573) This reverts commit 343102a.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Currently,
CompositeLossMetricssums the losses without considering their weights (i.e., the number of live targets). To make this a weighted sum, downstream code has been implementingCompositeLossWeightsto inject the number of live targets intoloss_weights. This is essentially patching a surprising logic (initail loss sum) with complex logic (CompositeLossWeights) into a straightforward one (weighted sum).Therefore, we’re changing the default loss aggregation logic to be straightforward from the beginning.
From now on, our standarized loss aggregation logic is
Historically, the complex logic was introduced because the weights of losses returned by child metrics were unknown. But now that child metrics return losses as
WeightedScalar, we can adopt a simpler, cleaner aggregation logic.Note: alternative formulation could be
However, when num_each_samples is large and each_loss_weight is small, the denominator can become disproportionately large. So we discard this option.