Skip to content

Commit c8c61d2

Browse files
More unit tests fixed now that decode_cf works
1 parent 676c188 commit c8c61d2

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

tests/test_particlefile.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,8 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile):
107107
pfile.write(pset, fieldset.time_interval.left + np.timedelta64(1, "D"))
108108
pfile.write(pset, fieldset.time_interval.left + np.timedelta64(2, "D"))
109109

110-
ds = xr.open_zarr(tmp_zarrfile, decode_cf=False).load()
111-
pytest.skip(
112-
"TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value"
113-
)
114-
assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms"))
110+
ds = xr.open_zarr(tmp_zarrfile)
111+
np.testing.assert_allclose(ds["time"][:, 0] - fieldset.time_interval.left, np.timedelta64(0, "s"))
115112
if chunks_obs is not None:
116113
assert ds["time"][:].shape == chunks
117114
else:
@@ -200,7 +197,7 @@ def IncrLon(particles, fieldset): # pragma: no cover
200197
for _ in range(npart):
201198
pset.execute(IncrLon, dt=dt, runtime=np.timedelta64(1, "s"), output_file=pfile)
202199

203-
ds = xr.open_zarr(tmp_zarrfile, decode_cf=False)
200+
ds = xr.open_zarr(tmp_zarrfile)
204201
samplevar = ds["sample_var"][:]
205202
assert samplevar.shape == (npart, min(maxvar, npart + 1))
206203
# test whether samplevar[:, k] = k
@@ -216,12 +213,10 @@ def test_write_timebackward(fieldset, tmp_zarrfile):
216213
pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
217214
pset.execute(DoNothing, runtime=np.timedelta64(3, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile)
218215

219-
# TODO v4: Fix decode_cf and remove the following lines
220-
ds = xr.open_zarr(tmp_zarrfile, decode_cf=False)
216+
ds = xr.open_zarr(tmp_zarrfile)
221217
trajs = ds["trajectory"][:]
222218

223219
output_time = ds["time"][:].values
224-
output_time = np.where(output_time > 100.0, np.nan, output_time) # ignore fill values
225220

226221
assert trajs.values.dtype == "int64"
227222
assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release
@@ -298,13 +293,12 @@ def IncreaseAge(particles, fieldset): # pragma: no cover
298293

299294
pset.execute(IncreaseAge, runtime=np.timedelta64(npart * 2, "s"), dt=np.timedelta64(1, "s"), output_file=ofile)
300295

301-
# TODO v4: Fix metadata and re-enable decode_cf
302-
ds = xr.open_zarr(tmp_zarrfile, decode_cf=False)
296+
ds = xr.open_zarr(tmp_zarrfile)
303297
age = ds["age"][:].values
304-
ds_time = np.zeros_like(age)
298+
ds_timediff = np.zeros_like(age)
305299
for i in range(npart):
306-
ds_time[i, :] = ds.time.values[i, :] - (time[i] - fieldset.time_interval.left) / np.timedelta64(1, "s")
307-
assert np.allclose(age[~np.isnan(age)], ds_time[~np.isnan(age)])
300+
ds_timediff[i, :] = (ds.time.values[i, :] - time[i]) / np.timedelta64(1, "s")
301+
np.testing.assert_equal(age, ds_timediff)
308302

309303

310304
def test_reset_dt(fieldset, tmp_zarrfile):

0 commit comments

Comments
 (0)