Skip to content

Commit 0ce0650

Browse files
Merge pull request #1641 from OceanParcels/croco_3D_velocities
Support for CROCO 3D velocities
2 parents 6e72612 + e73b275 commit 0ce0650

File tree

16 files changed

+666
-45
lines changed

16 files changed

+666
-45
lines changed

docs/documentation/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Parcels has several documentation and tutorial Jupyter notebooks and scripts whi
2222
../examples/documentation_indexing.ipynb
2323
../examples/tutorial_nemo_curvilinear.ipynb
2424
../examples/tutorial_nemo_3D.ipynb
25+
../examples/tutorial_croco_3D.ipynb
2526
../examples/tutorial_NestedFields.ipynb
2627
../examples/tutorial_timevaryingdepthdimensions.ipynb
2728
../examples/tutorial_periodic_boundaries.ipynb

docs/examples/tutorial_croco_3D.ipynb

Lines changed: 332 additions & 0 deletions
Large diffs are not rendered by default.

parcels/_typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ class ParcelsAST(ast.AST):
3535
) # corresponds with `interp_method` (which can also be dict mapping field names to method)
3636
PathLike = str | os.PathLike
3737
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
38-
VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type`
38+
VectorType = Literal["3D", "3DSigma", "2D"] | None # corresponds with `vector_type`
3939
ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode`
40-
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype`
40+
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo", "croco"] # corresponds with `gridindexingtype`
4141
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status`
4242
TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `time_periodic`
4343
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]

parcels/application_kernels/advection.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55
from parcels.tools.statuscodes import StatusCode
66

7-
__all__ = ["AdvectionRK4", "AdvectionEE", "AdvectionRK45", "AdvectionRK4_3D", "AdvectionAnalytical"]
7+
__all__ = [
8+
"AdvectionRK4",
9+
"AdvectionEE",
10+
"AdvectionRK45",
11+
"AdvectionRK4_3D",
12+
"AdvectionAnalytical",
13+
"AdvectionRK4_3D_CROCO",
14+
]
815

916

1017
def AdvectionRK4(particle, fieldset, time):
@@ -40,6 +47,51 @@ def AdvectionRK4_3D(particle, fieldset, time):
4047
particle_ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * particle.dt # noqa
4148

4249

50+
def AdvectionRK4_3D_CROCO(particle, fieldset, time):
51+
"""Advection of particles using fourth-order Runge-Kutta integration including vertical velocity.
52+
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
53+
"""
54+
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]
55+
56+
(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
57+
w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
58+
lon1 = particle.lon + u1 * 0.5 * particle.dt
59+
lat1 = particle.lat + v1 * 0.5 * particle.dt
60+
sig_dep1 = sig_dep + w1 * 0.5 * particle.dt
61+
dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]
62+
63+
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
64+
w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
65+
lon2 = particle.lon + u2 * 0.5 * particle.dt
66+
lat2 = particle.lat + v2 * 0.5 * particle.dt
67+
sig_dep2 = sig_dep + w2 * 0.5 * particle.dt
68+
dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]
69+
70+
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
71+
w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
72+
lon3 = particle.lon + u3 * particle.dt
73+
lat3 = particle.lat + v3 * particle.dt
74+
sig_dep3 = sig_dep + w3 * particle.dt
75+
dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]
76+
77+
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
78+
w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
79+
lon4 = particle.lon + u4 * particle.dt
80+
lat4 = particle.lat + v4 * particle.dt
81+
sig_dep4 = sig_dep + w4 * particle.dt
82+
dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]
83+
84+
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * particle.dt # noqa
85+
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * particle.dt # noqa
86+
particle_ddepth += ( # noqa
87+
(dep1 - particle.depth) * 2
88+
+ 2 * (dep2 - particle.depth) * 2
89+
+ 2 * (dep3 - particle.depth)
90+
+ dep4
91+
- particle.depth
92+
) / 6
93+
94+
4395
def AdvectionEE(particle, fieldset, time):
4496
"""Advection of particles using Explicit Euler (aka Euler Forward) integration."""
4597
(u1, v1) = fieldset.UV[particle]

