@@ -218,12 +218,9 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
218218 else :
219219 _ei = particles .ei [:, self .igrid ]
220220
221- tau , ti = _search_time_index (self , time )
222- position = self .grid .search (z , y , x , ei = _ei )
223- _update_particles_ei (particles , position , self )
224- _update_particle_states_position (particles , position )
221+ particle_positions , grid_positions = _get_positions (self , time , z , y , x , particles , _ei )
225222
226- value = self ._interp_method (self , ti , position , tau , time , z , y , x )
223+ value = self ._interp_method (particle_positions , grid_positions , self )
227224
228225 _update_particle_states_interp_value (particles , value )
229226
@@ -304,20 +301,17 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
304301 else :
305302 _ei = particles .ei [:, self .igrid ]
306303
307- tau , ti = _search_time_index (self .U , time )
308- position = self .grid .search (z , y , x , ei = _ei )
309- _update_particles_ei (particles , position , self )
310- _update_particle_states_position (particles , position )
304+ particle_positions , grid_positions = _get_positions (self .U , time , z , y , x , particles , _ei )
311305
312306 if self ._vector_interp_method is None :
313- u = self .U ._interp_method (self . U , ti , position , tau , time , z , y , x )
314- v = self .V ._interp_method (self . V , ti , position , tau , time , z , y , x )
307+ u = self .U ._interp_method (particle_positions , grid_positions , self . U )
308+ v = self .V ._interp_method (particle_positions , grid_positions , self . V )
315309 if "3D" in self .vector_type :
316- w = self .W ._interp_method (self . W , ti , position , tau , time , z , y , x )
310+ w = self .W ._interp_method (particle_positions , grid_positions , self . W )
317311 else :
318312 w = 0.0
319313 else :
320- (u , v , w ) = self ._vector_interp_method (self , ti , position , tau , time , z , y , x )
314+ (u , v , w ) = self ._vector_interp_method (particle_positions , grid_positions , self )
321315
322316 if applyConversion :
323317 u = self .U .units .to_target (u , z , y , x )
@@ -343,45 +337,54 @@ def __getitem__(self, key):
343337 return _deal_with_errors (error , key , vector_type = self .vector_type )
344338
345339
346- def _update_particles_ei (particles , position , field ):
340+ def _update_particles_ei (particles , grid_positions : dict , field : Field ):
347341 """Update the element index (ei) of the particles"""
348342 if particles is not None :
349343 if isinstance (field .grid , XGrid ):
350344 particles .ei [:, field .igrid ] = field .grid .ravel_index (
351345 {
352- "X" : position ["X" ][0 ],
353- "Y" : position ["Y" ][0 ],
354- "Z" : position ["Z" ][0 ],
346+ "X" : grid_positions ["X" ]["index" ],
347+ "Y" : grid_positions ["Y" ]["index" ],
348+ "Z" : grid_positions ["Z" ]["index" ],
355349 }
356350 )
357351 elif isinstance (field .grid , UxGrid ):
358352 particles .ei [:, field .igrid ] = field .grid .ravel_index (
359353 {
360- "Z" : position ["Z" ][0 ],
361- "FACE" : position ["FACE" ][0 ],
354+ "Z" : grid_positions ["Z" ]["index" ],
355+ "FACE" : grid_positions ["FACE" ]["index" ],
362356 }
363357 )
364358
365359
366- def _update_particle_states_position (particles , position ):
360+ def _update_particle_states_position (particles , grid_positions : dict ):
367361 """Update the particle states based on the position dictionary."""
368362 if particles : # TODO also support uxgrid search
369363 for dim in ["X" , "Y" ]:
370- if dim in position :
364+ if dim in grid_positions :
371365 particles .state = np .maximum (
372- np .where (position [dim ][0 ] == - 1 , StatusCode .ErrorOutOfBounds , particles .state ), particles .state
366+ np .where (grid_positions [dim ]["index" ] == - 1 , StatusCode .ErrorOutOfBounds , particles .state ),
367+ particles .state ,
373368 )
374369 particles .state = np .maximum (
375- np .where (position [dim ][0 ] == GRID_SEARCH_ERROR , StatusCode .ErrorGridSearching , particles .state ),
370+ np .where (
371+ grid_positions [dim ]["index" ] == GRID_SEARCH_ERROR ,
372+ StatusCode .ErrorGridSearching ,
373+ particles .state ,
374+ ),
376375 particles .state ,
377376 )
378- if "Z" in position :
377+ if "Z" in grid_positions :
379378 particles .state = np .maximum (
380- np .where (position ["Z" ][0 ] == RIGHT_OUT_OF_BOUNDS , StatusCode .ErrorOutOfBounds , particles .state ),
379+ np .where (
380+ grid_positions ["Z" ]["index" ] == RIGHT_OUT_OF_BOUNDS , StatusCode .ErrorOutOfBounds , particles .state
381+ ),
381382 particles .state ,
382383 )
383384 particles .state = np .maximum (
384- np .where (position ["Z" ][0 ] == LEFT_OUT_OF_BOUNDS , StatusCode .ErrorThroughSurface , particles .state ),
385+ np .where (
386+ grid_positions ["Z" ]["index" ] == LEFT_OUT_OF_BOUNDS , StatusCode .ErrorThroughSurface , particles .state
387+ ),
385388 particles .state ,
386389 )
387390
@@ -469,3 +472,14 @@ def _assert_same_time_interval(fields: list[Field]) -> None:
469472 raise ValueError (
470473 f"Fields must have the same time domain. { fields [0 ].name } : { reference_time_interval } , { field .name } : { field .time_interval } "
471474 )
475+
476+
477+ def _get_positions (field : Field , time , z , y , x , particles , _ei ) -> tuple [dict , dict ]:
478+ """Initialize and populate particle_positions and grid_positions dictionaries"""
479+ particle_positions = {"time" : time , "z" : z , "lat" : y , "lon" : x }
480+ grid_positions = {}
481+ grid_positions .update (_search_time_index (field , time ))
482+ grid_positions .update (field .grid .search (z , y , x , ei = _ei ))
483+ _update_particles_ei (particles , grid_positions , field )
484+ _update_particle_states_position (particles , grid_positions )
485+ return particle_positions , grid_positions
0 commit comments