Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*.jpeg
*.dot
*.png
*.svg
~*
*~
.project
Expand Down
26 changes: 23 additions & 3 deletions janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class EoS(BaseCalculation):
write_kwargs : Optional[OutputKwargs],
Keyword arguments to pass to ase.io.write to save generated structures.
Default is {}.
plot_to_file : bool
Whether to save plot equation of state to svg. Default is False.
plot_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to EquationOfState.plot. Default is {}.
file_prefix : Optional[PathLike]
Prefix for output filenames. Default is inferred from structure name, or
chemical formula of the structure.
Expand Down Expand Up @@ -119,6 +123,8 @@ def __init__(
write_results: bool = True,
write_structures: bool = False,
write_kwargs: Optional[OutputKwargs] = None,
plot_to_file: bool = False,
plot_kwargs: Optional[dict[str, Any]] = None,
file_prefix: Optional[PathLike] = None,
) -> None:
"""
Expand Down Expand Up @@ -174,12 +180,16 @@ def __init__(
write_kwargs : Optional[OutputKwargs],
Keyword arguments to pass to ase.io.write to save generated structures.
Default is {}.
plot_to_file : bool
Whether to save plot equation of state to svg. Default is False.
plot_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to EquationOfState.plot. Default is {}.
file_prefix : Optional[PathLike]
Prefix for output filenames. Default is inferred from structure name, or
chemical formula of the structure.
"""
(read_kwargs, minimize_kwargs, write_kwargs) = none_to_dict(
(read_kwargs, minimize_kwargs, write_kwargs)
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs) = none_to_dict(
(read_kwargs, minimize_kwargs, write_kwargs, plot_kwargs)
)

self.min_volume = min_volume
Expand All @@ -192,6 +202,8 @@ def __init__(
self.write_results = write_results
self.write_structures = write_structures
self.write_kwargs = write_kwargs
self.plot_to_file = plot_to_file
self.plot_kwargs = plot_kwargs

if (
(self.minimize or self.minimize_all)
Expand Down Expand Up @@ -243,12 +255,17 @@ def __init__(
"filemode": "a",
}

# Set output file
# Set output files
self.write_kwargs.setdefault("filename", None)
self.write_kwargs["filename"] = self._build_filename(
"generated.extxyz", filename=self.write_kwargs["filename"]
).absolute()

self.plot_kwargs.setdefault("filename", None)
self.plot_kwargs["filename"] = self._build_filename(
"eos-plot.svg", filename=self.plot_kwargs["filename"]
).absolute()

self.results = {}
self.volumes = []
self.energies = []
Expand Down Expand Up @@ -320,6 +337,9 @@ def run(self) -> EoSResults:
"v_0": v_0,
}

if self.plot_to_file:
eos.plot(**self.plot_kwargs)

return self.results

def _calc_volumes_energies(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions janus_core/cli/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def eos(
Option(help="Whether to write out all genereated structures."),
] = False,
write_kwargs: WriteKwargs = None,
plot_to_file: Annotated[
bool,
Option(help="Whether to plot equation of state."),
] = False,
arch: Architecture = "mace_mp",
device: Device = "cpu",
model_path: ModelPath = None,
Expand Down Expand Up @@ -112,6 +116,8 @@ def eos(
write_kwargs : Optional[dict[str, Any]],
Keyword arguments to pass to ase.io.write to save generated structures.
Default is {}.
plot_to_file : bool
Whether to save plot equation of state to svg. Default is False.
arch : Optional[str]
MLIP architecture to use for geometry optimization.
Default is "mace_mp".
Expand Down Expand Up @@ -176,6 +182,7 @@ def eos(
"minimize_kwargs": minimize_kwargs,
"write_structures": write_structures,
"write_kwargs": write_kwargs,
"plot_to_file": plot_to_file,
"file_prefix": file_prefix,
}

Expand Down
18 changes: 18 additions & 0 deletions tests/test_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,21 @@ def test_logging(tmp_path):

assert log_file.exists()
assert single_point.struct.info["emissions"] > 0


def test_plot(tmp_path):
"""Test plotting equation of state."""
plot_file = tmp_path / "plot.svg"

eos = EoS(
struct_path=DATA_PATH / "NaCl.cif",
arch="mace_mp",
calc_kwargs={"model": MODEL_PATH},
plot_to_file=True,
plot_kwargs={"filename": plot_file},
file_prefix=tmp_path / "NaCl",
)

results = eos.run()
assert all(key in results for key in ("eos", "bulk_modulus", "e_0", "v_0"))
assert plot_file.exists()
22 changes: 22 additions & 0 deletions tests/test_eos_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,25 @@ def test_invalid_traj_input(tmp_path):
)
assert result.exit_code == 1
assert isinstance(result.exception, ValueError)


def test_plot(tmp_path):
"""Test plotting equation of state."""
file_prefix = tmp_path / "NaCl"
plot_path = tmp_path / "NaCl-eos-plot.svg"

result = runner.invoke(
app,
[
"eos",
"--struct",
DATA_PATH / "NaCl.cif",
"--n-volumes",
4,
"--plot-to-file",
"--file-prefix",
file_prefix,
],
)
assert result.exit_code == 0
assert plot_path.exists()