Skip to content

Commit 58c9fc6

Browse files
committed
make sure tests pass
1 parent 2ce9300 commit 58c9fc6

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

mp_nerf/ml_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_fe
5656
amb_idxs = np.array(pairs["indexs"]).flatten().tolist()
5757
idxs = torch.tensor([
5858
k for k,s in enumerate(seq) if s==aa and \
59-
idx in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist() )
59+
k in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist()[0] )
6060
]).long()
6161
# check if any AAs matching
6262
if idxs.shape[0] == 0:
@@ -115,7 +115,7 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,
115115
Outputs: (B, N_atoms)
116116
"""
117117
fape_store = []
118-
if l_func is not None:
118+
if l_func is None:
119119
l_func = lambda x,y,eps=1e-7,sup=max_val: (((x-y)**2).sum(dim=-1) + eps).sqrt()
120120
# for chain
121121
for s in range(pred_coords.shape[0]):
@@ -144,8 +144,8 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,
144144

145145
# measure errors - for residue
146146
for i,rot_mat in enumerate(rot_mats):
147-
fape_store[s] += l1( pred_center[s][mask_center[s]] @ rot_mat,
148-
true_center[s][mask_center[s]]
147+
fape_store[s] += l_func( pred_center[s][mask_center[s]] @ rot_mat,
148+
true_center[s][mask_center[s]]
149149
).clamp(0, max_val)
150150
fape_store[s] /= rot_mats.shape[0]
151151

tests/test_ml_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ def test_rename_symmetric_atoms():
2626
seq_list = ["AGHHKLHRTVNMSTIL"]
2727
pred_coors = torch.randn(1, 16, 14, 3)
2828
true_coors = torch.randn(1, 16, 14, 3)
29-
cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(-1)
29+
cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(0)
3030
pred_feats = torch.randn(1, 16, 14, 16)
3131

32+
print(cloud_mask.shape)
33+
3234
renamed = rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_feats=pred_feats)
3335
assert renamed[0].shape == pred_coors.shape and renamed[1].shape == pred_feats.shape, "Shapes don't match"
3436

0 commit comments

Comments
 (0)