Skip to content

Commit dd8644e

Browse files
authored
Merge pull request #13 from So-Takamoto/cell_grad_rewrite
use shift for gradient calculation instead of cell
2 parents 4a0d47d + eca0300 commit dd8644e

16 files changed

+241
-161
lines changed

.flexci/config.pbtxt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ configs {
33
key: "torch-dftd.pytest"
44
value {
55
requirement {
6-
cpu: 4
7-
memory: 24
6+
cpu: 6
7+
memory: 36
88
disk: 10
99
gpu: 1
1010
}

.flexci/pytest_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ main() {
2020
docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
2121
${IMAGE} \
2222
bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
23-
pip install cupy-cuda102 pytorch-pfn-extras && \
23+
pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
2424
pip install -e .[develop] && \
2525
pysen run lint && \
2626
pytest --cov=torch_dftd -n $(nproc) -m 'not slow' tests &&

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[tool:pytest]
2+
markers =
3+
slow: mark test as slow.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup_requires: List[str] = []
77
install_requires: List[str] = [
88
"ase>=3.18, <4.0.0", # Note that we require ase==3.21.1 for pytest.
9-
"pymatgen",
9+
"pymatgen>=2020.1.28",
1010
]
1111
extras_require: Dict[str, List[str]] = {
1212
"develop": ["pysen[lint]==0.9.1"],

tests/functions_tests/test_triplets.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def test_calc_triplets():
1313
dtype=torch.long,
1414
device=device,
1515
)
16-
shift = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
17-
shift[:, 0] = torch.tensor(
16+
shift_pos = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
17+
shift_pos[:, 0] = torch.tensor(
1818
[1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], dtype=torch.float32, device=device
1919
)
2020
# print("shift", shift.shape)
21-
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
22-
edge_index, shift
21+
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
22+
edge_index, shift_pos
2323
)
2424
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
2525
# print("multiplicity", multiplicity.shape, multiplicity)
@@ -38,6 +38,20 @@ def test_calc_triplets():
3838
)
3939
assert multiplicity.shape == (n_triplets,)
4040
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float32))
41+
42+
assert torch.allclose(
43+
edge_jk.cpu(),
44+
torch.tensor([[7, 6], [8, 6], [8, 7], [9, 10], [9, 11], [11, 10]], dtype=torch.long),
45+
)
46+
# shift for edge `i->j`, `i->k`, `j->k`.
47+
triplet_shift = torch.stack(
48+
[
49+
-shift_pos[edge_jk[:, 0]],
50+
-shift_pos[edge_jk[:, 1]],
51+
shift_pos[edge_jk[:, 0]] - shift_pos[edge_jk[:, 1]],
52+
],
53+
dim=1,
54+
)
4155
assert torch.allclose(
4256
triplet_shift.cpu()[:, :, 0],
4357
torch.tensor(
@@ -61,7 +75,7 @@ def test_calc_triplets_noshift():
6175
edge_index = torch.tensor(
6276
[[0, 1, 1, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 1, 3]], dtype=torch.long, device=device
6377
)
64-
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
78+
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
6579
edge_index, dtype=torch.float64
6680
)
6781
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
@@ -78,13 +92,7 @@ def test_calc_triplets_noshift():
7892
assert multiplicity.shape == (n_triplets,)
7993
assert multiplicity.dtype == torch.float64
8094
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float64))
81-
assert torch.all(
82-
triplet_shift.cpu()
83-
== torch.zeros(
84-
(n_triplets, 3, 3),
85-
dtype=torch.float32,
86-
)
87-
)
95+
assert torch.all(edge_jk.cpu() == torch.tensor([[1, 0], [2, 3]], dtype=torch.long))
8896
assert torch.all(batch_triplets.cpu() == torch.zeros((n_triplets,), dtype=torch.long))
8997

9098

