Skip to content

Commit 465fc12

Browse files
committed
Patch large time and step values in the trajectory writers so that they do not overflow the low-precision types those file formats use
1 parent fc58376 commit 465fc12

File tree

2 files changed

+78
-10
lines changed

2 files changed

+78
-10
lines changed

moleculekit/writers.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@
3131
# fmt: on
3232

3333

34+
def _format_large_time(mol, ext):
    """Prepare ``mol.step`` and ``mol.time`` for writers that store them in 32 bits.

    The steps are converted to ``np.uint32`` and the times are divided by
    1e3 (fs -> ps). If the last step does not survive the cast to
    ``np.uint32`` the steps are replaced by ``1..numFrames``. If the last
    time (in ps) is not exactly equal to its ``np.float32`` representation,
    the times are shifted down by a constant offset derived from the first
    step and the step interval.

    Parameters
    ----------
    mol : Molecule
        Object providing the ``step``, ``time``, ``numFrames`` and ``fstep``
        attributes. ``step`` and ``time`` are used as numpy arrays.
    ext : str
        Name of the output format. Only used in the warning messages.

    Returns
    -------
    step : np.ndarray
        The steps as ``np.uint32``, renumbered from 1 if they did not fit.
    time : np.ndarray
        The times divided by 1e3, shifted if they did not fit.

    Raises
    ------
    AssertionError
        If the times have to be shifted and the step interval is 0.
    """
    step = mol.step
    time = mol.time
    nframes = mol.numFrames

    # An empty trajectory has no last frame to range-check. Return the plain
    # conversions instead of failing with an IndexError on step[-1] / time[-1].
    if len(step) == 0 or len(time) == 0:
        return step.astype(np.uint32), time / 1e3

    # Step interval between consecutive frames. With a single frame there is
    # no interval to measure, so the step value itself is used.
    if len(step) == 1:
        trajfreq = step[0]
    else:
        trajfreq = step[1] - step[0]

    # Only the last step is checked: if it round-trips through uint32 the
    # steps are kept, otherwise they are replaced by a plain frame counter.
    if step[-1] != step[-1].astype(np.uint32):
        logger.warning(
            f"Molecule.step contains values too large to be written to a {ext} file. They will be renumbered starting from 1."
        )
        step = np.arange(1, nframes + 1, dtype=np.uint32)
    else:
        step = step.astype(np.uint32)

    time = time / 1e3  # convert from fs to ps
    # Exact comparison: this fires whenever the last time in ps is not
    # exactly representable as a float32 value.
    if time[-1] != time[-1].astype(np.float32):
        logger.warning(
            f"Molecule.time contains values too large to be written to a {ext} file. They will be renumbered starting from 0."
        )
        if trajfreq == 0:
            raise AssertionError("The trajectory step should not be 0")
        # NOTE(review): presumably mol.fstep is in ns, which makes this the
        # duration of one step in fs -- confirm against Molecule.fstep.
        timestep = (mol.fstep / trajfreq) / 1e-6
        # Subtract the time of the step one interval before the first frame,
        # then convert from fs to ps. Uses the original mol.time / mol.step,
        # not the already converted local arrays.
        time = (mol.time - ((mol.step[0] - trajfreq) * timestep)) / 1e3

    return step, time
63+
64+
3465
def _format_pdb_name(name, resname, element=None):
3566
name = name[:4]
3667
first_col = f"{name:<4}"
@@ -263,9 +294,10 @@ def XTCwrite(mol, filename):
263294
if os.path.isfile(filename):
264295
os.unlink(filename)
265296

297+
step, time = _format_large_time(mol, "XTC")
298+
266299
box = box.astype(np.float32) * 0.1
267-
step = step.astype(np.int32)
268-
time = time.astype(np.float32) / 1e3 # Convert from fs to ps
300+
time = time.astype(np.float32)
269301
coords = coords.astype(np.float32) * 0.1 # Convert from A to nm
270302
if not box.flags["C_CONTIGUOUS"]:
271303
box = np.ascontiguousarray(box)
@@ -671,6 +703,16 @@ def DCDwrite(mol, filename):
671703
nsavc = int(mol.step[1] - mol.step[0])
672704
fstep = mol.fstep * 1000 # ns to ps
673705
delta = fstep / nsavc / 0.04888821 # Conversion factor found in OpenMM
706+
if mol.step[0] != mol.step[0].astype(np.int32):
707+
logger.warning(
708+
"Molecule.step contains values too large to be written to DCD file. They will be renumbered starting from 1."
709+
)
710+
istart //= nsavc
711+
delta *= nsavc
712+
nsavc = 1
713+
# If it's still too large just start from 1
714+
if istart != np.array(istart).astype(np.int32):
715+
istart = 1
674716
except Exception:
675717
istart = 0
676718
nsavc = 1
@@ -699,8 +741,7 @@ def TRRwrite(mol, filename):
699741
from moleculekit.trr import TRRTrajectoryFile
700742

