Skip to content

Commit a39a72c

Browse files
Removing distionary from method to kernels
Instead directly using Kernel functions in test_advection
1 parent c107a5d commit a39a72c

File tree

1 file changed

+42
-54
lines changed

1 file changed

+42
-54
lines changed

tests/test_advection.py

Lines changed: 42 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,6 @@
2626
)
2727
from tests.utils import round_and_hash_float_array
2828

29-
kernel = {
30-
"EE": AdvectionEE,
31-
"RK2": AdvectionRK2,
32-
"RK2_3D": AdvectionRK2_3D,
33-
"RK4": AdvectionRK4,
34-
"RK4_3D": AdvectionRK4_3D,
35-
"RK45": AdvectionRK45,
36-
# "AA": AdvectionAnalytical,
37-
"AdvDiffEM": AdvectionDiffusionEM,
38-
"AdvDiffM1": AdvectionDiffusionM1,
39-
}
40-
4129

4230
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
4331
def test_advection_zonal(mesh, npart=10):
@@ -260,32 +248,32 @@ def test_radialrotation(npart=10):
260248

261249

262250
@pytest.mark.parametrize(
263-
"method, rtol",
251+
"kernel, rtol",
264252
[
265-
("EE", 1e-2),
266-
("AdvDiffEM", 1e-2),
267-
("AdvDiffM1", 1e-2),
268-
("RK2", 6e-5),
269-
("RK2_3D", 6e-5),
270-
("RK4", 1e-5),
271-
("RK4_3D", 1e-5),
272-
("RK45", 1e-4),
253+
(AdvectionEE, 1e-2),
254+
(AdvectionDiffusionEM, 1e-2),
255+
(AdvectionDiffusionM1, 1e-2),
256+
(AdvectionRK2, 6e-5),
257+
(AdvectionRK2_3D, 6e-5),
258+
(AdvectionRK4, 1e-5),
259+
(AdvectionRK4_3D, 1e-5),
260+
(AdvectionRK45, 1e-4),
273261
],
274262
)
275-
def test_moving_eddy(method, rtol):
263+
def test_moving_eddy(kernel, rtol):
276264
ds = moving_eddy_dataset()
277265
grid = XGrid.from_dataset(ds)
278266
U = Field("U", ds["U"], grid, interp_method=XLinear)
279267
V = Field("V", ds["V"], grid, interp_method=XLinear)
280-
if method in ["RK2_3D", "RK4_3D"]:
268+
if kernel in [AdvectionRK2_3D, AdvectionRK4_3D]:
281269
# Using W to test 3D advection (assuming same velocity as V)
282270
W = Field("W", ds["V"], grid, interp_method=XLinear)
283271
UVW = VectorField("UVW", U, V, W)
284272
fieldset = FieldSet([U, V, W, UVW])
285273
else:
286274
UV = VectorField("UV", U, V)
287275
fieldset = FieldSet([U, V, UV])
288-
if method in ["AdvDiffEM", "AdvDiffM1"]:
276+
if kernel in [AdvectionDiffusionEM, AdvectionDiffusionM1]:
289277
# Add zero diffusivity field for diffusion kernels
290278
ds["Kh"] = (["time", "depth", "YG", "XG"], np.full(ds["U"].shape, 0))
291279
fieldset.add_field(Field("Kh", ds["Kh"], grid, interp_method=XLinear), "Kh_zonal")
@@ -295,11 +283,11 @@ def test_moving_eddy(method, rtol):
295283
start_lon, start_lat, start_z = 12000, 12500, 12500
296284
dt = np.timedelta64(30, "m")
297285

298-
if method == "RK45":
286+
if kernel == AdvectionRK45:
299287
fieldset.add_constant("RK45_tol", rtol)
300288

301289
pset = ParticleSet(fieldset, lon=start_lon, lat=start_lat, z=start_z, time=np.timedelta64(0, "s"))
302-
pset.execute(kernel[method], dt=dt, endtime=np.timedelta64(1, "h"))
290+
pset.execute(kernel, dt=dt, endtime=np.timedelta64(1, "h"))
303291

