Skip to content

Commit da7cfe6

Browse files
authored
Merge pull request #102 from mahdip72/main
Enable AMP Compatibility for GCP Model to Reduce VRAM Usage
2 parents 340885e + 4b4f3d0 commit da7cfe6

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* Add support for handling backward OOMs gracefully [#83](https://github.com/a-r-j/ProteinWorkshop/pull/83)
1717
* Update GCPNet paper link [#85](https://github.com/a-r-j/ProteinWorkshop/pull/85)
1818
* Add ability for `BenchmarkModel` to have its decoder disabled [#101](https://github.com/a-r-j/ProteinWorkshop/pull/101)
19+
* Fix dtype mismatch in `gcp.py` that broke Automatic Mixed Precision (AMP) training [#102](https://github.com/a-r-j/ProteinWorkshop/pull/102)
1920

2021
### Framework
2122

proteinworkshop/models/graph_encoders/layers/gcp.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,20 @@ def scalarize(
265265

266266
if node_mask is not None:
267267
edge_mask = node_mask[row] & node_mask[col]
268+
# Initialize destination tensor
268269
local_scalar_rep_i = torch.zeros(
269270
(edge_index.shape[1], 3, 3), device=edge_index.device
270271
)
271-
local_scalar_rep_i[edge_mask] = torch.matmul(
272+
# Calculate the source value (result of matmul, likely Half under AMP)
273+
matmul_result = torch.matmul(
272274
frames[edge_mask], vector_rep_i[edge_mask]
273275
)
276+
# Explicitly cast the source value to the destination's dtype before assignment
277+
local_scalar_rep_i[edge_mask] = matmul_result.to(local_scalar_rep_i.dtype)
278+
274279
local_scalar_rep_i = local_scalar_rep_i.transpose(-1, -2)
275280
else:
281+
# This path might need similar treatment if it causes issues
276282
local_scalar_rep_i = torch.matmul(frames, vector_rep_i).transpose(-1, -2)
277283

278284
# potentially enable E(3)-equivariance and, thereby, chirality-invariance

0 commit comments

Comments
 (0)