Skip to content

Commit

Permalink
compiler: subdtype numpy for dtype lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 8, 2024
1 parent 9abeea8 commit 94d5571
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 25 deletions.
6 changes: 1 addition & 5 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ def _complex_dtypes(iet, lang):
"""
mapper = {}

for s in FindSymbols('indexeds').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

for s in FindSymbols().visit(iet):
for s in FindSymbols('indexeds|basics|symbolics').visit(iet):
if s.dtype in lang['types']:
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype])

Expand Down
19 changes: 16 additions & 3 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import ctypes as ct
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.tools import CustomNpType
from devito.tools.dtypes_lowering import ctypes_vector_mapper


__all__ = ['CBB', 'CDataManager', 'COrchestrator']


CCFloat = CustomNpType('_Complex float', np.complex64)
CCDouble = CustomNpType('_Complex double', np.complex128)
class CCFloat(np.complex64):
pass


class CCDouble(np.complex128):
pass


c_complex = type('_Complex float', (ct.c_double,), {})
c_double_complex = type('_Complex double', (ct.c_longdouble,), {})

ctypes_vector_mapper[CCFloat] = c_complex
ctypes_vector_mapper[CCDouble] = c_double_complex


class CBB(LangBB):
Expand Down
20 changes: 17 additions & 3 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ctypes as ct
import numpy as np

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.tools import CustomNpType
from devito.tools.dtypes_lowering import ctypes_vector_mapper

__all__ = ['CXXBB']

Expand Down Expand Up @@ -43,8 +44,21 @@
"""

CXXCFloat = CustomNpType('std::complex', np.complex64, template='float')
CXXCDouble = CustomNpType('std::complex', np.complex128, template='double')

class CXXCFloat(np.complex64):
pass


class CXXCDouble(np.complex128):
pass


cxx_complex = type('std::complex<float>', (ct.c_double,), {})
cxx_double_complex = type('std::complex<double>', (ct.c_longdouble,), {})


ctypes_vector_mapper[CXXCFloat] = cxx_complex
ctypes_vector_mapper[CXXCDouble] = cxx_double_complex


class CXXBB(LangBB):
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def single_prec(self, expr=None, with_f=False):
if no_f and expr is not None:
return False
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return dtype in [np.float32, np.float16, np.complex64]
return any(issubclass(dtype, d) for d in [np.float32, np.float16, np.complex64])

def complex_prec(self, expr=None):
if self.compiler._cpp:
Expand Down
14 changes: 1 addition & 13 deletions devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype',
'dtype_to_cstr', 'dtype_to_ctype', 'dtype_to_mpitype', 'dtype_len',
'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper',
'is_external_ctype', 'infer_dtype', 'CustomDtype', 'CustomNpType']
'is_external_ctype', 'infer_dtype', 'CustomDtype']


# *** Custom np.dtypes
Expand Down Expand Up @@ -123,18 +123,6 @@ def __repr__(self):
__str__ = __repr__


class CustomNpType(CustomDtype):
"""
Custom dtype for underlying numpy type.
"""

def __init__(self, name, nptype, template=None, modifier=None):
self.nptype = nptype
super().__init__(name, template, modifier)

def __call__(self, val):
return self.nptype(val)

# *** np.dtypes lowering


Expand Down

0 comments on commit 94d5571

Please sign in to comment.