Skip to content

Commit ecd7cd1

Browse files
authored
Merge pull request #9 from So-Takamoto/batch_abc
Bugfix: batch calculation with abc=True
2 parents 34785c0 + ee62187 commit ecd7cd1

File tree

7 files changed

+136
-33
lines changed

7 files changed

+136
-33
lines changed

.flexci/config.pbtxt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ configs {
88
disk: 10
99
gpu: 1
1010
}
11+
time_limit {
12+
seconds: 1800
13+
}
1114
command:
1215
"bash -x .flexci/pytest_script.sh"
1316
}

tests/test_torch_dftd3_calculator.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,15 @@ def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms):
7272

7373

7474
def _test_calc_energy_force_stress(
75-
damping, xc, old, atoms, device="cpu", dtype=torch.float64, abc=False, cnthr=15.0
75+
damping,
76+
xc,
77+
old,
78+
atoms,
79+
device="cpu",
80+
dtype=torch.float64,
81+
bidirectional=True,
82+
abc=False,
83+
cnthr=15.0,
7684
):
7785
cutoff = 22.0 # Make test faster
7886
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -95,6 +103,7 @@ def _test_calc_energy_force_stress(
95103
cutoff=cutoff,
96104
cnthr=cnthr,
97105
abc=abc,
106+
bidirectional=bidirectional,
98107
)
99108
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
100109

@@ -201,13 +210,39 @@ def test_calc_energy_force_stress_with_dft():
201210
@pytest.mark.parametrize("atoms", _create_atoms())
202211
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
203212
@pytest.mark.parametrize("dtype", [torch.float64])
213+
@pytest.mark.parametrize("bidirectional", [True, False])
204214
@pytest.mark.parametrize("abc", [True])
205-
def test_calc_energy_force_stress_device_abc(damping, old, atoms, device, dtype, abc):
215+
def test_calc_energy_force_stress_device_abc(
216+
damping, old, atoms, device, dtype, bidirectional, abc
217+
):
206218
"""Test: check tri-partite calc with device, dtype dependency."""
207219
xc = "pbe"
208-
_test_calc_energy_force_stress(
209-
damping, xc, old, atoms, device=device, dtype=dtype, abc=abc, cnthr=7.0
210-
)
220+
if np.all(atoms.pbc) and bidirectional == False:
221+
# TODO: bidirectional=False is not implemented for pbc now.
222+
with pytest.raises(NotImplementedError):
223+
_test_calc_energy_force_stress(
224+
damping,
225+
xc,
226+
old,
227+
atoms,
228+
device=device,
229+
dtype=dtype,
230+
bidirectional=bidirectional,
231+
abc=abc,
232+
cnthr=7.0,
233+
)
234+
else:
235+
_test_calc_energy_force_stress(
236+
damping,
237+
xc,
238+
old,
239+
atoms,
240+
device=device,
241+
dtype=dtype,
242+
bidirectional=bidirectional,
243+
abc=abc,
244+
cnthr=7.0,
245+
)
211246

212247

213248
if __name__ == "__main__":

tests/test_torch_dftd3_calculator_batch.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
1414

1515

16-
def _create_atoms() -> List[Atoms]:
16+
def _create_atoms() -> List[List[Atoms]]:
1717
"""Initialization"""
1818
atoms = molecule("CH3CH2OCH3")
1919

2020
slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
2121
slab.pbc = np.array([True, True, True])
22-
return [atoms, slab]
22+
23+
slab_wo_pbc = slab.copy()
24+
slab_wo_pbc.pbc = np.array([False, False, False])
25+
26+
null = Atoms()
27+
return [[atoms, slab], [atoms, slab_wo_pbc], [null]]
2328

2429

2530
def _assert_energy_equal_batch(calc1, atoms_list: List[Atoms]):
@@ -60,7 +65,15 @@ def _assert_energy_force_stress_equal_batch(calc1, atoms_list: List[Atoms]):
6065

6166

