From 2b280901ef039f48df21c0f01857246b6504e711 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 2 Aug 2023 10:15:54 -0400 Subject: [PATCH 01/11] api: add support for complex dtype --- devito/data/allocators.py | 8 +++++-- devito/finite_differences/differentiable.py | 2 +- devito/operator/operator.py | 6 +++++ devito/passes/clusters/factorization.py | 4 ++-- devito/passes/iet/misc.py | 26 ++++++++++++++++++++- devito/symbolics/inspection.py | 5 ++++ devito/tools/dtypes_lowering.py | 4 +++- devito/types/basic.py | 23 +++++++++++++++--- 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/devito/data/allocators.py b/devito/data/allocators.py index 72289c57bf..14f1b04fd1 100644 --- a/devito/data/allocators.py +++ b/devito/data/allocators.py @@ -92,8 +92,12 @@ def initialize(cls): return def alloc(self, shape, dtype, padding=0): - datasize = int(reduce(mul, shape)) - ctype = dtype_to_ctype(dtype) + # For complex number, allocate double the size of its real/imaginary part + alloc_dtype = dtype(0).real.__class__ + c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 + + datasize = int(reduce(mul, shape) * c_scale) + ctype = dtype_to_ctype(alloc_dtype) # Add padding, if any try: diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 2e1fef6548..8b5a47207c 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -68,7 +68,7 @@ def grid(self): @cached_property def dtype(self): - dtypes = {f.dtype for f in self.find(Indexed)} - {None} + dtypes = {f.dtype for f in self._functions} - {None} return infer_dtype(dtypes) @cached_property diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 363c2507e3..e2fe8dd3a1 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -469,6 +469,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) + + # 
Complex header if needed. Needs to be done specialization + # as some specific cases requires complex to be loaded first + complex_include(graph) + + # Specialize graph = cls._specialize_iet(graph, **kwargs) # Instrument the IET for C-level profiling diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 33253e245e..794e437e97 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,6 +1,7 @@ from collections import defaultdict from sympy import Add, Mul, S, collect +from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols @@ -173,8 +174,7 @@ def _collect_nested(expr): Recursion helper for `collect_nested`. """ # Return semantic (rebuilt expression, factorization candidates) - - if expr.is_Number: + if expr.kind is NumberKind: return expr, {'coeffs': expr} elif expr.is_Function: return expr, {'funcs': expr} diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index f0b2b7f4f5..f8a99ead5f 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,9 @@ import cgen import numpy as np import sympy +import numpy as np +from devito import configuration from devito.finite_differences import Max, Min from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, FindNodes, FindSymbols, Transformer, Uxreplace, @@ -16,7 +18,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols'] + 'generate_macros', 'minimize_symbols', 'complex_include'] @iet_pass @@ -240,6 +242,28 @@ def minimize_symbols(iet): return iet, {} +@iet_pass +def complex_include(iet): + """ + Add headers for complex arithmetic + """ + if configuration['language'] == 'cuda': + lib = 'cuComplex.h' + elif configuration['language'] == 'hip': + lib = 'hip/hip_complex.h' + else: + lib = 
'complex.h' + + functions = FindSymbols().visit(iet) + for f in functions: + try: + if np.issubdtype(f.dtype, np.complexfloating): + return iet, {'includes': (lib,)} + except TypeError: + pass + return iet, {} + + def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 437d48fff0..8339aabc2c 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,6 +3,8 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) +from sympy.core.operations import AssocOp +from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative @@ -167,6 +169,7 @@ def _(expr, estimate, seen): return 0, True +@_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) @_estimate_cost.register(ReservedWord) def _(expr, estimate, seen): @@ -189,6 +192,8 @@ def _(expr, estimate, seen): flops, flags = _estimate_cost.registry[object](expr, estimate, seen) if {S.One, S.NegativeOne}.intersection(expr.args): flops -= 1 + if ImaginaryUnit in expr.args: + flops *= 2 return flops, flags diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 4e7908a552..f2d0c6ad31 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -133,6 +133,9 @@ def dtype_to_cstr(dtype): def dtype_to_ctype(dtype): """Translate numpy.dtype into a ctypes type.""" + if isinstance(dtype, CustomDtype): + return dtype + try: return ctypes_vector_mapper[dtype] except KeyError: @@ -230,7 +233,6 @@ def ctypes_to_cstr(ctype, toarray=None): retval = '%s[%d]' % (ctypes_to_cstr(ctype._type_, toarray), ctype._length_) elif ctype.__name__.startswith('c_'): name = ctype.__name__[2:] - # A primitive datatype # FIXME: Is there 
a better way of extracting the C typename ? # Here, we're following the ctypes convention that each basic type has diff --git a/devito/types/basic.py b/devito/types/basic.py index e21bae6453..065efe0590 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, + CustomDtype) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -432,7 +433,16 @@ def _C_name(self): @property def _C_ctype(self): - return dtype_to_ctype(self.dtype) + if isinstance(self.dtype, CustomDtype): + return self.dtype + elif np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + else: + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1470,7 +1480,14 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - return POINTER(dtype_to_ctype(self.dtype)) + if np.issubdtype(self.dtype, np.complexfloating): + rtype = self.dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return POINTER(r) + else: + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype From aa353b4614957e4a833fed485f57754821a9df37 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:01:27 -0400 Subject: [PATCH 02/11] api: fix printer for complex dtype --- devito/finite_differences/differentiable.py | 3 +-- devito/passes/iet/misc.py | 1 - devito/symbolics/inspection.py | 1 - devito/symbolics/printer.py | 10 ++++++++ devito/types/basic.py | 2 +- tests/test_operator.py | 28 +++++++++++++++++---- 6 files changed, 35 insertions(+), 10 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 8b5a47207c..a95e8be88d 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -14,8 +14,7 @@ from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) -from devito.types import (Array, DimensionTuple, Evaluable, Indexed, - StencilDimension) +from devito.types import Array, DimensionTuple, Evaluable, StencilDimension __all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', 'Weights'] diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index f8a99ead5f..1f7bbbe881 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,6 @@ import cgen import numpy as np import sympy -import numpy as np from devito import configuration from devito.finite_differences import Max, Min diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 8339aabc2c..3332bc68d6 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -3,7 +3,6 @@ import numpy as np from sympy import (Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple) -from sympy.core.operations import AssocOp from sympy.core.numbers import ImaginaryUnit from devito.finite_differences import Derivative diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py 
index 2a25ef5c12..fddd133ac1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -43,6 +43,10 @@ def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16] + def complex_prec(self, expr=None): + dtype = sympy_dtype(expr) if expr is not None else self.dtype + return np.issubdtype(dtype, np.complexfloating) + def parenthesize(self, item, level, strict=False): if isinstance(item, BooleanFunction): return "(%s)" % self._print(item) @@ -110,6 +114,8 @@ def _print_math_func(self, expr, nest=False, known=None): if self.single_prec(expr): cname = '%sf' % cname + if self.complex_prec(expr): + cname = 'c%s' % cname args = ', '.join((self._print(arg) for arg in expr.args)) @@ -255,8 +261,12 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) + if self.single_prec(): func_name = '%sf' % func_name + if self.complex_prec(): + func_name = 'c%s' % func_name + return '%s(%s)' % (func_name, self._print(*expr.args)) def _print_DefFunction(self, expr): diff --git a/devito/types/basic.py b/devito/types/basic.py index 065efe0590..8ee4fefc7f 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1482,7 +1482,7 @@ def _C_ctype(self): try: if np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return POINTER(r) diff --git a/tests/test_operator.py b/tests/test_operator.py index d5759c1c92..db962b7d6b 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -9,7 +9,7 @@ SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension, NODE, CELL, dimensions, configuration, TensorFunction, TensorTimeFunction, VectorFunction, VectorTimeFunction, - div, grad, switchconfig) + div, grad, switchconfig, exp) from devito 
import Inc, Le, Lt, Ge, Gt # noqa from devito.exceptions import InvalidOperator from devito.finite_differences.differentiable import diff2sympy @@ -640,6 +640,24 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq, opt='noop') + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestAllocation: @@ -724,10 +742,10 @@ def verify_parameters(self, parameters, expected): """ boilerplate = ['timers'] parameters = [p.name for p in parameters] - for exp in expected: - if exp not in parameters + boilerplate: - error("Missing parameter: %s" % exp) - assert exp in parameters + boilerplate + for expi in expected: + if expi not in parameters + boilerplate: + error("Missing parameter: %s" % expi) + assert expi in parameters + boilerplate extra = [p for p in parameters if p not in expected and p not in boilerplate] if len(extra) > 0: error("Redundant parameters: %s" % str(extra)) From 92dfd9a1f308ec973a41cb91c275adefb9c4802d Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 May 2024 08:17:39 -0400 Subject: [PATCH 03/11] compiler: fix alias dtype with complex numbers --- devito/symbolics/inspection.py | 8 +++++++- devito/types/basic.py | 2 +- tests/test_gpu_common.py | 18 ++++++++++++++++++ tests/test_operator.py | 2 +- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 3332bc68d6..6649fa86bf 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py 
@@ -304,4 +304,10 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass - return infer_dtype(dtypes) + dtype = infer_dtype(dtypes) + + # Promote if complex + if expr.has(ImaginaryUnit): + dtype = np.promote_types(dtype, np.complex64).type + + return dtype diff --git a/devito/types/basic.py b/devito/types/basic.py index 8ee4fefc7f..8e5d0d0455 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -437,7 +437,7 @@ def _C_ctype(self): return self.dtype elif np.issubdtype(self.dtype, np.complexfloating): rtype = self.dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctname = '%s complex' % dtype_to_cstr(rtype) ctype = dtype_to_ctype(rtype) r = type(ctname, (ctype,), {}) return r diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 8f100a1082..450814bf74 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -66,6 +66,24 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] + def test_complex(self): + grid = Grid((5, 5)) + x, y = grid.dimensions + # Float32 complex is called complex64 in numpy + u = Function(name="u", grid=grid, dtype=np.complex64) + + eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + # Currently wrong alias type + op = Operator(eq) + op() + + # Check against numpy + dx = grid.spacing_map[x.spacing] + xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) + npres = xx + 1j*yy + np.exp(1j + dx) + + assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + class TestPassesOptional: diff --git a/tests/test_operator.py b/tests/test_operator.py index db962b7d6b..5d975685ce 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -648,7 +648,7 @@ def test_complex(self): eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) # Currently wrong alias type - op = Operator(eq, opt='noop') + op = Operator(eq) op() # Check against numpy From 4364524b84e903f501375635c056f47db2734dfe Mon Sep 17 00:00:00 2001 From: 
mloubout Date: Wed, 22 May 2024 08:25:51 -0400 Subject: [PATCH 04/11] api: move complex ctype to dtype lowering --- devito/operator/operator.py | 2 +- devito/passes/clusters/factorization.py | 3 +-- devito/passes/iet/misc.py | 24 +++++++++++++----------- devito/symbolics/printer.py | 3 +++ devito/tools/dtypes_lowering.py | 8 ++++++++ devito/types/basic.py | 23 +++-------------------- tests/test_gpu_common.py | 2 +- tests/test_operator.py | 1 + 8 files changed, 31 insertions(+), 35 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index e2fe8dd3a1..266aacd81c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -472,7 +472,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done specialization # as some specific cases requires complex to be loaded first - complex_include(graph) + complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 794e437e97..47222a33be 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -1,7 +1,6 @@ from collections import defaultdict from sympy import Add, Mul, S, collect -from sympy.core import NumberKind from devito.ir import cluster_pass from devito.symbolics import BasicWrapperMixin, estimate_cost, retrieve_symbols @@ -174,7 +173,7 @@ def _collect_nested(expr): Recursion helper for `collect_nested`. 
""" # Return semantic (rebuilt expression, factorization candidates) - if expr.kind is NumberKind: + if expr.is_Number: return expr, {'coeffs': expr} elif expr.is_Function: return expr, {'funcs': expr} diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 1f7bbbe881..7211c5877a 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -4,7 +4,6 @@ import numpy as np import sympy -from devito import configuration from devito.finite_differences import Max, Min from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, FindNodes, FindSymbols, Transformer, Uxreplace, @@ -241,25 +240,28 @@ def minimize_symbols(iet): return iet, {} +_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} + + @iet_pass -def complex_include(iet): +def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - if configuration['language'] == 'cuda': - lib = 'cuComplex.h' - elif configuration['language'] == 'hip': - lib = 'hip/hip_complex.h' - else: - lib = 'complex.h' + lib = _complex_lib.get(language, 'complex.h') - functions = FindSymbols().visit(iet) - for f in functions: + headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise + if compiler._cpp: + headers = {('_Complex_I', ('1.0fi'))} + + for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,)} + return iet, {'includes': (lib,), 'headers': headers} except TypeError: pass + return iet, {} diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fddd133ac1..7de815549e 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -210,6 +210,9 @@ def _print_Float(self, expr): return rv + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py 
index f2d0c6ad31..ff40f6c7d6 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -136,6 +136,14 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype + # Complex data + if np.issubdtype(dtype, np.complexfloating): + rtype = dtype(0).real.__class__ + ctname = '%s _Complex' % dtype_to_cstr(rtype) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r + try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index 8e5d0d0455..e21bae6453 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,8 +13,7 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex, dtype_to_cstr, - CustomDtype) + frozendict, memoized_meth, sympy_mutex) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -433,16 +432,7 @@ def _C_name(self): @property def _C_ctype(self): - if isinstance(self.dtype, CustomDtype): - return self.dtype - elif np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - else: - return dtype_to_ctype(self.dtype) + return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): """ @@ -1480,14 +1470,7 @@ def _C_name(self): @cached_property def _C_ctype(self): try: - if np.issubdtype(self.dtype, np.complexfloating): - rtype = self.dtype(0).real.__class__ - ctname = '%s complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return POINTER(r) - else: - return POINTER(dtype_to_ctype(self.dtype)) + return POINTER(dtype_to_ctype(self.dtype)) except TypeError: # `dtype` is a ctypes-derived type! 
return self.dtype diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 450814bf74..d1af179792 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -7,7 +7,7 @@ from conftest import assert_structure from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, Dimension, MatrixSparseTimeFunction, SparseTimeFunction, - SubDimension, SubDomain, SubDomainSet, TimeFunction, + SubDimension, SubDomain, SubDomainSet, TimeFunction, exp, Operator, configuration, switchconfig, TensorTimeFunction) from devito.arch import get_gpu_info from devito.exceptions import InvalidArgument diff --git a/tests/test_operator.py b/tests/test_operator.py index 5d975685ce..9cdf34e313 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -655,6 +655,7 @@ def test_complex(self): dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) + print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 470f4f50cff87f4091c0360e7b414c00e9c2574c Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 28 May 2024 13:00:56 -0400 Subject: [PATCH 05/11] compiler: generate std:complex for cpp compilers --- devito/ir/iet/visitors.py | 43 +++++++++++++++++++++++---------- devito/passes/iet/misc.py | 4 +-- devito/symbolics/printer.py | 8 ++++++ devito/tools/dtypes_lowering.py | 7 ++---- tests/test_gpu_common.py | 3 ++- tests/test_operator.py | 2 +- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 54e9188e1a..69c99a5161 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,6 +10,7 @@ import ctypes import cgen as c +import numpy as np from sympy import IndexedBase from sympy.core.function import Application @@ -188,6 +189,21 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' + def _complex_type(self, ctypestr, dtype): + # Not complex + 
try: + if not np.issubdtype(dtype, np.complexfloating): + return ctypestr + except TypeError: + return ctypestr + # Complex only supported for float and double + if ctypestr not in ('float', 'double'): + return ctypestr + if self._compiler._cpp: + return 'std::complex<%s>' % ctypestr + else: + return '%s _Complex' % ctypestr + def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. @@ -243,10 +259,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = obj._C_typedata + strtype = self._complex_type(obj._C_typedata, obj.dtype) strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = ctypes_to_cstr(obj._C_ctype) + strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -376,10 +392,11 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if f.is_PointerArray: # lvalue - lvalue = c.Value(i._C_typedata, '**%s' % f.name) + lvalue = c.Value(cstr, '**%s' % f.name) # rvalue if isinstance(o.obj, ArrayObject): @@ -388,7 +405,7 @@ def visit_PointerCast(self, o): v = f._C_name else: assert False - rvalue = '(%s**) %s' % (i._C_typedata, v) + rvalue = '(%s**) %s' % (cstr, v) else: # lvalue @@ -399,10 +416,10 @@ def visit_PointerCast(self, o): if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in o.castshape) rshape = '(*)%s' % shape - lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: rshape = '*' - lvalue = c.Value(i._C_typedata, '*%s' % v) + lvalue = c.Value(cstr, '*%s' % v) if o.alignment and f._data_alignment: lvalue = c.AlignedAttribute(f._data_alignment, lvalue) @@ -415,14 +432,14 @@ def 
visit_PointerCast(self, o): else: assert False - rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v) + rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v) else: if isinstance(o.obj, Pointer): v = o.obj.name else: v = f._C_name - rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v) + rvalue = '(%s %s) %s' % (cstr, rshape, v) return c.Initializer(lvalue, rvalue) @@ -430,15 +447,15 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed + cstr = self._complex_type(i._C_typedata, i.dtype) if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) - rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name, + rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, - '(*restrict %s)%s' % (a0.name, shape)) + lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) else: - rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name) - lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name) + rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name) + lvalue = c.Value(cstr, '*restrict %s' % a0.name) if a0._data_alignment: lvalue = c.AlignedAttribute(a0._data_alignment, lvalue) else: diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 7211c5877a..1eac0664d4 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -248,12 +248,12 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex.h') + lib = _complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h') headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: - headers = {('_Complex_I', ('1.0fi'))} + headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} for f in FindSymbols().visit(iet): try: diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 
7de815549e..fd15796bf8 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -264,8 +264,16 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): func_name = str(expr.func) +<<<<<<< HEAD if self.single_prec(): +======= + dtype = self.dtype + if np.issubdtype(dtype, np.complexfloating): + func_name = 'c%s' % func_name + dtype = self.dtype(0).real.dtype.type + if dtype == np.float32: +>>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index ff40f6c7d6..6ca336e305 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,10 +139,7 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - ctname = '%s _Complex' % dtype_to_cstr(rtype) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r + return dtype_to_ctype(rtype) try: return ctypes_vector_mapper[dtype] @@ -217,7 +214,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None): +def ctypes_to_cstr(ctype, toarray=None, cpp=False): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index d1af179792..c7bb0c0211 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -2,6 +2,7 @@ import pytest import numpy as np +import sympy import scipy.sparse from conftest import assert_structure @@ -72,7 +73,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() diff 
--git a/tests/test_operator.py b/tests/test_operator.py index 9cdf34e313..61b117bcc6 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -646,7 +646,7 @@ def test_complex(self): # Float32 complex is called complex64 in numpy u = Function(name="u", grid=grid, dtype=np.complex64) - eq = Eq(u, x + 1j*y + exp(1j + x.spacing)) + eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) op() From 7ffff0a2e8ad831d7d00aa244a8e037a47e3ffea Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 12:33:30 -0400 Subject: [PATCH 06/11] compiler: add std::complex arithmetic defs for unsupported types --- devito/ir/iet/visitors.py | 3 ++- devito/passes/iet/misc.py | 33 +++++++++++++++++++++++++++++++-- devito/symbolics/printer.py | 10 +--------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 69c99a5161..aed6eb1351 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -14,6 +14,7 @@ from sympy import IndexedBase from sympy.core.function import Application +from devito.parameters import configuration from devito.exceptions import VisitorException from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -177,7 +178,7 @@ class CGen(Visitor): def __init__(self, *args, compiler=None, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler + self._compiler = compiler or configuration['compiler'] # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 1eac0664d4..5b53c43796 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -248,17 +248,26 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ - lib = _complex_lib.get(language, 'complex' if compiler._cpp else 
'complex.h') + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} + # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + # Constant I headers = {('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + # Mix arithmetic definitions + dest = compiler.get_jit_dir() + hfile = dest.joinpath('stdcomplex_arith.h') + if not hfile.is_file(): + with open(str(hfile), 'w') as ff: + ff.write(str(_stdcomplex_defs)) + lib += (str(hfile),) for f in FindSymbols().visit(iet): try: if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': (lib,), 'headers': headers} + return iet, {'includes': lib, 'headers': headers} except TypeError: pass @@ -343,3 +352,23 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} + + +_stdcomplex_defs = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} +""" diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fd15796bf8..c7917b3ea1 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -41,7 +41,7 @@ def compiler(self): def single_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16] + return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): dtype = sympy_dtype(expr) if expr is not None else self.dtype @@ -264,16 +264,8 @@ def _print_ComponentAccess(self, expr): def _print_TrigonometricFunction(self, expr): 
func_name = str(expr.func) -<<<<<<< HEAD if self.single_prec(): -======= - dtype = self.dtype - if np.issubdtype(dtype, np.complexfloating): - func_name = 'c%s' % func_name - dtype = self.dtype(0).real.dtype.type - if dtype == np.float32: ->>>>>>> 75d50a431 (compiler: generate std:complex for cpp compilers) func_name = '%sf' % func_name if self.complex_prec(): func_name = 'c%s' % func_name From d1dd24e3b50787318a20463d5e7b8b6259bd4cbe Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 30 May 2024 14:08:12 -0400 Subject: [PATCH 07/11] compiler: fix alias dtype with complex numbers --- devito/__init__.py | 5 +++-- devito/arch/compiler.py | 24 +++++++++++++++++++-- devito/ir/iet/visitors.py | 34 +++++++++++++----------------- devito/operator/operator.py | 7 ++++--- devito/passes/iet/misc.py | 37 +++++++++++++++++++++++---------- devito/symbolics/inspection.py | 6 ++++-- devito/tools/dtypes_lowering.py | 12 ++++++++--- tests/test_gpu_common.py | 7 ++++--- tests/test_operator.py | 8 +++---- 9 files changed, 90 insertions(+), 50 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index b0a981dcfa..b407988eb8 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -56,7 +56,8 @@ def reinit_compiler(val): """ Re-initialize the Compiler. 
""" - configuration['compiler'].__init__(suffix=configuration['compiler'].suffix, + configuration['compiler'].__init__(name=configuration['compiler'].name, + suffix=configuration['compiler'].suffix, mpi=configuration['mpi']) return val @@ -65,7 +66,7 @@ def reinit_compiler(val): configuration.add('platform', 'cpu64', list(platform_registry), callback=lambda i: platform_registry[i]()) configuration.add('compiler', 'custom', list(compiler_registry), - callback=lambda i: compiler_registry[i]()) + callback=lambda i: compiler_registry[i](name=i)) # Setup language for shared-memory parallelism preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 9cd94ed597..de4711d257 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -180,6 +180,8 @@ def __init__(self): _cpp = False def __init__(self, **kwargs): + self._name = kwargs.pop('name', self.__class__.__name__) + super().__init__(**kwargs) self.__lookup_cmds__() @@ -223,13 +225,13 @@ def __new_with__(self, **kwargs): Create a new Compiler from an existing one, inherenting from it the flags that are not specified via ``kwargs``. """ - return self.__class__(suffix=kwargs.pop('suffix', self.suffix), + return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix), mpi=kwargs.pop('mpi', configuration['mpi']), **kwargs) @property def name(self): - return self.__class__.__name__ + return self._name @property def version(self): @@ -245,6 +247,20 @@ def version(self): return version + @property + def _complex_ctype(self): + """ + Type definition for complex numbers. 
THese two cases cover 99% of the cases since + - Hip is now using std::complex +https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings + - Sycl supports std::complex + - C's _Complex is part of C99 + """ + if self._cpp: + return lambda dtype: 'std::complex<%s>' % str(dtype) + else: + return lambda dtype: '%s _Complex' % str(dtype) + def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -699,6 +715,10 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' + @property + def _complex_ctype(self): + return lambda dtype: 'thrust::complex<%s>' % str(dtype) + class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index aed6eb1351..1e21f1d8ba 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -10,11 +10,10 @@ import ctypes import cgen as c -import numpy as np from sympy import IndexedBase from sympy.core.function import Application -from devito.parameters import configuration +from devito.parameters import configuration, switchconfig from devito.exceptions import VisitorException from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) @@ -190,20 +189,15 @@ def __init__(self, *args, compiler=None, **kwargs): } _restrict_keyword = 'restrict' - def _complex_type(self, ctypestr, dtype): - # Not complex - try: - if not np.issubdtype(dtype, np.complexfloating): - return ctypestr - except TypeError: - return ctypestr - # Complex only supported for float and double - if ctypestr not in ('float', 'double'): - return ctypestr - if self._compiler._cpp: - return 'std::complex<%s>' % ctypestr - else: - return '%s _Complex' % ctypestr + @property + def compiler(self): + return self._compiler + + def visit(self, o, *args, **kwargs): + # Make sure the visitor always is within the generating compiler + # in case the configuration is accessed + with 
switchconfig(compiler=self.compiler.name): + return super().visit(o, *args, **kwargs) def _gen_struct_decl(self, obj, masked=()): """ @@ -260,10 +254,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = self._complex_type(obj._C_typedata, obj.dtype) + strtype = obj._C_typedata strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype) + strtype = ctypes_to_cstr(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -393,7 +387,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if f.is_PointerArray: # lvalue @@ -448,7 +442,7 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = self._complex_type(i._C_typedata, i.dtype) + cstr = i._C_typedata if o.flat is None: shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 266aacd81c..57ef1134ae 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -470,8 +470,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # Complex header if needed. Needs to be done specialization - # as some specific cases requires complex to be loaded first + # Complex header if needed. 
Needs to be done before specialization + # as some specific cases require complex to be loaded first complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize @@ -1353,7 +1353,8 @@ def parse_kwargs(**kwargs): raise InvalidOperator("Illegal `compiler=%s`" % str(compiler)) kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'], language=kwargs['language'], - mpi=configuration['mpi']) + mpi=configuration['mpi'], + name=compiler) elif any([platform, language]): kwargs['compiler'] =\ configuration['compiler'].__new_with__(platform=kwargs['platform'], diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 5b53c43796..53ebe7d3e8 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -12,7 +12,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split +from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', @@ -240,7 +240,7 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'} +_complex_lib = {'cuda': 'thrust/complex.h'} @iet_pass @@ -248,14 +248,20 @@ def complex_include(iet, language, compiler): """ Add headers for complex arithmetic """ + # Check if there is complex numbers that always take dtype precedence + max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) + if not np.issubdtype(max_dtype, np.complexfloating): + return iet, {} + lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) headers = {} # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise if compiler._cpp: + c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) # Constant I - headers = 
{('_Complex_I', ('std::complex(0.0f, 1.0f)'))} + headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} # Mix arithmetic definitions dest = compiler.get_jit_dir() hfile = dest.joinpath('stdcomplex_arith.h') @@ -264,14 +270,7 @@ def complex_include(iet, language, compiler): ff.write(str(_stdcomplex_defs)) lib += (str(hfile),) - for f in FindSymbols().visit(iet): - try: - if np.issubdtype(f.dtype, np.complexfloating): - return iet, {'includes': lib, 'headers': headers} - except TypeError: - pass - - return iet, {} + return iet, {'includes': lib, 'headers': headers} def remove_redundant_moddims(iet): @@ -362,8 +361,19 @@ def _rename_subdims(target, dimensions): return std::complex<_Tp>(b.real() * a, b.imag() * a); } +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + template std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ return std::complex<_Tp>(b.real() / a, b.imag() / a); } @@ -371,4 +381,9 @@ def _rename_subdims(target, dimensions): std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ return std::complex<_Tp>(b.real() + a, b.imag()); } + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} """ diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 6649fa86bf..53c7b07e39 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -304,10 +304,12 @@ def sympy_dtype(expr, base=None): dtypes.add(i.dtype) except AttributeError: pass + dtype = infer_dtype(dtypes) - # Promote if complex - if expr.has(ImaginaryUnit): + # Promote if we missed complex number, i.e f 
+ I + is_im = np.issubdtype(dtype, np.complexfloating) + if expr.has(ImaginaryUnit) and not is_im: dtype = np.promote_types(dtype, np.complex64).type return dtype diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 6ca336e305..8a30b04cc4 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -139,7 +139,12 @@ def dtype_to_ctype(dtype): # Complex data if np.issubdtype(dtype, np.complexfloating): rtype = dtype(0).real.__class__ - return dtype_to_ctype(rtype) + from devito import configuration + make = configuration['compiler']._complex_ctype + ctname = make(dtype_to_cstr(rtype)) + ctype = dtype_to_ctype(rtype) + r = type(ctname, (ctype,), {}) + return r try: return ctypes_vector_mapper[dtype] @@ -214,7 +219,7 @@ class c_restrict_void_p(ctypes.c_void_p): # *** ctypes lowering -def ctypes_to_cstr(ctype, toarray=None, cpp=False): +def ctypes_to_cstr(ctype, toarray=None): """Translate ctypes types into C strings.""" if ctype in ctypes_vector_mapper.values(): retval = ctype.__name__ @@ -308,7 +313,8 @@ def infer_dtype(dtypes): # Resolve the vector types, if any dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes} - fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)} + fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or + np.issubdtype(i, np.complexfloating)} if len(fdtypes) > 1: return max(fdtypes, key=lambda i: np.dtype(i).itemsize) elif len(fdtypes) == 1: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index c7bb0c0211..79b6dccb08 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -67,15 +67,16 @@ def test_maxpar_option(self): assert trees[0][0] is trees[1][0] assert trees[0][1] is not trees[1][1] - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = 
Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index 61b117bcc6..c1a8809379 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -640,22 +640,22 @@ def test_tensor(self, func1): op2 = Operator([Eq(f, f.dx) for f in f1.values()]) assert str(op1.ccode) == str(op2.ccode) - def test_complex(self): + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) + def test_complex(self, dtype): grid = Grid((5, 5)) x, y = grid.dimensions - # Float32 complex is called complex64 in numpy - u = Function(name="u", grid=grid, dtype=np.complex64) + u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) # Currently wrong alias type op = Operator(eq) + # print(op) op() # Check against numpy dx = grid.spacing_map[x.spacing] xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - print(op) assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) From 4f43f26d270d3bf2dc5815f796a47f1dc4f2475c Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 31 May 2024 09:58:54 -0400 Subject: [PATCH 08/11] compiler: fix internal language specific types and cast wip --- devito/arch/compiler.py | 3 +- devito/operator/operator.py | 2 +- devito/passes/iet/__init__.py | 1 + devito/passes/iet/misc.py | 71 +----------------------------- devito/symbolics/extended_sympy.py | 29 +++++++++++- tests/test_gpu_common.py | 2 - tests/test_operator.py | 2 - 7 files changed, 33 insertions(+), 77 deletions(-) diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index de4711d257..a7d05259e5 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -250,7 +250,7 @@ def version(self): @property def _complex_ctype(self): """ - Type definition 
for complex numbers. THese two cases cover 99% of the cases since + Type definition for complex numbers. These two cases cover 99% of the cases since - Hip is now using std::complex https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - Sycl supports std::complex @@ -998,6 +998,7 @@ def __new_with__(self, **kwargs): 'nvc++': NvidiaCompiler, 'nvidia': NvidiaCompiler, 'cuda': CudaCompiler, + 'nvcc': CudaCompiler, 'osx': ClangCompiler, 'intel': OneapiCompiler, 'icx': OneapiCompiler, diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 57ef1134ae..efb47640ce 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -472,7 +472,7 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Complex header if needed. Needs to be done before specialization # as some specific cases require complex to be loaded first - complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler']) + include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index c09db00c9b..6b4ada0b73 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,3 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa +from .complex import * # noqa diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 53ebe7d3e8..50511b6005 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -16,7 +16,7 @@ from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros', 'minimize_symbols', 'complex_include'] + 'generate_macros', 'minimize_symbols'] @iet_pass @@ -240,39 +240,6 @@ def minimize_symbols(iet): return iet, {} -_complex_lib = {'cuda': 'thrust/complex.h'} - - -@iet_pass -def complex_include(iet, 
language, compiler): - """ - Add headers for complex arithmetic - """ - # Check if there is complex numbers that always take dtype precedence - max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)]) - if not np.issubdtype(max_dtype, np.complexfloating): - return iet, {} - - lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),) - - headers = {} - - # For openacc (cpp) need to define constant _Complex_I that isn't found otherwise - if compiler._cpp: - c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type) - # Constant I - headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))} - # Mix arithmetic definitions - dest = compiler.get_jit_dir() - hfile = dest.joinpath('stdcomplex_arith.h') - if not hfile.is_file(): - with open(str(hfile), 'w') as ff: - ff.write(str(_stdcomplex_defs)) - lib += (str(hfile),) - - return iet, {'includes': lib, 'headers': headers} - - def remove_redundant_moddims(iet): key = lambda d: d.is_Modulo and d.origin is not None mds = [d for d in FindSymbols('dimensions').visit(iet) if key(d)] @@ -351,39 +318,3 @@ def _rename_subdims(target, dimensions): return {d: d._rebuild(d.root.name) for d in dims if d.root not in dimensions and names.count(d.root.name) < 2} - - -_stdcomplex_defs = """ -#include - -template -std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() * a, b.imag() * a); -} - -template -std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ - _Tp denom = b.real() * b.real () + b.imag() * b.imag() - return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); -} - -template -std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() / a, b.imag() / a); -} - -template -std::complex<_Tp> operator + 
(const _Ti & a, const std::complex<_Tp> & b){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} - -template -std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ - return std::complex<_Tp>(b.real() + a, b.imag()); -} -""" diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 7ed801d17a..03fec7438a 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority +from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -811,6 +812,20 @@ class VOID(Cast): _base_typ = 'void' +class CFLOAT(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('float') + + +class CDOUBLE(Cast): + + @property + def _base_typ(self): + return configuration['compiler']._complex_ctype('double') + + class CHARP(CastStar): base = CHAR @@ -827,6 +842,14 @@ class USHORTP(CastStar): base = USHORT +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + cast_mapper = { np.int8: CHAR, np.uint8: UCHAR, @@ -839,6 +862,8 @@ class USHORTP(CastStar): np.float32: FLOAT, # noqa float: DOUBLE, # noqa np.float64: DOUBLE, # noqa + np.complex64: CFLOAT, # noqa + np.complex128: CDOUBLE, # noqa (np.int8, '*'): CHARP, (np.uint8, '*'): UCHARP, @@ -849,7 +874,9 @@ class USHORTP(CastStar): (np.int64, '*'): INTP, # noqa (np.float32, '*'): FLOATP, # noqa (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP # noqa + (np.float64, '*'): DOUBLEP, # noqa + (np.complex64, '*'): CFLOATP, # noqa + (np.complex128, '*'): CDOUBLEP, # noqa } for base_name in ['int', 'float', 'double']: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 
79b6dccb08..e229cbb98d 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -74,9 +74,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - print(op) op() # Check against numpy diff --git a/tests/test_operator.py b/tests/test_operator.py index c1a8809379..283249aac1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -647,9 +647,7 @@ def test_complex(self, dtype): u = Function(name="u", grid=grid, dtype=dtype) eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing)) - # Currently wrong alias type op = Operator(eq) - # print(op) op() # Check against numpy From 6b4f12d838f847462e630b369887137a16772a5b Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 20 Jun 2024 10:28:38 -0400 Subject: [PATCH 09/11] compiler: rework dtype lowering --- devito/arch/compiler.py | 20 +--- devito/ir/iet/visitors.py | 2 +- devito/operator/operator.py | 4 - devito/passes/iet/__init__.py | 2 +- devito/passes/iet/definitions.py | 12 ++- devito/passes/iet/dtypes.py | 58 ++++++++++++ devito/passes/iet/langbase.py | 11 +++ devito/passes/iet/languages/C.py | 12 ++- devito/passes/iet/languages/CXX.py | 69 ++++++++++++++ devito/passes/iet/languages/openacc.py | 5 +- devito/passes/iet/misc.py | 2 +- devito/symbolics/__init__.py | 1 + devito/symbolics/extended_dtypes.py | 123 ++++++++++++++++++++++++ devito/symbolics/extended_sympy.py | 126 +------------------------ devito/symbolics/inspection.py | 3 +- devito/symbolics/printer.py | 12 ++- devito/tools/dtypes_lowering.py | 24 ++--- devito/types/basic.py | 33 +++++-- devito/types/misc.py | 2 +- 19 files changed, 344 insertions(+), 177 deletions(-) create mode 100644 devito/passes/iet/dtypes.py create mode 100644 devito/passes/iet/languages/CXX.py create mode 100644 devito/symbolics/extended_dtypes.py diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 
a7d05259e5..61cfa22b4c 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -247,20 +247,6 @@ def version(self): return version - @property - def _complex_ctype(self): - """ - Type definition for complex numbers. These two cases cover 99% of the cases since - - Hip is now using std::complex -https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - - Sycl supports std::complex - - C's _Complex is part of C99 - """ - if self._cpp: - return lambda dtype: 'std::complex<%s>' % str(dtype) - else: - return lambda dtype: '%s _Complex' % str(dtype) - def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -609,7 +595,7 @@ def __init_finalize__(self, **kwargs): self.cflags.remove('-O3') self.cflags.remove('-Wall') - self.cflags.append('-std=c++11') + self.cflags.append('-std=c++14') language = kwargs.pop('language', configuration['language']) platform = kwargs.pop('platform', configuration['platform']) @@ -715,10 +701,6 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' - @property - def _complex_ctype(self): - return lambda dtype: 'thrust::complex<%s>' % str(dtype) - class HipCompiler(Compiler): diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 1e21f1d8ba..6e9879d873 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -602,7 +602,7 @@ def visit_MultiTraversable(self, o): return c.Collection(body) def visit_UsingNamespace(self, o): - return c.Statement('using namespace %s' % ccode(o.namespace)) + return c.Statement('using namespace %s' % str(o.namespace)) def visit_Lambda(self, o): body = [] diff --git a/devito/operator/operator.py b/devito/operator/operator.py index efb47640ce..ba411c0ea6 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -470,10 +470,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # 
Complex header if needed. Needs to be done before specialization - # as some specific cases require complex to be loaded first - include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) - # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index 6b4ada0b73..1cdb97c794 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,4 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa -from .complex import * # noqa +from .dtypes import * # noqa diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index ca4164d184..81a0168d58 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,6 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create +from devito.passes.iet.dtypes import lower_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -73,10 +74,12 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. 
""" - def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, + compiler=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform + self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ @@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs): return iet, {} + @iet_pass + def make_langtypes(self, iet): + iet, metadata = lower_complex(iet, self.lang, self.compiler) + return iet, metadata + def process(self, graph): """ Apply the `place_definitions` and `place_casts` passes. """ self.place_definitions(graph, globs=set()) self.place_casts(graph) + self.make_langtypes(graph) class DeviceAwareDataManager(DataManager): @@ -564,6 +573,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) + self.make_langtypes(graph) def make_zero_init(obj): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py new file mode 100644 index 0000000000..912f707afd --- /dev/null +++ b/devito/passes/iet/dtypes.py @@ -0,0 +1,58 @@ +import numpy as np +import ctypes + +from devito.ir import FindSymbols, Uxreplace + +__all__ = ['lower_complex'] + + +def lower_complex(iet, lang, compiler): + """ + Add headers for complex arithmetic + """ + # Check if there is complex numbers that always take dtype precedence + types = {f.dtype for f in FindSymbols().visit(iet) + if not issubclass(f.dtype, ctypes._Pointer)} + + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return iet, {} + + lib = (lang['header-complex'],) + + metadata = {} + if lang.get('complex-namespace') is not None: + metadata['namespaces'] = lang['complex-namespace'] + + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 
'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) + + iet = _complex_dtypes(iet, lang) + metadata['includes'] = lib + print(metadata) + return iet, metadata + + +def _complex_dtypes(iet, lang): + """ + Lower dtypes to language specific types + """ + mapper = {} + + for s in FindSymbols('indexeds').visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + for s in FindSymbols().visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + body = Uxreplace(mapper).visit(iet.body) + params = Uxreplace(mapper).visit(iet.parameters) + iet = iet._rebuild(body=body, parameters=params) + + return iet diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index d27674c419..e34aa2dac3 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -31,6 +31,9 @@ def __getitem__(self, k): raise NotImplementedError("Missing required mapping for `%s`" % k) return self.mapper[k] + def get(self, k): + return self.mapper.get(k) + class LangBB(metaclass=LangMeta): @@ -200,6 +203,14 @@ def initialize(self, iet, options=None): """ return iet, {} + @iet_pass + def make_langtypes(self, iet): + """ + An `iet_pass` which transforms an IET such that the target language + types are introduced. 
+ """ + return iet, {} + @property def Region(self): return self.lang.Region diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4b3358798d..bd5e0e6413 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,18 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType __all__ = ['CBB', 'CDataManager', 'COrchestrator'] +CCFloat = CustomNpType('_Complex float', np.complex64) +CCDouble = CustomNpType('_Complex double', np.complex128) + + class CBB(LangBB): mapper = { @@ -19,7 +26,10 @@ class CBB(LangBB): 'host-free-pin': lambda i: Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: - Call('memcpy', (i, j, k)) + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex.h', + 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py new file mode 100644 index 0000000000..9f833d630b --- /dev/null +++ b/devito/passes/iet/languages/CXX.py @@ -0,0 +1,69 @@ +import numpy as np + +from devito.ir import Call, UsingNamespace +from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType + +__all__ = ['CXXBB'] + + +std_arith = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / 
(const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +""" + +CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') +CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + + +class CXXBB(LangBB): + + mapper = { + 'header-memcpy': 'string.h', + 'host-alloc': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-alloc-pin': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-free': lambda i: + Call('free', (i,)), + 'host-free-pin': lambda i: + Call('free', (i,)), + 'alloc-global-symbol': lambda i, j, k: + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex', + 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'def-complex': std_arith, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + } diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index bcd2c8d006..bcf5660ac7 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -122,7 +122,8 @@ class AccBB(PragmaLangBB): 'device-free': lambda i, *a: Call('acc_free', (i,)) } - mapper.update(CBB.mapper) + + mapper.update(CXXBB.mapper) 
Region = OmpRegion HostIteration = OmpIteration # Host parallelism still goes via OpenMP diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 50511b6005..f0b2b7f4f5 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -12,7 +12,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr +from devito.tools import Bunch, as_mapper, filter_ordered, split from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', diff --git a/devito/symbolics/__init__.py b/devito/symbolics/__init__.py index 0f5c261471..9d7bee01b8 100644 --- a/devito/symbolics/__init__.py +++ b/devito/symbolics/__init__.py @@ -1,4 +1,5 @@ from devito.symbolics.extended_sympy import * # noqa +from devito.symbolics.extended_dtypes import * # noqa from devito.symbolics.queries import * # noqa from devito.symbolics.search import * # noqa from devito.symbolics.printer import * # noqa diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py new file mode 100644 index 0000000000..c558eb4e18 --- /dev/null +++ b/devito/symbolics/extended_dtypes.py @@ -0,0 +1,123 @@ +import numpy as np + +from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa + int2, int3, int4) + +__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa + + +limits_mapper = { + np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), + np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), + np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), + np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), +} + + +class 
CustomType(ReservedWord): + pass + + +# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... +for base_name in ['int', 'float', 'double']: + for i in ['', '2', '3', '4']: + v = '%s%s' % (base_name, i) + cls = type(v.upper(), (Cast,), {'_base_typ': v}) + globals()[cls.__name__] = cls + + clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + globals()[clsp.__name__] = clsp + + +class CHAR(Cast): + _base_typ = 'char' + + +class SHORT(Cast): + _base_typ = 'short' + + +class USHORT(Cast): + _base_typ = 'unsigned short' + + +class UCHAR(Cast): + _base_typ = 'unsigned char' + + +class LONG(Cast): + _base_typ = 'long' + + +class ULONG(Cast): + _base_typ = 'unsigned long' + + +class CFLOAT(Cast): + _base_typ = 'float' + + +class CDOUBLE(Cast): + _base_typ = 'double' + + +class VOID(Cast): + _base_typ = 'void' + + +class CHARP(CastStar): + base = CHAR + + +class UCHARP(CastStar): + base = UCHAR + + +class SHORTP(CastStar): + base = SHORT + + +class USHORTP(CastStar): + base = USHORT + + +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + +cast_mapper = { + np.int8: CHAR, + np.uint8: UCHAR, + np.int16: SHORT, # noqa + np.uint16: USHORT, # noqa + int: INT, # noqa + np.int32: INT, # noqa + np.int64: LONG, + np.uint64: ULONG, + np.float32: FLOAT, # noqa + float: DOUBLE, # noqa + np.float64: DOUBLE, # noqa + + (np.int8, '*'): CHARP, + (np.uint8, '*'): UCHARP, + (int, '*'): INTP, # noqa + (np.uint16, '*'): USHORTP, # noqa + (np.int16, '*'): SHORTP, # noqa + (np.int32, '*'): INTP, # noqa + (np.int64, '*'): INTP, # noqa + (np.float32, '*'): FLOATP, # noqa + (float, '*'): DOUBLEP, # noqa + (np.float64, '*'): DOUBLEP, # noqa +} + +for base_name in ['int', 'float', 'double']: + for i in [2, 3, 4]: + v = '%s%d' % (base_name, i) + cls = locals()[v] + cast_mapper[cls] = locals()[v.upper()] + cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py 
index 03fec7438a..b386a68a79 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,7 +7,6 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority -from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -20,8 +19,7 @@ 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', - 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper'] + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] class CondEq(sympy.Eq): @@ -548,14 +546,6 @@ class ValueLimit(ReservedWord, sympy.Expr): pass -limits_mapper = { - np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), - np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), - np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), - np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), -} - - class DefFunction(Function, Pickable): """ @@ -773,120 +763,6 @@ def __new__(cls, base=''): return cls.base(base, '*') -# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... 
-for base_name in ['int', 'float', 'double']: - for i in ['', '2', '3', '4']: - v = '%s%s' % (base_name, i) - cls = type(v.upper(), (Cast,), {'_base_typ': v}) - globals()[cls.__name__] = cls - - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) - globals()[clsp.__name__] = clsp - - -class CHAR(Cast): - _base_typ = 'char' - - -class SHORT(Cast): - _base_typ = 'short' - - -class USHORT(Cast): - _base_typ = 'unsigned short' - - -class UCHAR(Cast): - _base_typ = 'unsigned char' - - -class LONG(Cast): - _base_typ = 'long' - - -class ULONG(Cast): - _base_typ = 'unsigned long' - - -class VOID(Cast): - _base_typ = 'void' - - -class CFLOAT(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('float') - - -class CDOUBLE(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('double') - - -class CHARP(CastStar): - base = CHAR - - -class UCHARP(CastStar): - base = UCHAR - - -class SHORTP(CastStar): - base = SHORT - - -class USHORTP(CastStar): - base = USHORT - - -class CFLOATP(CastStar): - base = CFLOAT - - -class CDOUBLEP(CastStar): - base = CDOUBLE - - -cast_mapper = { - np.int8: CHAR, - np.uint8: UCHAR, - np.int16: SHORT, # noqa - np.uint16: USHORT, # noqa - int: INT, # noqa - np.int32: INT, # noqa - np.int64: LONG, - np.uint64: ULONG, - np.float32: FLOAT, # noqa - float: DOUBLE, # noqa - np.float64: DOUBLE, # noqa - np.complex64: CFLOAT, # noqa - np.complex128: CDOUBLE, # noqa - - (np.int8, '*'): CHARP, - (np.uint8, '*'): UCHARP, - (int, '*'): INTP, # noqa - (np.uint16, '*'): USHORTP, # noqa - (np.int16, '*'): SHORTP, # noqa - (np.int32, '*'): INTP, # noqa - (np.int64, '*'): INTP, # noqa - (np.float32, '*'): FLOATP, # noqa - (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP, # noqa - (np.complex64, '*'): CFLOATP, # noqa - (np.complex128, '*'): CDOUBLEP, # noqa -} - -for base_name in ['int', 'float', 'double']: - for i in [2, 3, 4]: - v = '%s%d' % (base_name, i) - cls = 
locals()[v] - cast_mapper[cls] = locals()[v.upper()] - cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] - - # Some other utility objects Null = Macro('NULL') diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 53c7b07e39..11b95a16d3 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -8,7 +8,8 @@ from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning -from devito.symbolics.extended_sympy import (INT, CallFromPointer, Cast, +from devito.symbolics.extended_dtypes import INT +from devito.symbolics.extended_sympy import (CallFromPointer, Cast, DefFunction, ReservedWord) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c7917b3ea1..fc180300a3 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -11,6 +11,7 @@ from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter +from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction @@ -37,13 +38,17 @@ def dtype(self): @property def compiler(self): - return self._settings['compiler'] + return self._settings['compiler'] or configuration['compiler'] def single_prec(self, expr=None): + if self.compiler._cpp and expr is not None: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): + if self.compiler._cpp: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return np.issubdtype(dtype, np.complexfloating) @@ -211,7 +216,10 @@ def _print_Float(self, expr): return rv def _print_ImaginaryUnit(self, expr): - 
return '_Complex_I' + if self.compiler._cpp: + return '1i' + else: + return '_Complex_I' def _print_Differentiable(self, expr): return "(%s)" % self._print(expr._expr) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 8a30b04cc4..3d04f73e84 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] # *** Custom np.dtypes @@ -123,6 +123,18 @@ def __repr__(self): __str__ = __repr__ +class CustomNpType(CustomDtype): + """ + Custom dtype for underlying numpy type. + """ + + def __init__(self, name, nptype, template=None, modifier=None): + self.nptype = nptype + super().__init__(name, template, modifier) + + def __call__(self, val): + return self.nptype(val) + # *** np.dtypes lowering @@ -136,16 +148,6 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype - # Complex data - if np.issubdtype(dtype, np.complexfloating): - rtype = dtype(0).real.__class__ - from devito import configuration - make = configuration['compiler']._complex_ctype - ctname = make(dtype_to_cstr(rtype)) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index e21bae6453..15cf7ab1b8 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, CustomDtype, + Reconstructable) from 
devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -83,6 +84,9 @@ def _C_typedata(self): The type of the object in the generated code as a `str`. """ _type = self._C_ctype + if isinstance(_type, CustomDtype): + return _type + while issubclass(_type, _Pointer): _type = _type._type_ @@ -859,6 +863,7 @@ def __new__(cls, *args, **kwargs): name = kwargs.get('name') alias = kwargs.get('alias') function = kwargs.get('function') + dtype = kwargs.get('dtype') if alias or (function and function.name != name): function = kwargs['function'] = None @@ -866,7 +871,8 @@ def __new__(cls, *args, **kwargs): # definitely a reconstruction if function is not None and \ function.name == name and \ - function.indices == indices: + function.indices == indices and \ + function.dtype == dtype: # Special case: a syntactically identical alias of `function`, so # let's just return `function` itself return function @@ -1188,7 +1194,8 @@ def bound_symbols(self): @cached_property def indexed(self): """The wrapped IndexedData object.""" - return IndexedData(self.name, shape=self._shape, function=self.function) + return IndexedData(self.name, shape=self._shape, function=self.function, + dtype=self.dtype) @cached_property def dmap(self): @@ -1445,13 +1452,14 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable): __rargs__ = ('label', 'shape') __rkwargs__ = ('function',) - def __new__(cls, label, shape, function=None): + def __new__(cls, label, shape, function=None, dtype=None): # Make sure `label` is a devito.Symbol, not a sympy.Symbol if isinstance(label, str): label = Symbol(name=label, dtype=None) with sympy_mutex: obj = sympy.IndexedBase.__new__(cls, label, shape) obj.function = function + obj._dtype = dtype or function.dtype return obj func = Pickable._rebuild @@ -1485,7 +1493,7 @@ def indices(self): @property def dtype(self): - return self.function.dtype + return self._dtype @cached_property def 
free_symbols(self): @@ -1547,7 +1555,7 @@ def _C_ctype(self): return self.function._C_ctype -class Indexed(sympy.Indexed): +class Indexed(sympy.Indexed, Reconstructable): # The two type flags have changed in upstream sympy as of version 1.1, # but the below interpretation is used throughout the compiler to @@ -1559,6 +1567,17 @@ class Indexed(sympy.Indexed): is_Dimension = False + __rargs__ = ('base', 'indices') + __rkwargs__ = ('dtype',) + + def __new__(cls, base, *indices, dtype=None, **kwargs): + if len(indices) == 1: + indices = as_tuple(indices[0]) + newobj = sympy.Indexed.__new__(cls, base, *indices) + newobj._dtype = dtype or base.dtype + + return newobj + @memoized_meth def __str__(self): return super().__str__() @@ -1580,7 +1599,7 @@ def function(self): @property def dtype(self): - return self.function.dtype + return self._dtype @property def name(self): diff --git a/devito/types/misc.py b/devito/types/misc.py index 72f1ab895a..b8f68e39c1 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -79,7 +79,7 @@ class FIndexed(Indexed, Pickable): __rkwargs__ = ('strides_map', 'accessor') def __new__(cls, base, *args, strides_map=None, accessor=None): - obj = super().__new__(cls, base, *args) + obj = super().__new__(cls, base, args) obj.strides_map = frozendict(strides_map or {}) obj.accessor = accessor From 9abeea84145b24f3bd634b70e905ef1c3330a98c Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 27 Jun 2024 07:59:59 -0400 Subject: [PATCH 10/11] compiler: switch to c++14 for complex_literals --- devito/passes/iet/dtypes.py | 2 +- devito/passes/iet/languages/CXX.py | 2 +- devito/symbolics/extended_dtypes.py | 2 +- devito/symbolics/extended_sympy.py | 6 +----- devito/symbolics/printer.py | 15 +++++++++++---- tests/test_gpu_common.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 912f707afd..1932b60f3a 100644 --- a/devito/passes/iet/dtypes.py +++ 
b/devito/passes/iet/dtypes.py @@ -33,7 +33,7 @@ def lower_complex(iet, lang, compiler): iet = _complex_dtypes(iet, lang) metadata['includes'] = lib - print(metadata) + return iet, metadata diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 9f833d630b..5f74070472 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -63,7 +63,7 @@ class CXXBB(LangBB): Call('memcpy', (i, j, k)), # Complex 'header-complex': 'complex', - 'complex-namespace': [UsingNamespace('std:complex_literals')], + 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, } diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index c558eb4e18..0e8ce0cc98 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -4,7 +4,7 @@ from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa int2, int3, int4) -__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa +__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index b386a68a79..19fcd83d4e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -18,7 +18,7 @@ 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', + 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] @@ -508,10 +508,6 @@ class Keyword(ReservedWord): pass -class CustomType(ReservedWord): - pass - - class 
String(ReservedWord): pass diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index fc180300a3..c9c73ed0b4 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -34,14 +34,18 @@ class CodePrinter(C99CodePrinter): @property def dtype(self): - return self._settings['dtype'] + try: + return self._settings['dtype'].nptype + except AttributeError: + return self._settings['dtype'] @property def compiler(self): return self._settings['compiler'] or configuration['compiler'] - def single_prec(self, expr=None): - if self.compiler._cpp and expr is not None: + def single_prec(self, expr=None, with_f=False): + no_f = self.compiler._cpp and not with_f + if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] @@ -217,7 +221,10 @@ def _print_Float(self, expr): def _print_ImaginaryUnit(self, expr): if self.compiler._cpp: - return '1i' + if self.single_prec(with_f=True): + return '1if' + else: + return '1i' else: return '_Complex_I' diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e229cbb98d..2e26b78c22 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -82,7 +82,7 @@ def test_complex(self, dtype): xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5)) npres = xx + 1j*yy + np.exp(1j + dx) - assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0) + assert np.allclose(u.data, npres.T, rtol=1e-6, atol=0) class TestPassesOptional: From 94d5571c3c100316c6444de91552e1b688834b7a Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 8 Jul 2024 12:47:53 -0400 Subject: [PATCH 11/11] compiler: subdtype numpy for dtype lowering --- devito/passes/iet/dtypes.py | 6 +----- devito/passes/iet/languages/C.py | 19 ++++++++++++++++--- devito/passes/iet/languages/CXX.py | 20 +++++++++++++++++--- devito/symbolics/printer.py | 2 +- devito/tools/dtypes_lowering.py | 14 +------------- 5 files changed, 36 
insertions(+), 25 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1932b60f3a..57eb10c4d8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -43,11 +43,7 @@ def _complex_dtypes(iet, lang): """ mapper = {} - for s in FindSymbols('indexeds').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - for s in FindSymbols().visit(iet): + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): if s.dtype in lang['types']: mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index bd5e0e6413..2cee279428 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,16 +1,29 @@ +import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper + __all__ = ['CBB', 'CDataManager', 'COrchestrator'] -CCFloat = CustomNpType('_Complex float', np.complex64) -CCDouble = CustomNpType('_Complex double', np.complex128) +class CCFloat(np.complex64): + pass + + +class CCDouble(np.complex128): + pass + + +c_complex = type('_Complex float', (ct.c_double,), {}) +c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) + +ctypes_vector_mapper[CCFloat] = c_complex +ctypes_vector_mapper[CCDouble] = c_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 5f74070472..fb802acb8b 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,8 +1,9 @@ +import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.tools import 
CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -43,8 +44,21 @@ """ -CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') -CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + +class CXXCFloat(np.complex64): + pass + + +class CXXCDouble(np.complex128): + pass + + +cxx_complex = type('std::complex<float>', (ct.c_double,), {}) +cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {}) + + +ctypes_vector_mapper[CXXCFloat] = cxx_complex +ctypes_vector_mapper[CXXCDouble] = cxx_double_complex class CXXBB(LangBB): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c9c73ed0b4..77bc407dd6 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -48,7 +48,7 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16, np.complex64] + return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) def complex_prec(self, expr=None): if self.compiler._cpp: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 3d04f73e84..43def2d8cd 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype'] # *** Custom np.dtypes @@ -123,18 +123,6 @@ def __repr__(self): __str__ = __repr__ -class CustomNpType(CustomDtype): - """ - Custom dtype for underlying numpy type.
- """ - - def __init__(self, name, nptype, template=None, modifier=None): - self.nptype = nptype - super().__init__(name, template, modifier) - - def __call__(self, val): - return self.nptype(val) - # *** np.dtypes lowering