Skip to content

Commit 0efeb6d

Browse files
Fix info (#323)
* Test density updated correctly * Update density during MD correctly
1 parent 2a08145 commit 0efeb6d

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

janus_core/calculations/md.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -509,12 +509,21 @@ def __init__(
509509

510510
self._parse_correlations()
511511

512-
def _set_time_step(self):
513-
"""Set time in fs and current dynamics step to info."""
512+
def _set_info(self):
513+
"""Set time in fs, current dynamics step, and density to info."""
514514
time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs
515515
step = self.offset + self.dyn.nsteps
516516
self.dyn.atoms.info["time_fs"] = time
517517
self.dyn.atoms.info["step"] = step
518+
try:
519+
density = (
520+
np.sum(self.dyn.atoms.get_masses())
521+
/ self.dyn.atoms.get_volume()
522+
* DENS_FACT
523+
)
524+
self.dyn.atoms.info["density"] = density
525+
except ValueError:
526+
self.dyn.atoms.info["density"] = 0.0
518527

519528
def _prepare_restart(self) -> None:
520529
"""Prepare restart files, structure and offset."""
@@ -726,19 +735,13 @@ def get_stats(self) -> dict[str, float]:
726735
e_kin = self.dyn.atoms.get_kinetic_energy() / self.n_atoms
727736
current_temp = e_kin / (1.5 * units.kB)
728737

729-
self._set_time_step()
738+
self._set_info()
730739

731740
time_now = datetime.datetime.now()
732741
real_time = time_now - self.dyn.atoms.info["real_time"]
733742
self.dyn.atoms.info["real_time"] = time_now
734743

735744
try:
736-
density = (
737-
np.sum(self.dyn.atoms.get_masses())
738-
/ self.dyn.atoms.get_volume()
739-
* DENS_FACT
740-
)
741-
self.dyn.atoms.info["density"] = density
742745
volume = self.dyn.atoms.get_volume()
743746
pressure = (
744747
-np.trace(
@@ -754,7 +757,6 @@ def get_stats(self) -> dict[str, float]:
754757
except ValueError:
755758
volume = 0.0
756759
pressure = 0.0
757-
density = 0.0
758760
pressure_tensor = np.zeros(6)
759761

760762
return {
@@ -765,7 +767,7 @@ def get_stats(self) -> dict[str, float]:
765767
"EKin/N": e_kin,
766768
"T": current_temp,
767769
"ETot/N": e_pot + e_kin,
768-
"Density": density,
770+
"Density": self.dyn.atoms.info["density"],
769771
"Volume": volume,
770772
"P": pressure,
771773
"Pxx": pressure_tensor[0],
@@ -874,7 +876,7 @@ def _write_traj(self) -> None:
874876
self.dyn.nsteps > self.traj_start + self.traj_start % self.traj_every
875877
)
876878

877-
self._set_time_step()
879+
self._set_info()
878880
write_kwargs = self.write_kwargs
879881
write_kwargs["filename"] = self.traj_file
880882
write_kwargs["append"] = append
@@ -895,7 +897,7 @@ def _write_final_state(self) -> None:
895897
# Append if final file has been created
896898
append = self.created_final_file
897899

898-
self._set_time_step()
900+
self._set_info()
899901
write_kwargs = self.write_kwargs
900902
write_kwargs["filename"] = self.final_file
901903
write_kwargs["append"] = append
@@ -998,7 +1000,7 @@ def _write_restart(self) -> None:
9981000
if step > 0:
9991001
write_kwargs = self.write_kwargs
10001002
write_kwargs["filename"] = self._restart_file
1001-
self._set_time_step()
1003+
self._set_info()
10021004

10031005
output_structs(
10041006
images=self.struct,

tests/test_md.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,3 +1058,30 @@ def test_auto_restart_restart_stem(tmp_path):
10581058

10591059
final_traj = read(traj_path, index=":")
10601060
assert len(final_traj) == 9
1061+
1062+
1063+
def test_set_info(tmp_path):
1064+
"""Test info is set at correct frequency."""
1065+
file_prefix = tmp_path / "npt"
1066+
traj_path = tmp_path / "npt-traj.extxyz"
1067+
1068+
single_point = SinglePoint(
1069+
struct_path=DATA_PATH / "NaCl.cif",
1070+
arch="mace",
1071+
calc_kwargs={"model": MODEL_PATH},
1072+
)
1073+
1074+
npt = NPT(
1075+
struct=single_point.struct,
1076+
steps=10,
1077+
temp=1000,
1078+
stats_every=7,
1079+
file_prefix=file_prefix,
1080+
seed=2024,
1081+
traj_every=10,
1082+
)
1083+
1084+
npt.run()
1085+
final_struct = read(traj_path, index="-1")
1086+
assert npt.struct.info["density"] == pytest.approx(2.120952627887493)
1087+
assert final_struct.info["density"] == pytest.approx(2.120952627887493)

0 commit comments

Comments
 (0)