701743
xyz = np.transpose(mol.coords, (2, 0, 1)) / 10 # Convert Angstrom to nm
702-
time = mol.time / 1000 # Convert fs to ps
703-
step = mol.step
744+
step, time = _format_large_time(mol, "TRR")
704745
boxvectors = np.transpose(mol.boxvectors, (2, 0, 1)) / 10 # Angstrom to nm
705746
with TRRTrajectoryFile(filename, "w") as fh:
706747
fh.write(xyz, time=time, step=step, box=boxvectors, lambd=None)
@@ -728,8 +769,10 @@ def NETCDFwrite(mol, filename):
728769
)
729770
n_frames, n_atoms = coordinates.shape[0], coordinates.shape[1]
730771

772+
step, time = _format_large_time(mol, "NETCDF")
773+
731774
time = ensure_type(
732-
mol.time / 1000, # Convert from fs to ps
775+
time, # In ps
733776
np.float32,
734777
1,
735778
"time",
@@ -739,7 +782,7 @@ def NETCDFwrite(mol, filename):
739782
add_newaxis_on_deficient_ndim=True,
740783
)
741784
step = ensure_type(
742-
mol.step,
785+
step,
743786
np.int32,
744787
1,
745788
"step",
@@ -984,7 +1027,7 @@ def MDTRAJwrite(mol, filename):
9841027
traj = Trajectory(
9851028
xyz=np.transpose(mol.coords, (2, 0, 1)) / 10, # Ang to nm
9861029
topology=traj.topology,
987-
time=time,
1030+
time=time.astype(np.float32),
9881031
unitcell_lengths=box,
9891032
unitcell_angles=boxangles,
9901033
)

tests/test_writers.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,41 @@ def _test_cif_mol2_atom_renaming():
185185
assert filelines == reflines, f"Failed comparison of {reffile2} {tmpfile}"
186186

187187

188-
@pytest.mark.parametrize("ext", ("netcdf", "trr", "binpos", "dcd", "xyz", "xyz.gz"))
189-
def _test_traj_writers(ext):
188+
@pytest.mark.parametrize(
189+
"ext", ("xtc", "netcdf", "trr", "binpos", "dcd", "xyz", "xyz.gz")
190+
)
191+
@pytest.mark.parametrize("maxtime", [1e9, 1e15])
192+
def _test_traj_writers(ext, maxtime):
190193
from moleculekit.molecule import Molecule
191194
import tempfile
192195

193196
mol = Molecule(os.path.join(curr_dir, "test_readers", "1N09", "structure.prmtop"))
194197
mol.read(os.path.join(curr_dir, "test_readers", "1N09", "output.dcd"))
198+
# 1e9 fs = 1us. Test if the trajectories can write steps of 100ps over 1us trajectories
199+
trajfreq = 25000
200+
timestep = 4
201+
timefreq = trajfreq * timestep
202+
mol.time[:] = np.arange(
203+
maxtime,
204+
maxtime + mol.numFrames * timefreq,
205+
timefreq,
206+
dtype=Molecule._dtypes["time"],
207+
)
208+
mol.step[:] = np.arange(
209+
maxtime / timestep,
210+
maxtime / timestep + (mol.numFrames * trajfreq),
211+
trajfreq,
212+
dtype=Molecule._dtypes["step"],
213+
)
195214

196215
with tempfile.TemporaryDirectory() as tmpdir:
197216
mol.write(os.path.join(tmpdir, f"output.{ext}"))
198217
molc = Molecule(os.path.join(tmpdir, f"output.{ext}"))
199218

219+
if maxtime > 1e10:
220+
mol.time -= (mol.step[0] - trajfreq) * timestep
221+
mol.step[:] = range(1, mol.numFrames + 1)
222+
200223
if ext == "binpos":
201224
assert np.allclose(mol.coords, molc.coords, atol=1e-6)
202225
elif ext in ("xyz", "xyz.gz"):
@@ -207,12 +230,14 @@ def _test_traj_writers(ext):
207230
fieldPrecision={"coords": 2e-5},
208231
)
209232
else:
233+
coor_prec = 3e-6 if ext != "xtc" else 1e-2
234+
assert abs(mol.fstep - molc.fstep) < 1e-2
210235
assert mol_equal(
211236
mol,
212237
molc,
213238
checkFields=Molecule._traj_fields,
214239
exceptFields=("fileloc"),
215-
fieldPrecision={"coords": 3e-6, "box": 3e-6},
240+
fieldPrecision={"coords": coor_prec, "box": 3e-6, "time": 1e-3},
216241
)
217242

218243

0 commit comments

Comments
 (0)