From 445b39e7b3e494f1ef8b90b475133c0d44808872 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 15 Nov 2024 09:25:49 -0500 Subject: [PATCH] api: fix pickling of sparse operations --- devito/operations/interpolators.py | 10 ++++++++-- tests/test_pickle.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index df90d96086..5a9160ea4b 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -13,7 +13,7 @@ from devito.finite_differences.elementary import floor from devito.logger import warning from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT -from devito.tools import as_tuple, flatten, filter_ordered +from devito.tools import as_tuple, flatten, filter_ordered, Pickable from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol, CustomDimension, SubFunction) from devito.types.utils import DimensionTuple @@ -33,7 +33,7 @@ def wrapper(interp, *args, **kwargs): return wrapper -class UnevaluatedSparseOperation(sympy.Expr, Evaluable): +class UnevaluatedSparseOperation(sympy.Expr, Evaluable, Pickable): """ Represents an Injection or an Interpolation operation performed on a @@ -48,6 +48,7 @@ class UnevaluatedSparseOperation(sympy.Expr, Evaluable): """ subdomain = None + __rargs__ = ('interpolator',) def __new__(cls, interpolator): obj = super().__new__(cls) @@ -79,6 +80,9 @@ class Interpolation(UnevaluatedSparseOperation): Evaluates to a list of Eq objects. """ + __rargs__ = ('expr', 'increment', 'implicit_dims', 'self_subs') + \ + UnevaluatedSparseOperation.__rargs__ + def __new__(cls, expr, increment, implicit_dims, self_subs, interpolator): obj = super().__new__(cls, interpolator) @@ -107,6 +111,8 @@ class Injection(UnevaluatedSparseOperation): Evaluates to a list of Eq objects. """ + __rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__ + def __new__(cls, field, expr, implicit_dims, interpolator): obj = super().__new__(cls, interpolator) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 6786282ee0..c2a676fde6 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -183,6 +183,27 @@ def test_alias_sparse_function(self, pickle): assert sf.dtype == f0.dtype == new_f0.dtype assert sf.npoint == f0.npoint == new_f0.npoint + @pytest.mark.parametrize('interp', ['linear', 'sinc']) + @pytest.mark.parametrize('op', ['inject', 'interpolate']) + def test_sparse_op(self, pickle, interp, op): + grid = Grid(shape=(3,)) + sf = SparseFunction(name='sf', grid=grid, npoint=3, space_order=2, + coordinates=[(0.,), (1.,), (2.,)], + interpolation=interp) + u = Function(name='u', grid=grid, space_order=4) + + if op == 'inject': + expr = sf.inject(u, sf) + else: + expr = sf.interpolate(u) + + pkl_expr = pickle.dumps(expr) + new_expr = pickle.loads(pkl_expr) + + assert new_expr.interpolator._name == expr.interpolator._name + assert new_expr.implicit_dims == expr.implicit_dims + assert str(new_expr.evaluate) == str(expr.evaluate) + def test_internal_symbols(self, pickle): s = dSymbol(name='s', dtype=np.float32) pkl_s = pickle.dumps(s)