Skip to content

Commit 1d2bda0

Browse files
committed
new utils for scn
1 parent 58c9fc6 commit 1d2bda0

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

mp_nerf/proteins.py

+20
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,26 @@ def build_scaffolds_from_scn_angles(seq, angles=None, coords=None, device="auto"
162162
#############################
163163

164164

165+
def modify_angles_mask_with_torsions(seq, angles_mask, torsions):
166+
""" Modifies a torsion mask to include variable torsions.
167+
Inputs:
168+
* seq: (L,) str. FASTA sequence
169+
* angles_mask: (2, L, 14) float tensor of (angles, torsions)
170+
* torsions: (L, 4) float tensor (or (L, 5) if it includes torsion for cb)
171+
Outputs: (2, L, 14) a new angles mask
172+
"""
173+
c_beta = torsions.shape[-1] == 5 # whether c_beta torsion is passed as well
174+
start = 4 if c_beta else 5
175+
# get mask of to-fill values
176+
torsion_mask = torch.tensor([SUPREME_INFO[aa]["torsion_mask"] for aa in seq]).to(torsions.device) # (L, 14)
177+
torsion_mask = torsion_mask != torsion_mask # values that are nan need replace
178+
# undesired outside of margins
179+
torsion_mask[:, :start] = torsion_mask[:, start+torsions.shape[-1]:] = False
180+
181+
angles_mask[1, torsion_mask] = torsions[ torsion_mask[:, start:start+torsions.shape[-1]] ]
182+
return angles_mask
183+
184+
165185
def modify_scaffolds_with_coords(scaffolds, coords):
166186
""" Gets scaffolds and fills in the right data.
167187
Inputs:

tests/test_main.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,14 @@ def test_nerf_and_dihedral():
3232
# doesnt work because the scn angle was not measured correctly
3333
# so the method corrects that incorrection
3434
assert (mp_nerf_torch(a, b, c, l, theta, chi - np.pi) - torch.tensor([1,0,6])).sum().abs() < 0.1
35-
assert get_dihedral(a, b, c, d).item() == chi
35+
assert get_dihedral(a, b, c, d).item() == chi
36+
37+
38+
def test_modify_angles_mask_with_torsions():
39+
# create inputs
40+
seq = "AGHHKLHRTVNMSTIL"
41+
angles_mask = torch.randn(2, 16, 14)
42+
torsions = torch.ones(16, 4)
43+
# ensure shape
44+
assert modify_angles_mask_with_torsions(seq, angles_mask, torsions).shape == angles_mask.shape, \
45+
"Shapes don't match"

0 commit comments

Comments
 (0)