@@ -103,7 +103,7 @@ def compute_vector_node_features(
103
103
vector_node_features = []
104
104
for feature in vector_features :
105
105
if feature == "orientation" :
106
- vector_node_features .append (orientations (x .coords ))
106
+ vector_node_features .append (orientations (x .coords , x . _slice_dict [ "coords" ] ))
107
107
elif feature == "virtual_cb_vector" :
108
108
raise NotImplementedError ("Virtual CB vector not implemented yet." )
109
109
else :
@@ -149,12 +149,46 @@ def compute_surface_feat(
149
149
150
150
@jaxtyped (typechecker = typechecker )
151
151
def orientations (
152
- X : Union [CoordTensor , AtomTensor ], ca_idx : int = 1
152
+ X : Union [CoordTensor , AtomTensor ], coords_slice_index : torch . Tensor , ca_idx : int = 1
153
153
) -> OrientationTensor :
154
154
if X .ndim == 3 :
155
155
X = X [:, ca_idx , :]
156
- forward = _normalize (X [1 :] - X [:- 1 ])
157
- backward = _normalize (X [:- 1 ] - X [1 :])
156
+
157
+ # NOTE: the first item in the coordinates slice index is always 0,
158
+ # and the last item is always the node count of the batch
159
+ batch_num_nodes = X .shape [0 ]
160
+ slice_index = coords_slice_index [1 :] - 1
161
+ last_node_index = slice_index [:- 1 ]
162
+ first_node_index = slice_index [:- 1 ] + 1
163
+ slice_mask = torch .zeros (batch_num_nodes - 1 , dtype = torch .bool )
164
+ last_node_forward_slice_mask = slice_mask .clone ()
165
+ first_node_backward_slice_mask = slice_mask .clone ()
166
+
167
+ # NOTE: all of the last (first) nodes in a subgraph have their
168
+ # forward (backward) vectors set to a padding value (i.e., 0.0)
169
+ # to mimic feature construction behavior with single input graphs
170
+ forward_slice = X [1 :] - X [:- 1 ]
171
+ backward_slice = X [:- 1 ] - X [1 :]
172
+ last_node_forward_slice_mask [last_node_index ] = True
173
+ first_node_backward_slice_mask [first_node_index - 1 ] = True # NOTE: for the backward slices, our indexing defaults to node index `1`
174
+ forward_slice [last_node_forward_slice_mask ] = 0.0 # NOTE: this handles all but the last node in the last subgraph
175
+ backward_slice [first_node_backward_slice_mask ] = 0.0 # NOTE: this handles all but the first node in the first subgraph
176
+
177
+ # NOTE: padding first and last nodes with zero vectors does not impact feature normalization
178
+ forward = _normalize (forward_slice )
179
+ backward = _normalize (backward_slice )
158
180
forward = F .pad (forward , [0 , 0 , 0 , 1 ])
159
181
backward = F .pad (backward , [0 , 0 , 1 , 0 ])
160
- return torch .cat ((forward .unsqueeze (- 2 ), backward .unsqueeze (- 2 )), dim = - 2 )
182
+ orientations = torch .cat ((forward .unsqueeze (- 2 ), backward .unsqueeze (- 2 )), dim = - 2 )
183
+
184
+ # optionally debug/verify the orientations
185
+ # last_node_indices = torch.cat((last_node_index, torch.tensor([batch_num_nodes - 1])), dim=0)
186
+ # first_node_indices = torch.cat((torch.tensor([0]), first_node_index), dim=0)
187
+ # intermediate_node_indices_mask = torch.ones(batch_num_nodes, device=X.device, dtype=torch.bool)
188
+ # intermediate_node_indices_mask[last_node_indices] = False
189
+ # intermediate_node_indices_mask[first_node_indices] = False
190
+ # assert not orientations[last_node_indices][:, 0].any() and orientations[last_node_indices][:, 1].any()
191
+ # assert orientations[first_node_indices][:, 0].any() and not orientations[first_node_indices][:, 1].any()
192
+ # assert orientations[intermediate_node_indices_mask][:, 0].any() and orientations[intermediate_node_indices_mask][:, 1].any()
193
+
194
+ return orientations
0 commit comments