@@ -56,7 +56,7 @@ def rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_fe
56
56
amb_idxs = np .array (pairs ["indexs" ]).flatten ().tolist ()
57
57
idxs = torch .tensor ([
58
58
k for k ,s in enumerate (seq ) if s == aa and \
59
- idx in set ( torch .nonzero (aux_cloud_mask [i , :, amb_idxs ].sum (dim = - 1 )).tolist () )
59
+ k in set ( torch .nonzero (aux_cloud_mask [i , :, amb_idxs ].sum (dim = - 1 )).tolist ()[ 0 ] )
60
60
]).long ()
61
61
# check if any AAs matching
62
62
if idxs .shape [0 ] == 0 :
@@ -115,7 +115,7 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,
115
115
Outputs: (B, N_atoms)
116
116
"""
117
117
fape_store = []
118
- if l_func is not None :
118
+ if l_func is None :
119
119
l_func = lambda x ,y ,eps = 1e-7 ,sup = max_val : (((x - y )** 2 ).sum (dim = - 1 ) + eps ).sqrt ()
120
120
# for chain
121
121
for s in range (pred_coords .shape [0 ]):
@@ -144,8 +144,8 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,
144
144
145
145
# measure errors - for residue
146
146
for i ,rot_mat in enumerate (rot_mats ):
147
- fape_store [s ] += l1 ( pred_center [s ][mask_center [s ]] @ rot_mat ,
148
- true_center [s ][mask_center [s ]]
147
+ fape_store [s ] += l_func ( pred_center [s ][mask_center [s ]] @ rot_mat ,
148
+ true_center [s ][mask_center [s ]]
149
149
).clamp (0 , max_val )
150
150
fape_store [s ] /= rot_mats .shape [0 ]
151
151
0 commit comments