From ac6d213d85a72687bc3c6f13eee476dcc096d6fb Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 20 Jun 2024 10:28:38 -0400 Subject: [PATCH] compiler: rework dtype lowering --- devito/arch/compiler.py | 18 ---- devito/operator/operator.py | 6 +- devito/passes/iet/__init__.py | 2 +- devito/passes/iet/definitions.py | 12 ++- devito/passes/iet/dtypes.py | 54 +++++++++++ devito/passes/iet/langbase.py | 11 +++ devito/passes/iet/languages/C.py | 13 ++- devito/passes/iet/languages/CXX.py | 69 ++++++++++++++ devito/passes/iet/languages/openacc.py | 5 +- devito/passes/iet/misc.py | 2 +- devito/symbolics/__init__.py | 1 + devito/symbolics/extended_dtypes.py | 123 ++++++++++++++++++++++++ devito/symbolics/extended_sympy.py | 126 +------------------------ devito/symbolics/inspection.py | 3 +- devito/symbolics/printer.py | 7 +- devito/tools/dtypes_lowering.py | 24 ++--- devito/types/basic.py | 33 +++++-- devito/types/misc.py | 2 +- 18 files changed, 336 insertions(+), 175 deletions(-) create mode 100644 devito/passes/iet/dtypes.py create mode 100644 devito/passes/iet/languages/CXX.py create mode 100644 devito/symbolics/extended_dtypes.py diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 5df3891074d..d7abc1f762f 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -245,20 +245,6 @@ def version(self): return version - @property - def _complex_ctype(self): - """ - Type definition for complex numbers. These two cases cover 99% of the cases since - - Hip is now using std::complex -https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings - - Sycl supports std::complex - - C's _Complex is part of C99 - """ - if self._cpp: - return lambda dtype: 'std::complex<%s>' % str(dtype) - else: - return lambda dtype: '%s _Complex' % str(dtype) - def get_version(self): result, stdout, stderr = call_capture_output((self.cc, "--version")) if result != 0: @@ -713,10 +699,6 @@ def __lookup_cmds__(self): self.MPICC = 'nvcc' self.MPICXX = 'nvcc' - @property - def _complex_ctype(self): - return lambda dtype: 'thrust::complex<%s>' % str(dtype) - class HipCompiler(Compiler): diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 0e4b07379a7..df46aca8b49 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -22,7 +22,7 @@ from devito.parameters import configuration from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, - error_mapper, include_complex) + error_mapper) from devito.symbolics import estimate_cost from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, @@ -466,10 +466,6 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Lower IET to a target-specific IET graph = Graph(iet, **kwargs) - # Complex header if needed. Needs to be done before specialization - # as some specific cases require complex to be loaded first - include_complex(graph, language=kwargs['language'], compiler=kwargs['compiler']) - # Specialize graph = cls._specialize_iet(graph, **kwargs) diff --git a/devito/passes/iet/__init__.py b/devito/passes/iet/__init__.py index 6b4ada0b737..1cdb97c7946 100644 --- a/devito/passes/iet/__init__.py +++ b/devito/passes/iet/__init__.py @@ -8,4 +8,4 @@ from .instrument import * # noqa from .languages import * # noqa from .errors import * # noqa -from .complex import * # noqa +from .dtypes import * # noqa diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index ca4164d184d..81a0168d58d 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -12,6 +12,7 @@ from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, FindSymbols, MapExprStmts, Transformer, make_callable) from devito.passes import is_gpu_create +from devito.passes.iet.dtypes import lower_complex from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, @@ -73,10 +74,12 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. """ - def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, + compiler=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform + self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ @@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs): return iet, {} + @iet_pass + def make_langtypes(self, iet): + iet, metadata = lower_complex(iet, self.lang, self.compiler) + return iet, metadata + def process(self, graph): """ Apply the `place_definitions` and `place_casts` passes. """ self.place_definitions(graph, globs=set()) self.place_casts(graph) + self.make_langtypes(graph) class DeviceAwareDataManager(DataManager): @@ -564,6 +573,7 @@ def process(self, graph): self.place_devptr(graph) self.place_bundling(graph, writes_input=graph.writes_input) self.place_casts(graph) + self.make_langtypes(graph) def make_zero_init(obj): diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py new file mode 100644 index 00000000000..d287d8fecba --- /dev/null +++ b/devito/passes/iet/dtypes.py @@ -0,0 +1,54 @@ +import numpy as np +import ctypes + +from devito.ir import FindSymbols, Uxreplace + +__all__ = ['lower_complex'] + + +def lower_complex(iet, lang, compiler): + """ + Add headers for complex arithmetic + """ + # Check if there is complex numbers that always take dtype precedence + types = {f.dtype for f in FindSymbols().visit(iet) + if not issubclass(f.dtype, ctypes._Pointer)} + + if not any(np.issubdtype(d, np.complexfloating) for d in types): + return iet, {} + + lib = (lang['header-complex'],) + headers = lang.get('I-def') + + # Some languges such as c++11 need some extra arithmetic definitions + if lang.get('def-complex'): + dest = compiler.get_jit_dir() + hfile = dest.joinpath('complex_arith.h') + with open(str(hfile), 'w') as ff: + ff.write(str(lang['def-complex'])) + lib += (str(hfile),) + + iet = _complex_dtypes(iet, lang) + + return iet, {'includes': lib, 'headers': headers} + + +def _complex_dtypes(iet, lang): + """ + Lower dtypes to language specific types + """ + mapper = {} + + for s in FindSymbols('indexeds').visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + for s in FindSymbols().visit(iet): + if s.dtype in lang['types']: + mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) + + body = Uxreplace(mapper).visit(iet.body) + params = Uxreplace(mapper).visit(iet.parameters) + iet = iet._rebuild(body=body, parameters=params) + + return iet diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py index d27674c4194..e34aa2dac3e 100644 --- a/devito/passes/iet/langbase.py +++ b/devito/passes/iet/langbase.py @@ -31,6 +31,9 @@ def __getitem__(self, k): raise NotImplementedError("Missing required mapping for `%s`" % k) return self.mapper[k] + def get(self, k): + return self.mapper.get(k) + class LangBB(metaclass=LangMeta): @@ -200,6 +203,14 @@ def initialize(self, iet, options=None): """ return iet, {} + @iet_pass + def make_langtypes(self, iet): + """ + An `iet_pass` which transforms an IET such that the target language + types are introduced. + """ + return iet, {} + @property def Region(self): return self.lang.Region diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4b3358798d2..65c6d4bd49d 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,18 @@ +import numpy as np + from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType __all__ = ['CBB', 'CDataManager', 'COrchestrator'] +CCFloat = CustomNpType('_Complex float', np.complex64) +CCDouble = CustomNpType('_Complex double', np.complex128) + + class CBB(LangBB): mapper = { @@ -19,7 +26,11 @@ class CBB(LangBB): 'host-free-pin': lambda i: Call('free', (i,)), 'alloc-global-symbol': lambda i, j, k: - Call('memcpy', (i, j, k)) + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex.h', + 'types': {np.complex128: CCDouble, np.complex64: CCFloat}, + 'I-def': None } diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py new file mode 100644 index 00000000000..b00cd439c0e --- /dev/null +++ b/devito/passes/iet/languages/CXX.py @@ -0,0 +1,69 @@ +import numpy as np + +from devito.ir import Call +from devito.passes.iet.langbase import LangBB +from devito.tools import CustomNpType + +__all__ = ['CXXBB'] + + +std_arith = """ +#include + +template +std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() * a, b.imag() * a); +} + +template +std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ + _Tp denom = b.real() * b.real () + b.imag() * b.imag() + return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); +} + +template +std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() / a, b.imag() / a); +} + +template +std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +template +std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ + return std::complex<_Tp>(b.real() + a, b.imag()); +} + +""" + +CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') +CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + + +class CXXBB(LangBB): + + mapper = { + 'header-memcpy': 'string.h', + 'host-alloc': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-alloc-pin': lambda i, j, k: + Call('posix_memalign', (i, j, k)), + 'host-free': lambda i: + Call('free', (i,)), + 'host-free-pin': lambda i: + Call('free', (i,)), + 'alloc-global-symbol': lambda i, j, k: + Call('memcpy', (i, j, k)), + # Complex + 'header-complex': 'complex', + 'I-def': (('_Complex_I', ('std::complex(0.0, 1.0)')),), + 'def-complex': std_arith, + 'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, + } diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 186a106211d..b9c59121efb 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -10,7 +10,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -118,7 +118,8 @@ class AccBB(PragmaLangBB): 'device-free': lambda i, *a: Call('acc_free', (i,)) } - mapper.update(CBB.mapper) + + mapper.update(CXXBB.mapper) Region = OmpRegion HostIteration = OmpIteration # Host parallelism still goes via OpenMP diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 50511b6005f..f0b2b7f4f54 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -12,7 +12,7 @@ from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, ccode) -from devito.tools import Bunch, as_mapper, filter_ordered, split, dtype_to_cstr +from devito.tools import Bunch, as_mapper, filter_ordered, split from devito.types import FIndexed __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', diff --git a/devito/symbolics/__init__.py b/devito/symbolics/__init__.py index 0f5c261471f..9d7bee01b85 100644 --- a/devito/symbolics/__init__.py +++ b/devito/symbolics/__init__.py @@ -1,4 +1,5 @@ from devito.symbolics.extended_sympy import * # noqa +from devito.symbolics.extended_dtypes import * # noqa from devito.symbolics.queries import * # noqa from devito.symbolics.search import * # noqa from devito.symbolics.printer import * # noqa diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py new file mode 100644 index 00000000000..c558eb4e186 --- /dev/null +++ b/devito/symbolics/extended_dtypes.py @@ -0,0 +1,123 @@ +import numpy as np + +from devito.symbolics.extended_sympy import ReservedWord, Cast, CastStar, ValueLimit +from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa + int2, int3, int4) + +__all__ = ['cast_mapper', 'limits_mapper', 'INT', 'FLOAT', 'DOUBLE', 'VOID'] # noqa + + +limits_mapper = { + np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), + np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), + np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), + np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), +} + + +class CustomType(ReservedWord): + pass + + +# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... +for base_name in ['int', 'float', 'double']: + for i in ['', '2', '3', '4']: + v = '%s%s' % (base_name, i) + cls = type(v.upper(), (Cast,), {'_base_typ': v}) + globals()[cls.__name__] = cls + + clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) + globals()[clsp.__name__] = clsp + + +class CHAR(Cast): + _base_typ = 'char' + + +class SHORT(Cast): + _base_typ = 'short' + + +class USHORT(Cast): + _base_typ = 'unsigned short' + + +class UCHAR(Cast): + _base_typ = 'unsigned char' + + +class LONG(Cast): + _base_typ = 'long' + + +class ULONG(Cast): + _base_typ = 'unsigned long' + + +class CFLOAT(Cast): + _base_typ = 'float' + + +class CDOUBLE(Cast): + _base_typ = 'double' + + +class VOID(Cast): + _base_typ = 'void' + + +class CHARP(CastStar): + base = CHAR + + +class UCHARP(CastStar): + base = UCHAR + + +class SHORTP(CastStar): + base = SHORT + + +class USHORTP(CastStar): + base = USHORT + + +class CFLOATP(CastStar): + base = CFLOAT + + +class CDOUBLEP(CastStar): + base = CDOUBLE + + +cast_mapper = { + np.int8: CHAR, + np.uint8: UCHAR, + np.int16: SHORT, # noqa + np.uint16: USHORT, # noqa + int: INT, # noqa + np.int32: INT, # noqa + np.int64: LONG, + np.uint64: ULONG, + np.float32: FLOAT, # noqa + float: DOUBLE, # noqa + np.float64: DOUBLE, # noqa + + (np.int8, '*'): CHARP, + (np.uint8, '*'): UCHARP, + (int, '*'): INTP, # noqa + (np.uint16, '*'): USHORTP, # noqa + (np.int16, '*'): SHORTP, # noqa + (np.int32, '*'): INTP, # noqa + (np.int64, '*'): INTP, # noqa + (np.float32, '*'): FLOATP, # noqa + (float, '*'): DOUBLEP, # noqa + (np.float64, '*'): DOUBLEP, # noqa +} + +for base_name in ['int', 'float', 'double']: + for i in [2, 3, 4]: + v = '%s%d' % (base_name, i) + cls = locals()[v] + cast_mapper[cls] = locals()[v.upper()] + cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 03fec7438af..b386a68a792 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,7 +7,6 @@ from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority -from devito import configuration from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa float3, float4, double2, double3, double4, int2, int3, @@ -20,8 +19,7 @@ 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace', - 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper'] + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit'] class CondEq(sympy.Eq): @@ -548,14 +546,6 @@ class ValueLimit(ReservedWord, sympy.Expr): pass -limits_mapper = { - np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')), - np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')), - np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')), - np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')), -} - - class DefFunction(Function, Pickable): """ @@ -773,120 +763,6 @@ def __new__(cls, base=''): return cls.base(base, '*') -# Dynamically create INT, INT2, .... INTP, INT2P, ... FLOAT, ... -for base_name in ['int', 'float', 'double']: - for i in ['', '2', '3', '4']: - v = '%s%s' % (base_name, i) - cls = type(v.upper(), (Cast,), {'_base_typ': v}) - globals()[cls.__name__] = cls - - clsp = type('%sP' % v.upper(), (CastStar,), {'base': cls}) - globals()[clsp.__name__] = clsp - - -class CHAR(Cast): - _base_typ = 'char' - - -class SHORT(Cast): - _base_typ = 'short' - - -class USHORT(Cast): - _base_typ = 'unsigned short' - - -class UCHAR(Cast): - _base_typ = 'unsigned char' - - -class LONG(Cast): - _base_typ = 'long' - - -class ULONG(Cast): - _base_typ = 'unsigned long' - - -class VOID(Cast): - _base_typ = 'void' - - -class CFLOAT(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('float') - - -class CDOUBLE(Cast): - - @property - def _base_typ(self): - return configuration['compiler']._complex_ctype('double') - - -class CHARP(CastStar): - base = CHAR - - -class UCHARP(CastStar): - base = UCHAR - - -class SHORTP(CastStar): - base = SHORT - - -class USHORTP(CastStar): - base = USHORT - - -class CFLOATP(CastStar): - base = CFLOAT - - -class CDOUBLEP(CastStar): - base = CDOUBLE - - -cast_mapper = { - np.int8: CHAR, - np.uint8: UCHAR, - np.int16: SHORT, # noqa - np.uint16: USHORT, # noqa - int: INT, # noqa - np.int32: INT, # noqa - np.int64: LONG, - np.uint64: ULONG, - np.float32: FLOAT, # noqa - float: DOUBLE, # noqa - np.float64: DOUBLE, # noqa - np.complex64: CFLOAT, # noqa - np.complex128: CDOUBLE, # noqa - - (np.int8, '*'): CHARP, - (np.uint8, '*'): UCHARP, - (int, '*'): INTP, # noqa - (np.uint16, '*'): USHORTP, # noqa - (np.int16, '*'): SHORTP, # noqa - (np.int32, '*'): INTP, # noqa - (np.int64, '*'): INTP, # noqa - (np.float32, '*'): FLOATP, # noqa - (float, '*'): DOUBLEP, # noqa - (np.float64, '*'): DOUBLEP, # noqa - (np.complex64, '*'): CFLOATP, # noqa - (np.complex128, '*'): CDOUBLEP, # noqa -} - -for base_name in ['int', 'float', 'double']: - for i in [2, 3, 4]: - v = '%s%d' % (base_name, i) - cls = locals()[v] - cast_mapper[cls] = locals()[v.upper()] - cast_mapper[(cls, '*')] = locals()['%sP' % v.upper()] - - # Some other utility objects Null = Macro('NULL') diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 53c7b07e395..11b95a16d35 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -8,7 +8,8 @@ from devito.finite_differences import Derivative from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning -from devito.symbolics.extended_sympy import (INT, CallFromPointer, Cast, +from devito.symbolics.extended_dtypes import INT +from devito.symbolics.extended_sympy import (CallFromPointer, Cast, DefFunction, ReservedWord) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c7917b3ea11..527bb795f1b 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -11,6 +11,7 @@ from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter +from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.types.basic import AbstractFunction @@ -37,13 +38,17 @@ def dtype(self): @property def compiler(self): - return self._settings['compiler'] + return self._settings['compiler'] or configuration['compiler'] def single_prec(self, expr=None): + if self.compiler._cpp and expr is not None: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return dtype in [np.float32, np.float16, np.complex64] def complex_prec(self, expr=None): + if self.compiler._cpp: + return False dtype = sympy_dtype(expr) if expr is not None else self.dtype return np.issubdtype(dtype, np.complexfloating) diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 8a30b04cc44..3d04f73e842 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] # *** Custom np.dtypes @@ -123,6 +123,18 @@ def __repr__(self): __str__ = __repr__ +class CustomNpType(CustomDtype): + """ + Custom dtype for underlying numpy type. + """ + + def __init__(self, name, nptype, template=None, modifier=None): + self.nptype = nptype + super().__init__(name, template, modifier) + + def __call__(self, val): + return self.nptype(val) + # *** np.dtypes lowering @@ -136,16 +148,6 @@ def dtype_to_ctype(dtype): if isinstance(dtype, CustomDtype): return dtype - # Complex data - if np.issubdtype(dtype, np.complexfloating): - rtype = dtype(0).real.__class__ - from devito import configuration - make = configuration['compiler']._complex_ctype - ctname = make(dtype_to_cstr(rtype)) - ctype = dtype_to_ctype(rtype) - r = type(ctname, (ctype,), {}) - return r - try: return ctypes_vector_mapper[dtype] except KeyError: diff --git a/devito/types/basic.py b/devito/types/basic.py index 394d6276b6c..cdf6b020906 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -13,7 +13,8 @@ from devito.data import default_allocator from devito.parameters import configuration from devito.tools import (Pickable, as_tuple, ctypes_to_cstr, dtype_to_ctype, - frozendict, memoized_meth, sympy_mutex) + frozendict, memoized_meth, sympy_mutex, CustomDtype, + Reconstructable) from devito.types.args import ArgProvider from devito.types.caching import Cached, Uncached from devito.types.lazy import Evaluable @@ -82,6 +83,9 @@ def _C_typedata(self): The type of the object in the generated code as a `str`. """ _type = self._C_ctype + if isinstance(_type, CustomDtype): + return _type + while issubclass(_type, _Pointer): _type = _type._type_ @@ -858,6 +862,7 @@ def __new__(cls, *args, **kwargs): name = kwargs.get('name') alias = kwargs.get('alias') function = kwargs.get('function') + dtype = kwargs.get('dtype') if alias or (function and function.name != name): function = kwargs['function'] = None @@ -865,7 +870,8 @@ def __new__(cls, *args, **kwargs): # definitely a reconstruction if function is not None and \ function.name == name and \ - function.indices == indices: + function.indices == indices and \ + function.dtype == dtype: # Special case: a syntactically identical alias of `function`, so # let's just return `function` itself return function @@ -1170,7 +1176,8 @@ def bound_symbols(self): @cached_property def indexed(self): """The wrapped IndexedData object.""" - return IndexedData(self.name, shape=self._shape, function=self.function) + return IndexedData(self.name, shape=self._shape, function=self.function, + dtype=self.dtype) @cached_property def dmap(self): @@ -1414,13 +1421,14 @@ class IndexedBase(sympy.IndexedBase, Basic, Pickable): __rargs__ = ('label', 'shape') __rkwargs__ = ('function',) - def __new__(cls, label, shape, function=None): + def __new__(cls, label, shape, function=None, dtype=None): # Make sure `label` is a devito.Symbol, not a sympy.Symbol if isinstance(label, str): label = Symbol(name=label, dtype=None) with sympy_mutex: obj = sympy.IndexedBase.__new__(cls, label, shape) obj.function = function + obj._dtype = dtype or function.dtype return obj func = Pickable._rebuild @@ -1454,7 +1462,7 @@ def indices(self): @property def dtype(self): - return self.function.dtype + return self._dtype @cached_property def free_symbols(self): @@ -1516,7 +1524,7 @@ def _C_ctype(self): return self.function._C_ctype -class Indexed(sympy.Indexed): +class Indexed(sympy.Indexed, Reconstructable): # The two type flags have changed in upstream sympy as of version 1.1, # but the below interpretation is used throughout the compiler to @@ -1528,6 +1536,17 @@ class Indexed(sympy.Indexed): is_Dimension = False + __rargs__ = ('base', 'indices') + __rkwargs__ = ('dtype',) + + def __new__(cls, base, *indices, dtype=None, **kwargs): + if len(indices) == 1: + indices = as_tuple(indices[0]) + newobj = sympy.Indexed.__new__(cls, base, *indices) + newobj._dtype = dtype or base.dtype + + return newobj + @memoized_meth def __str__(self): return super().__str__() @@ -1549,7 +1568,7 @@ def function(self): @property def dtype(self): - return self.function.dtype + return self._dtype @property def name(self): diff --git a/devito/types/misc.py b/devito/types/misc.py index 72f1ab895a4..b8f68e39c12 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -79,7 +79,7 @@ class FIndexed(Indexed, Pickable): __rkwargs__ = ('strides_map', 'accessor') def __new__(cls, base, *args, strides_map=None, accessor=None): - obj = super().__new__(cls, base, *args) + obj = super().__new__(cls, base, args) obj.strides_map = frozendict(strides_map or {}) obj.accessor = accessor