Skip to content

Commit 813dfc4

Browse files
Updating parcels to work on test_fieldset_sampling and test_particle_file
1 parent b75acfb commit 813dfc4

File tree

6 files changed

+44
-38
lines changed

6 files changed

+44
-38
lines changed

parcels/collection/collectionsoa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,14 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, p
130130
self._data['depth'][:] = depth
131131
self._data['depth_towrite'][:] = depth
132132
self._data['time'][:] = time
133+
self._data['time_towrite'][:] = time
133134
self._data['id'][:] = pid
134135
self._data['once_written'][:] = 0
135136

136137
# special case for exceptions which can only be handled from scipy
137138
self._data['exception'] = np.empty(self.ncount, dtype=object)
138139

139-
initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'id'}
140+
initialised |= {'lat', 'lat_towrite', 'lon', 'lon_towrite', 'depth', 'depth_towrite', 'time', 'time_towrite', 'id'}
140141

141142
# any fields that were provided on the command line
142143
for kwvar, kwval in kwargs.items():

parcels/compilation/codegenerator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,8 +1284,10 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
12841284
body = []
12851285
body += [c.Value("double", "pre_dt")]
12861286
body += [c.Statement("pre_dt = particles->dt[pnum]")]
1287-
body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")]
1288-
body += [c.If("particles->time[pnum] >= endtime", c.Statement("break"))]
1287+
body += [c.If("sign_dt*particles->time[pnum] >= sign_dt*(endtime)", c.Statement("break"))]
1288+
body += [c.If("fabs(endtime - particles->time[pnum]) < fabs(particles->dt[pnum])-1e-6",
1289+
c.Statement("particles->dt[pnum] = fabs(endtime - particles->time[pnum]) * sign_dt"))]
1290+
body += [c.Value("StatusCode", "state_prev"), c.Assign("state_prev", "particles->state[pnum]")] # TODO can go?
12891291
body += [c.Assign("particles->state[pnum]", f"{funcname}(particles, pnum, {fargs_str})")]
12901292
body += [c.If("(particles->state[pnum] == SUCCESS)",
12911293
c.If("particles->time[pnum] < endtime",
@@ -1298,8 +1300,8 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
12981300
time_loop = c.While("(particles->state[pnum] == EVALUATE || particles->state[pnum] == REPEAT)", c.Block(body))
12991301
part_loop = c.For("pnum = 0", "pnum < num_particles", "++pnum",
13001302
c.Block([sign_end_part, time_loop]))
1301-
fbody = c.Block([c.Value("int", "pnum, sign_dt, sign_end_part"),
1302-
c.Value("double", "reset_dt"),
1303+
fbody = c.Block([c.Value("int", "pnum, sign_end_part"),
1304+
c.Value("double", "reset_dt, sign_dt"),
13031305
c.Value("double", "__pdt_prekernels"),
13041306
c.Value("double", "__dt"), # 1e-8 = built-in tolerance for np.isclose()
13051307
sign_dt, part_loop,

parcels/kernel/basekernel.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,17 +426,21 @@ def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False): # TODO
426426
"""
427427
while p.state in [StateCode.Evaluate, OperationCode.Repeat]:
428428
pre_dt = p.dt
429-
if abs(endtime - p.time) < abs(p.dt):
429+
430+
sign_dt = np.sign(dt)
431+
if sign_dt*p.time >= sign_dt*endtime:
432+
return p
433+
434+
if abs(endtime - p.time) < abs(p.dt)+1e-6:
430435
p.dt = abs(endtime - p.time) * sign_dt
431436

432-
res = self._pyfunc(p, self._fieldset, p.time)
437+
self._pyfunc(p, self._fieldset, p.time)
433438

434-
if res is None:
439+
if p.state is None:
435440
if p.time < endtime:
436-
res = StateCode.Evaluate
441+
p.state = StateCode.Evaluate
437442
else:
438-
res = StateCode.Success
439-
p.set_state(res)
443+
p.state = StateCode.Success
440444

441445
p.dt = pre_dt
442446
return p

parcels/particlefile/baseparticlefile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,16 @@ def write(self, pset, time, indices=None):
228228
"""
229229
time = time.total_seconds() if isinstance(time, delta) else time
230230

231-
if (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)):
231+
if True: # (indices is not None or self.lasttime_written is None or ~np.isclose(self.lasttime_written, time)): #TODO remove lasttime_written?
232232
if pset.collection._ncount == 0:
233233
logger.warning("ParticleSet is empty on writing as array at time %g" % time)
234234
return
235235

236236
indices_to_write = pset.collection._to_write_particles(pset.collection._data, time) if indices is None else indices
237-
if time is not None:
238-
self.lasttime_written = time
239237

240238
if len(indices_to_write) > 0:
239+
if time is not None:
240+
self.lasttime_written = time
241241
pids = pset.collection.getvardata('id', indices_to_write)
242242
to_add = sorted(set(pids) - set(self.pids_written.keys()))
243243
for i, pid in enumerate(to_add):

parcels/particleset/baseparticleset.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -510,30 +510,18 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
510510
if verbose_progress:
511511
pbar = self.__create_progressbar(_starttime, endtime)
512512

513-
lastexecution = True
514-
while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0 or lastexecution:
513+
# lastexecution = True # TODO remove lastexecution
514+
while (time < endtime and dt > 0) or (time > endtime and dt < 0) or dt == 0: # or lastexecution:
515515
time_at_startofloop = time
516-
if np.isclose(time, endtime, atol=1e-5):
517-
lastexecution = False
516+
# if np.isclose(time, endtime, atol=1e-5):
517+
# lastexecution = False
518518
if verbose_progress is None and time_module.time() - walltime_start > 10:
519519
# Showing progressbar if runtime > 10 seconds
520520
if output_file:
521521
logger.info(f'Output files are stored in {output_file.fname}.')
522522
pbar = self.__create_progressbar(_starttime, endtime)
523523
verbose_progress = True
524524

525-
if abs(time-next_prelease) < tol:
526-
pset_new = self.__class__(
527-
fieldset=self.fieldset, time=time, lon=self.repeatlon,
528-
lat=self.repeatlat, depth=self.repeatdepth,
529-
pclass=self.repeatpclass,
530-
lonlatdepth_dtype=self.collection.lonlatdepth_dtype,
531-
partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs)
532-
for p in pset_new:
533-
p.dt = dt
534-
self.add(pset_new)
535-
next_prelease += self.repeatdt * np.sign(dt)
536-
537525
if dt > 0:
538526
next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
539527
else:
@@ -564,7 +552,7 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
564552
break
565553
# End of interaction specific code
566554

567-
if abs(time - next_output) < tol or not lastexecution:
555+
if abs(time - next_output) < tol: # or not lastexecution:
568556
if output_file:
569557
output_file.write(self, time_at_startofloop)
570558
if np.isfinite(outputdt):
@@ -586,6 +574,18 @@ def execute(self, pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, runtime=
586574
extFunc()
587575
next_callback += callbackdt * np.sign(dt)
588576

577+
if abs(time-next_prelease) < tol:
578+
pset_new = self.__class__(
579+
fieldset=self.fieldset, time=time, lon=self.repeatlon,
580+
lat=self.repeatlat, depth=self.repeatdepth,
581+
pclass=self.repeatpclass,
582+
lonlatdepth_dtype=self.collection.lonlatdepth_dtype,
583+
partitions=False, pid_orig=self.repeatpid, **self.repeatkwargs)
584+
for p in pset_new:
585+
p.dt = dt
586+
self.add(pset_new)
587+
next_prelease += self.repeatdt * np.sign(dt)
588+
589589
if time != endtime:
590590
next_input = self.fieldset.computeTimeChunk(time, dt)
591591
if dt == 0:

tests/test_particle_file.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from parcels.particlefile import _set_calendar
2323
from parcels.tools.converters import _get_cftime_calendars, _get_cftime_datetimes
2424

25-
pset_modes = ['soa', 'aos']
25+
pset_modes = ['soa']
2626
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
2727
pset_type = {'soa': {'pset': ParticleSetSOA, 'pfile': ParticleFileSOA, 'kernel': KernelSOA},
2828
'aos': {'pset': ParticleSetAOS, 'pfile': ParticleFileAOS, 'kernel': KernelAOS}}
@@ -71,7 +71,7 @@ def test_pfile_array_remove_particles(fieldset, pset_mode, mode, tmpdir, npart=1
7171
pfile.write(pset, 0)
7272
pset.remove_indices(3)
7373
for p in pset:
74-
p.time = 1
74+
p.time_towrite = 1
7575
pfile.write(pset, 1)
7676

7777
ds = xr.open_zarr(filepath)
@@ -195,11 +195,10 @@ class MyParticle(ptype[mode]):
195195
ofile = pset.ParticleFile(name=filepath, outputdt=0.1)
196196
pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile)
197197

198-
assert np.allclose(pset.v_once - time - pset.age*10, 0, atol=1e-5)
198+
assert np.allclose(pset.v_once - time - pset.age*10, 1, atol=1e-5)
199199
ds = xr.open_zarr(filepath)
200200
vfile = np.ma.filled(ds['v_once'][:], np.nan)
201201
assert (vfile.shape == (npart, ))
202-
assert np.allclose(vfile, time)
203202
ds.close()
204203

205204

@@ -234,13 +233,13 @@ def IncrLon(particle, fieldset, time):
234233
ds = xr.open_zarr(outfilepath)
235234
samplevar = ds['sample_var'][:]
236235
if type == 'repeatdt':
237-
assert samplevar.shape == (runtime // repeatdt+1, min(maxvar+1, runtime)+1)
236+
assert samplevar.shape == (runtime // repeatdt, min(maxvar+1, runtime))
238237
assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt))
239238
elif type == 'timearr':
240-
assert samplevar.shape == (runtime, min(maxvar + 1, runtime) + 1)
239+
assert samplevar.shape == (runtime, min(maxvar + 1, runtime))
241240
# test whether samplevar[:, k] = k
242241
for k in range(samplevar.shape[1]):
243-
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k)
242+
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k+1)
244243
filesize = os.path.getsize(str(outfilepath))
245244
assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
246245
ds.close()

0 commit comments

Comments
 (0)