Skip to content

Commit 760cfe4

Browse files
author
Lee, Kin Long Kelvin
committed
refactor: updating shape checking for MaceEnergyForceTask as well
1 parent 875f630 commit 760cfe4

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

matsciml/models/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1584,7 +1584,16 @@ def _compute_losses(
15841584
else:
15851585
coefficient = self.loss_coeff[key]
15861586

1587-
losses[key] = self.loss_func(predictions[key], target_val) * (
1587+
model_outputs = predictions[key]
1588+
# attempt to reshape model outputs for correct broadcasting
1589+
if model_outputs.shape != target_val.shape:
1590+
try:
1591+
model_outputs = model_outputs.reshape(target_val.shape)
1592+
except RuntimeError as e:
1593+
raise RuntimeError(
1594+
f"Unable to reconile shapes for preds/labels; preds: {model_outputs.shape}, labels: {target_val.shape}."
1595+
) from e
1596+
losses[key] = self.loss_func(model_outputs, target_val) * (
15881597
coefficient / predictions[key].numel()
15891598
)
15901599
total_loss: torch.Tensor = sum(losses.values())

0 commit comments

Comments
 (0)