Skip to content

Commit 4f55da7

Browse files
Merge pull request #1614 from OceanParcels/code_cleanups
Code cleanups
2 parents 2910326 + a8cb61f commit 4f55da7

File tree

4 files changed

+29
-27
lines changed

4 files changed

+29
-27
lines changed

parcels/application_kernels/advection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,21 +251,21 @@ def compute_ds(F0, F1, r, direction, tol):
251251
s_min = min(abs(ds_x), abs(ds_y), abs(ds_z), abs(ds_t / (dxdy * dz)))
252252

253253
# calculate end position in time s_min
254-
def compute_rs(ds, r, B, delta, s_min):
254+
def compute_rs(r, B, delta, s_min):
255255
if abs(B) < tol:
256256
return -delta * s_min + r
257257
else:
258258
return (r + delta / B) * math.exp(-B * s_min) - delta / B
259259

260-
rs_x = compute_rs(ds_x, xsi, B_x, delta_x, s_min)
261-
rs_y = compute_rs(ds_y, eta, B_y, delta_y, s_min)
260+
rs_x = compute_rs(xsi, B_x, delta_x, s_min)
261+
rs_y = compute_rs(eta, B_y, delta_y, s_min)
262262

263-
particle_dlon = (1.-rs_x)*(1.-rs_y) * px[0] + rs_x * (1.-rs_y) * px[1] + rs_x * rs_y * px[2] + (1.-rs_x)*rs_y * px[3] - particle.lon # noqa
264-
particle_dlat = (1.-rs_x)*(1.-rs_y) * py[0] + rs_x * (1.-rs_y) * py[1] + rs_x * rs_y * py[2] + (1.-rs_x)*rs_y * py[3] - particle.lat # noqa
263+
particle_dlon += (1.-rs_x)*(1.-rs_y) * px[0] + rs_x * (1.-rs_y) * px[1] + rs_x * rs_y * px[2] + (1.-rs_x)*rs_y * px[3] - particle.lon # noqa
264+
particle_dlat += (1.-rs_x)*(1.-rs_y) * py[0] + rs_x * (1.-rs_y) * py[1] + rs_x * rs_y * py[2] + (1.-rs_x)*rs_y * py[3] - particle.lat # noqa
265265

266266
if withW:
267-
rs_z = compute_rs(ds_z, zeta, B_z, delta_z, s_min)
268-
particle.depth = (1.-rs_z) * pz[0] + rs_z * pz[1]
267+
rs_z = compute_rs(zeta, B_z, delta_z, s_min)
268+
particle_ddepth += (1.-rs_z) * pz[0] + rs_z * pz[1] - particle.depth # noqa
269269

270270
if particle.dt > 0:
271271
particle.dt = max(direction * s_min * (dxdy * dz), 1e-7)

parcels/compilation/codegenerator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,8 @@ def visit_Subscript(self, node):
657657
if isinstance(node.value, FieldNode) or isinstance(node.value, VectorFieldNode):
658658
node.ccode = node.value.__getitem__(node.slice.ccode).ccode
659659
elif isinstance(node.value, ParticleXiYiZiTiAttributeNode):
660-
node.ccode = f"{node.value.obj}->{node.value.attr}[pnum, {node.slice.ccode}]"
660+
ngrid = str(self.fieldset.gridset.size if self.fieldset is not None else 1)
661+
node.ccode = f"{node.value.obj}->{node.value.attr}[pnum*{ngrid}+{node.slice.ccode}]"
661662
elif isinstance(node.value, IntrinsicNode):
662663
raise NotImplementedError(f"Subscript not implemented for object type {type(node.value).__name__}")
663664
else:

parcels/kernel.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ def __init__(self, fieldset, ptype, pyfunc=None, funcname=None, funccode=None, p
5555
self._ptype = ptype
5656
self._lib = None
5757
self.delete_cfiles = delete_cfiles
58-
self._cleanup_files = None
59-
self._cleanup_lib = None
6058
self._c_include = c_include
6159

6260
# Derive meta information from pyfunc, if not given

tests/test_particlefile.py

100644100755
Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def Update_lon(particle, fieldset, time):
284284
def test_write_xiyi(fieldset, mode, tmpdir):
285285
outfilepath = tmpdir.join("pfile_xiyi.zarr")
286286
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
287-
fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2]))
287+
fieldset.add_field(Field(name='P', data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2]))
288288
dt = 3600
289289

290290
XiYiParticle = ptype[mode].add_variables([
@@ -304,26 +304,29 @@ def Get_XiYi(particle, fieldset, time):
304304

305305
def SampleP(particle, fieldset, time):
306306
if time > 5*3600:
307-
tmp = fieldset.P[particle] # noqa
307+
_ = fieldset.P[particle] # To trigger sampling of the P field
308308

309-
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0], lat=[0.2], lonlatdepth_dtype=np.float64)
309+
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64)
310310
pfile = pset.ParticleFile(name=outfilepath, outputdt=dt)
311-
pset.execute([Get_XiYi, SampleP, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)
311+
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)
312312

313313
ds = xr.open_zarr(outfilepath)
314-
pxi0 = ds['pxi0'][:].values[0].astype(np.int32)
315-
pxi1 = ds['pxi1'][:].values[0].astype(np.int32)
316-
lons = ds['lon'][:].values[0]
317-
pyi = ds['pyi'][:].values[0].astype(np.int32)
318-
lats = ds['lat'][:].values[0]
319-
320-
assert (pxi0[0] == 0) and (pxi0[-1] == 11) # check that particle has moved
321-
assert np.all(pxi1[:7] == 0) # check that particle has not been sampled on grid 1 until time 6
322-
assert np.all(pxi1[7:] > 0) # check that particle has not been sampled on grid 1 after time 6
323-
for xi, lon in zip(pxi0[1:], lons[1:]):
324-
assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1]
325-
for yi, lat in zip(pyi[1:], lats[1:]):
326-
assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]
314+
pxi0 = ds['pxi0'][:].values.astype(np.int32)
315+
pxi1 = ds['pxi1'][:].values.astype(np.int32)
316+
lons = ds['lon'][:].values
317+
pyi = ds['pyi'][:].values.astype(np.int32)
318+
lats = ds['lat'][:].values
319+
320+
for p in range(pyi.shape[0]):
321+
assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved
322+
assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6
323+
assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6
324+
for xi, lon in zip(pxi0[p, 1:], lons[p, 1:]):
325+
assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1]
326+
for xi, lon in zip(pxi1[p, 6:], lons[p, 6:]):
327+
assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi+1]
328+
for yi, lat in zip(pyi[p, 1:], lats[p, 1:]):
329+
assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]
327330
ds.close()
328331

329332

0 commit comments

Comments
 (0)