Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Introducing complex (np.complex64/128) native support #2375

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
5 changes: 3 additions & 2 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def reinit_compiler(val):
"""
Re-initialize the Compiler.
"""
configuration['compiler'].__init__(suffix=configuration['compiler'].suffix,
configuration['compiler'].__init__(name=configuration['compiler'].name,
suffix=configuration['compiler'].suffix,
mpi=configuration['mpi'])
return val

Expand All @@ -65,7 +66,7 @@ def reinit_compiler(val):
configuration.add('platform', 'cpu64', list(platform_registry),
callback=lambda i: platform_registry[i]())
configuration.add('compiler', 'custom', list(compiler_registry),
callback=lambda i: compiler_registry[i]())
callback=lambda i: compiler_registry[i](name=i))

# Setup language for shared-memory parallelism
preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec
Expand Down
9 changes: 6 additions & 3 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def __init__(self):
_cpp = False

def __init__(self, **kwargs):
self._name = kwargs.pop('name', self.__class__.__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the name anymore, no?


super().__init__(**kwargs)

self.__lookup_cmds__()
Expand Down Expand Up @@ -223,13 +225,13 @@ def __new_with__(self, **kwargs):
Create a new Compiler from an existing one, inherenting from it
the flags that are not specified via ``kwargs``.
"""
return self.__class__(suffix=kwargs.pop('suffix', self.suffix),
return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix),
mpi=kwargs.pop('mpi', configuration['mpi']),
**kwargs)

@property
def name(self):
return self.__class__.__name__
return self._name

@property
def version(self):
Expand Down Expand Up @@ -593,7 +595,7 @@ def __init_finalize__(self, **kwargs):
self.cflags.remove('-O3')
self.cflags.remove('-Wall')

self.cflags.append('-std=c++11')
self.cflags.append('-std=c++14')

language = kwargs.pop('language', configuration['language'])
platform = kwargs.pop('platform', configuration['platform'])
Expand Down Expand Up @@ -978,6 +980,7 @@ def __new_with__(self, **kwargs):
'nvc++': NvidiaCompiler,
'nvidia': NvidiaCompiler,
'cuda': CudaCompiler,
'nvcc': CudaCompiler,
'osx': ClangCompiler,
'intel': OneapiCompiler,
'icx': OneapiCompiler,
Expand Down
8 changes: 6 additions & 2 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
datasize = int(reduce(mul, shape))
ctype = dtype_to_ctype(dtype)
# For complex number, allocate double the size of its real/imaginary part
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering maybe?

alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

datasize = int(reduce(mul, shape) * c_scale)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially PITA comment...

An observation: for complex we implement SoA, while for bundles we implement AoS

it's true we don't support it yet, but one day we might want to be able to allocate bundles-like Functions (hence AoS) in user-land

so, long story short, here I'd have

datasize  = infer_datasize(shape, dtype)

where for now you just put in the logic above

as I said, this is probably a nitpick, so feel free to ignore

ctype = dtype_to_ctype(alloc_dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should pass in dtype and extend dtype_to_ctype to figure out the rest


# Add padding, if any
try:
Expand Down
5 changes: 2 additions & 3 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from devito.logger import warning
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
infer_dtype, is_integer, split)
from devito.types import (Array, DimensionTuple, Evaluable, Indexed,
StencilDimension)
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension

__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
'Weights']
Expand Down Expand Up @@ -68,7 +67,7 @@ def grid(self):

@cached_property
def dtype(self):
dtypes = {f.dtype for f in self.find(Indexed)} - {None}
dtypes = {f.dtype for f in self._functions} - {None}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.find is so expensive! Good we can get rid of it

return infer_dtype(dtypes)

@cached_property
Expand Down
38 changes: 25 additions & 13 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration, switchconfig
from devito.exceptions import VisitorException
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
Expand Down Expand Up @@ -176,7 +177,7 @@ class CGen(Visitor):

def __init__(self, *args, compiler=None, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler
self._compiler = compiler or configuration['compiler']

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -188,6 +189,16 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

@property
def compiler(self):
    """
    The compiler stored on this code generator at construction time.
    """
    return self._compiler

def visit(self, o, *args, **kwargs):
    """
    Dispatch to the parent visitor from inside a ``switchconfig`` context
    whose ``compiler`` is set to the name of this generator's compiler.
    """
    # Anything reading the configuration while visiting should see the
    # compiler the code is being generated for
    generating = self.compiler.name
    with switchconfig(compiler=generating):
        return super().visit(o, *args, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand Down Expand Up @@ -376,10 +387,11 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = i._C_typedata

if f.is_PointerArray:
# lvalue
lvalue = c.Value(i._C_typedata, '**%s' % f.name)
lvalue = c.Value(cstr, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -388,7 +400,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (i._C_typedata, v)
rvalue = '(%s**) %s' % (cstr, v)

else:
# lvalue
Expand All @@ -399,10 +411,10 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
lvalue = c.Value(cstr, '*%s' % v)
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -415,30 +427,30 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (cstr, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = i._C_typedata
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
lvalue = c.Value(cstr, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
Expand Down Expand Up @@ -590,7 +602,7 @@ def visit_MultiTraversable(self, o):
return c.Collection(body)

def visit_UsingNamespace(self, o):
return c.Statement('using namespace %s' % ccode(o.namespace))
return c.Statement('using namespace %s' % str(o.namespace))

def visit_Lambda(self, o):
body = []
Expand Down
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Specialize
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

irrelevant

graph = cls._specialize_iet(graph, **kwargs)

# Instrument the IET for C-level profiling
Expand Down Expand Up @@ -1347,7 +1349,8 @@ def parse_kwargs(**kwargs):
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler))
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'],
language=kwargs['language'],
mpi=configuration['mpi'])
mpi=configuration['mpi'],
name=compiler)
elif any([platform, language]):
kwargs['compiler'] =\
configuration['compiler'].__new_with__(platform=kwargs['platform'],
Expand Down
1 change: 0 additions & 1 deletion devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _collect_nested(expr):
Recursion helper for `collect_nested`.
"""
# Return semantic (rebuilt expression, factorization candidates)

if expr.is_Number:
return expr, {'coeffs': expr}
elif expr.is_Function:
Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
from .dtypes import * # noqa
12 changes: 11 additions & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.dtypes import lower_complex
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
Expand Down Expand Up @@ -73,10 +74,12 @@ class DataManager:
The language used to express data allocations, deletions, and host-device transfers.
"""

def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
def __init__(self, rcompile=None, sregistry=None, platform=None,
compiler=None, **kwargs):
self.rcompile = rcompile
self.sregistry = sregistry
self.platform = platform
self.compiler = compiler

def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down Expand Up @@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs):

return iet, {}

@iet_pass
def make_langtypes(self, iet):
    """
    Run `lower_complex` on ``iet`` with this manager's language and
    compiler, returning the resulting IET together with its metadata.
    """
    lowered, metadata = lower_complex(iet, self.lang, self.compiler)
    return lowered, metadata

def process(self, graph):
    """
    Apply the `place_definitions`, `place_casts` and `make_langtypes`
    passes, in this order, to ``graph``.
    """
    self.place_definitions(graph, globs=set())
    self.place_casts(graph)
    self.make_langtypes(graph)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're calling this in the wrong place.

It should be called within https://github.com/devitocodes/devito/blob/master/devito/operator/operator.py#L480

that is potentially somewhere here:

https://github.com/devitocodes/devito/blob/master/devito/passes/iet/misc.py#L166

Note the name generate_macros is somewhat legacy, but you get why I'm suggesting to move it in there

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had it there originally but I thought this would be more fitting here

  • It is the literal needed for all definitions
  • It is Target "Datatype"

But I don't have a strong preference



class DeviceAwareDataManager(DataManager):
Expand Down Expand Up @@ -564,6 +573,7 @@ def process(self, graph):
self.place_devptr(graph)
self.place_bundling(graph, writes_input=graph.writes_input)
self.place_casts(graph)
self.make_langtypes(graph)


def make_zero_init(obj):
Expand Down
54 changes: 54 additions & 0 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import ctypes

from devito.ir import FindSymbols, Uxreplace

__all__ = ['lower_complex']


def lower_complex(iet, lang, compiler):
    """
    Prepare an IET for complex arithmetic in the target language.

    If none of the (non-pointer) dtypes found in ``iet`` is complex, the IET
    is returned untouched with empty metadata. Otherwise the dtypes are
    lowered via `_complex_dtypes` and the returned metadata carries:

        * 'includes': the ``lang['header-complex']`` header, plus a
          'complex_arith.h' file written into the compiler's JIT directory
          when ``lang`` provides a 'def-complex' entry;
        * 'namespaces': ``lang['complex-namespace']``, only if it is set.

    Returns the tuple ``(iet, metadata)``.
    """
    # Collect the dtypes in use, leaving out the ctypes pointer types
    dtypes = {sym.dtype for sym in FindSymbols().visit(iet)
              if not issubclass(sym.dtype, ctypes._Pointer)}

    # Without any complex dtype there is nothing to lower
    if not any(np.issubdtype(dtype, np.complexfloating) for dtype in dtypes):
        return iet, {}

    includes = (lang['header-complex'],)
    metadata = {}

    namespace = lang.get('complex-namespace')
    if namespace is not None:
        metadata['namespaces'] = namespace

    # Some languages, such as C++11, need extra arithmetic definitions;
    # these are dumped into a header file next to the JIT-compiled code
    definitions = lang.get('def-complex')
    if definitions:
        header = compiler.get_jit_dir().joinpath('complex_arith.h')
        with open(str(header), 'w') as fh:
            fh.write(str(definitions))
        includes = includes + (str(header),)

    iet = _complex_dtypes(iet, lang)
    metadata['includes'] = includes

    return iet, metadata


def _complex_dtypes(iet, lang):
    """
    Rebuild ``iet`` with every symbol whose dtype has an entry in
    ``lang['types']`` replaced by a copy carrying the mapped dtype.

    Both the body and the parameters of ``iet`` are substituted.
    """
    symbols = FindSymbols('indexeds|basics|symbolics').visit(iet)

    # Map each symbol needing lowering to its language-typed counterpart
    mapper = {s: s._rebuild(dtype=lang['types'][s.dtype])
              for s in symbols if s.dtype in lang['types']}

    return iet._rebuild(body=Uxreplace(mapper).visit(iet.body),
                        parameters=Uxreplace(mapper).visit(iet.parameters))
11 changes: 11 additions & 0 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __getitem__(self, k):
raise NotImplementedError("Missing required mapping for `%s`" % k)
return self.mapper[k]

def get(self, k, default=None):
    """
    Return the mapping registered for `k`, if any.

    A missing key is not an error here: `default` (``None`` unless given)
    is returned instead, mirroring ``dict.get``.
    """
    return self.mapper.get(k, default)


class LangBB(metaclass=LangMeta):

Expand Down Expand Up @@ -200,6 +203,14 @@ def initialize(self, iet, options=None):
"""
return iet, {}

@iet_pass
def make_langtypes(self, iet):
    """
    An `iet_pass` hook for introducing the target language types into an
    IET. This base implementation hands the IET back as is, along with
    empty metadata.
    """
    metadata = {}
    return iet, metadata

@property
def Region(self):
return self.lang.Region
Expand Down
Loading
Loading