88import pytest
99import torch
1010from ase import Atoms
11- from ase .build import fcc111 , molecule
11+ from ase .build import bulk , fcc111 , molecule
1212from ase .calculators .dftd3 import DFTD3
1313from ase .calculators .emt import EMT
1414from torch_dftd .testing .damping import damping_method_list , damping_xc_combination_list
1515from torch_dftd .torch_dftd3_calculator import TorchDFTD3Calculator
1616
1717
18- def _create_atoms () -> List [Atoms ]:
18+ @pytest .fixture (
19+ params = [
20+ pytest .param ("mol" , id = "mol" ),
21+ pytest .param ("slab" , id = "slab" ),
22+ pytest .param ("large" , marks = [pytest .mark .slow ], id = "large" ),
23+ ]
24+ )
25+ def atoms (request ) -> Atoms :
1926 """Initialization"""
20- atoms = molecule ("CH3CH2OCH3" )
27+ mol = molecule ("CH3CH2OCH3" )
2128
2229 slab = fcc111 ("Au" , size = (2 , 1 , 3 ), vacuum = 80.0 )
30+ slab .set_cell (
31+ slab .get_cell ().array @ np .array ([[1.0 , 0.1 , 0.2 ], [0.0 , 1.0 , 0.3 ], [0.0 , 0.0 , 1.0 ]])
32+ )
2333 slab .pbc = np .array ([True , True , True ])
24- return [atoms , slab ]
34+
35+ large_bulk = bulk ("Pt" , "fcc" ) * (4 , 4 , 4 )
36+
37+ atoms_dict = {"mol" : mol , "slab" : slab , "large" : large_bulk }
38+
39+ return atoms_dict [request .param ]
2540
2641
2742def _assert_energy_equal (calc1 , calc2 , atoms : Atoms ):
@@ -53,20 +68,21 @@ def _test_calc_energy(damping, xc, old, atoms, device="cpu", dtype=torch.float64
5368 _assert_energy_equal (dftd3_calc , torch_dftd3_calc , atoms )
5469
5570
56- def _assert_energy_force_stress_equal (calc1 , calc2 , atoms : Atoms ):
71+ def _assert_energy_force_stress_equal (calc1 , calc2 , atoms : Atoms , force_tol : float = 1e-5 ):
5772 calc1 .reset ()
5873 atoms .calc = calc1
5974 f1 = atoms .get_forces ()
6075 e1 = atoms .get_potential_energy ()
76+ if np .all (atoms .pbc == np .array ([True , True , True ])):
77+ s1 = atoms .get_stress ()
6178
6279 calc2 .reset ()
6380 atoms .calc = calc2
6481 f2 = atoms .get_forces ()
6582 e2 = atoms .get_potential_energy ()
6683 assert np .allclose (e1 , e2 , atol = 1e-4 , rtol = 1e-4 )
67- assert np .allclose (f1 , f2 , atol = 1e-5 , rtol = 1e-5 )
84+ assert np .allclose (f1 , f2 , atol = force_tol , rtol = force_tol )
6885 if np .all (atoms .pbc == np .array ([True , True , True ])):
69- s1 = atoms .get_stress ()
7086 s2 = atoms .get_stress ()
7187 assert np .allclose (s1 , s2 , atol = 1e-5 , rtol = 1e-5 )
7288
@@ -83,6 +99,9 @@ def _test_calc_energy_force_stress(
8399 cnthr = 15.0 ,
84100):
85101 cutoff = 22.0 # Make test faster
102+ force_tol = 1e-5
103+ if dtype == torch .float32 :
104+ force_tol = 1.0e-4
86105 with tempfile .TemporaryDirectory () as tmpdirname :
87106 dftd3_calc = DFTD3 (
88107 damping = damping ,
@@ -105,25 +124,22 @@ def _test_calc_energy_force_stress(
105124 abc = abc ,
106125 bidirectional = bidirectional ,
107126 )
108- _assert_energy_force_stress_equal (dftd3_calc , torch_dftd3_calc , atoms )
127+ _assert_energy_force_stress_equal (dftd3_calc , torch_dftd3_calc , atoms , force_tol = force_tol )
109128
110129
111130@pytest .mark .parametrize ("damping,xc,old" , damping_xc_combination_list )
112- @pytest .mark .parametrize ("atoms" , _create_atoms ())
113131def test_calc_energy (damping , xc , old , atoms ):
114132 """Test1-1: check damping,xc,old combination works for energy"""
115133 _test_calc_energy (damping , xc , old , atoms , device = "cpu" )
116134
117135
118136@pytest .mark .parametrize ("damping,xc,old" , damping_xc_combination_list )
119- @pytest .mark .parametrize ("atoms" , _create_atoms ())
120137def test_calc_energy_force_stress (damping , xc , old , atoms ):
121138 """Test1-2: check damping,xc,old combination works for energy, force & stress"""
122139 _test_calc_energy_force_stress (damping , xc , old , atoms , device = "cpu" )
123140
124141
125142@pytest .mark .parametrize ("damping,old" , damping_method_list )
126- @pytest .mark .parametrize ("atoms" , _create_atoms ())
127143@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
128144@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
129145def test_calc_energy_device (damping , old , atoms , device , dtype ):
@@ -133,7 +149,6 @@ def test_calc_energy_device(damping, old, atoms, device, dtype):
133149
134150
135151@pytest .mark .parametrize ("damping,old" , damping_method_list )
136- @pytest .mark .parametrize ("atoms" , _create_atoms ())
137152@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
138153@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
139154def test_calc_energy_force_stress_device (damping , old , atoms , device , dtype ):
@@ -142,7 +157,6 @@ def test_calc_energy_force_stress_device(damping, old, atoms, device, dtype):
142157 _test_calc_energy_force_stress (damping , xc , old , atoms , device = device , dtype = dtype )
143158
144159
145- @pytest .mark .parametrize ("atoms" , _create_atoms ())
146160@pytest .mark .parametrize ("damping,old" , damping_method_list )
147161def test_calc_energy_force_stress_bidirectional (atoms , damping , old ):
148162 """Test with bidirectional=False"""
@@ -161,7 +175,6 @@ def test_calc_energy_force_stress_bidirectional(atoms, damping, old):
161175 _assert_energy_force_stress_equal (dftd3_calc , torch_dftd3_calc , atoms )
162176
163177
164- @pytest .mark .parametrize ("atoms" , _create_atoms ())
165178@pytest .mark .parametrize ("damping,old" , damping_method_list )
166179def test_calc_energy_force_stress_cutoff_smoothing (atoms , damping , old ):
167180 """Test wit cutoff_smoothing."""
@@ -207,7 +220,6 @@ def test_calc_energy_force_stress_with_dft():
207220
208221
209222@pytest .mark .parametrize ("damping,old" , damping_method_list )
210- @pytest .mark .parametrize ("atoms" , _create_atoms ())
211223@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
212224@pytest .mark .parametrize ("dtype" , [torch .float64 ])
213225@pytest .mark .parametrize ("bidirectional" , [True , False ])
0 commit comments