parcels/compilation/codegenerator.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,13 @@ def __init__(self, field):
8080

8181

8282
class VectorFieldEvalNode(IntrinsicNode):
83-
def __init__(self, field, args, var, var2, var3, convert=True):
83+
def __init__(self, field, args, var, var2, var3, var4, convert=True):
8484
self.field = field
8585
self.args = args
8686
self.var = var # the variable in which the interpolated field is written
8787
self.var2 = var2 # second variable for UV interpolation
8888
self.var3 = var3 # third variable for UVW interpolation
89+
self.var4 = var4 # extra variable for sigma-scaling for croco
8990
self.convert = convert # whether to convert the result (like field.applyConversion)
9091

9192

@@ -107,12 +108,13 @@ def __getitem__(self, attr):
107108

108109

109110
class NestedVectorFieldEvalNode(IntrinsicNode):
110-
def __init__(self, fields, args, var, var2, var3):
111+
def __init__(self, fields, args, var, var2, var3, var4):
111112
self.fields = fields
112113
self.args = args
113114
self.var = var # the variable in which the interpolated field is written
114115
self.var2 = var2 # second variable for UV interpolation
115116
self.var3 = var3 # third variable for UVW interpolation
117+
self.var4 = var4 # extra variable for sigma-scaling for croco
116118

117119

118120
class GridNode(IntrinsicNode):
@@ -285,9 +287,10 @@ def visit_Subscript(self, node):
285287
elif isinstance(node.value, VectorFieldNode):
286288
tmp = self.get_tmp()
287289
tmp2 = self.get_tmp()
288-
tmp3 = self.get_tmp() if node.value.obj.vector_type == "3D" else None
290+
tmp3 = self.get_tmp() if "3D" in node.value.obj.vector_type else None
291+
tmp4 = self.get_tmp() if "3DSigma" in node.value.obj.vector_type else None
289292
# Insert placeholder node for field eval ...
290-
self.stmt_stack += [VectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3)]
293+
self.stmt_stack += [VectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3, tmp4)]
291294
# .. and return the name of the temporary that will be populated
292295
if tmp3:
293296
return ast.Tuple([ast.Name(id=tmp), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
@@ -300,8 +303,9 @@ def visit_Subscript(self, node):
300303
elif isinstance(node.value, NestedVectorFieldNode):
301304
tmp = self.get_tmp()
302305
tmp2 = self.get_tmp()
303-
tmp3 = self.get_tmp() if list.__getitem__(node.value.obj, 0).vector_type == "3D" else None
304-
self.stmt_stack += [NestedVectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3)]
306+
tmp3 = self.get_tmp() if "3D" in list.__getitem__(node.value.obj, 0).vector_type else None
307+
tmp4 = self.get_tmp() if "3DSigma" in list.__getitem__(node.value.obj, 0).vector_type else None
308+
self.stmt_stack += [NestedVectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3, tmp4)]
305309
if tmp3:
306310
return ast.Tuple([ast.Name(id=tmp), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
307311
else:
@@ -371,7 +375,8 @@ def visit_Call(self, node):
371375
# get a temporary value to assign result to
372376
tmp1 = self.get_tmp()
373377
tmp2 = self.get_tmp()
374-
tmp3 = self.get_tmp() if node.func.field.obj.vector_type == "3D" else None
378+
tmp3 = self.get_tmp() if "3D" in node.func.field.obj.vector_type else None
379+
tmp4 = self.get_tmp() if "3DSigma" in node.func.field.obj.vector_type else None
375380
# whether to convert
376381
convert = True
377382
if "applyConversion" in node.keywords:
@@ -382,7 +387,7 @@ def visit_Call(self, node):
382387
# convert args to Index(Tuple(*args))
383388
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))
384389

