@@ -80,12 +80,13 @@ def __init__(self, field):
8080
8181
8282class VectorFieldEvalNode (IntrinsicNode ):
83- def __init__ (self , field , args , var , var2 , var3 , convert = True ):
83+ def __init__ (self , field , args , var , var2 , var3 , var4 , convert = True ):
8484 self .field = field
8585 self .args = args
8686 self .var = var # the variable in which the interpolated field is written
8787 self .var2 = var2 # second variable for UV interpolation
8888 self .var3 = var3 # third variable for UVW interpolation
89+ self .var4 = var4 # extra variable for sigma-scaling for croco
8990 self .convert = convert # whether to convert the result (like field.applyConversion)
9091
9192
@@ -107,12 +108,13 @@ def __getitem__(self, attr):
107108
108109
109110class NestedVectorFieldEvalNode (IntrinsicNode ):
110- def __init__ (self , fields , args , var , var2 , var3 ):
111+ def __init__ (self , fields , args , var , var2 , var3 , var4 ):
111112 self .fields = fields
112113 self .args = args
113114 self .var = var # the variable in which the interpolated field is written
114115 self .var2 = var2 # second variable for UV interpolation
115116 self .var3 = var3 # third variable for UVW interpolation
117+ self .var4 = var4 # extra variable for sigma-scaling for croco
116118
117119
118120class GridNode (IntrinsicNode ):
@@ -285,9 +287,10 @@ def visit_Subscript(self, node):
285287 elif isinstance (node .value , VectorFieldNode ):
286288 tmp = self .get_tmp ()
287289 tmp2 = self .get_tmp ()
288- tmp3 = self .get_tmp () if node .value .obj .vector_type == "3D" else None
290+ tmp3 = self .get_tmp () if "3D" in node .value .obj .vector_type else None
291+ tmp4 = self .get_tmp () if "3DSigma" in node .value .obj .vector_type else None
289292 # Insert placeholder node for field eval ...
290- self .stmt_stack += [VectorFieldEvalNode (node .value , node .slice , tmp , tmp2 , tmp3 )]
293+ self .stmt_stack += [VectorFieldEvalNode (node .value , node .slice , tmp , tmp2 , tmp3 , tmp4 )]
291294 # .. and return the name of the temporary that will be populated
292295 if tmp3 :
293296 return ast .Tuple ([ast .Name (id = tmp ), ast .Name (id = tmp2 ), ast .Name (id = tmp3 )], ast .Load ())
@@ -300,8 +303,9 @@ def visit_Subscript(self, node):
300303 elif isinstance (node .value , NestedVectorFieldNode ):
301304 tmp = self .get_tmp ()
302305 tmp2 = self .get_tmp ()
303- tmp3 = self .get_tmp () if list .__getitem__ (node .value .obj , 0 ).vector_type == "3D" else None
304- self .stmt_stack += [NestedVectorFieldEvalNode (node .value , node .slice , tmp , tmp2 , tmp3 )]
306+ tmp3 = self .get_tmp () if "3D" in list .__getitem__ (node .value .obj , 0 ).vector_type else None
307+ tmp4 = self .get_tmp () if "3DSigma" in list .__getitem__ (node .value .obj , 0 ).vector_type else None
308+ self .stmt_stack += [NestedVectorFieldEvalNode (node .value , node .slice , tmp , tmp2 , tmp3 , tmp4 )]
305309 if tmp3 :
306310 return ast .Tuple ([ast .Name (id = tmp ), ast .Name (id = tmp2 ), ast .Name (id = tmp3 )], ast .Load ())
307311 else :
@@ -371,7 +375,8 @@ def visit_Call(self, node):
371375 # get a temporary value to assign result to
372376 tmp1 = self .get_tmp ()
373377 tmp2 = self .get_tmp ()
374- tmp3 = self .get_tmp () if node .func .field .obj .vector_type == "3D" else None
378+ tmp3 = self .get_tmp () if "3D" in node .func .field .obj .vector_type else None
379+ tmp4 = self .get_tmp () if "3DSigma" in node .func .field .obj .vector_type else None
375380 # whether to convert
376381 convert = True
377382 if "applyConversion" in node .keywords :
@@ -382,7 +387,7 @@ def visit_Call(self, node):
382387 # convert args to Index(Tuple(*args))
383388 args = ast .Index (value = ast .Tuple (node .args , ast .Load ()))
384389
385- self .stmt_stack += [VectorFieldEvalNode (node .func .field , args , tmp1 , tmp2 , tmp3 , convert )]
390+ self .stmt_stack += [VectorFieldEvalNode (node .func .field , args , tmp1 , tmp2 , tmp3 , tmp4 , convert )]
386391 if tmp3 :
387392 return ast .Tuple ([ast .Name (id = tmp1 ), ast .Name (id = tmp2 ), ast .Name (id = tmp3 )], ast .Load ())
388393 else :
@@ -421,6 +426,8 @@ def __init__(self, fieldset=None, ptype=JITParticle):
421426 self .fieldset = fieldset
422427 self .ptype = ptype
423428 self .field_args = collections .OrderedDict ()
429+ if isinstance (fieldset .U , Field ) and fieldset .U .gridindexingtype == "croco" and hasattr (fieldset , "H" ):
430+ self .field_args ["H" ] = fieldset .H # CROCO requires H field
424431 self .vector_field_args = collections .OrderedDict ()
425432 self .const_args = collections .OrderedDict ()
426433
@@ -456,7 +463,7 @@ def generate(self, py_ast, funcvars: list[str]):
456463 for kvar in self .kernel_vars + self .array_vars :
457464 if kvar in funcvars :
458465 funcvars .remove (kvar )
459- self .ccode .body .insert (0 , c .Value ("int" , " parcels_interp_state" ))
466+ self .ccode .body .insert (0 , c .Statement ("int parcels_interp_state = 0 " ))
460467 if len (funcvars ) > 0 :
461468 for f in funcvars :
462469 self .ccode .body .insert (0 , c .Statement (f"type_coord { f } = 0" ))
@@ -819,6 +826,16 @@ def visit_FieldEvalNode(self, node):
819826 self .visit (node .field )
820827 self .visit (node .args )
821828 args = self ._check_FieldSamplingArguments (node .args .ccode )
829+ statements_croco = []
830+ if "croco" in node .field .obj .gridindexingtype and node .field .obj .name != "H" :
831+ statements_croco .append (
832+ c .Assign (
833+ "parcels_interp_state" ,
834+ f"temporal_interpolation({ args [3 ]} , { args [2 ]} , 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{ node .var } , LINEAR, { node .field .obj .gridindexingtype .upper ()} )" ,
835+ )
836+ )
837+ statements_croco .append (c .Statement (f"{ node .var } = { args [1 ]} /{ node .var } " ))
838+ args = (args [0 ], node .var , args [2 ], args [3 ])
822839 ccode_eval = node .field .obj ._ccode_eval (node .var , * args )
823840 stmts = [
824841 c .Assign ("parcels_interp_state" , ccode_eval ),
@@ -830,12 +847,22 @@ def visit_FieldEvalNode(self, node):
830847 conv_stat = c .Statement (f"{ node .var } *= { ccode_conv } " )
831848 stmts += [conv_stat ]
832849
833- node .ccode = c .Block (stmts + [c .Statement ("CHECKSTATUS_KERNELLOOP(parcels_interp_state)" )])
850+ node .ccode = c .Block (statements_croco + stmts + [c .Statement ("CHECKSTATUS_KERNELLOOP(parcels_interp_state)" )])
834851
835852 def visit_VectorFieldEvalNode (self , node ):
836853 self .visit (node .field )
837854 self .visit (node .args )
838855 args = self ._check_FieldSamplingArguments (node .args .ccode )
856+ statements_croco = []
857+ if "3DSigma" in node .field .obj .vector_type :
858+ statements_croco .append (
859+ c .Assign (
860+ "parcels_interp_state" ,
861+ f"temporal_interpolation({ args [3 ]} , { args [2 ]} , 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{ node .var } , LINEAR, { node .field .obj .U .gridindexingtype .upper ()} )" ,
862+ )
863+ )
864+ statements_croco .append (c .Statement (f"{ node .var4 } = { args [1 ]} /{ node .var } " ))
865+ args = (args [0 ], node .var4 , args [2 ], args [3 ])
839866 ccode_eval = node .field .obj ._ccode_eval (
840867 node .var , node .var2 , node .var3 , node .field .obj .U , node .field .obj .V , node .field .obj .W , * args
841868 )
@@ -845,12 +872,13 @@ def visit_VectorFieldEvalNode(self, node):
845872 statements = [c .Statement (f"{ node .var } *= { ccode_conv1 } " ), c .Statement (f"{ node .var2 } *= { ccode_conv2 } " )]
846873 else :
847874 statements = []
848- if node .convert and node .field .obj .vector_type == "3D" :
875+ if node .convert and "3D" in node .field .obj .vector_type :
849876 ccode_conv3 = node .field .obj .W ._ccode_convert (* args )
850877 statements .append (c .Statement (f"{ node .var3 } *= { ccode_conv3 } " ))
851878 conv_stat = c .Block (statements )
852879 node .ccode = c .Block (
853880 [
881+ c .Block (statements_croco ),
854882 c .Assign ("parcels_interp_state" , ccode_eval ),
855883 c .Assign ("particles->state[pnum]" , "max(particles->state[pnum], parcels_interp_state)" ),
856884 conv_stat ,
@@ -891,7 +919,7 @@ def visit_NestedVectorFieldEvalNode(self, node):
891919 statements = [c .Statement (f"{ node .var } *= { ccode_conv1 } " ), c .Statement (f"{ node .var2 } *= { ccode_conv2 } " )]
892920 else :
893921 statements = []
894- if fld . vector_type == "3D" :
922+ if "3D" in fld . vector_type :
895923 ccode_conv3 = fld .W ._ccode_convert (* args )
896924 statements .append (c .Statement (f"{ node .var3 } *= { ccode_conv3 } " ))
897925 cstat += [
0 commit comments