1313from torch_dftd .torch_dftd3_calculator import TorchDFTD3Calculator
1414
1515
16- def _create_atoms () -> List [Atoms ]:
16+ def _create_atoms () -> List [List [ Atoms ] ]:
1717 """Initialization"""
1818 atoms = molecule ("CH3CH2OCH3" )
1919
2020 slab = fcc111 ("Au" , size = (2 , 1 , 3 ), vacuum = 80.0 )
2121 slab .pbc = np .array ([True , True , True ])
22- return [atoms , slab ]
22+
23+ slab_wo_pbc = slab .copy ()
24+ slab_wo_pbc .pbc = np .array ([False , False , False ])
25+
26+ null = Atoms ()
27+ return [[atoms , slab ], [atoms , slab_wo_pbc ], [null ]]
2328
2429
2530def _assert_energy_equal_batch (calc1 , atoms_list : List [Atoms ]):
@@ -60,7 +65,15 @@ def _assert_energy_force_stress_equal_batch(calc1, atoms_list: List[Atoms]):
6065
6166
6267def _test_calc_energy_force_stress (
63- damping , xc , old , atoms_list , device = "cpu" , dtype = torch .float64 , abc = False , cnthr = 15.0
68+ damping ,
69+ xc ,
70+ old ,
71+ atoms_list ,
72+ device = "cpu" ,
73+ dtype = torch .float64 ,
74+ bidirectional = True ,
75+ abc = False ,
76+ cnthr = 15.0 ,
6477):
6578 cutoff = 22.0 # Make test faster
6679 torch_dftd3_calc = TorchDFTD3Calculator (
@@ -72,41 +85,68 @@ def _test_calc_energy_force_stress(
7285 cutoff = cutoff ,
7386 cnthr = cnthr ,
7487 abc = abc ,
88+ bidirectional = bidirectional ,
7589 )
7690 _assert_energy_force_stress_equal_batch (torch_dftd3_calc , atoms_list )
7791
7892
7993@pytest .mark .parametrize ("damping,old" , damping_method_list )
94+ @pytest .mark .parametrize ("atoms_list" , _create_atoms ())
8095@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
8196@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
82- def test_calc_energy_device_batch (damping , old , device , dtype ):
97+ def test_calc_energy_device_batch (damping , old , atoms_list , device , dtype ):
8398 """Test2-1: check device, dtype dependency. with only various damping method."""
8499 xc = "pbe"
85- atoms_list = _create_atoms ()
86100 _test_calc_energy (damping , xc , old , atoms_list , device = device , dtype = dtype )
87101
88102
89103@pytest .mark .parametrize ("damping,old" , damping_method_list )
104+ @pytest .mark .parametrize ("atoms_list" , _create_atoms ())
90105@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
91106@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
92- def test_calc_energy_force_stress_device_batch (damping , old , device , dtype ):
107+ def test_calc_energy_force_stress_device_batch (damping , old , atoms_list , device , dtype ):
93108 """Test2-2: check device, dtype dependency. with only various damping method."""
94109 xc = "pbe"
95- atoms_list = _create_atoms ()
96110 _test_calc_energy_force_stress (damping , xc , old , atoms_list , device = device , dtype = dtype )
97111
98112
99113@pytest .mark .parametrize ("damping,old" , damping_method_list )
114+ @pytest .mark .parametrize ("atoms_list" , _create_atoms ())
100115@pytest .mark .parametrize ("device" , ["cpu" , "cuda:0" ])
116+ @pytest .mark .parametrize ("bidirectional" , [True , False ])
101117@pytest .mark .parametrize ("dtype" , [torch .float64 ])
102- def test_calc_energy_force_stress_device_batch_abc (damping , old , device , dtype ):
103- """Test2-2: check device, dtype dependency. with only various damping method."""
118+ def test_calc_energy_force_stress_device_batch_abc (
119+ damping , old , atoms_list , device , bidirectional , dtype
120+ ):
121+ """Test2-3: check device, dtype dependency. with only various damping method."""
104122 xc = "pbe"
105123 abc = True
106- atoms_list = _create_atoms ()
107- _test_calc_energy_force_stress (
108- damping , xc , old , atoms_list , device = device , dtype = dtype , cnthr = 7.0
109- )
124+ if any ([np .all (atoms .pbc ) for atoms in atoms_list ]) and bidirectional == False :
125+ # TODO: bidirectional=False is not implemented for pbc now.
126+ with pytest .raises (NotImplementedError ):
127+ _test_calc_energy_force_stress (
128+ damping ,
129+ xc ,
130+ old ,
131+ atoms_list ,
132+ device = device ,
133+ dtype = dtype ,
134+ bidirectional = bidirectional ,
135+ abc = abc ,
136+ cnthr = 7.0 ,
137+ )
138+ else :
139+ _test_calc_energy_force_stress (
140+ damping ,
141+ xc ,
142+ old ,
143+ atoms_list ,
144+ device = device ,
145+ dtype = dtype ,
146+ bidirectional = bidirectional ,
147+ abc = abc ,
148+ cnthr = 7.0 ,
149+ )
110150
111151
112152if __name__ == "__main__" :
0 commit comments