Skip to content

Commit

Permalink
Refactor code for centroid encoding and move layers classes to a dedicated module
Browse files Browse the repository at this point in the history
  • Loading branch information
althonos committed Nov 25, 2023
1 parent 43f97b6 commit cf378fb
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 109 deletions.
57 changes: 9 additions & 48 deletions mini3di/_unkerasify.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

import abc
import enum
import functools
import itertools
import struct
import typing
from typing import BinaryIO, Iterable
from typing import BinaryIO, Iterable, List

import numpy

from .utils import relu

if typing.TYPE_CHECKING:
from .utils import ArrayNxM
from .layers import Layer, DenseLayer


class LayerType(enum.IntEnum):
Expand All @@ -36,46 +31,6 @@ class ActivationType(enum.IntEnum):
HARD_SIGMOID = 6


class Layer(abc.ABC):
    """Abstract interface for one callable stage of a network.

    Concrete subclasses must implement ``__call__``; the class itself
    cannot be instantiated because ``__call__`` is abstract.
    """

    @abc.abstractmethod
    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Transform the input array ``X`` and return the result."""
        raise NotImplementedError


class DenseLayer(Layer):
    """A fully-connected layer computing ``X @ weights + biases``.

    The result is passed through ``relu`` only when ``activation`` equals
    ``ActivationType.RELU``; any other value leaves the affine output as is.
    """

    def __init__(self, weights, biases=None, activation=ActivationType.RELU):
        self.activation = activation
        self.weights = numpy.asarray(weights)
        # Without explicit biases, use one zero per output column.
        self.biases = (
            numpy.zeros(self.weights.shape[1])
            if biases is None
            else numpy.asarray(biases)
        )

    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Apply the layer to ``X`` and return the activated output."""
        out = numpy.asarray(X) @ self.weights
        out += self.biases
        if self.activation != ActivationType.RELU:
            return out
        # The freshly allocated buffer is handed to `relu` as its output.
        return relu(out, out=out)


class Model:
    """An ordered chain of callables applied one after another."""

    def __init__(self, layers: Iterable[Layer] = ()):
        # Materialise the iterable so the layers can be applied repeatedly.
        self.layers = list(layers)

    @classmethod
    def load(cls, f: BinaryIO) -> Model:
        """Build a model from the layers a `KerasifyParser` yields for ``f``."""
        return cls(KerasifyParser(f))

    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Feed ``X`` through every layer in order and return the final output."""
        out = X
        for layer in self.layers:
            out = layer(out)
        return out


class KerasifyParser:
"""An incomplete parser for model files serialized with `kerasify`.
Expand Down Expand Up @@ -130,6 +85,12 @@ def read(self) -> typing.Optional[Layer]:
)
biases = numpy.frombuffer(self._read(f"={b0}f"), dtype="f4").copy()
activation = ActivationType(self._get("I")[0])
return DenseLayer(weights, biases, activation)
if activation not in (ActivationType.LINEAR, ActivationType.RELU):
raise NotImplementedError(f"Unsupported activation type: {activation!r}")
return DenseLayer(weights, biases, activation==ActivationType.RELU)
else:
raise NotImplementedError(f"Unsupported layer type: {layer_type!r}")


def load(fh: BinaryIO) -> List[Layer]:
    """Collect every layer parsed from the binary stream ``fh`` into a list.

    The stream is wrapped in a `KerasifyParser`, which is then exhausted by
    iteration (presumably yielding one layer per step -- confirm against the
    parser's iterator implementation).
    """
    parser = KerasifyParser(fh)
    return list(parser)
97 changes: 36 additions & 61 deletions mini3di/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import numpy
import numpy.ma

from . import _unkerasify
from .layers import Layer, CentroidLayer, Model
from .utils import normalize
from ._unkerasify import KerasifyParser, Layer, Model

try:
from importlib.resources import files as resource_files
Expand All @@ -25,60 +26,6 @@
DISTANCE_ALPHA_BETA = 1.5336
ALPHABET = numpy.array(list("ACDEFGHIKLMNPQRSTVWYX"))