6267
def _test_calc_energy_force_stress(
63-
damping, xc, old, atoms_list, device="cpu", dtype=torch.float64, abc=False, cnthr=15.0
68+
damping,
69+
xc,
70+
old,
71+
atoms_list,
72+
device="cpu",
73+
dtype=torch.float64,
74+
bidirectional=True,
75+
abc=False,
76+
cnthr=15.0,
6477
):
6578
cutoff = 22.0 # Make test faster
6679
torch_dftd3_calc = TorchDFTD3Calculator(
@@ -72,41 +85,68 @@ def _test_calc_energy_force_stress(
7285
cutoff=cutoff,
7386
cnthr=cnthr,
7487
abc=abc,
88+
bidirectional=bidirectional,
7589
)
7690
_assert_energy_force_stress_equal_batch(torch_dftd3_calc, atoms_list)
7791

7892

7993
@pytest.mark.parametrize("damping,old", damping_method_list)
94+
@pytest.mark.parametrize("atoms_list", _create_atoms())
8095
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
8196
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
82-
def test_calc_energy_device_batch(damping, old, device, dtype):
97+
def test_calc_energy_device_batch(damping, old, atoms_list, device, dtype):
8398
"""Test2-1: check device, dtype dependency. with only various damping method."""
8499
xc = "pbe"
85-
atoms_list = _create_atoms()
86100
_test_calc_energy(damping, xc, old, atoms_list, device=device, dtype=dtype)
87101

88102

89103
@pytest.mark.parametrize("damping,old", damping_method_list)
104+
@pytest.mark.parametrize("atoms_list", _create_atoms())
90105
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
91106
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
92-
def test_calc_energy_force_stress_device_batch(damping, old, device, dtype):
107+
def test_calc_energy_force_stress_device_batch(damping, old, atoms_list, device, dtype):
93108
"""Test2-2: check device, dtype dependency. with only various damping method."""
94109
xc = "pbe"
95-
atoms_list = _create_atoms()
96110
_test_calc_energy_force_stress(damping, xc, old, atoms_list, device=device, dtype=dtype)
97111

98112

99113
@pytest.mark.parametrize("damping,old", damping_method_list)
114+
@pytest.mark.parametrize("atoms_list", _create_atoms())
100115
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
116+
@pytest.mark.parametrize("bidirectional", [True, False])
101117
@pytest.mark.parametrize("dtype", [torch.float64])
102-
def test_calc_energy_force_stress_device_batch_abc(damping, old, device, dtype):
103-
"""Test2-2: check device, dtype dependency. with only various damping method."""
118+
def test_calc_energy_force_stress_device_batch_abc(
119+
damping, old, atoms_list, device, bidirectional, dtype
120+
):
121+
"""Test2-3: check device, dtype dependency. with only various damping method."""
104122
xc = "pbe"
105123
abc = True
106-
atoms_list = _create_atoms()
107-
_test_calc_energy_force_stress(
108-
damping, xc, old, atoms_list, device=device, dtype=dtype, cnthr=7.0
109-
)
124+
if any([np.all(atoms.pbc) for atoms in atoms_list]) and bidirectional == False:
125+
# TODO: bidirectional=False is not implemented for pbc now.
126+
with pytest.raises(NotImplementedError):
127+
_test_calc_energy_force_stress(
128+
damping,
129+
xc,
130+
old,
131+
atoms_list,
132+
device=device,
133+
dtype=dtype,
134+
bidirectional=bidirectional,
135+
abc=abc,
136+
cnthr=7.0,
137+
)
138+
else:
139+
_test_calc_energy_force_stress(
140+
damping,
141+
xc,
142+
old,
143+
atoms_list,
144+
device=device,
145+
dtype=dtype,
146+
bidirectional=bidirectional,
147+
abc=abc,
148+
cnthr=7.0,
149+
)
110150

111151

112152
if __name__ == "__main__":

torch_dftd/functions/dftd2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ def edisp_d2(
7272
g = e6.sum()[None]
7373
else:
7474
# (n_graphs,)
75-
n_graphs = int(batch[-1]) + 1
75+
if batch.size()[0] == 0:
76+
n_graphs = 1
77+
else:
78+
n_graphs = int(batch[-1]) + 1
7679
g = e6.new_zeros((n_graphs,))
7780
g.scatter_add_(0, batch_edge, e6)
7881

torch_dftd/functions/dftd3.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ def edisp(
250250
g = e68.sum()[None]
251251
else:
252252
# (n_graphs,)
253-
n_graphs = int(batch[-1]) + 1
253+
if batch.size()[0] == 0:
254+
n_graphs = 1
255+
else:
256+
n_graphs = int(batch[-1]) + 1
254257
g = e68.new_zeros((n_graphs,))
255258
g.scatter_add_(0, batch_edge, e68)
256259

@@ -262,29 +265,40 @@ def edisp(
262265
# r_abc = r[within_cutoff]
263266
# r2_abc = r2[within_cutoff]
264267
edge_index_abc = edge_index[:, within_cutoff]
268+
batch_edge_abc = None if batch_edge is None else batch_edge[within_cutoff]
265269
# c6_abc = c6[within_cutoff]
266270
shift_abc = None if shift is None else shift[within_cutoff]
267271

268272
n_atoms = Z.shape[0]
269273
if not bidirectional:
270274
# (2, n_edges) -> (2, n_edges * 2)
271-
edge_index_abc = torch.cat([edge_index_abc, edge_index_abc.flip(dims=[1])], dim=1)
275+
edge_index_abc = torch.cat([edge_index_abc, edge_index_abc.flip(dims=[0])], dim=1)
276+
# (n_edges, ) -> (n_edges * 2, )
277+
batch_edge_abc = (
278+
None
279+
if batch_edge_abc is None
280+
else torch.cat([batch_edge_abc, batch_edge_abc], dim=0)
281+
)
282+
# (n_edges, ) -> (n_edges * 2, )
283+
shift_abc = None if shift_abc is None else torch.cat([shift_abc, -shift_abc], dim=0)
272284
with torch.no_grad():
273285
# triplet_node_index, triplet_edge_index = calc_triplets_cycle(edge_index_abc, n_atoms, shift=shift_abc)
274286
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
275-
edge_index_abc, shift=shift_abc, dtype=pos.dtype, batch_edge=batch_edge
287+
edge_index_abc, shift=shift_abc, dtype=pos.dtype, batch_edge=batch_edge_abc
276288
)
289+
batch_triplets = None if batch_edge is None else batch_triplets
277290

278291
# Apply `cnthr` cutoff threshold for r_kj
279292
idx_j, idx_k = triplet_node_index[:, 1], triplet_node_index[:, 2]
280293
ts2 = triplet_shift[:, 2]
281-
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2)
294+
r_jk = calc_distances(pos, torch.stack([idx_j, idx_k], dim=0), cell, ts2, batch_triplets)
282295
kj_within_cutoff = r_jk <= cnthr
283296

284297
triplet_node_index = triplet_node_index[kj_within_cutoff]
285-
multiplicity, triplet_shift = (
298+
multiplicity, triplet_shift, batch_triplets = (
286299
multiplicity[kj_within_cutoff],
287300
triplet_shift[kj_within_cutoff],
301+
None if batch_triplets is None else batch_triplets[kj_within_cutoff],
288302
)
289303

290304
idx_i, idx_j, idx_k = (
@@ -294,8 +308,8 @@ def edisp(
294308
)
295309
ts0, ts1, ts2 = triplet_shift[:, 0], triplet_shift[:, 1], triplet_shift[:, 2]
296310

297-
r_ij = calc_distances(pos, torch.stack([idx_i, idx_j], dim=0), cell, ts0)
298-
r_ik = calc_distances(pos, torch.stack([idx_i, idx_k], dim=0), cell, ts1)
311+
r_ij = calc_distances(pos, torch.stack([idx_i, idx_j], dim=0), cell, ts0, batch_triplets)
312+
r_ik = calc_distances(pos, torch.stack([idx_i, idx_k], dim=0), cell, ts1, batch_triplets)
299313
r_jk = r_jk[kj_within_cutoff]
300314

301315
Zti, Ztj, Ztk = Z[idx_i], Z[idx_j], Z[idx_k]
@@ -328,9 +342,9 @@ def edisp(
328342
if batch_edge is None:
329343
e6abc = e3.sum()
330344
g += e6abc
331-
print("g", g)
332-
print("e6abc", e6abc)
345+
# print("g", g)
346+
# print("e6abc", e6abc)
333347
else:
334348
g.scatter_add_(0, batch_triplets, e3)
335-
print("g", g)
349+
# print("g", g)
336350
return g # (n_graphs,)

torch_dftd/nn/base_dftd_module.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ def calc_energy(
7878
if batch is None:
7979
return [{"energy": E_disp.item()}]
8080
else:
81-
n_graphs = int(batch[-1]) + 1
81+
if batch.size()[0] == 0:
82+
n_graphs = 1
83+
else:
84+
n_graphs = int(batch[-1]) + 1
8285
return [{"energy": E_disp[i].item()} for i in range(n_graphs)]
8386

8487
def calc_energy_and_forces(
@@ -133,7 +136,10 @@ def calc_energy_and_forces(
133136
if batch is None:
134137
results_list = [{"energy": E_disp.item(), "forces": forces.cpu().numpy()}]
135138
else:
136-
n_graphs = int(batch[-1]) + 1
139+
if batch.size()[0] == 0:
140+
n_graphs = 1
141+
else:
142+
n_graphs = int(batch[-1]) + 1
137143
results_list = [{"energy": E_disp[i].item()} for i in range(n_graphs)]
138144
for i in range(n_graphs):
139145
results_list[i]["forces"] = forces[batch == i].cpu().numpy()

torch_dftd/torch_dftd3_calculator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,10 @@ def batch_calculate(self, atoms_list=None, properties=["energy"], system_changes
166166
)
167167
batch_dicts = dict(
168168
Z=torch.cat([d["Z"] for d in input_dicts_list], dim=0), # (n_nodes,)
169-
pos=torch.cat([d["pos"] for d in input_dicts_list], dim=0), # (n_nodes,)
170-
cell=cell_batch, # (bs, 3, 3)
169+
pos=torch.cat([d["pos"] for d in input_dicts_list], dim=0).requires_grad_(
170+
True
171+
), # (n_nodes,)
172+
cell=cell_batch.requires_grad_(True), # (bs, 3, 3)
171173
pbc=torch.stack([d["pbc"] for d in input_dicts_list]), # (bs, 3)
172174
shift=torch.cat([d["shift"] for d in input_dicts_list], dim=0), # (n_nodes,)
173175
)

0 commit comments

Comments
 (0)