@@ -579,8 +579,6 @@ def XLinearInvdistLandTracer(
579579 """Linear spatial interpolation on a regular grid, where points on land are not used."""
580580 values = XLinear (particle_positions , grid_positions , field )
581581
582- on_land = np .argwhere (np .isnan (values ))
583-
584582 xi , xsi = grid_positions ["X" ]["index" ], grid_positions ["X" ]["bcoord" ]
585583 yi , eta = grid_positions ["Y" ]["index" ], grid_positions ["Y" ]["bcoord" ]
586584 zi , zeta = grid_positions ["Z" ]["index" ], grid_positions ["Z" ]["bcoord" ]
@@ -592,27 +590,29 @@ def XLinearInvdistLandTracer(
592590
593591 corner_data = _get_corner_data_Agrid (field .data , ti , zi , yi , xi , lenT , lenZ , len (xsi ), axis_dim )
594592
595- def is_land (p : int ):
596- value = corner_data [:, :, :, :, p ]
597- return np .where (np .isnan (value ), True , False )
593+ land_mask = np .isnan (corner_data )
594+ nb_land = np .sum (land_mask , axis = (0 , 1 , 2 , 3 ))
598595
599- for p in on_land :
600- land = is_land (p )
601- nb_land = np .sum (land )
602- if nb_land == 4 * lenZ * lenT :
603- values [p ] = 0.0
604- else :
605- val = 0
606- w_sum = 0
607- for t in range (lenT ):
608- for k in range (lenZ ):
609- for j in range (2 ):
610- for i in range (2 ):
611- if land [t ][k ][j ][i ] == 0 :
612- distance = pow ((eta [p ] - j ), 2 ) + pow ((xsi [p ] - i ), 2 )
613- val += corner_data [t , k , j , i , p ] / distance
614- w_sum += 1 / distance
615- values [p ] = val / w_sum
596+ if np .any (nb_land ):
597+ all_land_mask = nb_land == 4 * lenZ * lenT
598+ values [all_land_mask ] = 0.0
599+
600+ not_all_land = ~ all_land_mask
601+ if np .any (not_all_land ):
602+ i_grid = np .arange (2 )[None , None , None , :, None ]
603+ j_grid = np .arange (2 )[None , None , :, None , None ]
604+ eta_b = eta [None , None , None , None , :]
605+ xsi_b = xsi [None , None , None , None , :]
606+
607+ inv_dist = 1.0 / ((eta_b - j_grid ) ** 2 + (xsi_b - i_grid ) ** 2 )
608+
609+ valid_mask = ~ land_mask
610+ weighted = np .where (valid_mask , corner_data * inv_dist , 0.0 )
611+
612+ val = np .sum (weighted , axis = (0 , 1 , 2 , 3 ))
613+ w_sum = np .sum (np .where (valid_mask , inv_dist , 0.0 ), axis = (0 , 1 , 2 , 3 ))
614+
615+ values [not_all_land ] = val [not_all_land ] / w_sum [not_all_land ]
616616
617617 return values .compute () if is_dask_collection (values ) else values
618618
0 commit comments