class CentroidEncoder:
    """Assign each 2D embedding to the index of its nearest centroid.

    Rows whose mask entry is false receive ``invalid_state`` instead of
    the index of the closest centroid.
    """

    # Built-in table of 20 two-dimensional centroids used by `load`.
    _CENTROIDS: ArrayNx2[numpy.float32] = numpy.array(
        [
            [-1.0729, -0.3600],
            [-0.1356, -1.8914],
            [0.4948, -0.4205],
            [-0.9874, 0.8128],
            [-1.6621, -0.4259],
            [2.1394, 0.0486],
            [1.5558, -0.1503],
            [2.9179, 1.1437],
            [-2.8814, 0.9956],
            [-1.1400, -2.0068],
            [3.2025, 1.7356],
            [1.7769, -1.3037],
            [0.6901, -1.2554],
            [-1.1061, -1.3397],
            [2.1495, -0.8030],
            [2.3060, -1.4988],
            [2.5522, 0.6046],
            [0.7786, -2.1660],
            [-2.3030, 0.3813],
            [1.0290, 0.8772],
        ]
    )

    def __init__(
        self,
        centroids: ArrayNx2[numpy.float32],
        invalid_state: int = 2
    ) -> None:
        self.invalid_state = invalid_state
        self.centroids = numpy.asarray(centroids)
        # Squared norm of every centroid, kept as a (1, K) row for broadcasting.
        self.r2 = numpy.sum(self.centroids**2, axis=1).reshape(1, -1)

    @classmethod
    def load(cls, invalid_state: int = 2):
        """Create an encoder from a private copy of the built-in centroids."""
        table = cls._CENTROIDS.copy()
        return cls(table, invalid_state=invalid_state)

    def __call__(
        self,
        embeddings: ArrayNx2[numpy.floating],
        mask: ArrayN[numpy.bool_]
    ) -> ArrayN[numpy.uint8]:
        """Return one ``uint8`` state per row of ``embeddings``."""
        # Squared distances via ||e||^2 - 2 e.c + ||c||^2, one row per embedding.
        sq_norms = (embeddings * embeddings).sum(axis=1, keepdims=True)
        cross = 2 * embeddings @ self.centroids.T
        distances = sq_norms - cross + self.r2
        # Index of the closest centroid, stored as uint8.
        states = distances.argmin(axis=1).astype(numpy.uint8)
        # Rows not selected by the mask get the invalid state.
        states[~mask] = self.invalid_state
        return states


