Skip to content

Commit c690dd6

Browse files
authored
Merge pull request #20 from takagi/remove-onnx-gather
Reduce Gather op when ONNX-exported
2 parents a5b9fae + 51d5c32 commit c690dd6

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

torch_dftd/functions/dftd3.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,15 @@ def _getc6_impl(
109109
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
110110
) -> Tensor:
111111
# gather the relevant entries from the table
112-
# c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
113-
c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
114-
# calculate c6 coefficients
115-
116-
# cn0, cn1, cn2 (n_edges, 5, 5)
117-
cn0 = c6ab_[:, :, :, 0]
118-
cn1 = c6ab_[:, :, :, 1]
119-
cn2 = c6ab_[:, :, :, 2]
112+
# c6ab (95, 95, 5, 5, 3) --> cni (9025, 5, 5, 1)
113+
cn0, cn1, cn2 = c6ab.reshape(-1, 5, 5, 3).split(1, dim=3)
114+
index = Zi * c6ab.size(1) + Zj
115+
116+
# cni (9025, 5, 5, 1) --> cni (n_edges, 5, 5)
117+
cn0 = cn0.squeeze(dim=3)[index].type(nci.dtype)
118+
cn1 = cn1.squeeze(dim=3)[index].type(nci.dtype)
119+
cn2 = cn2.squeeze(dim=3)[index].type(nci.dtype)
120+
120121
r = (cn1 - nci[:, None, None]) ** 2 + (cn2 - ncj[:, None, None]) ** 2
121122

122123
n_edges = r.shape[0]

0 commit comments

Comments
 (0)