Skip to content

Commit 3c44175

Browse files
lint
1 parent 3372bb5 commit 3c44175

File tree

2 files changed

+74
-47
lines changed

2 files changed

+74
-47
lines changed

emmet-core/emmet/core/neb.py

+54-38
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,16 @@ class NebTaskDoc(BaseModel, extra="allow"):
4343
None,
4444
description="The initial and final configurations (reactants and products) of the barrier.",
4545
)
46-
endpoint_energies : Optional[Sequence[float]] = Field(
47-
None,
48-
description="Energies of the endpoint structures."
46+
endpoint_energies: Optional[Sequence[float]] = Field(
47+
None, description="Energies of the endpoint structures."
4948
)
50-
endpoint_calculations : Optional[list[Calculation]] = Field(
51-
None,
52-
description = "Calculation information for the endpoint structures"
49+
endpoint_calculations: Optional[list[Calculation]] = Field(
50+
None, description="Calculation information for the endpoint structures"
5351
)
54-
endpoint_objects : Optional[list[dict]] = Field(
52+
endpoint_objects: Optional[list[dict]] = Field(
5553
None, description="VASP objects for each endpoint calculation."
5654
)
57-
endpoint_directories : Optional[list[str]] = Field(
55+
endpoint_directories: Optional[list[str]] = Field(
5856
None, description="List of the directories for the endpoint calculations."
5957
)
6058

@@ -140,25 +138,33 @@ def set_barriers(self) -> Self:
140138
def num_images(self) -> int:
141139
"""Return the number of VASP calculations / number of images performed."""
142140
return len(self.image_directories)
143-
141+
144142
@property
145143
def energies(self) -> list[float]:
146144
"""Return the endpoint (optional) and image energies."""
147145
if self.endpoint_energies is not None:
148-
return [self.endpoint_energies[0], *self.image_energies, self.endpoint_energies[1]]
146+
return [
147+
self.endpoint_energies[0],
148+
*self.image_energies,
149+
self.endpoint_energies[1],
150+
]
149151
return self.image_energies
150152

151153
@property
152154
def structures(self) -> list[Structure]:
153155
"""Return the endpoint and image structures."""
154-
return [self.endpoint_structures[0], *self.image_structures, self.endpoint_structures[1]]
155-
156+
return [
157+
self.endpoint_structures[0],
158+
*self.image_structures,
159+
self.endpoint_structures[1],
160+
]
161+
156162
@classmethod
157163
def from_directory(
158164
cls,
159165
dir_name: Union[Path, str],
160166
volumetric_files: Tuple[str, ...] = _VOLUMETRIC_FILES,
161-
store_calculations : bool = True,
167+
store_calculations: bool = True,
162168
**neb_task_doc_kwargs,
163169
) -> Self:
164170
"""
@@ -172,7 +178,7 @@ def from_directory(
172178

173179
neb_directories = sorted(dir_name.glob("[0-9][0-9]"))
174180

175-
if (ep_calcs := neb_task_doc_kwargs.pop("endpoint_calculations", None) ) is None:
181+
if (ep_calcs := neb_task_doc_kwargs.pop("endpoint_calculations", None)) is None:
176182
endpoint_directories = [neb_directories[0], neb_directories[-1]]
177183
endpoint_structures = [
178184
Structure.from_file(zpath(f"{endpoint_dir}/POSCAR"))
@@ -181,12 +187,8 @@ def from_directory(
181187
endpoint_energies = None
182188
else:
183189
endpoint_directories = neb_task_doc_kwargs.pop("endpoint_directories")
184-
endpoint_structures = [
185-
ep_calc.output.structure for ep_calc in ep_calcs
186-
]
187-
endpoint_energies = [
188-
ep_calc.output.energy for ep_calc in ep_calcs
189-
]
190+
endpoint_structures = [ep_calc.output.structure for ep_calc in ep_calcs]
191+
endpoint_energies = [ep_calc.output.energy for ep_calc in ep_calcs]
190192

191193
image_directories = neb_directories[1:-1]
192194

@@ -216,8 +218,7 @@ def from_directory(
216218
task_state = (
217219
TaskState.SUCCESS
218220
if all(
219-
calc.has_vasp_completed == TaskState.SUCCESS
220-
for calc in calcs_to_check
221+
calc.has_vasp_completed == TaskState.SUCCESS for calc in calcs_to_check
221222
)
222223
else TaskState.FAILED
223224
)
@@ -247,11 +248,11 @@ def from_directory(
247248

248249
return cls(
249250
endpoint_structures=endpoint_structures,
250-
endpoint_energies = endpoint_energies,
251-
endpoint_directories = [str(ep_dir) for ep_dir in endpoint_directories],
252-
endpoint_calculations = ep_calcs if store_calculations else None,
251+
endpoint_energies=endpoint_energies,
252+
endpoint_directories=[str(ep_dir) for ep_dir in endpoint_directories],
253+
endpoint_calculations=ep_calcs if store_calculations else None,
253254
image_calculations=image_calculations if store_calculations else None,
254-
image_structures = image_structures,
255+
image_structures=image_structures,
255256
dir_name=str(dir_name),
256257
image_directories=[str(img_dir) for img_dir in image_directories],
257258
orig_inputs=inputs["orig_inputs"],
@@ -271,7 +272,7 @@ def from_directories(
271272
endpoint_directories: list[str | Path],
272273
neb_directory: str | Path,
273274
volumetric_files: Tuple[str, ...] = _VOLUMETRIC_FILES,
274-
**neb_task_doc_kwargs
275+
**neb_task_doc_kwargs,
275276
) -> Self:
276277
"""
277278
Return an NebTaskDoc from endpoint and NEB calculation directories.
@@ -282,12 +283,26 @@ def from_directories(
282283
endpoint_calculations = [None for _ in range(2)]
283284
endpoint_objects = [None for _ in range(2)]
284285
for idx, endpoint_dir in enumerate(endpoint_directories):
285-
vasp_files = _find_vasp_files(endpoint_dir, volumetric_files=volumetric_files)
286-
ep_key = "standard" if vasp_files.get("standard") else "relax" + str(max(
287-
int(k.split("relax")[-1]) for k in vasp_files if k.startswith("relax")
288-
))
286+
vasp_files = _find_vasp_files(
287+
endpoint_dir, volumetric_files=volumetric_files
288+
)
289+
ep_key = (
290+
"standard"
291+
if vasp_files.get("standard")
292+
else "relax"
293+
+ str(
294+
max(
295+
int(k.split("relax")[-1])
296+
for k in vasp_files
297+
if k.startswith("relax")
298+
)
299+
)
300+
)
289301

290-
endpoint_calculations[idx], endpoint_objects[idx] = Calculation.from_vasp_files(
302+
(
303+
endpoint_calculations[idx],
304+
endpoint_objects[idx],
305+
) = Calculation.from_vasp_files(
291306
dir_name=endpoint_dir,
292307
task_name=f"NEB endpoint {idx + 1}",
293308
vasprun_file=vasp_files[ep_key]["vasprun_file"],
@@ -299,16 +314,17 @@ def from_directories(
299314
"parse_potcar_file": False,
300315
},
301316
)
302-
317+
303318
return cls.from_directory(
304319
neb_directory,
305320
volumetric_files=volumetric_files,
306-
endpoint_calculations = endpoint_calculations,
307-
endpoint_objects = endpoint_objects,
308-
endpoint_directories = endpoint_directories,
309-
**neb_task_doc_kwargs
321+
endpoint_calculations=endpoint_calculations,
322+
endpoint_objects=endpoint_objects,
323+
endpoint_directories=endpoint_directories,
324+
**neb_task_doc_kwargs,
310325
)
311-
326+
327+
312328
def neb_barrier_spline_fit(
313329
energies: Sequence[float],
314330
spline_kwargs: dict | None = None,

emmet-core/tests/test_neb.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def test_neb_doc(test_dir, from_dir: bool):
3737
assert neb_doc.num_images == 3
3838
assert len(neb_doc.image_structures) == neb_doc.num_images
3939
assert len(neb_doc.energies) == neb_doc.num_images
40-
assert len(neb_doc.structures) == neb_doc.num_images + 2 # always includes endpoints
40+
assert (
41+
len(neb_doc.structures) == neb_doc.num_images + 2
42+
) # always includes endpoints
4143
assert isinstance(neb_doc.orig_inputs, OrigInputs)
4244

4345
# test that NEB image calculations are all VASP Calculation objects
@@ -77,8 +79,8 @@ def test_neb_doc(test_dir, from_dir: bool):
7779
)
7880
assert len(neb_doc.image_energies) == neb_doc.num_images
7981

80-
def test_from_directories(test_dir):
8182

83+
def test_from_directories(test_dir):
8284
with TemporaryDirectory() as tmpdir:
8385
tmpdir = Path(tmpdir)
8486
shutil.unpack_archive(test_dir / "neb_sample_calc.zip", tmpdir, "zip")
@@ -87,23 +89,32 @@ def test_from_directories(test_dir):
8789
tmpdir / "neb",
8890
)
8991

90-
assert all(isinstance(ep_calc,Calculation) for ep_calc in neb_doc.endpoint_calculations)
91-
9292
assert all(
93-
"relax_endpoint_" in ep_dir for ep_dir in neb_doc.endpoint_directories
93+
isinstance(ep_calc, Calculation) for ep_calc in neb_doc.endpoint_calculations
9494
)
9595

96+
assert all("relax_endpoint_" in ep_dir for ep_dir in neb_doc.endpoint_directories)
97+
9698
assert len(neb_doc.energies) == neb_doc.num_images + 2
9799
assert len(neb_doc.structures) == neb_doc.num_images + 2
98-
assert isinstance(neb_doc.barrier_analysis,dict)
100+
assert isinstance(neb_doc.barrier_analysis, dict)
99101

100102
assert all(
101103
neb_doc.barrier_analysis.get(k) is not None
102-
for k in ("energies","frame_index","cubic_spline_pars","ts_frame_index","ts_energy","ts_in_frames","forward_barrier","reverse_barrier")
104+
for k in (
105+
"energies",
106+
"frame_index",
107+
"cubic_spline_pars",
108+
"ts_frame_index",
109+
"ts_energy",
110+
"ts_in_frames",
111+
"forward_barrier",
112+
"reverse_barrier",
113+
)
103114
)
104115

105116
assert all(
106-
getattr(neb_doc,f"{direction}_barrier") == neb_doc.barrier_analysis[f"{direction}_barrier"]
117+
getattr(neb_doc, f"{direction}_barrier")
118+
== neb_doc.barrier_analysis[f"{direction}_barrier"]
107119
for direction in ("forward", "reverse")
108120
)
109-

0 commit comments

Comments
 (0)