2 files changed, +8 −1 lines changed

  * Add support for handling backward OOMs gracefully [#83](https://github.com/a-r-j/ProteinWorkshop/pull/83)
  * Update GCPNet paper link [#85](https://github.com/a-r-j/ProteinWorkshop/pull/85)
  * Add ability for `BenchmarkModel` to have its decoder disabled [#101](https://github.com/a-r-j/ProteinWorkshop/pull/101)
+ * Fix dtype mismatch in `gcp.py` that broke Automatic Mixed Precision (AMP) training [#102](https://github.com/a-r-j/ProteinWorkshop/pull/102)

  ### Framework
proteinworkshop/models/graph_encoders/layers

@@ -265,14 +265,20 @@ def scalarize(

         if node_mask is not None:
             edge_mask = node_mask[row] & node_mask[col]
+            # Initialize the destination tensor
             local_scalar_rep_i = torch.zeros(
                 (edge_index.shape[1], 3, 3), device=edge_index.device
             )
-            local_scalar_rep_i[edge_mask] = torch.matmul(
+            # Calculate the source value (result of the matmul, likely Half under AMP)
+            matmul_result = torch.matmul(
                 frames[edge_mask], vector_rep_i[edge_mask]
             )
+            # Explicitly cast the source value to the destination's dtype before assignment
+            local_scalar_rep_i[edge_mask] = matmul_result.to(local_scalar_rep_i.dtype)
+
             local_scalar_rep_i = local_scalar_rep_i.transpose(-1, -2)
         else:
+            # This path might need similar treatment if it causes issues
             local_scalar_rep_i = torch.matmul(frames, vector_rep_i).transpose(-1, -2)

         # potentially enable E(3)-equivariance and, thereby, chirality-invariance
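
For context, the sketch below is a minimal, standalone reproduction of the failure mode this hunk addresses (tensor names like `dest`, `src`, and `mask` are illustrative, not taken from the repository): under CUDA autocast, `torch.matmul` produces a float16 tensor while `torch.zeros` still allocates float32, so the masked indexed assignment fails with a dtype-mismatch error unless the source is cast first.

```python
import torch

# Minimal sketch of the AMP dtype mismatch fixed in gcp.py (illustrative names).
device = "cuda" if torch.cuda.is_available() else "cpu"
frames = torch.randn(8, 3, 3, device=device)
vectors = torch.randn(8, 3, 3, device=device)
mask = torch.rand(8, device=device) > 0.5

with torch.autocast(device_type=device, enabled=(device == "cuda")):
    dest = torch.zeros(8, 3, 3, device=device)       # allocated as float32 even inside autocast
    src = torch.matmul(frames[mask], vectors[mask])   # float16 on CUDA inside autocast
    # dest[mask] = src                                # raises under AMP: source/destination dtypes differ
    dest[mask] = src.to(dest.dtype)                   # explicit cast makes the assignment safe
```

Casting the source to the destination's dtype (rather than allocating the zeros tensor in half precision) keeps the accumulated representation in float32, which is the usual choice under mixed precision.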