Skip to content

Commit a5b9fae

Browse files
authored
Merge pull request #17 from takagi/codegen-dftd3
Split c6 computation to reduce peak memory usage
2 parents 9095c4d + 0fe6778 commit a5b9fae

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

torch_dftd/functions/dftd3.py

Lines changed: 37 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -44,12 +44,12 @@ def _ncoord(
4444
"""
4545
if cutoff is not None:
4646
# Calculate _ncoord only for r < cutoff
47-
within_cutoff = r <= cutoff
48-
r = r[within_cutoff]
49-
# Zi = Zi[within_cutoff]
50-
# Zj = Zj[within_cutoff]
51-
idx_i = idx_i[within_cutoff]
52-
idx_j = idx_j[within_cutoff]
47+
indices = torch.nonzero(r <= cutoff).reshape(-1)
48+
r = r[indices]
49+
# Zi = Zi[indices]
50+
# Zj = Zj[indices]
51+
idx_i = idx_i[indices]
52+
idx_j = idx_j[indices]
5353
Zi = Z[idx_i]
5454
Zj = Z[idx_j]
5555
rco = rcov[Zi] + rcov[Zj] # (n_edges,)
@@ -67,7 +67,13 @@ def _ncoord(
6767

6868

6969
def _getc6(
70-
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
70+
Zi: Tensor,
71+
Zj: Tensor,
72+
nci: Tensor,
73+
ncj: Tensor,
74+
c6ab: Tensor,
75+
k3: float = d3_k3,
76+
n_chunks: Optional[int] = None,
7177
) -> Tensor:
7278
"""interpolate c6
7379
@@ -78,10 +84,30 @@ def _getc6(
7884
ncj: (n_edges,)
7985
c6ab:
8086
k3:
87+
n_chunks:
8188
8289
Returns:
8390
c6 (Tensor): (n_edges,)
8491
"""
92+
if n_chunks is None:
93+
return _getc6_impl(Zi, Zj, nci, ncj, c6ab, k3=k3)
94+
95+
# TODO(takagi) More balanced split like torch.tensor_split as, for example,
96+
# trying to split 13 elements into 6 chunks currently gives 5 chunks:
97+
# ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12])
98+
n_chunks_t = torch.tensor(n_chunks)
99+
chunk_size = torch.ceil(Zi.shape[0] / n_chunks_t).to(torch.int64)
100+
c6s = []
101+
for i in range(0, n_chunks):
102+
chunk_start = i * chunk_size
103+
slc = slice(chunk_start, chunk_start + chunk_size)
104+
c6s.append(_getc6_impl(Zi[slc], Zj[slc], nci[slc], ncj[slc], c6ab, k3=k3))
105+
return torch.cat(c6s, 0)
106+
107+
108+
def _getc6_impl(
109+
Zi: Tensor, Zj: Tensor, nci: Tensor, ncj: Tensor, c6ab: Tensor, k3: float = d3_k3
110+
) -> Tensor:
85111
# gather the relevant entries from the table
86112
# c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
87113
c6ab_ = c6ab[Zi, Zj].type(nci.dtype)
@@ -98,7 +124,7 @@ def _getc6(
98124
if cn0.size(0) == 0:
99125
k3_rnc = (k3 * r).view(n_edges, n_c6ab)
100126
else:
101-
k3_rnc = torch.where(cn0 > 0.0, k3 * r, -1.0e20 * torch.ones_like(r)).view(n_edges, n_c6ab)
127+
k3_rnc = torch.where(cn0 > 0.0, k3 * r, -1.0e20).view(n_edges, n_c6ab)
102128
r_ratio = torch.softmax(k3_rnc, dim=1)
103129
c6 = (r_ratio * cn0.view(n_edges, n_c6ab)).sum(dim=1)
104130
return c6
@@ -130,6 +156,7 @@ def edisp(
130156
damping: str = "zero",
131157
bidirectional: bool = False,
132158
abc: bool = False,
159+
n_chunks: Optional[int] = None,
133160
):
134161
"""compute d3 dispersion energy in Hartree
135162
@@ -159,6 +186,7 @@ def edisp(
159186
damping (str): damping method, only "zero" is supported.
160187
bidirectional (bool): calculated `edge_index` is bidirectional or not.
161188
abc (bool): ATM 3-body interaction
189+
n_chunks (int or None): number of times to split c6 computation to reduce peak memory
162190
163191
Returns:
164192
energy: (n_graphs,) Energy in Hartree unit.
@@ -190,7 +218,7 @@ def edisp(
190218

191219
nci = nc[idx_i]
192220
ncj = nc[idx_j]
193-
c6 = _getc6(Zi, Zj, nci, ncj, c6ab=c6ab, k3=k3) # c6 coefficients
221+
c6 = _getc6(Zi, Zj, nci, ncj, c6ab=c6ab, k3=k3, n_chunks=n_chunks) # c6 coefficients
194222

195223
c8 = 3 * c6 * r2r4[Zi].type(c6.dtype) * r2r4[Zj].type(c6.dtype) # c8 coefficient
196224

torch_dftd/nn/dftd3_module.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ class DFTD3Module(BaseDFTDModule):
2222
abc (bool): ATM 3-body interaction
2323
dtype (dtype): internal calculation is done in this precision.
2424
bidirectional (bool): calculated `edge_index` is bidirectional or not.
25+
n_chunks (int): number of times to split c6 computation to reduce peak memory.
2526
"""
2627

2728
def __init__(
@@ -33,6 +34,7 @@ def __init__(
3334
dtype=torch.float32,
3435
bidirectional: bool = False,
3536
cutoff_smoothing: str = "none",
37+
n_chunks: Optional[int] = None,
3638
):
3739
super(DFTD3Module, self).__init__()
3840

@@ -62,6 +64,7 @@ def __init__(
6264
self.dtype = dtype
6365
self.bidirectional = bidirectional
6466
self.cutoff_smoothing = cutoff_smoothing
67+
self.n_chunks = n_chunks
6568

6669
def calc_energy_batch(
6770
self,
@@ -105,5 +108,6 @@ def calc_energy_batch(
105108
abc=self.abc,
106109
pos=pos_bohr,
107110
cell=cell_bohr,
111+
n_chunks=self.n_chunks,
108112
)
109113
return E_disp

0 commit comments

Comments (0)