File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -1584,7 +1584,16 @@ def _compute_losses(
1584
1584
else :
1585
1585
coefficient = self .loss_coeff [key ]
1586
1586
1587
- losses [key ] = self .loss_func (predictions [key ], target_val ) * (
1587
+ model_outputs = predictions [key ]
1588
+ # attempt to reshape model outputs for correct broadcasting
1589
+ if model_outputs .shape != target_val .shape :
1590
+ try :
1591
+ model_outputs = model_outputs .reshape (target_val .shape )
1592
+ except RuntimeError as e :
1593
+ raise RuntimeError (
1594
+ f"Unable to reconile shapes for preds/labels; preds: { model_outputs .shape } , labels: { target_val .shape } ."
1595
+ ) from e
1596
+ losses [key ] = self .loss_func (model_outputs , target_val ) * (
1588
1597
coefficient / predictions [key ].numel ()
1589
1598
)
1590
1599
total_loss : torch .Tensor = sum (losses .values ())
You can’t perform that action at this time.
0 commit comments