@@ -1103,9 +1103,18 @@ def _compute_losses(
1103
1103
loss_func_signature = signature (loss_func .forward ).parameters
1104
1104
# TODO refactor this once outputs are homogenized
1105
1105
if isinstance (predictions , dict ):
1106
- kwargs = { "input" : predictions [key ], "target" : target_val }
1106
+ model_outputs = predictions [key ]
1107
1107
else :
1108
- kwargs = {"input" : getattr (predictions , key ), "target" : target_val }
1108
+ model_outputs = getattr (predictions , key )
1109
+ # to ensure broadcasting works correctly
1110
+ if model_outputs .shape != target_val .shape :
1111
+ try :
1112
+ model_outputs = model_outputs .reshape (target_val .shape )
1113
+ except RuntimeError as e :
1114
+ raise RuntimeError (
1115
+ f"Unable to reconcile prediction/label shapes; preds: { model_outputs .shape } , labels: { target_val .shape } "
1116
+ ) from e
1117
+ kwargs = {"input" : model_outputs , "target" : target_val }
1109
1118
if not isinstance (kwargs ["input" ], torch .Tensor ):
1110
1119
raise KeyError (f"Expected model to produce output with key { key } ." )
1111
1120
# pack atoms per graph information too
@@ -1575,7 +1584,16 @@ def _compute_losses(
1575
1584
else :
1576
1585
coefficient = self .loss_coeff [key ]
1577
1586
1578
- losses [key ] = self .loss_func (predictions [key ], target_val ) * (
1587
+ model_outputs = predictions [key ]
1588
+ # attempt to reshape model outputs for correct broadcasting
1589
+ if model_outputs .shape != target_val .shape :
1590
+ try :
1591
+ model_outputs = model_outputs .reshape (target_val .shape )
1592
+ except RuntimeError as e :
1593
+ raise RuntimeError (
1594
+ f"Unable to reconile shapes for preds/labels; preds: { model_outputs .shape } , labels: { target_val .shape } ."
1595
+ ) from e
1596
+ losses [key ] = self .loss_func (model_outputs , target_val ) * (
1579
1597
coefficient / predictions [key ].numel ()
1580
1598
)
1581
1599
total_loss : torch .Tensor = sum (losses .values ())
0 commit comments