From 94d5571c3c100316c6444de91552e1b688834b7a Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 8 Jul 2024 12:47:53 -0400 Subject: [PATCH] compiler: subdtype numpy for dtype lowering --- devito/passes/iet/dtypes.py | 6 +----- devito/passes/iet/languages/C.py | 19 ++++++++++++++++--- devito/passes/iet/languages/CXX.py | 20 +++++++++++++++++--- devito/symbolics/printer.py | 2 +- devito/tools/dtypes_lowering.py | 14 +------------- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index 1932b60f3a..57eb10c4d8 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -43,11 +43,7 @@ def _complex_dtypes(iet, lang): """ mapper = {} - for s in FindSymbols('indexeds').visit(iet): - if s.dtype in lang['types']: - mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) - - for s in FindSymbols().visit(iet): + for s in FindSymbols('indexeds|basics|symbolics').visit(iet): if s.dtype in lang['types']: mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index bd5e0e6413..2cee279428 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,16 +1,29 @@ +import ctypes as ct import numpy as np from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper + __all__ = ['CBB', 'CDataManager', 'COrchestrator'] -CCFloat = CustomNpType('_Complex float', np.complex64) -CCDouble = CustomNpType('_Complex double', np.complex128) +class CCFloat(np.complex64): + pass + + +class CCDouble(np.complex128): + pass + + +c_complex = type('_Complex float', (ct.c_double,), {}) +c_double_complex = type('_Complex double', (ct.c_longdouble,), {}) + +ctypes_vector_mapper[CCFloat] = c_complex +ctypes_vector_mapper[CCDouble] = c_double_complex class CBB(LangBB): diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index 5f74070472..fb802acb8b 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,8 +1,9 @@ +import ctypes as ct import numpy as np from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.tools import CustomNpType +from devito.tools.dtypes_lowering import ctypes_vector_mapper __all__ = ['CXXBB'] @@ -43,8 +44,21 @@ """ -CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') -CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') + +class CXXCFloat(np.complex64): + pass + + +class CXXCDouble(np.complex128): + pass + + +cxx_complex = type('std::complex', (ct.c_double,), {}) +cxx_double_complex = type('std::complex', (ct.c_longdouble,), {}) + + +ctypes_vector_mapper[CXXCFloat] = cxx_complex +ctypes_vector_mapper[CXXCDouble] = cxx_double_complex class CXXBB(LangBB): diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c9c73ed0b4..77bc407dd6 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -48,7 +48,7 @@ def single_prec(self, expr=None, with_f=False): if no_f and expr is not None: return False dtype = sympy_dtype(expr) if expr is not None else self.dtype - return dtype in [np.float32, np.float16, np.complex64] + return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64]) def complex_prec(self, expr=None): if self.compiler._cpp: diff --git a/devito/tools/dtypes_lowering.py b/devito/tools/dtypes_lowering.py index 3d04f73e84..43def2d8cd 100644 --- a/devito/tools/dtypes_lowering.py +++ b/devito/tools/dtypes_lowering.py @@ -13,7 +13,7 @@ 'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype', 'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper', - 'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType'] + 'is_external_ctype', 'infer_dtype', 'CustomDtype'] # *** Custom np.dtypes @@ -123,18 +123,6 @@ def __repr__(self): __str__ = __repr__ -class CustomNpType(CustomDtype): - """ - Custom dtype for underlying numpy type. - """ - - def __init__(self, name, nptype, template=None, modifier=None): - self.nptype = nptype - super().__init__(name, template, modifier) - - def __call__(self, val): - return self.nptype(val) - # *** np.dtypes lowering