class _BaseEncoder(abc.ABC, typing.Generic[T]):
@abc.abstractmethod
Expand Down Expand Up @@ -299,11 +246,39 @@ def encode_atoms(


class Encoder(_BaseEncoder["ArrayN[numpy.uint8]"]):

_INVALID_STATE = 2
_CENTROIDS: ArrayNx2[numpy.float32] = numpy.array(
[
[-1.0729, -0.3600],
[-0.1356, -1.8914],
[0.4948, -0.4205],
[-0.9874, 0.8128],
[-1.6621, -0.4259],
[2.1394, 0.0486],
[1.5558, -0.1503],
[2.9179, 1.1437],
[-2.8814, 0.9956],
[-1.1400, -2.0068],
[3.2025, 1.7356],
[1.7769, -1.3037],
[0.6901, -1.2554],
[-1.1061, -1.3397],
[2.1495, -0.8030],
[2.3060, -1.4988],
[2.5522, 0.6046],
[0.7786, -2.1660],
[-2.3030, 0.3813],
[1.0290, 0.8772],
]
)

def __init__(self) -> None:
self.feature_encoder = FeatureEncoder()
with resource_files(__package__).joinpath("encoder_weights_3di.kerasify").open("rb") as f:
self.vae_encoder = Model.load(f)
self.centroid_encoder = CentroidEncoder.load()
layers = _unkerasify.load(f)
layers.append(CentroidLayer(self._CENTROIDS))
self.vae_encoder = Model(layers)

def encode_atoms(
self,
Expand All @@ -313,12 +288,12 @@ def encode_atoms(
c: ArrayNx3[numpy.floating],
) -> ArrayN[numpy.uint8]:
descriptors = self.feature_encoder.encode_atoms(ca, cb, n, c)
embeddings = self.vae_encoder(descriptors.data)
states = self.centroid_encoder(embeddings, ~descriptors.mask[:, 0])
states = self.vae_encoder(descriptors.data)
states[descriptors.mask[:, 0]] = self._INVALID_STATE
return numpy.ma.masked_array(
states,
mask=~descriptors.mask[:, 0],
fill_value=self.centroid_encoder.invalid_state
mask=descriptors.mask[:, 0],
fill_value=self._INVALID_STATE,
)

def build_sequence(self, states: ArrayN[numpy.uint8]) -> str:
Expand Down
67 changes: 67 additions & 0 deletions mini3di/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""A mini implementation of the neural network layers used in ``foldseek``.
"""
from __future__ import annotations

import abc
import functools
import typing
from typing import Iterable, Optional

import numpy

from .utils import relu

if typing.TYPE_CHECKING:
from .utils import ArrayNxM, ArrayN


class Layer(abc.ABC):
    """Abstract interface for one callable stage of a network.

    Concrete subclasses must implement ``__call__``; the class itself
    cannot be instantiated because ``__call__`` is abstract.
    """

    @abc.abstractmethod
    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Transform the input array ``X`` and return the result."""
        raise NotImplementedError


class DenseLayer(Layer):
    """A fully-connected layer computing ``X @ weights + biases``.

    When ``activation`` is truthy the affine output is additionally passed
    through ``relu``; otherwise it is returned unchanged.
    """

    def __init__(
        self,
        weights: ArrayNxM[numpy.floating],
        biases: Optional[ArrayN[numpy.floating]] = None,
        activation: bool = True
    ):
        self.activation = activation
        self.weights = numpy.asarray(weights)
        # Without explicit biases, use one zero per output column.
        self.biases = (
            numpy.zeros(self.weights.shape[1])
            if biases is None
            else numpy.asarray(biases)
        )

    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Apply the layer to ``X`` and return the (optionally activated) output."""
        out = numpy.asarray(X) @ self.weights
        out += self.biases
        if not self.activation:
            return out
        # The freshly allocated buffer is handed to `relu` as its output.
        return relu(out, out=out)


class CentroidLayer:
    """Map each input row to the index of its nearest centroid.

    Both the centroids and the inputs are converted with ``numpy.asarray``,
    so nested sequences are accepted as well as arrays (matching how
    ``DenseLayer`` treats its arguments).
    """

    def __init__(self, centroids: ArrayNxM[numpy.floating]) -> None:
        """Store the centroids and precompute their squared norms.

        Arguments:
            centroids: One centroid per row; the number of columns must
                match the number of columns of the inputs given to
                ``__call__``.
        """
        # Convert eagerly: a plain list would fail on `**` below.
        self.centroids = numpy.asarray(centroids)
        # Squared norm of every centroid, kept as a (1, K) row for broadcasting.
        self.r2 = numpy.sum(self.centroids**2, axis=1).reshape(-1, 1).T

    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayN[numpy.uint8]:
        """Return, for every row of ``X``, the index of the closest centroid.

        Arguments:
            X: One point per row.

        Returns:
            A one-dimensional ``uint8`` array with one centroid index per
            row of ``X``.
        """
        _X = numpy.asarray(X)
        # compute pairwise squared distance matrix using
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2
        r1 = numpy.sum(_X * _X, 1).reshape(-1, 1)
        D = r1 - 2 * _X @ self.centroids.T + self.r2
        # find closest centroid; indices are stored as uint8
        return D.argmin(axis=1).astype(numpy.uint8)


class Model:
    """An ordered chain of callables applied one after another."""

    def __init__(self, layers: Iterable[Layer] = ()):
        # Materialise the iterable so the layers can be applied repeatedly.
        self.layers = list(layers)

    def __call__(self, X: ArrayNxM[numpy.floating]) -> ArrayNxM[numpy.floating]:
        """Feed ``X`` through every layer in order and return the final output."""
        out = X
        for layer in self.layers:
            out = layer(out)
        return out

0 comments on commit cf378fb

Please sign in to comment.