Skip to content

Commit

Permalink
api: fix pickling of sparse operations
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 15, 2024
1 parent e0e69d4 commit 445b39e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
10 changes: 8 additions & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from devito.finite_differences.elementary import floor
from devito.logger import warning
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten, filter_ordered
from devito.tools import as_tuple, flatten, filter_ordered, Pickable
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension, SubFunction)
from devito.types.utils import DimensionTuple
Expand All @@ -33,7 +33,7 @@ def wrapper(interp, *args, **kwargs):
return wrapper


class UnevaluatedSparseOperation(sympy.Expr, Evaluable):
class UnevaluatedSparseOperation(sympy.Expr, Evaluable, Pickable):

"""
Represents an Injection or an Interpolation operation performed on a
Expand All @@ -48,6 +48,7 @@ class UnevaluatedSparseOperation(sympy.Expr, Evaluable):
"""

subdomain = None
__rargs__ = ('interpolator',)

def __new__(cls, interpolator):
obj = super().__new__(cls)
Expand Down Expand Up @@ -79,6 +80,9 @@ class Interpolation(UnevaluatedSparseOperation):
Evaluates to a list of Eq objects.
"""

__rargs__ = ('expr', 'increment', 'implicit_dims', 'self_subs') + \
UnevaluatedSparseOperation.__rargs__

def __new__(cls, expr, increment, implicit_dims, self_subs, interpolator):
obj = super().__new__(cls, interpolator)

Expand Down Expand Up @@ -107,6 +111,8 @@ class Injection(UnevaluatedSparseOperation):
Evaluates to a list of Eq objects.
"""

__rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__

def __new__(cls, field, expr, implicit_dims, interpolator):
obj = super().__new__(cls, interpolator)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ def test_alias_sparse_function(self, pickle):
assert sf.dtype == f0.dtype == new_f0.dtype
assert sf.npoint == f0.npoint == new_f0.npoint

@pytest.mark.parametrize('interp', ['linear', 'sinc'])
@pytest.mark.parametrize('op', ['inject', 'interpolate'])
def test_sparse_op(self, pickle, interp, op):
grid = Grid(shape=(3,))
sf = SparseFunction(name='sf', grid=grid, npoint=3, space_order=2,
coordinates=[(0.,), (1.,), (2.,)],
interpolation=interp)
u = Function(name='u', grid=grid, space_order=4)

if op == 'inject':
expr = sf.inject(u, sf)
else:
expr = sf.interpolate(u)

pkl_expr = pickle.dumps(expr)
new_expr = pickle.loads(pkl_expr)

assert new_expr.interpolator._name == expr.interpolator._name
assert new_expr.implicit_dims == expr.implicit_dims
assert str(new_expr.evaluate) == str(expr.evaluate)

def test_internal_symbols(self, pickle):
s = dSymbol(name='s', dtype=np.float32)
pkl_s = pickle.dumps(s)
Expand Down

0 comments on commit 445b39e

Please sign in to comment.