diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py index 9ed1432cb9..36a9b0c169 100644 --- a/src/parcels/_core/field.py +++ b/src/parcels/_core/field.py @@ -22,6 +22,7 @@ from parcels._core.utils.time import TimeInterval from parcels._core.uxgrid import UxGrid from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx +from parcels._python import assert_same_function_signature from parcels._reprs import default_repr from parcels._typing import VectorType from parcels.interpolators import ( @@ -30,7 +31,6 @@ ZeroInterpolator, ZeroInterpolator_Vector, ) -from parcels.utils._helpers import _assert_same_function_signature __all__ = ["Field", "VectorField"] @@ -139,7 +139,7 @@ def __init__( if interp_method is None: self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)] else: - _assert_same_function_signature(interp_method, ref=ZeroInterpolator, context="Interpolation") + assert_same_function_signature(interp_method, ref=ZeroInterpolator, context="Interpolation") self._interp_method = interp_method self.igrid = -1 # Default the grid index to -1 @@ -195,7 +195,7 @@ def interp_method(self): @interp_method.setter def interp_method(self, method: Callable): - _assert_same_function_signature(method, ref=ZeroInterpolator, context="Interpolation") + assert_same_function_signature(method, ref=ZeroInterpolator, context="Interpolation") self._interp_method = method def _check_velocitysampling(self): @@ -270,7 +270,7 @@ def __init__( if vector_interp_method is None: self._vector_interp_method = None else: - _assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator_Vector, context="Interpolation") + assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator_Vector, context="Interpolation") self._vector_interp_method = vector_interp_method def __repr__(self): @@ -286,7 +286,7 @@ def vector_interp_method(self): @vector_interp_method.setter def vector_interp_method(self, method: Callable): - _assert_same_function_signature(method, ref=ZeroInterpolator_Vector, context="Interpolation") + assert_same_function_signature(method, ref=ZeroInterpolator_Vector, context="Interpolation") self._vector_interp_method = method def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True): diff --git a/src/parcels/_core/kernel.py b/src/parcels/_core/kernel.py index a1404dc08d..83ff810c15 100644 --- a/src/parcels/_core/kernel.py +++ b/src/parcels/_core/kernel.py @@ -17,12 +17,12 @@ _raise_outside_time_interval_error, ) from parcels._core.warnings import KernelWarning +from parcels._python import assert_same_function_signature from parcels.kernels import ( AdvectionAnalytical, AdvectionRK4, AdvectionRK45, ) -from parcels.utils._helpers import _assert_same_function_signature if TYPE_CHECKING: from collections.abc import Callable @@ -67,7 +67,7 @@ def __init__( for f in pyfuncs: if not isinstance(f, types.FunctionType): raise TypeError(f"Argument pyfunc should be a function or list of functions. Got {type(f)}") - _assert_same_function_signature(f, ref=AdvectionRK4, context="Kernel") + assert_same_function_signature(f, ref=AdvectionRK4, context="Kernel") if len(pyfuncs) == 0: raise ValueError("List of `pyfuncs` should have at least one function.") diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 123efd4387..63c32b4091 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -15,7 +15,7 @@ import parcels from parcels._core.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, ParticleClass -from parcels.utils._helpers import timedelta_to_float +from parcels._core.utils.time import timedelta_to_float if TYPE_CHECKING: from parcels._core.particle import Variable diff --git a/src/parcels/utils/interpolation_utils.py b/src/parcels/_core/utils/interpolation.py similarity index 100% rename from src/parcels/utils/interpolation_utils.py rename to src/parcels/_core/utils/interpolation.py diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py index 8cc34acf45..3cc836ff5f 100644 --- a/src/parcels/_core/utils/time.py +++ b/src/parcels/_core/utils/time.py @@ -144,3 +144,12 @@ def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> n return np.timedelta64(0, "s") except Exception as e: raise ValueError(f"Could not convert {dt!r} to np.timedelta64.") from e + + +def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: + """Convert a timedelta to a float in seconds.""" + if isinstance(dt, timedelta): + return dt.total_seconds() + if isinstance(dt, np.timedelta64): + return float(dt / np.timedelta64(1, "s")) + return float(dt) diff --git a/src/parcels/_interpolation.py b/src/parcels/_interpolation.py deleted file mode 100644 index e38f0d0d0f..0000000000 --- a/src/parcels/_interpolation.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Mapping -from dataclasses import dataclass - -import numpy as np - -from parcels._typing import GridIndexingType -from parcels.utils._helpers import should_calculate_next_ti - - -@dataclass -class InterpolationContext2D: - """Information provided by Parcels during 2D spatial interpolation. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019 for more info. - - Attributes - ---------- - data: np.ndarray - field data of shape (time, y, x) - tau: float - time interpolation coordinate in unit length - eta: float - y-direction interpolation coordinate in unit cube (between 0 and 1) - xsi: float - x-direction interpolation coordinate in unit cube (between 0 and 1) - ti: int - time index - yi: int - y index of cell containing particle - xi: int - x index of cell containing particle - - """ - - data: np.ndarray - tau: float - eta: float - xsi: float - ti: int - yi: int - xi: int - - -@dataclass -class InterpolationContext3D: - """Information provided by Parcels during 3D spatial interpolation. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019 for more info. - - Attributes - ---------- - data: np.ndarray - field data of shape (time, z, y, x). This needs to be complete in the vertical - direction as some interpolation methods need to know whether they are at the - surface or bottom. - tau: float - time interpolation coordinate in unit length - zeta: float - vertical interpolation coordinate in unit cube - eta: float - y-direction interpolation coordinate in unit cube - xsi: float - x-direction interpolation coordinate in unit cube - zi: int - z index of cell containing particle - ti: int - time index - yi: int - y index of cell containing particle - xi: int - x index of cell containing particle - gridindexingtype: GridIndexingType - grid indexing type - - """ - - data: np.ndarray - tau: float - zeta: float - eta: float - xsi: float - ti: int - zi: int - yi: int - xi: int - gridindexingtype: GridIndexingType # included in 3D as z-face is indexed differently with MOM5 and POP - - -_interpolator_registry_2d: dict[str, Callable[[InterpolationContext2D], float]] = {} -_interpolator_registry_3d: dict[str, Callable[[InterpolationContext3D], float]] = {} - - -def get_2d_interpolator_registry() -> Mapping[str, Callable[[InterpolationContext2D], float]]: - # See Discussion on Python Discord for more context (function prevents re-alias of global variable) - # _interpolator_registry_2d etc shouldn't be imported directly - # https://discord.com/channels/267624335836053506/1329136004459794483 - return _interpolator_registry_2d - - -def get_3d_interpolator_registry() -> Mapping[str, Callable[[InterpolationContext3D], float]]: - return _interpolator_registry_3d - - -def register_2d_interpolator(name: str): - def decorator(interpolator: Callable[[InterpolationContext2D], float]): - _interpolator_registry_2d[name] = interpolator - return interpolator - - return decorator - - -def register_3d_interpolator(name: str): - def decorator(interpolator: Callable[[InterpolationContext3D], float]): - _interpolator_registry_3d[name] = interpolator - return interpolator - - return decorator - - -@register_2d_interpolator("nearest") -def _nearest_2d(ctx: InterpolationContext2D) -> float: - xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1 - yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1 - ft0 = ctx.data[ctx.ti, yii, xii] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - ft1 = ctx.data[ctx.ti + 1, yii, xii] - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 - - -def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int, xi: int) -> float: - """Interpolation on a unit square. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.""" - return ( - (1 - xsi) * (1 - eta) * data[yi, xi] - + xsi * (1 - eta) * data[yi, xi + 1] - + xsi * eta * data[yi + 1, xi + 1] - + (1 - xsi) * eta * data[yi + 1, xi] - ) - - -@register_2d_interpolator("linear") -@register_2d_interpolator("bgrid_velocity") -@register_2d_interpolator("partialslip") -@register_2d_interpolator("freeslip") -def _linear_2d(ctx: InterpolationContext2D) -> float: - ft0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi) - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - ft1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti + 1, :, :], yi=ctx.yi, xi=ctx.xi) - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 - - -@register_2d_interpolator("linear_invdist_land_tracer") -def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float: - xsi = ctx.xsi - eta = ctx.eta - data = ctx.data - yi = ctx.yi - xi = ctx.xi - ti = ctx.ti - land = np.isclose(data[ti, yi : yi + 2, xi : xi + 2], 0.0) - nb_land = np.sum(land) - - def _get_data_temporalinterp(*, ti, yi, xi): - dt0 = data[ti, yi, xi] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return dt0 - dt1 = data[ti + 1, yi, xi] - return (1 - ctx.tau) * dt0 + ctx.tau * dt1 - - if nb_land == 4: - return 0 - elif nb_land > 0: - val = 0 - w_sum = 0.0 - for j in range(2): - for i in range(2): - distance = pow((eta - j), 2) + pow((xsi - i), 2) - if np.isclose(distance, 0): - if land[j][i] == 1: # index search led us directly onto land - return 0 - else: - return _get_data_temporalinterp(ti=ti, yi=yi + j, xi=xi + i) - elif land[j][i] == 0: - val += _get_data_temporalinterp(ti=ti, yi=yi + j, xi=xi + i) / distance - w_sum += 1 / distance - return val / w_sum - else: - return _interp_on_unit_square(eta=eta, xsi=xsi, data=data[ti, :, :], yi=yi, xi=xi) - - -@register_2d_interpolator("cgrid_tracer") -@register_2d_interpolator("bgrid_tracer") -def _tracer_2d(ctx: InterpolationContext2D) -> float: - ft0 = ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - ft1 = ctx.data[ctx.ti + 1, ctx.yi + 1, ctx.xi + 1] - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 - - -@register_3d_interpolator("nearest") -def _nearest_3d(ctx: InterpolationContext3D) -> float: - xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1 - yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1 - zii = ctx.zi if ctx.zeta <= 0.5 else ctx.zi + 1 - ft0 = ctx.data[ctx.ti, zii, yii, xii] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - ft1 = ctx.data[ctx.ti + 1, zii, yii, xii] - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 - - -def _get_cgrid_depth_point(*, zeta: float, data: np.ndarray, zi: int, yi: int, xi: int) -> float: - f0 = data[zi, yi, xi] - f1 = data[zi + 1, yi, xi] - return (1 - zeta) * f0 + zeta * f1 - - -@register_3d_interpolator("cgrid_velocity") -def _cgrid_W_velocity_3d(ctx: InterpolationContext3D) -> float: - # evaluating W velocity in c_grid - if ctx.gridindexingtype == "nemo": - ft0 = _get_cgrid_depth_point( - zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1 - ) - elif ctx.gridindexingtype in ["mitgcm", "croco"]: - ft0 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi) - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - - if ctx.gridindexingtype == "nemo": - ft1 = _get_cgrid_depth_point( - zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1 - ) - elif ctx.gridindexingtype in ["mitgcm", "croco"]: - ft1 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi) - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 - - -@register_3d_interpolator("linear_invdist_land_tracer") -def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float: - land = np.isclose(ctx.data[ctx.ti, ctx.zi : ctx.zi + 2, ctx.yi : ctx.yi + 2, ctx.xi : ctx.xi + 2], 0.0) - nb_land = np.sum(land) - - def _get_data_temporalinterp(*, ti, zi, yi, xi): - dt0 = ctx.data[ti, zi, yi, xi] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return dt0 - dt1 = data[ti + 1, zi, yi, xi] - return (1 - ctx.tau) * dt0 + ctx.tau * dt1 - - if nb_land == 8: - return 0 - elif nb_land > 0: - val = 0.0 - w_sum = 0.0 - for k in range(2): - for j in range(2): - for i in range(2): - distance = pow((ctx.zeta - k), 2) + pow((ctx.eta - j), 2) + pow((ctx.xsi - i), 2) - if np.isclose(distance, 0): - if land[k][j][i] == 1: # index search led us directly onto land - return 0 - else: - return _get_data_temporalinterp(ti=ctx.ti, zi=ctx.zi + k, yi=ctx.yi + j, xi=ctx.xi + i) - elif land[k][j][i] == 0: - val += ( - _get_data_temporalinterp(ti=ctx.ti, zi=ctx.zi + k, yi=ctx.yi + j, xi=ctx.xi + i) / distance - ) - w_sum += 1 / distance - return val / w_sum - else: - data = ctx.data[ctx.ti, ctx.zi, :, :] - f0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=data, yi=ctx.yi, xi=ctx.xi) - - data = ctx.data[ctx.ti, ctx.zi + 1, :, :] - f1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=data, yi=ctx.yi, xi=ctx.xi) - - return (1 - ctx.zeta) * f0 + ctx.zeta * f1 - - -def _get_3d_f0_f1(*, eta: float, xsi: float, data: np.ndarray, zi: int, yi: int, xi: int) -> tuple[float, float | None]: - data_2d = data[zi, :, :] - f0 = _interp_on_unit_square(eta=eta, xsi=xsi, data=data_2d, yi=yi, xi=xi) - try: - data_2d = data[zi + 1, :, :] - except IndexError: - f1 = None # POP indexing at edge of domain - else: - f1 = _interp_on_unit_square(eta=eta, xsi=xsi, data=data_2d, yi=yi, xi=xi) - - return f0, f1 - - -def _z_layer_interp( - *, zeta: float, f0: float, f1: float | None, zi: int, zdim: int, gridindexingtype: GridIndexingType -): - if gridindexingtype == "pop" and zi >= zdim - 2: - # Since POP is indexed at cell top, allow linear interpolation of W to zero in lowest cell - return (1 - zeta) * f0 - assert f1 is not None, "f1 should not be None for gridindexingtype != 'pop'" - if gridindexingtype == "mom5" and zi == -1: - # Since MOM5 is indexed at cell bottom, allow linear interpolation of W to zero in uppermost cell - return zeta * f1 - return (1 - zeta) * f0 + zeta * f1 - - -@register_3d_interpolator("linear") -@register_3d_interpolator("partialslip") -@register_3d_interpolator("freeslip") -def _linear_3d(ctx: InterpolationContext3D) -> float: - zdim = ctx.data.shape[1] - data_3d = ctx.data[ctx.ti, :, :, :] - fz0, fz1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi) - if should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - data_3d = ctx.data[ctx.ti + 1, :, :, :] - fz0_t1, fz1_t1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi) - fz0 = (1 - ctx.tau) * fz0 + ctx.tau * fz0_t1 - if fz1_t1 is not None and fz1 is not None: - fz1 = (1 - ctx.tau) * fz1 + ctx.tau * fz1_t1 - - return _z_layer_interp(zeta=ctx.zeta, f0=fz0, f1=fz1, zi=ctx.zi, zdim=zdim, gridindexingtype=ctx.gridindexingtype) - - -@register_3d_interpolator("bgrid_velocity") -def _linear_3d_bgrid_velocity(ctx: InterpolationContext3D) -> float: - if ctx.gridindexingtype == "mom5": - ctx.zeta = 1.0 - else: - ctx.zeta = 0.0 - return _linear_3d(ctx) - - -@register_3d_interpolator("bgrid_w_velocity") -def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float: - ctx.eta = 1.0 - ctx.xsi = 1.0 - return _linear_3d(ctx) - - -@register_3d_interpolator("bgrid_tracer") -@register_3d_interpolator("cgrid_tracer") -def _tracer_3d(ctx: InterpolationContext3D) -> float: - ft0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1] - if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]): - return ft0 - ft1 = ctx.data[ctx.ti + 1, ctx.zi, ctx.yi + 1, ctx.xi + 1] - return (1 - ctx.tau) * ft0 + ctx.tau * ft1 diff --git a/src/parcels/_python.py b/src/parcels/_python.py index 91f5b46458..a78e9bf0d4 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,4 +1,6 @@ # Generic Python helpers +import inspect +from collections.abc import Callable def isinstance_noimport(obj, class_or_tuple): @@ -12,17 +14,22 @@ def isinstance_noimport(obj, class_or_tuple): ) -def test_isinstance_noimport(): - class A: - pass +def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None: + """Ensures a function `f` has the same signature as the reference function `ref`.""" + sig_ref = inspect.signature(ref) + sig = inspect.signature(f) - class B: - pass + if len(sig_ref.parameters) != len(sig.parameters): + raise ValueError( + f"{context} function must have {len(sig_ref.parameters)} parameters, got {len(sig.parameters)}" + ) - a = A() - b = B() - - assert isinstance_noimport(a, "A") - assert not isinstance_noimport(a, "B") - assert isinstance_noimport(b, ("A", "B")) - assert not isinstance_noimport(b, "C") + for param1, param2 in zip(sig_ref.parameters.values(), sig.parameters.values(), strict=False): + if param1.kind != param2.kind: + raise ValueError( + f"Parameter '{param2.name}' has incorrect parameter kind. Expected {param1.kind}, got {param2.kind}" + ) + if param1.name != param2.name: + raise ValueError( + f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'" + ) diff --git a/src/parcels/interpolators.py b/src/parcels/interpolators.py index 45dda5bea8..51e16510f7 100644 --- a/src/parcels/interpolators.py +++ b/src/parcels/interpolators.py @@ -8,7 +8,7 @@ import xarray as xr from dask import is_dask_collection -import parcels.utils.interpolation_utils as i_u +import parcels._core.utils.interpolation as i_u if TYPE_CHECKING: from parcels._core.field import Field, VectorField diff --git a/src/parcels/kernels/advection.py b/src/parcels/kernels/advection.py index aa08d48912..3955b72b9a 100644 --- a/src/parcels/kernels/advection.py +++ b/src/parcels/kernels/advection.py @@ -189,7 +189,7 @@ def AdvectionAnalytical(particles, fieldset): # pragma: no cover """ import numpy as np - import parcels.utils.interpolation_utils as i_u + import parcels._core.utils.interpolation as i_u tol = 1e-10 I_s = 10 # number of intermediate time steps diff --git a/src/parcels/utils/__init__.py b/src/parcels/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/parcels/utils/_helpers.py b/src/parcels/utils/_helpers.py deleted file mode 100644 index 0b078110e0..0000000000 --- a/src/parcels/utils/_helpers.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Internal helpers for Parcels.""" - -from __future__ import annotations - -import inspect -from collections.abc import Callable -from datetime import timedelta - -import numpy as np - -PACKAGE = "Parcels" - - -def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: - """Convert a timedelta to a float in seconds.""" - if isinstance(dt, timedelta): - return dt.total_seconds() - if isinstance(dt, np.timedelta64): - return float(dt / np.timedelta64(1, "s")) - return float(dt) - - -def should_calculate_next_ti(ti: int, tau: float, tdim: int): - """Check if the time is beyond the last time in the field""" - return np.greater(tau, 0) and ti < tdim - 1 - - -def _assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None: - """Ensures a function `f` has the same signature as the reference function `ref`.""" - sig_ref = inspect.signature(ref) - sig = inspect.signature(f) - - if len(sig_ref.parameters) != len(sig.parameters): - raise ValueError( - f"{context} function must have {len(sig_ref.parameters)} parameters, got {len(sig.parameters)}" - ) - - for param1, param2 in zip(sig_ref.parameters.values(), sig.parameters.values(), strict=False): - if param1.kind != param2.kind: - raise ValueError( - f"Parameter '{param2.name}' has incorrect parameter kind. Expected {param1.kind}, got {param2.kind}" - ) - if param1.name != param2.name: - raise ValueError( - f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'" - ) diff --git a/src/parcels/utils/timer.py b/src/parcels/utils/timer.py deleted file mode 100644 index ea1181a448..0000000000 --- a/src/parcels/utils/timer.py +++ /dev/null @@ -1,69 +0,0 @@ -import datetime -import time - -from parcels._compat import MPI - -__all__ = [] # type: ignore - - -class Timer: - def __init__(self, name, parent=None, start=True): - self._start = None - self._t = 0 - self._name = name - self._children = [] - self._parent = parent - if self._parent: - self._parent._children.append(self) - if start: - self.start() - - def start(self): - if self._parent: - assert self._parent._start, f"Timer '{self._name}' cannot be started. Its parent timer does not run" - if self._start is not None: - raise RuntimeError(f"Timer {self._name} cannot start since it is already running") - self._start = time.time() - - def stop(self): - assert self._start, f"Timer '{self._name}' was stopped before being started" - self._t += time.time() - self._start - self._start = None - - def print_local(self): - if self._start: - print(f"Timer '{self._name}': {self._t + time.time() - self._start:g} s (process running)") - else: - print(f"Timer '{self._name}': {self._t:g} s") - - def local_time(self): - return self._t + time.time() - self._start if self._start else self._t - - def print_tree_sequential(self, step=0, root_time=0, parent_time=0): - time = self.local_time() - if step == 0: - root_time = time - print(f"({round(time / root_time * 100):3d}%)", end="") - print(" " * (step + 1), end="") - if step > 0: - print(f"({round(time / parent_time * 100):3d}%) ", end="") - t_str = f"{time:1.3e} s" if root_time < 300 else datetime.timedelta(seconds=time) - print(f"Timer {(self._name).ljust(20 - 2 * step + 7 * (step == 0))}: {t_str}") - for child in self._children: - child.print_tree_sequential(step + 1, root_time, time) - - def print_tree(self, step=0, root_time=0, parent_time=0): - if MPI is None: - self.print_tree_sequential(step, root_time, parent_time) - else: - mpi_comm = MPI.COMM_WORLD - mpi_rank = mpi_comm.Get_rank() - mpi_size = mpi_comm.Get_size() - if mpi_size == 1: - self.print_tree_sequential(step, root_time, parent_time) - else: - for iproc in range(mpi_size): - if iproc == mpi_rank: - print(f"Proc {mpi_rank}/{mpi_size} - Timer tree") - self.print_tree_sequential(step, root_time, parent_time) - mpi_comm.Barrier() diff --git a/tests-v3/tools/test_helpers.py b/tests-v3/tools/test_helpers.py index 1403b679f0..c3499b55a1 100644 --- a/tests-v3/tools/test_helpers.py +++ b/tests-v3/tools/test_helpers.py @@ -1,10 +1,7 @@ -from datetime import timedelta - -import numpy as np import pytest import parcels.tools._helpers as helpers -from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float +from parcels.tools._helpers import deprecated, deprecated_made_private def test_format_list_items_multiline(): @@ -67,20 +64,3 @@ def some_function(x, y): some_function(1, 2) assert "deprecated::" in some_function.__doc__ - - -@pytest.mark.parametrize( - "input, expected", - [ - (timedelta(days=1), 24 * 60 * 60), - (np.timedelta64(1, "D"), 24 * 60 * 60), - (3600.0, 3600.0), - ], -) -def test_timedelta_to_float(input, expected): - assert timedelta_to_float(input) == expected - - -def test_timedelta_to_float_exceptions(): - with pytest.raises((ValueError, TypeError)): - timedelta_to_float("invalid_type") diff --git a/tests/test_python.py b/tests/test_python.py new file mode 100644 index 0000000000..01f3327561 --- /dev/null +++ b/tests/test_python.py @@ -0,0 +1,17 @@ +from parcels._python import isinstance_noimport + + +def test_isinstance_noimport(): + class A: + pass + + class B: + pass + + a = A() + b = B() + + assert isinstance_noimport(a, "A") + assert not isinstance_noimport(a, "B") + assert isinstance_noimport(b, ("A", "B")) + assert not isinstance_noimport(b, "C") diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index 5bab70b3b2..b321744fce 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -8,7 +8,7 @@ from hypothesis import given from hypothesis import strategies as st -from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy +from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy, timedelta_to_float calendar_strategy = st.sampled_from( [ @@ -198,3 +198,20 @@ def test_time_interval_intersection_different_calendars(): def test_maybe_convert_python_timedelta_to_numpy(td, expected): result = maybe_convert_python_timedelta_to_numpy(td) assert result == expected + + +@pytest.mark.parametrize( + "input, expected", + [ + (timedelta(days=1), 24 * 60 * 60), + (np.timedelta64(1, "D"), 24 * 60 * 60), + (3600.0, 3600.0), + ], +) +def test_timedelta_to_float(input, expected): + assert timedelta_to_float(input) == expected + + +def test_timedelta_to_float_exceptions(): + with pytest.raises((ValueError, TypeError)): + timedelta_to_float("invalid_type")