@@ -95,7 +103,7 @@ def test_calc_triplets_noshift():
95103
def test_calc_triplets_no_triplets(edge_index):
96104
# edge_index = edge_index.to("cuda:0")
97105
# No triplet exist in this graph. Case1: No edge, Case 2 No triplets in this edge.
98-
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(edge_index)
106+
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index)
99107
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
100108
# print("multiplicity", multiplicity.shape, multiplicity)
101109
# print("triplet_shift", triplet_shift.shape, triplet_shift)
@@ -104,7 +112,7 @@ def test_calc_triplets_no_triplets(edge_index):
104112
# 0 triplets exist.
105113
assert triplet_node_index.shape == (0, 3)
106114
assert multiplicity.shape == (0,)
107-
assert triplet_shift.shape == (0, 3, 3)
115+
assert edge_jk.shape == (0, 2)
108116
assert batch_triplets.shape == (0,)
109117

110118

tests/test_torch_dftd3_calculator.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,35 @@
88
import pytest
99
import torch
1010
from ase import Atoms
11-
from ase.build import fcc111, molecule
11+
from ase.build import bulk, fcc111, molecule
1212
from ase.calculators.dftd3 import DFTD3
1313
from ase.calculators.emt import EMT
1414
from torch_dftd.testing.damping import damping_method_list, damping_xc_combination_list
1515
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
1616

1717

18-
def _create_atoms() -> List[Atoms]:
18+
@pytest.fixture(
19+
params=[
20+
pytest.param("mol", id="mol"),
21+
pytest.param("slab", id="slab"),
22+
pytest.param("large", marks=[pytest.mark.slow], id="large"),
23+
]
24+
)
25+
def atoms(request) -> Atoms:
1926
"""Initialization"""
20-
atoms = molecule("CH3CH2OCH3")
27+
mol = molecule("CH3CH2OCH3")
2128

2229
slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
30+
slab.set_cell(
31+
slab.get_cell().array @ np.array([[1.0, 0.1, 0.2], [0.0, 1.0, 0.3], [0.0, 0.0, 1.0]])
32+
)
2333
slab.pbc = np.array([True, True, True])
24-
return [atoms, slab]
34+
35+
large_bulk = bulk("Pt", "fcc") * (4, 4, 4)
36+
37+
atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}
38+
39+
return atoms_dict[request.param]
2540

2641

2742
def _assert_energy_equal(calc1, calc2, atoms: Atoms):
@@ -53,20 +68,21 @@ def _test_calc_energy(damping, xc, old, atoms, device="cpu", dtype=torch.float64
5368
_assert_energy_equal(dftd3_calc, torch_dftd3_calc, atoms)
5469

5570

56-
def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms):
71+
def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms, force_tol: float = 1e-5):
5772
calc1.reset()
5873
atoms.calc = calc1
5974
f1 = atoms.get_forces()
6075
e1 = atoms.get_potential_energy()
76+
if np.all(atoms.pbc == np.array([True, True, True])):
77+
s1 = atoms.get_stress()
6178

6279
calc2.reset()
6380
atoms.calc = calc2
6481
f2 = atoms.get_forces()
6582
e2 = atoms.get_potential_energy()
6683
assert np.allclose(e1, e2, atol=1e-4, rtol=1e-4)
67-
assert np.allclose(f1, f2, atol=1e-5, rtol=1e-5)
84+
assert np.allclose(f1, f2, atol=force_tol, rtol=force_tol)
6885
if np.all(atoms.pbc == np.array([True, True, True])):
69-
s1 = atoms.get_stress()
7086
s2 = atoms.get_stress()
7187
assert np.allclose(s1, s2, atol=1e-5, rtol=1e-5)
7288

@@ -83,6 +99,9 @@ def _test_calc_energy_force_stress(
8399
cnthr=15.0,
84100
):
85101
cutoff = 22.0 # Make test faster
102+
force_tol = 1e-5
103+
if dtype == torch.float32:
104+
force_tol = 1.0e-4
86105
with tempfile.TemporaryDirectory() as tmpdirname:
87106
dftd3_calc = DFTD3(
88107
damping=damping,
@@ -105,25 +124,22 @@ def _test_calc_energy_force_stress(
105124
abc=abc,
106125
bidirectional=bidirectional,
107126
)
108-
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
127+
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms, force_tol=force_tol)
109128

