Skip to content

Commit

Permalink
Add a BaseEncoder to encode residues to partner indices
Browse files Browse the repository at this point in the history
  • Loading branch information
althonos committed May 10, 2024
1 parent 2a20286 commit 5053316
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
4 changes: 2 additions & 2 deletions mini3di/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""A NumPy port of the ``foldseek`` code for encoding structures to 3di.
"""

__all__ = ["Encoder", "FeatureEncoder"]
__all__ = ["Encoder", "FeatureEncoder", "PartnerIndexEncoder", "VirtualCenterEncoder"]
__version__ = "0.1.0"
__author__ = "Martin Larralde <[email protected]>"
__license__ = "GPLv3"
__credits__ = "Martin Steinegger and his lab for ``foldseek``."

from .encoder import Encoder, FeatureEncoder
from .encoder import Encoder, FeatureEncoder, PartnerIndexEncoder, VirtualCenterEncoder
64 changes: 46 additions & 18 deletions mini3di/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ def encode_chain(
if ca_residue:
residues = [residue for residue in chain.get_residues() if "CA" in residue]
else:
residues = chain.get_residues()
residues = list(chain.get_residues())
# extract atom coordinates
r = len(residues)
ca = numpy.array(numpy.nan, dtype=numpy.float32).repeat(3 * r).reshape(r, 3)
cb = ca.copy()
n = ca.copy()
c = ca.copy()
for i, residue in enumerate(residues):
ca[i, :] = residue["CA"].coord
if "CA" in residue:
ca[i, :] = residue["CA"].coord
if "N" in residue:
n[i, :] = residue["N"].coord
if "C" in residue:
Expand Down Expand Up @@ -250,12 +251,52 @@ def encode_atoms(
)


class PartnerIndexEncoder(_BaseEncoder["ArrayN[numpy.int64]"]):
    """An encoder for converting a protein structure to partner indices.

    For each residue, the coordinates of the virtual center are computed
    from the coordinates of the *Cα*, *Cβ* and *N* atoms. A pairwise
    distance matrix is then created, and the index of the closest partner
    residue is extracted for each position.
    """

    def __init__(self) -> None:
        """Create the virtual-center encoder used by `encode_atoms`."""
        self.vc_encoder = VirtualCenterEncoder()

    def _find_residue_partners(
        self,
        x: ArrayNx3[numpy.floating],
    ) -> ArrayN[numpy.int64]:
        """Return, for each row of ``x``, the index of its nearest other row.

        Arguments:
            x: The coordinates to compare, one row of three values per
                residue.

        Returns:
            An integer array with one entry per row of ``x``, holding the
            index selected by ``argmin`` over the row's squared distances.
            An empty input yields an empty array.

        """
        # nothing to pair up: bail out before indexing `r[0]` / `r[-1]`,
        # which would raise an IndexError on an empty array
        if len(x) == 0:
            return numpy.empty(0, dtype=numpy.int64)
        # compute pairwise squared distance matrix, using the identity
        # |a - b|^2 = |a|^2 - 2 a.b + |b|^2
        r = numpy.sum(x * x, axis=-1).reshape(-1, 1)
        # poison the norms of the first and last rows so that every
        # distance involving them becomes NaN
        # NOTE(review): presumably because terminal residues are never
        # valid partners -- confirm against the reference implementation
        r[0] = r[-1] = numpy.nan
        D = r - 2 * numpy.ma.dot(x, x.T) + r.T
        # avoid selecting residue itself as the best
        D[numpy.diag_indices_from(D)] = numpy.inf
        # get the closest non-masked residue: NaN distances are replaced
        # in place (`copy=False`) with +inf so they lose to finite ones
        return numpy.nan_to_num(D, copy=False, nan=numpy.inf).argmin(axis=1)

    def encode_atoms(
        self,
        ca: ArrayNx3[numpy.floating],
        cb: ArrayNx3[numpy.floating],
        n: ArrayNx3[numpy.floating],
        c: ArrayNx3[numpy.floating],
    ) -> ArrayN[numpy.int64]:
        """Encode backbone atom coordinates to partner indices.

        Arguments:
            ca: The coordinates of the *Cα* atoms.
            cb: The coordinates of the *Cβ* atoms.
            n: The coordinates of the *N* atoms.
            c: The coordinates of the *C* atoms.

        Returns:
            The partner index found for each residue.

        """
        # encode backbone atoms to virtual center
        vc = self.vc_encoder.encode_atoms(ca, cb, n, c)
        # find closest neighbor for each residue
        return self._find_residue_partners(vc)


class FeatureEncoder(_BaseEncoder["ArrayN[numpy.float32]"]):
"""An encoder for converting a protein structure to structural descriptors.
"""

def __init__(self) -> None:
    """Create the sub-encoders used to compute structural descriptors."""
    self.partner_index_encoder = PartnerIndexEncoder()
    # share the partner encoder's virtual-center encoder instead of
    # building a second instance that would be discarded immediately
    self.vc_encoder = self.partner_index_encoder.vc_encoder

def _calc_conformation_descriptors(
self,
Expand Down Expand Up @@ -286,19 +327,6 @@ def _calc_conformation_descriptors(
desc[I, 9] = numpy.copysign(numpy.log(numpy.abs(J - I) + 1), J - I)
return desc

def _find_residue_partners(
self,
x: ArrayNx3[numpy.floating],
) -> ArrayN[numpy.int64]:
# compute pairwise squared distance matrix
r = numpy.sum(x * x, axis=-1).reshape(-1, 1)
r[0] = r[-1] = numpy.nan
D = r - 2 * numpy.ma.dot(x, x.T) + r.T
# avoid selecting residue itself as the best
D[numpy.diag_indices_from(D)] = numpy.inf
# get the closest non-masked residue
return numpy.nan_to_num(D, copy=False, nan=numpy.inf).argmin(axis=1)

def _create_descriptor_mask(
self,
mask: ArrayN[numpy.bool_],
Expand All @@ -320,10 +348,10 @@ def encode_atoms(
n: ArrayNx3[numpy.floating],
c: ArrayNx3[numpy.floating],
) -> ArrayN[numpy.uint8]:
# compute the virtual center form the backbone atoms
# encode backbone atoms to virtual center
vc = self.vc_encoder.encode_atoms(ca, cb, n, c)
# find closest neighbor for each residue
partner_index = self._find_residue_partners(vc)
partner_index = self.partner_index_encoder._find_residue_partners(vc)
# build position features from residue angles
descriptors = self._calc_conformation_descriptors(ca, partner_index)
# create mask
Expand Down

0 comments on commit 5053316

Please sign in to comment.