2222 "UXPiecewiseLinearNode" ,
2323 "XFreeslip" ,
2424 "XLinear" ,
25+ "XLinearInvdistLandTracer" ,
2526 "XNearest" ,
2627 "XPartialslip" ,
2728 "ZeroInterpolator" ,
@@ -75,11 +76,11 @@ def _get_corner_data_Agrid(
7576
7677 # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/z
7778 yi_1 = np .clip (yi + 1 , 0 , data .shape [2 ] - 1 )
78- yi = np .tile (np .repeat ( np . column_stack ( [yi , yi_1 ]), 2 ), ( lenT ) * ( lenZ ) )
79+ yi = np .tile (np .array ( [yi , yi , yi_1 , yi_1 ]). flatten (), lenT * lenZ )
7980
8081 # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/z
8182 xi_1 = np .clip (xi + 1 , 0 , data .shape [3 ] - 1 )
82- xi = np .tile (np .column_stack ([xi , xi_1 , xi , xi_1 ]).flatten (), ( lenT ) * ( lenZ ) )
83+ xi = np .tile (np .array ([xi , xi_1 ]).flatten (), lenT * lenZ * 2 )
8384
8485 # Create DataArrays for indexing
8586 selection_dict = {
@@ -91,7 +92,7 @@ def _get_corner_data_Agrid(
9192 if "time" in data .dims :
9293 selection_dict ["time" ] = xr .DataArray (ti , dims = ("points" ))
9394
94- return data .isel (selection_dict ).data .reshape (lenT , lenZ , npart , 4 )
95+ return data .isel (selection_dict ).data .reshape (lenT , lenZ , 2 , 2 , npart )
9596
9697
9798def XLinear (
@@ -114,22 +115,22 @@ def XLinear(
114115 corner_data = _get_corner_data_Agrid (data , ti , zi , yi , xi , lenT , lenZ , len (xsi ), axis_dim )
115116
116117 if lenT == 2 :
117- tau = tau [np .newaxis , :, np . newaxis ]
118- corner_data = corner_data [0 , :, :, : ] * (1 - tau ) + corner_data [1 , :, : , :] * tau
118+ tau = tau [np .newaxis , :]
119+ corner_data = corner_data [0 , :] * (1 - tau ) + corner_data [1 , :] * tau
119120 else :
120- corner_data = corner_data [0 , :, :, : ]
121+ corner_data = corner_data [0 , :]
121122
122123 if lenZ == 2 :
123- zeta = zeta [:, np .newaxis ]
124- corner_data = corner_data [0 , :, : ] * (1 - zeta ) + corner_data [1 , : , :] * zeta
124+ zeta = zeta [np .newaxis , : ]
125+ corner_data = corner_data [0 , :] * (1 - zeta ) + corner_data [1 , :] * zeta
125126 else :
126- corner_data = corner_data [0 , :, : ]
127+ corner_data = corner_data [0 , :]
127128
128129 value = (
129- (1 - xsi ) * (1 - eta ) * corner_data [: , 0 ]
130- + xsi * (1 - eta ) * corner_data [: , 1 ]
131- + (1 - xsi ) * eta * corner_data [:, 2 ]
132- + xsi * eta * corner_data [:, 3 ]
130+ (1 - xsi ) * (1 - eta ) * corner_data [0 , 0 , : ]
131+ + xsi * (1 - eta ) * corner_data [0 , 1 , : ]
132+ + (1 - xsi ) * eta * corner_data [1 , 0 , : ]
133+ + xsi * eta * corner_data [1 , 1 , : ]
133134 )
134135 return value .compute () if is_dask_collection (value ) else value
135136
@@ -409,8 +410,8 @@ def _Spatialslip(
409410 corner_dataV = _get_corner_data_Agrid (vectorfield .V .data , ti , zi , yi , xi , lenT , lenZ , npart , axis_dim )
410411
411412 def is_land (ti : int , zi : int , yi : int , xi : int ):
412- uval = corner_dataU [ti , zi , : , xi + 2 * yi ]
413- vval = corner_dataV [ti , zi , : , xi + 2 * yi ]
413+ uval = corner_dataU [ti , zi , yi , xi , : ]
414+ vval = corner_dataV [ti , zi , yi , xi , : ]
414415 return np .where (np .isclose (uval , 0.0 ) & np .isclose (vval , 0.0 ), True , False )
415416
416417 f_u = np .ones_like (xsi )
@@ -571,6 +572,52 @@ def XNearest(
571572 return value .compute () if is_dask_collection (value ) else value
572573
573574
575+ def XLinearInvdistLandTracer (
576+ particle_positions : dict [str , float | np .ndarray ],
577+ grid_positions : dict [_XGRID_AXES , dict [str , int | float | np .ndarray ]],
578+ field : Field ,
579+ ):
580+ """Linear spatial interpolation on a regular grid, where points on land are not used."""
581+ values = XLinear (particle_positions , grid_positions , field )
582+
583+ xi , xsi = grid_positions ["X" ]["index" ], grid_positions ["X" ]["bcoord" ]
584+ yi , eta = grid_positions ["Y" ]["index" ], grid_positions ["Y" ]["bcoord" ]
585+ zi , zeta = grid_positions ["Z" ]["index" ], grid_positions ["Z" ]["bcoord" ]
586+ ti , tau = grid_positions ["T" ]["index" ], grid_positions ["T" ]["bcoord" ]
587+
588+ axis_dim = field .grid .get_axis_dim_mapping (field .data .dims )
589+ lenT = 2 if np .any (tau > 0 ) else 1
590+ lenZ = 2 if np .any (zeta > 0 ) else 1
591+
592+ corner_data = _get_corner_data_Agrid (field .data , ti , zi , yi , xi , lenT , lenZ , len (xsi ), axis_dim )
593+
594+ land_mask = np .isnan (corner_data )
595+ nb_land = np .sum (land_mask , axis = (0 , 1 , 2 , 3 ))
596+
597+ if np .any (nb_land ):
598+ all_land_mask = nb_land == 4 * lenZ * lenT
599+ values [all_land_mask ] = 0.0
600+
601+ not_all_land = ~ all_land_mask
602+ if np .any (not_all_land ):
603+ i_grid = np .arange (2 )[None , None , None , :, None ]
604+ j_grid = np .arange (2 )[None , None , :, None , None ]
605+ eta_b = eta [None , None , None , None , :]
606+ xsi_b = xsi [None , None , None , None , :]
607+
608+ inv_dist = 1.0 / ((eta_b - j_grid ) ** 2 + (xsi_b - i_grid ) ** 2 )
609+
610+ valid_mask = ~ land_mask
611+ weighted = np .where (valid_mask , corner_data * inv_dist , 0.0 )
612+
613+ val = np .sum (weighted , axis = (0 , 1 , 2 , 3 ))
614+ w_sum = np .sum (np .where (valid_mask , inv_dist , 0.0 ), axis = (0 , 1 , 2 , 3 ))
615+
616+ values [not_all_land ] = val [not_all_land ] / w_sum [not_all_land ]
617+
618+ return values .compute () if is_dask_collection (values ) else values
619+
620+
574621def UXPiecewiseConstantFace (
575622 particle_positions : dict [str , float | np .ndarray ],
576623 grid_positions : dict [_UXGRID_AXES , dict [str , int | float | np .ndarray ]],
0 commit comments