Skip to content

Commit 99c6f81

Browse files
amoreheada-r-j
andauthored
Fix orientation batching (#58)
* Fix orientation batching * Vectorize orientations feature computation * Comment out debugging code * Simplify code * update changelog --------- Co-authored-by: Arian Jamasb <[email protected]>
1 parent 14e33c8 commit 99c6f81

File tree

5 files changed

+42
-9
lines changed

5 files changed

+42
-9
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
### Features
1111

1212
* Improves positional encoding performance by adding a `seq_pos` attribute on `Data/Protein` objects in the base dataset getter. [#53](https://github.com/a-r-j/ProteinWorkshop/pull/53/)
13+
* Ensure correct batched computation of orientation features. [#58](https://github.com/a-r-j/ProteinWorkshop/pull/58/)
1314

1415
### Models
1516

proteinworkshop/config/visualise.yaml

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# === 1. Set config parameters ===
44
name: "" # default name for the experiment, "" means logger (eg. wandb) will generate a unique name
5-
seed: 52 # seed for random number generators in pytorch, numpy and python.random
5+
seed: 52 # seed for random number generators in pytorch, numpy and python.random (as well as in UMAP)
66
num_workers: 16 # number of subprocesses to use for data loading.
77

88
# === 2. Specify defaults here. Defaults will be overwritten by equivalently named options in this file ===
@@ -29,7 +29,6 @@ compile: True
2929
# simply provide checkpoint path and plot filepath to embed dataset and plot its UMAP embeddings
3030
ckpt_path: null # path to checkpoint to load
3131
plot_filepath: null # path to which to save embeddings plot
32-
seed: 42 # random seed to be used by the UMAP algorithm
3332
use_cuda_device: True # if True, use an available CUDA device for embedding generation
3433
cuda_device_index: 0 # if CUDA devices are targeted and available, which available CUDA device to use for embedding generation
3534

proteinworkshop/datasets/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ def create_example_batch(n: int = 4) -> ProteinBatch:
137137
batch.pos = batch.coords[:, 1, :]
138138
batch.x = F.one_hot(batch.residue_type, num_classes=23).float()
139139

140+
batch.x_vector_attr = orientations(batch.pos, batch._slice_dict["coords"])
140141
batch.graph_y = torch.randint(0, 2, (n, 1))
141142

142-
batch.x_vector_attr = orientations(batch.pos)
143143
batch.edge_attr = pos_emb(batch.edge_index, 9)
144144
batch.edge_vector_attr = _normalize(
145145
batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]]

proteinworkshop/features/node_features.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def compute_vector_node_features(
103103
vector_node_features = []
104104
for feature in vector_features:
105105
if feature == "orientation":
106-
vector_node_features.append(orientations(x.coords))
106+
vector_node_features.append(orientations(x.coords, x._slice_dict["coords"]))
107107
elif feature == "virtual_cb_vector":
108108
raise NotImplementedError("Virtual CB vector not implemented yet.")
109109
else:
@@ -149,12 +149,46 @@ def compute_surface_feat(
149149

150150
@jaxtyped(typechecker=typechecker)
151151
def orientations(
152-
X: Union[CoordTensor, AtomTensor], ca_idx: int = 1
152+
X: Union[CoordTensor, AtomTensor], coords_slice_index: torch.Tensor, ca_idx: int = 1
153153
) -> OrientationTensor:
154154
if X.ndim == 3:
155155
X = X[:, ca_idx, :]
156-
forward = _normalize(X[1:] - X[:-1])
157-
backward = _normalize(X[:-1] - X[1:])
156+
157+
# NOTE: the first item in the coordinates slice index is always 0,
158+
# and the last item is always the node count of the batch
159+
batch_num_nodes = X.shape[0]
160+
slice_index = coords_slice_index[1:] - 1
161+
last_node_index = slice_index[:-1]
162+
first_node_index = slice_index[:-1] + 1
163+
slice_mask = torch.zeros(batch_num_nodes - 1, dtype=torch.bool)
164+
last_node_forward_slice_mask = slice_mask.clone()
165+
first_node_backward_slice_mask = slice_mask.clone()
166+
167+
# NOTE: all of the last (first) nodes in a subgraph have their
168+
# forward (backward) vectors set to a padding value (i.e., 0.0)
169+
# to mimic feature construction behavior with single input graphs
170+
forward_slice = X[1:] - X[:-1]
171+
backward_slice = X[:-1] - X[1:]
172+
last_node_forward_slice_mask[last_node_index] = True
173+
first_node_backward_slice_mask[first_node_index - 1] = True # NOTE: for the backward slices, our indexing defaults to node index `1`
174+
forward_slice[last_node_forward_slice_mask] = 0.0 # NOTE: this handles all but the last node in the last subgraph
175+
backward_slice[first_node_backward_slice_mask] = 0.0 # NOTE: this handles all but the first node in the first subgraph
176+
177+
# NOTE: padding first and last nodes with zero vectors does not impact feature normalization
178+
forward = _normalize(forward_slice)
179+
backward = _normalize(backward_slice)
158180
forward = F.pad(forward, [0, 0, 0, 1])
159181
backward = F.pad(backward, [0, 0, 1, 0])
160-
return torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2)
182+
orientations = torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2)
183+
184+
# optionally debug/verify the orientations
185+
# last_node_indices = torch.cat((last_node_index, torch.tensor([batch_num_nodes - 1])), dim=0)
186+
# first_node_indices = torch.cat((torch.tensor([0]), first_node_index), dim=0)
187+
# intermediate_node_indices_mask = torch.ones(batch_num_nodes, device=X.device, dtype=torch.bool)
188+
# intermediate_node_indices_mask[last_node_indices] = False
189+
# intermediate_node_indices_mask[first_node_indices] = False
190+
# assert not orientations[last_node_indices][:, 0].any() and orientations[last_node_indices][:, 1].any()
191+
# assert orientations[first_node_indices][:, 0].any() and not orientations[first_node_indices][:, 1].any()
192+
# assert orientations[intermediate_node_indices_mask][:, 0].any() and orientations[intermediate_node_indices_mask][:, 1].any()
193+
194+
return orientations

proteinworkshop/train.py

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import torch.nn as nn
1515
import torch_geometric
1616
from graphein.protein.tensor.dataloader import ProteinDataLoader
17-
from graphein.ml.datasets.foldcomp_dataset import FoldCompLightningDataModule
1817
from lightning.pytorch.callbacks import Callback
1918
from lightning.pytorch.loggers import Logger
2019
from loguru import logger as log

0 commit comments

Comments
 (0)