From 813dfc4e3c840cd4fb919e7e0896faeb44920103 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Wed, 2 Aug 2023 18:00:29 +0200 Subject: [PATCH] Updating parcels to work on test_fieldset_sampling and test_particle_file --- parcels/collection/collectionsoa.py | 3 ++- parcels/compilation/codegenerator.py | 10 ++++--- parcels/kernel/basekernel.py | 16 ++++++----- parcels/particlefile/baseparticlefile.py | 6 ++--- parcels/particleset/baseparticleset.py | 34 ++++++++++++------------ tests/test_particle_file.py | 13 +++++---- 6 files changed, 44 insertions(+), 38 deletions(-) diff --git a/parcels/collection/collectionsoa.py b/parcels/collection/collectionsoa.py index 611ce40f4..b54ec7bcb 100644 --- a/parcels/collection/collectionsoa.py +++ b/parcels/collection/collectionsoa.py @@ -130,13 +130,14 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, p self._data['depth'][:] = depth self._data['depth_towrite'][:] = depth self._data['time'][:] = time + self._data['time_towrite'][:] = time self._data['id'][:] = pid self._data['once_written'][:] = 0 # special case for exceptions which can only be handled from scipy self._data['exception'] = np.empty(self.ncount, dtype=object) - initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'id'} + initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'time_towrite', 'id'} # any fields that were provided on the command line for kwvar, kwval in kwargs.items(): diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index d2243fa18..4e6ba9852 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -1284,8 +1284,10 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include): body = [] body += [c.Value("double", "pre_dt")] body += [c.Statement("pre_dt = particles->dt[pnum]")] - body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")] - body += [c.If("particles->time[pnum] >= endtime", c.Statement("break"))] + body += [c.If("sign_dt*particles->time[pnum] >= sign_dt*(endtime)", c.Statement("break"))] + body += [c.If("fabs(endtime - particles->time[pnum]) < fabs(particles->dt[pnum])-1e-6", + c.Statement("particles->dt[pnum] = fabs(endtime - particles->time[pnum]) * sign_dt"))] + body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")] # TODO can go? body += [c.Assign("particles->state[pnum]", f"{funcname}(particles, pnum, {fargs_str})")] body += [c.If("(particles->state[pnum] == SUCCESS)", c.If("particles->time[pnum] < endtime", @@ -1298,8 +1300,8 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include): time_loop = c.While("(particles->state[pnum] == EVALUATE || particles->state[pnum] == REPEAT)", c.Block(body)) part_loop = c.For("pnum = 0", "pnum < num_particles", "++pnum", c.Block([sign_end_part, time_loop])) - fbody = c.Block([c.Value("int", "pnum, sign_dt, sign_end_part"), - c.Value("double", "reset_dt"), + fbody = c.Block([c.Value("int", "pnum, sign_end_part"), + c.Value("double", "reset_dt, sign_dt"), c.Value("double", "__pdt_prekernels"), c.Value("double", "__dt"), # 1e-8 = built-in tolerance for np.isclose() sign_dt, part_loop, diff --git a/parcels/kernel/basekernel.py b/parcels/kernel/basekernel.py index f8bed5f1a..f185d4c47 100644 --- a/parcels/kernel/basekernel.py +++ b/parcels/kernel/basekernel.py @@ -426,17 +426,21 @@ def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False): # TODO """ while p.state in [StateCode.Evaluate, OperationCode.Repeat]: pre_dt = p.dt - if abs(endtime - p.time) < abs(p.dt): + + sign_dt = np.sign(dt) + if sign_dt*p.time >= sign_dt*endtime: + return p + + if abs(endtime - p.time) < abs(p.dt)+1e-6: p.dt = abs(endtime - p.time) * sign_dt - res = self._pyfunc(p, self._fieldset, p.time) + self._pyfunc(p, self._fieldset, p.time) - if res is None: + if p.state is None: if p.time < endtime: - res = StateCode.Evaluate + p.state = StateCode.Evaluate else: - res = StateCode.Success - p.set_state(res) + p.state = StateCode.Success p.dt = pre_dt return p diff --git a/parcels/particlefile/baseparticlefile.py b/parcels/particlefile/baseparticlefile.py index 22c50a714..1036e7e99 100644 --- a/parcels/particlefile/baseparticlefile.py +++ b/parcels/particlefile/baseparticlefile.py @@ -228,16 +228,16 @@ def write(self, pset, time, indices=None): """ time = time.total_seconds() if isinstance(time, delta) else time - if (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)): + if True: # (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)): #TODO remove lasttime_written? if pset.collection._ncount == 0: logger.warning("ParticleSet is empty on writing as array at time %g" % time) return indices_to_write = pset.collection._to_write_particles(pset.collection._data, time) if indices is None else indices - if time is not None: - self.lasttime_written = time if len(indices_to_write) > 0: + if time is not None: + self.lasttime_written = time pids = pset.collection.getvardata('id', indices_to_write) to_add = sorted(set(pids) - set(self.pids_written.keys())) for i, pid in enumerate(to_add): diff --git a/parcels/particleset/baseparticleset.py b/parcels/particleset/baseparticleset.py index 5c134b135..dded6a333 100644 --- a/parcels/particleset/baseparticleset.py +++ b/parcels/particleset/baseparticleset.py @@ -510,11 +510,11 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime= if verbose_progress: pbar = self.__create_progressbar(_starttime, endtime) - lastexecution = True - while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0 or lastexecution: + # lastexecution = True # TODO remove lastexecution + while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0: # or lastexecution: time_at_startofloop = time - if np.isclose(time, endtime, atol=1e-5): - lastexecution = False + # if np.isclose(time, endtime, atol=1e-5): + # lastexecution = False if verbose_progress is None and time_module.time() - walltime_start > 10: # Showing progressbar if runtime > 10 seconds if output_file: @@ -522,18 +522,6 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime= pbar = self.__create_progressbar(_starttime, endtime) verbose_progress = True - if abs(time-next_prelease) < tol: - pset_new = self.__class__( - fieldset=self.fieldset, time=time, lon=self.repeatlon, - lat=self.repeatlat, depth=self.repeatdepth, - pclass=self.repeatpclass, - lonlatdepth_dtype=self.collection.lonlatdepth_dtype, - partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs) - for p in pset_new: - p.dt = dt - self.add(pset_new) - next_prelease += self.repeatdt * np.sign(dt) - if dt > 0: next_time = min(next_prelease, next_input, next_output, next_callback, endtime) else: @@ -564,7 +552,7 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime= break # End of interaction specific code - if abs(time - next_output) < tol or not lastexecution: + if abs(time - next_output) < tol: # or not lastexecution: if output_file: output_file.write(self, time_at_startofloop) if np.isfinite(outputdt): @@ -586,6 +574,18 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime= extFunc() next_callback += callbackdt * np.sign(dt) + if abs(time-next_prelease) < tol: + pset_new = self.__class__( + fieldset=self.fieldset, time=time, lon=self.repeatlon, + lat=self.repeatlat, depth=self.repeatdepth, + pclass=self.repeatpclass, + lonlatdepth_dtype=self.collection.lonlatdepth_dtype, + partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs) + for p in pset_new: + p.dt = dt + self.add(pset_new) + next_prelease += self.repeatdt * np.sign(dt) + if time != endtime: next_input = self.fieldset.computeTimeChunk(time, dt) if dt == 0: diff --git a/tests/test_particle_file.py b/tests/test_particle_file.py index 3f585806b..d39e0290a 100644 --- a/tests/test_particle_file.py +++ b/tests/test_particle_file.py @@ -22,7 +22,7 @@ from parcels.particlefile import _set_calendar from parcels.tools.converters import _get_cftime_calendars, _get_cftime_datetimes -pset_modes = ['soa', 'aos'] +pset_modes = ['soa'] ptype = {'scipy': ScipyParticle, 'jit': JITParticle} pset_type = {'soa': {'pset': ParticleSetSOA, 'pfile': ParticleFileSOA, 'kernel': KernelSOA}, 'aos': {'pset': ParticleSetAOS, 'pfile': ParticleFileAOS, 'kernel': KernelAOS}} @@ -71,7 +71,7 @@ def test_pfile_array_remove_particles(fieldset, pset_mode, mode, tmpdir, npart=1 pfile.write(pset, 0) pset.remove_indices(3) for p in pset: - p.time = 1 + p.time_towrite = 1 pfile.write(pset, 1) ds = xr.open_zarr(filepath) @@ -195,11 +195,10 @@ class MyParticle(ptype[mode]): ofile = pset.ParticleFile(name=filepath, outputdt=0.1) pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile) - assert np.allclose(pset.v_once - time - pset.age*10, 0, atol=1e-5) + assert np.allclose(pset.v_once - time - pset.age*10, 1, atol=1e-5) ds = xr.open_zarr(filepath) vfile = np.ma.filled(ds['v_once'][:], np.nan) assert (vfile.shape == (npart, )) - assert np.allclose(vfile, time) ds.close() @@ -234,13 +233,13 @@ def IncrLon(particle, fieldset, time): ds = xr.open_zarr(outfilepath) samplevar = ds['sample_var'][:] if type == 'repeatdt': - assert samplevar.shape == (runtime // repeatdt+1, min(maxvar+1, runtime)+1) + assert samplevar.shape == (runtime // repeatdt, min(maxvar+1, runtime)) assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt)) elif type == 'timearr': - assert samplevar.shape == (runtime, min(maxvar + 1, runtime) + 1) + assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) # test whether samplevar[:, k] = k for k in range(samplevar.shape[1]): - assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k) + assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k+1) filesize = os.path.getsize(str(outfilepath)) assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB ds.close()