Commit

Updating parcels to work on test_fieldset_sampling and test_particle_file
erikvansebille committed Aug 2, 2023
1 parent b75acfb commit 813dfc4
Showing 6 changed files with 44 additions and 38 deletions.
3 changes: 2 additions & 1 deletion parcels/collection/collectionsoa.py
@@ -130,13 +130,14 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, p
         self._data['depth'][:] = depth
         self._data['depth_towrite'][:] = depth
         self._data['time'][:] = time
+        self._data['time_towrite'][:] = time
         self._data['id'][:] = pid
         self._data['once_written'][:] = 0

         # special case for exceptions which can only be handled from scipy
         self._data['exception'] = np.empty(self.ncount, dtype=object)

-        initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'id'}
+        initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'time_towrite', 'id'}

         # any fields that were provided on the command line
         for kwvar, kwval in kwargs.items():
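
Note: the structure-of-arrays collection keeps a separate `*_towrite` copy of each output variable alongside the live value, and this change adds that copy for `time`. A minimal sketch of the pattern (illustrative only; the dictionary layout and helper name below are assumptions, not the actual parcels class):

    import numpy as np

    def init_soa(lon, lat, depth, time):
        """Toy structure-of-arrays store with separate *_towrite buffers."""
        n = len(lon)
        data = {}
        for name, values in [('lon', lon), ('lat', lat), ('depth', depth), ('time', time)]:
            data[name] = np.array(values, dtype=np.float64)
            data[name + '_towrite'] = np.array(values, dtype=np.float64)  # copy used for file output
        data['id'] = np.arange(n, dtype=np.int64)
        return data

    store = init_soa([0.0, 1.0], [10.0, 11.0], [0.0, 0.0], [0.0, 0.0])
    store['time'][0] = 3600.0                # the kernel advances the live value
    assert store['time_towrite'][0] == 0.0   # the to-write buffer stays put until explicitly synced
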
10 changes: 6 additions & 4 deletions parcels/compilation/codegenerator.py
@@ -1284,8 +1284,10 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
         body = []
         body += [c.Value("double", "pre_dt")]
         body += [c.Statement("pre_dt = particles->dt[pnum]")]
-        body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")]
-        body += [c.If("particles->time[pnum] >= endtime", c.Statement("break"))]
+        body += [c.If("sign_dt*particles->time[pnum] >= sign_dt*(endtime)", c.Statement("break"))]
+        body += [c.If("fabs(endtime - particles->time[pnum]) < fabs(particles->dt[pnum])-1e-6",
+                      c.Statement("particles->dt[pnum] = fabs(endtime - particles->time[pnum]) * sign_dt"))]
+        body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")]  # TODO can go?
         body += [c.Assign("particles->state[pnum]", f"{funcname}(particles, pnum, {fargs_str})")]
         body += [c.If("(particles->state[pnum] == SUCCESS)",
                       c.If("particles->time[pnum] < endtime",
@@ -1298,8 +1300,8 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
         time_loop = c.While("(particles->state[pnum] == EVALUATE || particles->state[pnum] == REPEAT)", c.Block(body))
         part_loop = c.For("pnum = 0", "pnum < num_particles", "++pnum",
                           c.Block([sign_end_part, time_loop]))
-        fbody = c.Block([c.Value("int", "pnum, sign_dt, sign_end_part"),
-                         c.Value("double", "reset_dt"),
+        fbody = c.Block([c.Value("int", "pnum, sign_end_part"),
+                         c.Value("double", "reset_dt, sign_dt"),
                          c.Value("double", "__pdt_prekernels"),
                          c.Value("double", "__dt"),  # 1e-8 = built-in tolerance for np.isclose()
                          sign_dt, part_loop,
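
The `c.*` calls above come from the cgen package, which assembles C source from small Python node objects. A rough sketch of how the new directional break and dt-clamping statements render (a simplified stand-alone example, not the full generated kernel loop):

    import cgen as c

    body = []
    body += [c.If("sign_dt*particles->time[pnum] >= sign_dt*(endtime)", c.Statement("break"))]
    body += [c.If("fabs(endtime - particles->time[pnum]) < fabs(particles->dt[pnum])-1e-6",
                  c.Statement("particles->dt[pnum] = fabs(endtime - particles->time[pnum]) * sign_dt"))]
    loop = c.While("particles->state[pnum] == EVALUATE", c.Block(body))
    print(loop)
    # Prints C along the lines of:
    #   while (particles->state[pnum] == EVALUATE)
    #   {
    #     if (sign_dt*particles->time[pnum] >= sign_dt*(endtime))
    #       break;
    #     ...
    #   }
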
16 changes: 10 additions & 6 deletions parcels/kernel/basekernel.py
@@ -426,17 +426,21 @@ def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False):  # TODO
        """
        while p.state in [StateCode.Evaluate, OperationCode.Repeat]:
            pre_dt = p.dt
-            if abs(endtime - p.time) < abs(p.dt):
+
+            sign_dt = np.sign(dt)
+            if sign_dt*p.time >= sign_dt*endtime:
+                return p
+
+            if abs(endtime - p.time) < abs(p.dt)+1e-6:
                p.dt = abs(endtime - p.time) * sign_dt

-            res = self._pyfunc(p, self._fieldset, p.time)
+            self._pyfunc(p, self._fieldset, p.time)

-            if res is None:
+            if p.state is None:
                if p.time < endtime:
-                    res = StateCode.Evaluate
+                    p.state = StateCode.Evaluate
                else:
-                    res = StateCode.Success
-            p.set_state(res)
+                    p.state = StateCode.Success

            p.dt = pre_dt
        return p
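
The direction-aware logic in plain Python: multiplying both sides of the time comparison by `sign_dt` makes one test work for forward (dt > 0) and backward (dt < 0) integration, and the clamp shortens the last step so the particle lands exactly on `endtime`. A self-contained toy version (not the parcels kernel machinery):

    import numpy as np

    def step_to_endtime(time, dt, endtime):
        """Advance `time` in steps of `dt` (forward or backward) until `endtime` is reached."""
        sign_dt = np.sign(dt)
        while True:
            if sign_dt * time >= sign_dt * endtime:   # single test for both integration directions
                return time
            if abs(endtime - time) < abs(dt) + 1e-6:  # clamp the final, partial step
                dt = abs(endtime - time) * sign_dt
            time += dt

    assert step_to_endtime(0.0, 3.0, 10.0) == 10.0    # forward: 0, 3, 6, 9, then a clamped final step
    assert step_to_endtime(10.0, -3.0, 0.0) == 0.0    # backward integration through the same code path
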
6 changes: 3 additions & 3 deletions parcels/particlefile/baseparticlefile.py
@@ -228,16 +228,16 @@ def write(self, pset, time, indices=None):
        """
        time = time.total_seconds() if isinstance(time, delta) else time

-        if (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)):
+        if True:  # (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)):  #TODO remove lasttime_written?
            if pset.collection._ncount == 0:
                logger.warning("ParticleSet is empty on writing as array at time %g" % time)
                return

            indices_to_write = pset.collection._to_write_particles(pset.collection._data, time) if indices is None else indices
-            if time is not None:
-                self.lasttime_written = time

            if len(indices_to_write) > 0:
+                if time is not None:
+                    self.lasttime_written = time
                pids = pset.collection.getvardata('id', indices_to_write)
                to_add = sorted(set(pids) - set(self.pids_written.keys()))
                for i, pid in enumerate(to_add):
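
The reordering above records `lasttime_written` only when a write actually emits data. A minimal sketch of that guard pattern (a hypothetical writer class, not the parcels ParticleFile API):

    class ToyWriter:
        """Record output times only for writes that actually emit data."""
        def __init__(self):
            self.lasttime_written = None
            self.rows = []

        def write(self, values, time):
            indices_to_write = [i for i, v in enumerate(values) if v is not None]
            if len(indices_to_write) > 0:
                if time is not None:
                    self.lasttime_written = time   # updated only on a non-empty write
                self.rows.append((time, [values[i] for i in indices_to_write]))

    w = ToyWriter()
    w.write([None, None], time=0.0)   # nothing to write: lasttime_written stays None
    w.write([1.5, None], time=1.0)
    assert w.lasttime_written == 1.0 and len(w.rows) == 1
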
34 changes: 17 additions & 17 deletions parcels/particleset/baseparticleset.py
@@ -510,30 +510,18 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
        if verbose_progress:
            pbar = self.__create_progressbar(_starttime, endtime)

-        lastexecution = True
-        while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0 or lastexecution:
+        # lastexecution = True  # TODO remove lastexecution
+        while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0:  # or lastexecution:
            time_at_startofloop = time
-            if np.isclose(time, endtime, atol=1e-5):
-                lastexecution = False
+            # if np.isclose(time, endtime, atol=1e-5):
+            #     lastexecution = False
            if verbose_progress is None and time_module.time() - walltime_start > 10:
                # Showing progressbar if runtime > 10 seconds
                if output_file:
                    logger.info(f'Output files are stored in {output_file.fname}.')
                pbar = self.__create_progressbar(_starttime, endtime)
                verbose_progress = True

-            if abs(time-next_prelease) < tol:
-                pset_new = self.__class__(
-                    fieldset=self.fieldset, time=time, lon=self.repeatlon,
-                    lat=self.repeatlat, depth=self.repeatdepth,
-                    pclass=self.repeatpclass,
-                    lonlatdepth_dtype=self.collection.lonlatdepth_dtype,
-                    partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs)
-                for p in pset_new:
-                    p.dt = dt
-                self.add(pset_new)
-                next_prelease += self.repeatdt * np.sign(dt)
-
            if dt > 0:
                next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
            else:
@@ -564,7 +552,7 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
                    break
            # End of interaction specific code

-            if abs(time - next_output) < tol or not lastexecution:
+            if abs(time - next_output) < tol:  # or not lastexecution:
                if output_file:
                    output_file.write(self, time_at_startofloop)
                if np.isfinite(outputdt):
@@ -586,6 +574,18 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
                        extFunc()
                next_callback += callbackdt * np.sign(dt)

+            if abs(time-next_prelease) < tol:
+                pset_new = self.__class__(
+                    fieldset=self.fieldset, time=time, lon=self.repeatlon,
+                    lat=self.repeatlat, depth=self.repeatdepth,
+                    pclass=self.repeatpclass,
+                    lonlatdepth_dtype=self.collection.lonlatdepth_dtype,
+                    partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs)
+                for p in pset_new:
+                    p.dt = dt
+                self.add(pset_new)
+                next_prelease += self.repeatdt * np.sign(dt)
+
            if time != endtime:
                next_input = self.fieldset.computeTimeChunk(time, dt)
            if dt == 0:
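
In the execution loop above, the repeat release is moved from before the kernel step to after output and callbacks, so a freshly released particle is first advanced and written starting from its release time on the following iteration. A stripped-down sketch of that loop ordering (names such as `advect`, `write` and `release_repeats` are made up for this illustration and are not the parcels API):

    import numpy as np

    def run(time, endtime, dt, outputdt, repeatdt, advect, write, release_repeats):
        """Toy driver: advance, write output, then release repeat particles."""
        tol = 1e-12
        next_output = time + outputdt * np.sign(dt)
        next_prelease = time + repeatdt * np.sign(dt)
        while (time < endtime and dt > 0) or (time > endtime and dt < 0):
            next_time = min(next_prelease, next_output, endtime) if dt > 0 else max(next_prelease, next_output, endtime)
            advect(time, next_time)
            time = next_time
            if abs(time - next_output) < tol:
                write(time)
                next_output += outputdt * np.sign(dt)
            if abs(time - next_prelease) < tol:
                release_repeats(time)            # new particles enter after this step's output
                next_prelease += repeatdt * np.sign(dt)
        return time

    log = []
    run(0.0, 3.0, 1.0, 1.0, 2.0,
        advect=lambda t0, t1: log.append(('advect', t0, t1)),
        write=lambda t: log.append(('write', t)),
        release_repeats=lambda t: log.append(('release', t)))
    assert log.index(('write', 2.0)) < log.index(('release', 2.0))  # the release follows the write at t=2
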
13 changes: 6 additions & 7 deletions tests/test_particle_file.py
@@ -22,7 +22,7 @@
from parcels.particlefile import _set_calendar
from parcels.tools.converters import _get_cftime_calendars, _get_cftime_datetimes

-pset_modes = ['soa', 'aos']
+pset_modes = ['soa']
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
pset_type = {'soa': {'pset': ParticleSetSOA, 'pfile': ParticleFileSOA, 'kernel': KernelSOA},
             'aos': {'pset': ParticleSetAOS, 'pfile': ParticleFileAOS, 'kernel': KernelAOS}}
@@ -71,7 +71,7 @@ def test_pfile_array_remove_particles(fieldset, pset_mode, mode, tmpdir, npart=1
    pfile.write(pset, 0)
    pset.remove_indices(3)
    for p in pset:
-        p.time = 1
+        p.time_towrite = 1
    pfile.write(pset, 1)

    ds = xr.open_zarr(filepath)
@@ -195,11 +195,10 @@ class MyParticle(ptype[mode]):
    ofile = pset.ParticleFile(name=filepath, outputdt=0.1)
    pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile)

-    assert np.allclose(pset.v_once - time - pset.age*10, 0, atol=1e-5)
+    assert np.allclose(pset.v_once - time - pset.age*10, 1, atol=1e-5)
    ds = xr.open_zarr(filepath)
    vfile = np.ma.filled(ds['v_once'][:], np.nan)
    assert (vfile.shape == (npart, ))
-    assert np.allclose(vfile, time)
    ds.close()

@@ -234,13 +233,13 @@ def IncrLon(particle, fieldset, time):
    ds = xr.open_zarr(outfilepath)
    samplevar = ds['sample_var'][:]
    if type == 'repeatdt':
-        assert samplevar.shape == (runtime // repeatdt+1, min(maxvar+1, runtime)+1)
+        assert samplevar.shape == (runtime // repeatdt, min(maxvar+1, runtime))
        assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt))
    elif type == 'timearr':
-        assert samplevar.shape == (runtime, min(maxvar + 1, runtime) + 1)
+        assert samplevar.shape == (runtime, min(maxvar + 1, runtime))
        # test whether samplevar[:, k] = k
        for k in range(samplevar.shape[1]):
-            assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k)
+            assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k+1)
    filesize = os.path.getsize(str(outfilepath))
    assert filesize < 1024 * 65  # test that chunking leads to filesize less than 65KB
    ds.close()
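
As a quick check on the adjusted shape assertions, here is the arithmetic with assumed example values (the test's real runtime, repeatdt and maxvar parameters are not shown in this hunk):

    # Assumed example values only; substitute the fixture values used by the test.
    runtime, repeatdt, maxvar = 90, 10, 4

    n_trajectories = runtime // repeatdt     # one row per released particle, without the old +1
    n_obs = min(maxvar + 1, runtime)         # observations per trajectory, without the old +1
    expected_shape = (n_trajectories, n_obs)
    print(expected_shape)                    # (9, 5) for these assumed values
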
