Skip to content

Commit 95f34b0

Browse files
authored
Merge pull request #29 from pyiron/add_calc_minimize
Add calc minimize
2 parents dc8d3ac + aeaa151 commit 95f34b0

File tree

2 files changed

+102
-11
lines changed

2 files changed

+102
-11
lines changed

sphinx_parser/jobs.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,30 @@
33
from sphinx_parser.potential import get_paw_from_structure
44

55

6-
def calc_static(
7-
structure,
8-
eCut=25,
9-
xc=1,
10-
spinPolarized=False,
11-
maxSteps=30,
12-
ekt=0.2,
13-
k_point_coords=[0.5, 0.5, 0.5],
6+
def set_base_parameters(
7+
structure: "ase.Atoms",
8+
eCut: float = 25,
9+
xc: int = 1,
10+
maxSteps: int = 30,
11+
ekt: float = 0.2,
12+
k_point_coords: list = [0.5, 0.5, 0.5],
1413
):
14+
"""
15+
Set the base parameters for the sphinx input file
16+
17+
Args:
18+
structure (ase.Atoms): ASE Atoms object
19+
eCut (float, optional): Energy cutoff. Defaults to 25.
20+
xc (int, optional): Exchange-correlation functional. Defaults to 1.
21+
maxSteps (int, optional): Maximum number of steps. Defaults to 30.
22+
ekt (float, optional): Temperature. Defaults to 0.2.
23+
k_point_coords (list, optional): K-point coordinates. Defaults to [0.5, 0.5, 0.5].
24+
25+
Returns:
26+
dict: Sphinx input dictionary
27+
"""
1528
struct_group, spin_lst = get_structure_group(structure)
29+
spinPolarized = spin_lst is not None
1630
main_group = sphinx.main.create(
1731
scfDiag=sphinx.main.scfDiag.create(
1832
maxSteps=maxSteps, blockCCG=sphinx.main.scfDiag.blockCCG.create()
@@ -40,3 +54,64 @@ def calc_static(
4054
initialGuess=initial_guess_group,
4155
)
4256
return input_sx
57+
58+
59+
def apply_minimization(sphinx_input, mode="linQN", dEnergy=1.0e-6, maxSteps=50):
60+
"""
61+
Apply minimization to the sphinx input file
62+
63+
Args:
64+
sphinx_input (dict): Sphinx input dictionary
65+
mode (str, optional): Minimization mode. Defaults to "linQN".
66+
dEnergy (float, optional): Energy tolerance. Defaults to 1.0e-6.
67+
maxSteps (int, optional): Maximum number of steps. Defaults to 50.
68+
69+
Returns:
70+
dict: Sphinx input dictionary
71+
"""
72+
input_sx = sphinx_input.copy()
73+
if "main" not in input_sx or "scfDiag" not in input_sx["main"]:
74+
raise ValueError("main group not found - run set_base_parameters first")
75+
if mode == "linQN":
76+
input_sx["main"] = sphinx.main.create(
77+
linQN=sphinx.main.linQN.create(
78+
dEnergy=dEnergy,
79+
maxSteps=maxSteps,
80+
bornOppenheimer=sphinx.main.linQN.bornOppenheimer.create(
81+
scfDiag=input_sx["main"]["scfDiag"]
82+
),
83+
)
84+
)
85+
elif mode == "QN":
86+
input_sx["main"] = sphinx.main.create(
87+
QN=sphinx.main.QN.create(
88+
dEnergy=dEnergy,
89+
maxSteps=maxSteps,
90+
bornOppenheimer=sphinx.main.QN.bornOppenheimer.create(
91+
scfDiag=input_sx["main"]["scfDiag"]
92+
),
93+
)
94+
)
95+
elif mode == "ricQN":
96+
input_sx["main"] = sphinx.main.create(
97+
ricQN=sphinx.main.ricQN.create(
98+
dEnergy=dEnergy,
99+
maxSteps=maxSteps,
100+
bornOppenheimer=sphinx.main.ricQN.bornOppenheimer.create(
101+
scfDiag=input_sx["main"]["scfDiag"]
102+
),
103+
)
104+
)
105+
elif mode == "ricTS":
106+
input_sx["main"] = sphinx.main.create(
107+
ricTS=sphinx.main.ricTS.create(
108+
dEnergy=dEnergy,
109+
maxSteps=maxSteps,
110+
bornOppenheimer=sphinx.main.ricTS.bornOppenheimer.create(
111+
scfDiag=input_sx["main"]["scfDiag"]
112+
),
113+
)
114+
)
115+
else:
116+
raise ValueError("mode not recognized")
117+
return input_sx

tests/unit/test_jobs.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,31 @@
11
import unittest
22
from ase.build import bulk
3-
from sphinx_parser.jobs import calc_static
3+
from sphinx_parser.jobs import set_base_parameters, apply_minimization
44
from sphinx_parser.toolkit import to_sphinx
55

66

77
class TestJobs(unittest.TestCase):
88
def test_magnetic_bulk(self):
99
structure = bulk("Fe", cubic=True)
10-
self.assertTrue("atomicSpin" in to_sphinx(calc_static(structure)))
10+
self.assertTrue("atomicSpin" in to_sphinx(set_base_parameters(structure)))
1111
structure = bulk("Al", cubic=True)
12-
self.assertFalse("atomicSpin" in to_sphinx(calc_static(structure)))
12+
self.assertFalse("atomicSpin" in to_sphinx(set_base_parameters(structure)))
13+
14+
def test_calc_minimize(self):
15+
structure = bulk("Fe", cubic=True)
16+
input_sx = set_base_parameters(structure)
17+
default_input_sx = to_sphinx(apply_minimization(input_sx))
18+
self.assertTrue("linQN" in default_input_sx)
19+
for term in ["ricQN", "QN", "ricTS"]:
20+
self.assertEqual(
21+
to_sphinx(apply_minimization(input_sx, mode=term)),
22+
default_input_sx.replace("linQN", term),
23+
)
24+
with self.assertRaises(ValueError):
25+
apply_minimization(input_sx, mode="not_a_valid_mode")
26+
input_sx.pop("main")
27+
with self.assertRaises(ValueError):
28+
apply_minimization(input_sx)
1329

1430

1531
if __name__ == "__main__":

0 commit comments

Comments
 (0)