304292
def truth_moving(x_0, y_0, t):
305293
t /= np.timedelta64(1, "s")
@@ -310,20 +298,20 @@ def truth_moving(x_0, y_0, t):
310298
exp_lon, exp_lat = truth_moving(start_lon, start_lat, pset.time[0])
311299
np.testing.assert_allclose(pset.lon, exp_lon, rtol=rtol)
312300
np.testing.assert_allclose(pset.lat, exp_lat, rtol=rtol)
313-
if method == "RK4_3D":
301+
if kernel == AdvectionRK4_3D:
314302
np.testing.assert_allclose(pset.z, exp_lat, rtol=rtol)
315303

316304

317305
@pytest.mark.parametrize(
318-
"method, rtol",
306+
"kernel, rtol",
319307
[
320-
("EE", 1e-1),
321-
("RK2", 3e-3),
322-
("RK4", 1e-5),
323-
("RK45", 1e-4),
308+
(AdvectionEE, 1e-1),
309+
(AdvectionRK2, 3e-3),
310+
(AdvectionRK4, 1e-5),
311+
(AdvectionRK45, 1e-4),
324312
],
325313
)
326-
def test_decaying_moving_eddy(method, rtol):
314+
def test_decaying_moving_eddy(kernel, rtol):
327315
ds = decaying_moving_eddy_dataset()
328316
grid = XGrid.from_dataset(ds)
329317
U = Field("U", ds["U"], grid, interp_method=XLinear)
@@ -334,12 +322,12 @@ def test_decaying_moving_eddy(method, rtol):
334322
start_lon, start_lat = 10000, 10000
335323
dt = np.timedelta64(60, "m")
336324

337-
if method == "RK45":
325+
if kernel == AdvectionRK45:
338326
fieldset.add_constant("RK45_tol", rtol)
339327
fieldset.add_constant("RK45_min_dt", 10 * 60)
340328