110129

111130
@pytest.mark.parametrize("damping,xc,old", damping_xc_combination_list)
112-
@pytest.mark.parametrize("atoms", _create_atoms())
113131
def test_calc_energy(damping, xc, old, atoms):
114132
"""Test1-1: check damping,xc,old combination works for energy"""
115133
_test_calc_energy(damping, xc, old, atoms, device="cpu")
116134

117135

118136
@pytest.mark.parametrize("damping,xc,old", damping_xc_combination_list)
119-
@pytest.mark.parametrize("atoms", _create_atoms())
120137
def test_calc_energy_force_stress(damping, xc, old, atoms):
121138
"""Test1-2: check damping,xc,old combination works for energy, force & stress"""
122139
_test_calc_energy_force_stress(damping, xc, old, atoms, device="cpu")
123140

124141

125142
@pytest.mark.parametrize("damping,old", damping_method_list)
126-
@pytest.mark.parametrize("atoms", _create_atoms())
127143
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
128144
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
129145
def test_calc_energy_device(damping, old, atoms, device, dtype):
@@ -133,7 +149,6 @@ def test_calc_energy_device(damping, old, atoms, device, dtype):
133149

134150

135151
@pytest.mark.parametrize("damping,old", damping_method_list)
136-
@pytest.mark.parametrize("atoms", _create_atoms())
137152
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
138153
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
139154
def test_calc_energy_force_stress_device(damping, old, atoms, device, dtype):
@@ -142,7 +157,6 @@ def test_calc_energy_force_stress_device(damping, old, atoms, device, dtype):
142157
_test_calc_energy_force_stress(damping, xc, old, atoms, device=device, dtype=dtype)
143158

144159

145-
@pytest.mark.parametrize("atoms", _create_atoms())
146160
@pytest.mark.parametrize("damping,old", damping_method_list)
147161
def test_calc_energy_force_stress_bidirectional(atoms, damping, old):
148162
"""Test with bidirectional=False"""
@@ -161,7 +175,6 @@ def test_calc_energy_force_stress_bidirectional(atoms, damping, old):
161175
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
162176

163177

164-
@pytest.mark.parametrize("atoms", _create_atoms())
165178
@pytest.mark.parametrize("damping,old", damping_method_list)
166179
def test_calc_energy_force_stress_cutoff_smoothing(atoms, damping, old):
167180
"""Test wit cutoff_smoothing."""
@@ -207,7 +220,6 @@ def test_calc_energy_force_stress_with_dft():
207220

208221

209222
@pytest.mark.parametrize("damping,old", damping_method_list)
210-
@pytest.mark.parametrize("atoms", _create_atoms())
211223
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
212224
@pytest.mark.parametrize("dtype", [torch.float64])
213225
@pytest.mark.parametrize("bidirectional", [True, False])

tests/test_torch_dftd3_calculator_batch.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,22 @@
88
import pytest
99
import torch
1010
from ase import Atoms
11-
from ase.build import fcc111, molecule
11+
from ase.build import bulk, fcc111, molecule
1212
from torch_dftd.testing.damping import damping_method_list
1313
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
1414

1515

16-
def _create_atoms() -> List[List[Atoms]]:
16+
@pytest.fixture(
17+
params=[
18+
pytest.param("case1", id="mol+slab"),
19+
pytest.param("case2", id="mol+slab(wo_pbc)"),
20+
pytest.param("case3", id="null"),
21+
pytest.param("case4", marks=[pytest.mark.slow], id="large"),
22+
]
23+
)
24+
def atoms_list(request) -> List[Atoms]:
1725
"""Initialization"""
18-
atoms = molecule("CH3CH2OCH3")
26+
mol = molecule("CH3CH2OCH3")
1927

