Skip to content

Commit 7972458

Browse files
precommit
1 parent 9fc6e94 commit 7972458

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

emmet-core/emmet/core/neb.py

+47-21
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from emmet.core.vasp.calculation import Calculation, VaspObject
2626
from emmet.core.vasp.task_valid import TaskState
2727

28+
2829
class NebMethod(ValueEnum):
2930
"""Common methods for NEB calculations.
3031
@@ -35,6 +36,7 @@ class NebMethod(ValueEnum):
3536
CLIMBING_IMAGE = "climbing_image"
3637
APPROX = "approxNEB"
3738

39+
3840
class HopFailureReason(ValueEnum):
3941
"""Define failure modes for ApproxNEB calculations."""
4042

@@ -46,20 +48,34 @@ class HopFailureReason(ValueEnum):
4648
class BarrierAnalysis(BaseModel):
4749
"""Define analysis schema for barrier calculations."""
4850

49-
energies : list[float] = Field(description="The energies of each frame along the reaction coordinate.")
50-
frame_index : list[float] | None = Field(None, description="The fractional index along the reaction coordinate, between 0 and 1.")
51-
cubic_spline_pars : list[list[float]] | None = Field(None, description="Parameters of the cubic spline used to fit the energies.")
52-
ts_frame_index : float | None = Field(None, description="The fractional index of the reaction coordinate.")
53-
ts_energy : float | None = Field(None,description="The energy at the transition state.")
54-
ts_in_frames : bool | None = Field(None, description="Whether the transition state is one of the computed snapshots.")
55-
forward_barrier : float | None = Field(None,description="The forwards barrier.")
56-
reverse_barrier : float | None = Field(None,description="The reverse barrier.")
51+
energies: list[float] = Field(
52+
description="The energies of each frame along the reaction coordinate."
53+
)
54+
frame_index: list[float] | None = Field(
55+
None,
56+
description="The fractional index along the reaction coordinate, between 0 and 1.",
57+
)
58+
cubic_spline_pars: list[list[float]] | None = Field(
59+
None, description="Parameters of the cubic spline used to fit the energies."
60+
)
61+
ts_frame_index: float | None = Field(
62+
None, description="The fractional index of the reaction coordinate."
63+
)
64+
ts_energy: float | None = Field(
65+
None, description="The energy at the transition state."
66+
)
67+
ts_in_frames: bool | None = Field(
68+
None,
69+
description="Whether the transition state is one of the computed snapshots.",
70+
)
71+
forward_barrier: float | None = Field(None, description="The forwards barrier.")
72+
reverse_barrier: float | None = Field(None, description="The reverse barrier.")
5773

5874
@classmethod
5975
def from_energies(
6076
cls,
6177
energies: Sequence[float],
62-
spline_kwargs: dict[str,Any] | None = None,
78+
spline_kwargs: dict[str, Any] | None = None,
6379
frame_match_tol: float = 1.0e-6,
6480
) -> Self:
6581
"""
@@ -91,9 +107,9 @@ def from_energies(
91107
analysis["ts_frame_index"] = -1
92108
analysis["ts_energy"] = -np.inf
93109
for crit_point in crit_points:
94-
if (energy := spline_fit(crit_point)) > analysis["ts_energy"] and spline_fit(
95-
crit_point, 2
96-
) <= 0.0:
110+
if (energy := spline_fit(crit_point)) > analysis[
111+
"ts_energy"
112+
] and spline_fit(crit_point, 2) <= 0.0:
97113
analysis["ts_frame_index"] = crit_point
98114
analysis["ts_energy"] = float(energy)
99115

@@ -107,9 +123,10 @@ def from_energies(
107123

108124
return cls(**analysis)
109125

126+
110127
class NebResult(BaseModel):
111128
"""Container class to store high-level NEB calculation info.
112-
129+
113130
This is intended to be code-agnostic, whereas NebTaskDoc
114131
is VASP-specific.
115132
"""
@@ -170,9 +187,13 @@ def set_barriers(self) -> Self:
170187
):
171188
self.barrier_analysis = BarrierAnalysis.from_energies(self.energies)
172189
for k in ("forward", "reverse"):
173-
setattr(self, f"{k}_barrier", getattr(self.barrier_analysis,f"{k}_barrier",None))
190+
setattr(
191+
self,
192+
f"{k}_barrier",
193+
getattr(self.barrier_analysis, f"{k}_barrier", None),
194+
)
174195
return self
175-
196+
176197

177198
class NebTaskDoc(NebResult):
178199
"""Define schema for VASP NEB tasks."""
@@ -187,7 +208,7 @@ class NebTaskDoc(NebResult):
187208
endpoint_calculations: list[Calculation] | None = Field(
188209
None, description="Calculation information for the endpoint structures"
189210
)
190-
endpoint_objects: list[dict[VaspObject,Any]] | None = Field(
211+
endpoint_objects: list[dict[VaspObject, Any]] | None = Field(
191212
None, description="VASP objects for each endpoint calculation."
192213
)
193214
endpoint_directories: list[str] | None = Field(
@@ -235,13 +256,13 @@ class NebTaskDoc(NebResult):
235256
None, description="Timestamp for when this task was completed"
236257
)
237258

238-
task_label : str | None = Field(
239-
None, description = "Label for the NEB calculation(s)."
259+
task_label: str | None = Field(
260+
None, description="Label for the NEB calculation(s)."
240261
)
241262

242-
def model_post_init(self, __context : Any) -> None:
263+
def model_post_init(self, __context: Any) -> None:
243264
"""Ensure base model fields are populated for analysis."""
244-
265+
245266
if self.energies is None:
246267
if self.endpoint_energies is not None:
247268
self.energies = [ # type: ignore[misc]
@@ -272,7 +293,11 @@ def model_post_init(self, __context : Any) -> None:
272293
calc.input.structure for calc in self.image_calculations
273294
]
274295

275-
self.initial_images = [ep_structures[0], *intermed_structs, ep_structures[1]]
296+
self.initial_images = [
297+
ep_structures[0],
298+
*intermed_structs,
299+
ep_structures[1],
300+
]
276301

277302
@classmethod
278303
def from_directory(
@@ -439,6 +464,7 @@ def from_directories(
439464
**neb_task_doc_kwargs,
440465
)
441466

467+
442468
class NebPathwayResult(BaseModel): # type: ignore[call-arg]
443469
"""Class for containing multiple NEB calculations, as along a reaction pathway."""
444470

0 commit comments

Comments
 (0)