diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 794e437e977..700cfadd764 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -174,7 +174,7 @@ def _collect_nested(expr): Recursion helper for `collect_nested`. """ # Return semantic (rebuilt expression, factorization candidates) - if expr.kind is NumberKind: + if expr.kind is NumberKind and not expr.is_Symbol: return expr, {'coeffs': expr} elif expr.is_Function: return expr, {'funcs': expr} diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 6fe38ee509b..9d44c8b0204 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -101,10 +101,11 @@ def _print_math_func(self, expr, nest=False, known=None): return super()._print_math_func(expr, nest=nest, known=known) dtype = sympy_dtype(expr) - if dtype is np.float32: - cname += 'f' - if np.issubdtype(self.dtype, np.complexfloating): + if np.issubdtype(dtype, np.complexfloating): cname = 'c%s' % cname + dtype = self.dtype(0).real.dtype + if dtype is np.float32: + cname = '%sf' % cname args = ', '.join((self._print(arg) for arg in expr.args)) return '%s(%s)' % (cname, args) @@ -198,6 +199,9 @@ def _print_Float(self, expr): return rv + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) @@ -249,10 +253,12 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) - if self.dtype == np.float32: - func_name += 'f' - if np.issubdtype(self.dtype, np.complexfloating): + dtype = self.dtype + if np.issubdtype(dtype, np.complexfloating): func_name = 'c%s' % func_name + dtype = self.dtype(0).real.dtype + if dtype == np.float32: + func_name = '%sf' % func_name return '%s(%s)' % (func_name, self._print(*expr.args)) def _print_DefFunction(self, expr): diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 3ac6f62fb4b..bd1d6a8afcc 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -136,6 +136,14 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype + # Complex data + if np.issubdtype(dtype, np.complexfloating): + rtype = dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index af08d6281d9..5ca2be984d4 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,8 +13,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, - CustomDtype) + frozendict, memoized_meth, sympy_mutex) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -432,16 +431,7 @@ def _C_name(self): @property def _C_ctype(self): - if isinstance(self.dtype, CustomDtype): - return self.dtype - elif np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - else: - return dtype_to_ctype(self.dtype) + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1438,14 +1428,7 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - if np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return POINTER(r) - else: - return POINTER(dtype_to_ctype(self.dtype)) + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! return self.dtype diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 28d0a38edd3..9edff99f56a 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -7,7 +7,7 @@ from conftest import assert_structure from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, Dimension, MatrixSparseTimeFunction, SparseTimeFunction, - SubDimension, SubDomain, SubDomainSet, TimeFunction, + SubDimension, SubDomain, SubDomainSet, TimeFunction, exp, Operator, configuration, switchconfig, TensorTimeFunction) from devito.arch import get_gpu_info from devito.exceptions import InvalidArgument diff --git a/tests/test_operator.py b/tests/test_operator.py index 376c7e5dff6..73d772a1cb4 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -655,6 +655,7 @@ def test_complex(self): dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) + print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)