Skip to content

Commit 3d66b45

Browse files
author
Kelvin Lee
authored
Merge pull request #347 from laserkelvin/loss-broadcasting-fix
Loss broadcasting fix
2 parents 32d690c + 760cfe4 commit 3d66b45

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

matsciml/models/base.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,9 +1103,18 @@ def _compute_losses(
11031103
loss_func_signature = signature(loss_func.forward).parameters
11041104
# TODO refactor this once outputs are homogenized
11051105
if isinstance(predictions, dict):
1106-
kwargs = {"input": predictions[key], "target": target_val}
1106+
model_outputs = predictions[key]
11071107
else:
1108-
kwargs = {"input": getattr(predictions, key), "target": target_val}
1108+
model_outputs = getattr(predictions, key)
1109+
# to ensure broadcasting works correctly
1110+
if model_outputs.shape != target_val.shape:
1111+
try:
1112+
model_outputs = model_outputs.reshape(target_val.shape)
1113+
except RuntimeError as e:
1114+
raise RuntimeError(
1115+
f"Unable to reconcile prediction/label shapes; preds: {model_outputs.shape}, labels: {target_val.shape}"
1116+
) from e
1117+
kwargs = {"input": model_outputs, "target": target_val}
11091118
if not isinstance(kwargs["input"], torch.Tensor):
11101119
raise KeyError(f"Expected model to produce output with key {key}.")
11111120
# pack atoms per graph information too
@@ -1575,7 +1584,16 @@ def _compute_losses(
15751584
else:
15761585
coefficient = self.loss_coeff[key]
15771586

1578-
losses[key] = self.loss_func(predictions[key], target_val) * (
1587+
model_outputs = predictions[key]
1588+
# attempt to reshape model outputs for correct broadcasting
1589+
if model_outputs.shape != target_val.shape:
1590+
try:
1591+
model_outputs = model_outputs.reshape(target_val.shape)
1592+
except RuntimeError as e:
1593+
raise RuntimeError(
1594+
f"Unable to reconile shapes for preds/labels; preds: {model_outputs.shape}, labels: {target_val.shape}."
1595+
) from e
1596+
losses[key] = self.loss_func(model_outputs, target_val) * (
15791597
coefficient / predictions[key].numel()
15801598
)
15811599
total_loss: torch.Tensor = sum(losses.values())

0 commit comments

Comments
 (0)