2028
slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
2129
slab.pbc = np.array([True, True, True])
@@ -24,7 +32,17 @@ def _create_atoms() -> List[List[Atoms]]:
2432
slab_wo_pbc.pbc = np.array([False, False, False])
2533

2634
null = Atoms()
27-
return [[atoms, slab], [atoms, slab_wo_pbc], [null]]
35+
36+
large_bulk = bulk("Pt", "fcc") * (8, 8, 8)
37+
38+
atoms_dict = {
39+
"case1": [mol, slab],
40+
"case2": [mol, slab_wo_pbc],
41+
"case3": [null],
42+
"case4": [large_bulk],
43+
}
44+
45+
return atoms_dict[request.param]
2846

2947

3048
def _assert_energy_equal_batch(calc1, atoms_list: List[Atoms]):
@@ -91,7 +109,6 @@ def _test_calc_energy_force_stress(
91109

92110

93111
@pytest.mark.parametrize("damping,old", damping_method_list)
94-
@pytest.mark.parametrize("atoms_list", _create_atoms())
95112
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
96113
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
97114
def test_calc_energy_device_batch(damping, old, atoms_list, device, dtype):
@@ -101,7 +118,6 @@ def test_calc_energy_device_batch(damping, old, atoms_list, device, dtype):
101118

102119

103120
@pytest.mark.parametrize("damping,old", damping_method_list)
104-
@pytest.mark.parametrize("atoms_list", _create_atoms())
105121
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
106122
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
107123
def test_calc_energy_force_stress_device_batch(damping, old, atoms_list, device, dtype):
@@ -111,7 +127,6 @@ def test_calc_energy_force_stress_device_batch(damping, old, atoms_list, device,
111127

112128

113129
@pytest.mark.parametrize("damping,old", damping_method_list)
114-
@pytest.mark.parametrize("atoms_list", _create_atoms())
115130
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
116131
@pytest.mark.parametrize("bidirectional", [True, False])
117132
@pytest.mark.parametrize("dtype", [torch.float64])

tests/test_torch_dftd3_calculator_benchmark.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,30 @@
44
import numpy as np
55
import pytest
66
from ase import Atoms
7-
from ase.build import fcc111, molecule
7+
from ase.build import bulk, fcc111, molecule
88
from ase.calculators.dftd3 import DFTD3
99
from ase.units import Bohr
1010
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
1111

1212

13-
def _create_atoms() -> List[Atoms]:
13+
@pytest.fixture(
14+
params=[
15+
pytest.param("mol", id="mol"),
16+
pytest.param("slab", id="slab"),
17+
pytest.param("large", marks=[pytest.mark.slow], id="large"),
18+
]
19+
)
20+
def atoms(request) -> Atoms:
1421
"""Initialization"""
15-
atoms = molecule("CH3CH2OCH3")
22+
mol = molecule("CH3CH2OCH3")
1623

1724
slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
1825
slab.pbc = np.array([True, True, True])
19-
return [atoms, slab]
26+
27+
large_bulk = bulk("Pt", "fcc") * (4, 4, 4)
28+
29+
atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}
30+
return atoms_dict[request.param]
2031

2132

2233
def calc_energy(calculator, atoms):
@@ -35,7 +46,6 @@ def calc_force_stress(calculator, atoms):
3546
return True
3647

3748

38-
@pytest.mark.parametrize("atoms", _create_atoms())
3949
def test_dftd3_calculator_benchmark(atoms, benchmark):
4050
damping = "bj"
4151
xc = "pbe"
@@ -53,7 +63,6 @@ def test_dftd3_calculator_benchmark(atoms, benchmark):
5363
)
5464

5565

56-
@pytest.mark.parametrize("atoms", _create_atoms())
5766
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
5867
def test_torch_dftd3_calculator_benchmark(atoms, device, benchmark):
5968
damping = "bj"

0 commit comments

Comments
 (0)