341329
pset = ParticleSet(fieldset, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
342-
pset.execute(kernel[method], dt=dt, endtime=np.timedelta64(1, "D"))
330+
pset.execute(kernel, dt=dt, endtime=np.timedelta64(1, "D"))
343331

344332
def truth_moving(x_0, y_0, t):
345333
t /= np.timedelta64(1, "s")
@@ -361,15 +349,15 @@ def truth_moving(x_0, y_0, t):
361349

362350

363351
@pytest.mark.parametrize(
364-
"method, rtol",
352+
"kernel, rtol",
365353
[
366-
("RK2", 0.1),
367-
("RK4", 0.1),
368-
("RK45", 0.1),
354+
(AdvectionRK2, 0.1),
355+
(AdvectionRK4, 0.1),
356+
(AdvectionRK45, 0.1),
369357
],
370358
)
371359
@pytest.mark.parametrize("grid_type", ["A", "C"])
372-
def test_stommelgyre_fieldset(method, rtol, grid_type):
360+
def test_stommelgyre_fieldset(kernel, rtol, grid_type):
373361
npart = 2
374362
ds = stommel_gyre_dataset(grid_type=grid_type)
375363
grid = XGrid.from_dataset(ds)
@@ -385,7 +373,7 @@ def test_stommelgyre_fieldset(method, rtol, grid_type):
385373
start_lon = np.linspace(10e3, 100e3, npart)
386374
start_lat = np.ones_like(start_lon) * 5000e3
387375

388-
if method == "RK45":
376+
if kernel == AdvectionRK45:
389377
fieldset.add_constant("RK45_tol", rtol)
390378

391379
SampleParticle = Particle.add_variable(
@@ -397,20 +385,20 @@ def UpdateP(particles, fieldset): # pragma: no cover
397385
particles.p_start = np.where(particles.time == 0, particles.p, particles.p_start)
398386

399387
pset = ParticleSet(fieldset, pclass=SampleParticle, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
400-
pset.execute([kernel[method], UpdateP], dt=dt, runtime=runtime)
388+
pset.execute([kernel, UpdateP], dt=dt, runtime=runtime)
401389
np.testing.assert_allclose(pset.p, pset.p_start, rtol=rtol)
402390

403391

404392
@pytest.mark.parametrize(
405-
"method, rtol",
393+
"kernel, rtol",
406394
[
407-
("RK2", 2e-2),
408-
("RK4", 5e-3),
409-
("RK45", 1e-4),
395+
(AdvectionRK2, 2e-2),
396+
(AdvectionRK4, 5e-3),
397+
(AdvectionRK45, 1e-4),
410398
],
411399
)
412400
@pytest.mark.parametrize("grid_type", ["A"]) # TODO also implement C-grid once available
413-
def test_peninsula_fieldset(method, rtol, grid_type):
401+
def test_peninsula_fieldset(kernel, rtol, grid_type):
414402
npart = 2
415403
ds = peninsula_dataset(grid_type=grid_type)
416404
grid = XGrid.from_dataset(ds)
@@ -425,7 +413,7 @@ def test_peninsula_fieldset(method, rtol, grid_type):
425413
start_lat = np.linspace(3e3, 47e3, npart)
426414
start_lon = 3e3 * np.ones_like(start_lat)
427415

428-
if method == "RK45":
416+
if kernel == AdvectionRK45:
429417
fieldset.add_constant("RK45_tol", rtol)
430418

431419
SampleParticle = Particle.add_variable(
@@ -437,7 +425,7 @@ def UpdateP(particles, fieldset): # pragma: no cover
437425
particles.p_start = np.where(particles.time == 0, particles.p, particles.p_start)
438426

439427
pset = ParticleSet(fieldset, pclass=SampleParticle, lon=start_lon, lat=start_lat, time=np.timedelta64(0, "s"))
440-
pset.execute([kernel[method], UpdateP], dt=dt, runtime=runtime)
428+
pset.execute([kernel, UpdateP], dt=dt, runtime=runtime)
441429
np.testing.assert_allclose(pset.p, pset.p_start, rtol=rtol)
442430

443431

@@ -483,8 +471,8 @@ def periodicBC(particles, fieldset): # pragma: no cover
483471
np.testing.assert_allclose(pset.lat, latp, atol=1e-1)
484472

485473

486-
@pytest.mark.parametrize("method", ["RK4", "RK4_3D"])
487-
def test_nemo_3D_curvilinear_fieldset(method):
474+
@pytest.mark.parametrize("kernel", [AdvectionRK4, AdvectionRK4_3D])
475+
def test_nemo_3D_curvilinear_fieldset(kernel):
488476
download_dir = parcels.download_example_dataset("NemoNorthSeaORCA025-N006_data")
489477
ufiles = download_dir.glob("*U.nc")
490478
dsu = xr.open_mfdataset(ufiles, decode_times=False, drop_variables=["nav_lat", "nav_lon"])
@@ -562,10 +550,10 @@ def test_nemo_3D_curvilinear_fieldset(method):
562550
lats = np.linspace(52.5, 51.6, npart)
563551
pset = parcels.ParticleSet(fieldset, lon=lons, lat=lats, z=np.ones_like(lons))
564552

565-
pset.execute(kernel[method], runtime=np.timedelta64(4, "D"), dt=np.timedelta64(6, "h"))
553+
pset.execute(kernel, runtime=np.timedelta64(4, "D"), dt=np.timedelta64(6, "h"))
566554

567-
if method == "RK4":
555+
if kernel == AdvectionRK4:
568556
np.testing.assert_equal(round_and_hash_float_array([p.lon for p in pset], decimals=5), 29977383852960156017546)
569-
elif method == "RK4_3D":
557+
elif kernel == AdvectionRK4_3D:
570558
# TODO check why decimals needs to be so low in RK4_3D (compare to v3)
571559
np.testing.assert_equal(round_and_hash_float_array([p.z for p in pset], decimals=1), 29747210774230389239432)

0 commit comments

Comments
 (0)