We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents a5b9fae + 51d5c32 commit c690dd6Copy full SHA for c690dd6
torch_dftd/functions/dftd3.py
@@ -109,14 +109,15 @@ def _getc6_impl(
109
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
110
) -> Tensor:
111
# gather the relevant entries from the table
112
- # c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
113
- c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
114
- # calculate c6 coefficients
115
-
116
- # cn0, cn1, cn2 (n_edges, 5, 5)
117
- cn0 = c6ab_[:, :, :, 0]
118
- cn1 = c6ab_[:, :, :, 1]
119
- cn2 = c6ab_[:, :, :, 2]
+ # c6ab (95, 95, 5, 5, 3) --> cni (9025, 5, 5, 1)
+ cn0, cn1, cn2 = c6ab.reshape(-1, 5, 5, 3).split(1, dim=3)
+ index = Zi * c6ab.size(1) + Zj
+
+ # cni (9025, 5, 5, 1) --> cni (n_edges, 5, 5)
+ cn0 = cn0.squeeze(dim=3)[index].type(nci.dtype)
+ cn1 = cn1.squeeze(dim=3)[index].type(nci.dtype)
+ cn2 = cn2.squeeze(dim=3)[index].type(nci.dtype)
120
121
r = (cn1 - nci[:, None, None]) ** 2 + (cn2 - ncj[:, None, None]) ** 2
122
123
n_edges = r.shape[0]
0 commit comments