385-
self.stmt_stack += [VectorFieldEvalNode(node.func.field, args, tmp1, tmp2, tmp3, convert)]
390+
self.stmt_stack += [VectorFieldEvalNode(node.func.field, args, tmp1, tmp2, tmp3, tmp4, convert)]
386391
if tmp3:
387392
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
388393
else:
@@ -421,6 +426,8 @@ def __init__(self, fieldset=None, ptype=JITParticle):
421426
self.fieldset = fieldset
422427
self.ptype = ptype
423428
self.field_args = collections.OrderedDict()
429+
if isinstance(fieldset.U, Field) and fieldset.U.gridindexingtype == "croco" and hasattr(fieldset, "H"):
430+
self.field_args["H"] = fieldset.H # CROCO requires H field
424431
self.vector_field_args = collections.OrderedDict()
425432
self.const_args = collections.OrderedDict()
426433

@@ -456,7 +463,7 @@ def generate(self, py_ast, funcvars: list[str]):
456463
for kvar in self.kernel_vars + self.array_vars:
457464
if kvar in funcvars:
458465
funcvars.remove(kvar)
459-
self.ccode.body.insert(0, c.Value("int", "parcels_interp_state"))
466+
self.ccode.body.insert(0, c.Statement("int parcels_interp_state = 0"))
460467
if len(funcvars) > 0:
461468
for f in funcvars:
462469
self.ccode.body.insert(0, c.Statement(f"type_coord {f} = 0"))
@@ -819,6 +826,16 @@ def visit_FieldEvalNode(self, node):
819826
self.visit(node.field)
820827
self.visit(node.args)
821828
args = self._check_FieldSamplingArguments(node.args.ccode)
829+
statements_croco = []
830+
if "croco" in node.field.obj.gridindexingtype and node.field.obj.name != "H":
831+
statements_croco.append(
832+
c.Assign(
833+
"parcels_interp_state",
834+
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.gridindexingtype.upper()})",
835+
)
836+
)
837+
statements_croco.append(c.Statement(f"{node.var} = {args[1]}/{node.var}"))
838+
args = (args[0], node.var, args[2], args[3])
822839
ccode_eval = node.field.obj._ccode_eval(node.var, *args)
823840
stmts = [
824841
c.Assign("parcels_interp_state", ccode_eval),
@@ -830,12 +847,22 @@ def visit_FieldEvalNode(self, node):
830847
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
831848
stmts += [conv_stat]
832849

833-
node.ccode = c.Block(stmts + [c.Statement("CHECKSTATUS_KERNELLOOP(parcels_interp_state)")])
850+
node.ccode = c.Block(statements_croco + stmts + [c.Statement("CHECKSTATUS_KERNELLOOP(parcels_interp_state)")])
834851

835852
def visit_VectorFieldEvalNode(self, node):
836853
self.visit(node.field)
837854
self.visit(node.args)
838855
args = self._check_FieldSamplingArguments(node.args.ccode)
856+
statements_croco = []
857+
if "3DSigma" in node.field.obj.vector_type:
858+
statements_croco.append(
859+
c.Assign(
860+
"parcels_interp_state",
861+
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.U.gridindexingtype.upper()})",
862+
)
863+
)
864+
statements_croco.append(c.Statement(f"{node.var4} = {args[1]}/{node.var}"))
865+
args = (args[0], node.var4, args[2], args[3])
839866
ccode_eval = node.field.obj._ccode_eval(
840867
node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
841868
)
@@ -845,12 +872,13 @@ def visit_VectorFieldEvalNode(self, node):
845872
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
846873
else:
847874
statements = []
848-
if node.convert and node.field.obj.vector_type == "3D":
875+
if node.convert and "3D" in node.field.obj.vector_type:
849876
ccode_conv3 = node.field.obj.W._ccode_convert(*args)
850877
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
851878
conv_stat = c.Block(statements)
852879
node.ccode = c.Block(
853880
[
881+
c.Block(statements_croco),
854882
c.Assign("parcels_interp_state", ccode_eval),
855883
c.Assign("particles->state[pnum]", "max(particles->state[pnum], parcels_interp_state)"),
856884
conv_stat,
@@ -891,7 +919,7 @@ def visit_NestedVectorFieldEvalNode(self, node):
891919
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
892920
else:
893921
statements = []
894-
if fld.vector_type == "3D":
922+
if "3D" in fld.vector_type:
895923
ccode_conv3 = fld.W._ccode_convert(*args)
896924
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
897925
cstat += [

0 commit comments

Comments
 (0)