diff --git a/mini3di/__init__.py b/mini3di/__init__.py index 79966a7..f2d21d7 100644 --- a/mini3di/__init__.py +++ b/mini3di/__init__.py @@ -1,10 +1,10 @@ """A NumPy port of the ``foldseek`` code for encoding structures to 3di. """ -__all__ = ["Encoder", "FeatureEncoder"] +__all__ = ["Encoder", "FeatureEncoder", "PartnerIndexEncoder", "VirtualCenterEncoder"] __version__ = "0.1.0" __author__ = "Martin Larralde " __license__ = "GPLv3" __credits__ = "Martin Steinegger and his lab for ``foldseek``." -from .encoder import Encoder, FeatureEncoder +from .encoder import Encoder, FeatureEncoder, PartnerIndexEncoder, VirtualCenterEncoder diff --git a/mini3di/encoder.py b/mini3di/encoder.py index 68783ec..55d15c0 100644 --- a/mini3di/encoder.py +++ b/mini3di/encoder.py @@ -71,7 +71,7 @@ def encode_chain( if ca_residue: residues = [residue for residue in chain.get_residues() if "CA" in residue] else: - residues = chain.get_residues() + residues = list(chain.get_residues()) # extract atom coordinates r = len(residues) ca = numpy.array(numpy.nan, dtype=numpy.float32).repeat(3 * r).reshape(r, 3) @@ -79,7 +79,8 @@ def encode_chain( n = ca.copy() c = ca.copy() for i, residue in enumerate(residues): - ca[i, :] = residue["CA"].coord + if "CA" in residue: + ca[i, :] = residue["CA"].coord if "N" in residue: n[i, :] = residue["N"].coord if "C" in residue: @@ -250,12 +251,52 @@ def encode_atoms( ) +class PartnerIndexEncoder(_BaseEncoder["ArrayN[numpy.int64]"]): + """An encoder for converting a protein structure to partner indices. + + For each residue, the coordinates of the virtual center are computed + from the coordinates of the *Cα*, *Cβ* and *N* atoms. A pairwise + distance matrix is then created, and the index of the closest partner + residue is extracted for each position. + + """ + + def __init__(self) -> None: + self.vc_encoder = VirtualCenterEncoder() + + def _find_residue_partners( + self, + x: ArrayNx3[numpy.floating], + ) -> ArrayN[numpy.int64]: + # compute pairwise squared distance matrix + r = numpy.sum(x * x, axis=-1).reshape(-1, 1) + r[0] = r[-1] = numpy.nan + D = r - 2 * numpy.ma.dot(x, x.T) + r.T + # avoid selecting residue itself as the best + D[numpy.diag_indices_from(D)] = numpy.inf + # get the closest non-masked residue + return numpy.nan_to_num(D, copy=False, nan=numpy.inf).argmin(axis=1) + + def encode_atoms( + self, + ca: ArrayNx3[numpy.floating], + cb: ArrayNx3[numpy.floating], + n: ArrayNx3[numpy.floating], + c: ArrayNx3[numpy.floating], + ) -> ArrayN[numpy.int64]: + # encode backbone atoms to virtual center + vc = self.vc_encoder.encode_atoms(ca, cb, n, c) + # find closest neighbor for each residue + return self._find_residue_partners(vc) + + class FeatureEncoder(_BaseEncoder["ArrayN[numpy.float32]"]): """An encoder for converting a protein structure to structural descriptors. """ def __init__(self) -> None: - self.vc_encoder = VirtualCenterEncoder() + self.partner_index_encoder = PartnerIndexEncoder() + self.vc_encoder = self.partner_index_encoder.vc_encoder def _calc_conformation_descriptors( self, @@ -286,19 +327,6 @@ def _calc_conformation_descriptors( desc[I, 9] = numpy.copysign(numpy.log(numpy.abs(J - I) + 1), J - I) return desc - def _find_residue_partners( - self, - x: ArrayNx3[numpy.floating], - ) -> ArrayN[numpy.int64]: - # compute pairwise squared distance matrix - r = numpy.sum(x * x, axis=-1).reshape(-1, 1) - r[0] = r[-1] = numpy.nan - D = r - 2 * numpy.ma.dot(x, x.T) + r.T - # avoid selecting residue itself as the best - D[numpy.diag_indices_from(D)] = numpy.inf - # get the closest non-masked residue - return numpy.nan_to_num(D, copy=False, nan=numpy.inf).argmin(axis=1) - def _create_descriptor_mask( self, mask: ArrayN[numpy.bool_], @@ -320,10 +348,10 @@ def encode_atoms( n: ArrayNx3[numpy.floating], c: ArrayNx3[numpy.floating], ) -> ArrayN[numpy.uint8]: - # compute the virtual center form the backbone atoms + # encode backbone atoms to virtual center vc = self.vc_encoder.encode_atoms(ca, cb, n, c) # find closest neighbor for each residue - partner_index = self._find_residue_partners(vc) + partner_index = self.partner_index_encoder._find_residue_partners(vc) # build position features from residue angles descriptors = self._calc_conformation_descriptors(ca, partner_index) # create mask