Skip to content

Loss broadcasting fix #347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 18, 2025
Merged

Conversation

laserkelvin
Copy link

This PR attempts to correct broadcasting issues due to shape mismatches in loss calculations.

This was brought on by the realization that broadcasting works a little differently (from some time ago) when tensor shapes are mismatched. In particular, labels come out of the pipeline with an extra dimension (e.g. [N, 1]) compared to graph readouts. The resulting behavior is actually very different from the intention:

>>> y
tensor([0.0551, 0.2665, 0.4638, 0.3288, 0.1201, 0.1515, 0.9187, 0.4527])
>>> x
tensor([[0.0961],
        [0.6194],
        [0.3628],
        [0.5289],
        [0.2046],
        [0.0989],
        [0.6621],
        [0.8717]])
>>> from torch.nn import MSELoss
>>> MSELoss()(y, x)
/python3.12/site-packages/torch/nn/modules/loss.py:610: UserWarning: Using a target size (torch.Size([8, 1])) that is different to the input size (torch.Size([8])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
tensor(0.1454)
>>> MSELoss()(y.view(-1, 1), x)
tensor(0.0535)

The code changes make it so that _compute_losses methods will check if model output and label shapes are mismatched, and if they are, attempt to reshape the model outputs to match the labels' shape before computing the loss.

@laserkelvin laserkelvin added the bug Something isn't working label Mar 17, 2025
@laserkelvin laserkelvin merged commit 3d66b45 into IntelLabs:main Mar 18, 2025
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants