Skip to content

Commit

Permalink
First attempt at rewrite of particle loop and Recovery-Kernels in scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvansebille committed Aug 1, 2023
1 parent 72a69e7 commit 3e1e6a3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 163 deletions.
26 changes: 25 additions & 1 deletion parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
FieldOutOfBoundSurfaceError,
FieldSamplingError,
TimeExtrapolationError,
ErrorCode
)

from .fieldfilebuffer import (
Expand Down Expand Up @@ -1182,6 +1183,19 @@ def __getitem__(self, key):
else:
return self.eval(*key)

# try: # TODO check if this is needed
# if _isParticle(key):
# val = self.eval(key.time, key.depth, key.lat, key.lon, key)
# else:
# val = self.eval(*key)
# except (FieldOutOfBoundError, FieldSamplingError):
# val = np.nan
# if _isParticle(key):
# key.state = ErrorCode.ErrorOutOfBounds
# elif _isParticle(key[-1]):
# key[-1].state = ErrorCode.ErrorOutOfBounds
# return val

def eval(self, time, z, y, x, particle=None, applyConversion=True):
"""Interpolate field values in space and time.
Expand Down Expand Up @@ -2058,7 +2072,17 @@ def __getitem__(self, key):
break
except (FieldOutOfBoundError, FieldSamplingError):
if iField == len(self)-1:
raise
if isinstance(self[iField], VectorField):
if self[iField].vector_type == '3D':
val = (np.nan, np.nan, np.nan)
else:
val = (np.nan, np.nan)
else:
val = np.nan
if _isParticle(key):
key.state = ErrorCode.ErrorOutOfBounds
elif _isParticle(key[-1]):
key[-1].state = ErrorCode.ErrorOutOfBounds
else:
pass
return val
107 changes: 12 additions & 95 deletions parcels/kernel/basekernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,13 @@
# === import just necessary field classes to perform setup checks === #
from parcels.field import (
Field,
FieldOutOfBoundError,
FieldOutOfBoundSurfaceError,
NestedField,
SummedField,
TimeExtrapolationError,
VectorField,
)
from parcels.grid import GridCode
from parcels.tools.global_statics import get_cache_dir
from parcels.tools.statuscodes import ErrorCode, OperationCode, StateCode
from parcels.tools.statuscodes import OperationCode, StateCode

__all__ = ['BaseKernel']

Expand Down Expand Up @@ -409,7 +406,7 @@ def load_fieldset_jit(self, pset):
if not g.lat.flags.c_contiguous:
g.lat = np.array(g.lat, order='C')

def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False):
def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False): # TODO check arguments to this function
"""Execute the kernel evaluation of for an individual particle.
Parameters
Expand All @@ -427,101 +424,21 @@ def evaluate_particle(self, p, endtime, sign_dt, dt, analytical=False):
sign_dt :
"""
variables = self._ptype.variables
# back up variables in case of OperationCode.Repeat
p_var_back = {}
pdt_prekernels = .0
# Don't execute particles that aren't started yet
sign_end_part = np.sign(endtime - p.time)
# Compute min/max dt for first timestep. Only use endtime-p.time for one timestep
reset_dt = False
if abs(endtime - p.time) < abs(p.dt):
dt_pos = abs(endtime - p.time)
reset_dt = True
else:
dt_pos = abs(p.dt)
reset_dt = False

# ==== numerically stable; also making sure that continuously-recovered particles do end successfully,
# as they fulfil the condition here on entering at the final calculation here. ==== #
if ((sign_end_part != sign_dt) or np.isclose(dt_pos, 0)) and not np.isclose(dt, 0):
if abs(p.time) >= abs(endtime):
p.set_state(StateCode.Success)
return p

while p.state in [StateCode.Evaluate, OperationCode.Repeat] or np.isclose(dt, 0):
for var in variables:
p_var_back[var.name] = getattr(p, var.name)
try:
pdt_prekernels = sign_dt * dt_pos
p.dt = pdt_prekernels
state_prev = p.state
res = self._pyfunc(p, self._fieldset, p.time)
if res is None:
res = StateCode.Success
while p.state in [StateCode.Evaluate, OperationCode.Repeat]:
pre_dt = p.dt
if abs(endtime - p.time) < abs(p.dt):
p.dt = abs(endtime - p.time) * sign_dt

if res is StateCode.Success and p.state != state_prev:
res = p.state

if not analytical and res == StateCode.Success and not np.isclose(p.dt, pdt_prekernels):
res = OperationCode.Repeat

except FieldOutOfBoundError as fse_xy:
res = ErrorCode.ErrorOutOfBounds
p.exception = fse_xy
except FieldOutOfBoundSurfaceError as fse_z:
res = ErrorCode.ErrorThroughSurface
p.exception = fse_z
except TimeExtrapolationError as fse_t:
res = ErrorCode.ErrorTimeExtrapolation
p.exception = fse_t

except Exception as e:
res = ErrorCode.Error
p.exception = e

# Handle particle time and time loop
if res in [StateCode.Success, OperationCode.Delete]:
# Update time and repeat
p.time += p.dt
if reset_dt and p.dt == pdt_prekernels:
p.dt = dt
p.update_next_dt()
if analytical:
p.dt = np.inf
if abs(endtime - p.time) < abs(p.dt):
dt_pos = abs(endtime - p.time)
reset_dt = True
else:
dt_pos = abs(p.dt)
reset_dt = False
res = self._pyfunc(p, self._fieldset, p.time)

sign_end_part = np.sign(endtime - p.time)
if res != OperationCode.Delete and not np.isclose(dt_pos, 0) and (sign_end_part == sign_dt):
if res is None:
if p.time < endtime:
res = StateCode.Evaluate
if sign_end_part != sign_dt:
dt_pos = 0

p.set_state(res)
if np.isclose(dt, 0):
break
else:
p.set_state(res)
# Try again without time update
for var in variables:
if var.name not in ['dt', 'state']:
setattr(p, var.name, p_var_back[var.name])
if abs(endtime - p.time) < abs(p.dt):
dt_pos = abs(endtime - p.time)
reset_dt = True
else:
dt_pos = abs(p.dt)
reset_dt = False
res = StateCode.Success
p.set_state(res)

sign_end_part = np.sign(endtime - p.time)
if sign_end_part != sign_dt:
dt_pos = 0
break
p.dt = pre_dt
return p

def execute_jit(self, pset, endtime, dt):
Expand Down
33 changes: 0 additions & 33 deletions parcels/kernel/kernelaos.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,36 +196,3 @@ def execute(self, pset, endtime, dt, recovery=None, output_file=None, execute_on

# Remove all particles that signalled deletion
self.remove_deleted(pset)

# Identify particles that threw errors
error_particles = [p for p in pset if p.state not in [StateCode.Success, StateCode.Evaluate]]

while len(error_particles) > 0:
# Apply recovery kernel
for p in error_particles:
if p.state == OperationCode.StopExecution:
return
if p.state == OperationCode.Repeat:
p.reset_state()
elif p.state == OperationCode.Delete:
pass
elif p.state in recovery_map:
recovery_kernel = recovery_map[p.state]
p.set_state(StateCode.Success)
recovery_kernel(p, self.fieldset, p.time)
if p.isComputed():
p.reset_state()
else:
logger.warning_once(f'Deleting particle {p.id} because of non-recoverable error')
p.delete()

# Remove all particles that signalled deletion
self.remove_deleted(pset)

# Execute core loop again to continue interrupted particles
if self.ptype.uses_jit:
self.execute_jit(pset, endtime, dt)
else:
self.execute_python(pset, endtime, dt)

error_particles = [p for p in pset if p.state not in [StateCode.Success, StateCode.Evaluate]]
34 changes: 0 additions & 34 deletions parcels/kernel/kernelsoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,37 +197,3 @@ def execute(self, pset, endtime, dt, recovery=None, output_file=None, execute_on

# Remove all particles that signalled deletion
self.remove_deleted(pset)

# Identify particles that threw errors
n_error = pset.num_error_particles

while n_error > 0:
error_pset = pset.error_particles
# Apply recovery kernel
for p in error_pset:
if p.state == OperationCode.StopExecution:
return
if p.state == OperationCode.Repeat:
p.reset_state()
elif p.state == OperationCode.Delete:
pass
elif p.state in recovery_map:
recovery_kernel = recovery_map[p.state]
p.set_state(StateCode.Success)
recovery_kernel(p, self.fieldset, p.time)
if p.isComputed():
p.reset_state()
else:
logger.warning_once(f'Deleting particle {p.id} because of non-recoverable error')
p.delete()

# Remove all particles that signalled deletion
self.remove_deleted(pset)

# Execute core loop again to continue interrupted particles
if self.ptype.uses_jit:
self.execute_jit(pset, endtime, dt)
else:
self.execute_python(pset, endtime, dt)

n_error = pset.num_error_particles

0 comments on commit 3e1e6a3

Please sign in to comment.