-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API: Introducing complex (np.complex64/128) native support #2375
base: master
Are you sure you want to change the base?
Changes from all commits
2b28090
aa353b4
92dfd9a
4364524
470f4f5
7ffff0a
d1dd24e
4f43f26
6b4f12d
9abeea8
94d5571
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,8 +92,12 @@ def initialize(cls): | |
return | ||
|
||
def alloc(self, shape, dtype, padding=0): | ||
datasize = int(reduce(mul, shape)) | ||
ctype = dtype_to_ctype(dtype) | ||
# For complex number, allocate double the size of its real/imaginary part | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. potentially useful elsewhere, so I'd move it into a function inside |
||
alloc_dtype = dtype(0).real.__class__ | ||
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 | ||
|
||
datasize = int(reduce(mul, shape) * c_scale) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. potentially PITA comment... An observation: for complex we implement SoA, while for bundles we implement AoS it's true we don't support it yet, but one day we might want to be able to allocate bundles-like Functions (hence AoS) in user-land so, long story short, here I'd have
where for now you just put in the logic above as I said, this is probably a nitpick, so feel free to ignore |
||
ctype = dtype_to_ctype(alloc_dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should pass in |
||
|
||
# Add padding, if any | ||
try: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,7 @@ | |
from devito.logger import warning | ||
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, | ||
infer_dtype, is_integer, split) | ||
from devito.types import (Array, DimensionTuple, Evaluable, Indexed, | ||
StencilDimension) | ||
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension | ||
|
||
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative', | ||
'Weights'] | ||
|
@@ -68,7 +67,7 @@ def grid(self): | |
|
||
@cached_property | ||
def dtype(self): | ||
dtypes = {f.dtype for f in self.find(Indexed)} - {None} | ||
dtypes = {f.dtype for f in self._functions} - {None} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. .find is so expensive! Good we can get rid of it |
||
return infer_dtype(dtypes) | ||
|
||
@cached_property | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -469,6 +469,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): | |
|
||
# Lower IET to a target-specific IET | ||
graph = Graph(iet, **kwargs) | ||
|
||
# Specialize | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. irrelevant |
||
graph = cls._specialize_iet(graph, **kwargs) | ||
|
||
# Instrument the IET for C-level profiling | ||
|
@@ -1347,7 +1349,8 @@ def parse_kwargs(**kwargs): | |
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler)) | ||
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'], | ||
language=kwargs['language'], | ||
mpi=configuration['mpi']) | ||
mpi=configuration['mpi'], | ||
name=compiler) | ||
elif any([platform, language]): | ||
kwargs['compiler'] =\ | ||
configuration['compiler'].__new_with__(platform=kwargs['platform'], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, | ||
FindSymbols, MapExprStmts, Transformer, make_callable) | ||
from devito.passes import is_gpu_create | ||
from devito.passes.iet.dtypes import lower_complex | ||
from devito.passes.iet.engine import iet_pass | ||
from devito.passes.iet.langbase import LangBB | ||
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, | ||
|
@@ -73,10 +74,12 @@ class DataManager: | |
The language used to express data allocations, deletions, and host-device transfers. | ||
""" | ||
|
||
def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): | ||
def __init__(self, rcompile=None, sregistry=None, platform=None, | ||
compiler=None, **kwargs): | ||
self.rcompile = rcompile | ||
self.sregistry = sregistry | ||
self.platform = platform | ||
self.compiler = compiler | ||
|
||
def _alloc_object_on_low_lat_mem(self, site, obj, storage): | ||
""" | ||
|
@@ -409,12 +412,18 @@ def place_casts(self, iet, **kwargs): | |
|
||
return iet, {} | ||
|
||
@iet_pass | ||
def make_langtypes(self, iet): | ||
iet, metadata = lower_complex(iet, self.lang, self.compiler) | ||
return iet, metadata | ||
|
||
def process(self, graph): | ||
""" | ||
Apply the `place_definitions` and `place_casts` passes. | ||
""" | ||
self.place_definitions(graph, globs=set()) | ||
self.place_casts(graph) | ||
self.make_langtypes(graph) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you're calling this in the wrong place. It should be called within https://github.com/devitocodes/devito/blob/master/devito/operator/operator.py#L480 that is potentially somewhere here: https://github.com/devitocodes/devito/blob/master/devito/passes/iet/misc.py#L166 Note the name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had it there originally but I thought this would be more fitting there
But I don't have a string preference |
||
|
||
|
||
class DeviceAwareDataManager(DataManager): | ||
|
@@ -564,6 +573,7 @@ def process(self, graph): | |
self.place_devptr(graph) | ||
self.place_bundling(graph, writes_input=graph.writes_input) | ||
self.place_casts(graph) | ||
self.make_langtypes(graph) | ||
|
||
|
||
def make_zero_init(obj): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import ctypes | ||
|
||
from devito.ir import FindSymbols, Uxreplace | ||
|
||
__all__ = ['lower_complex'] | ||
|
||
|
||
def lower_complex(iet, lang, compiler): | ||
""" | ||
Add headers for complex arithmetic | ||
""" | ||
# Check if there is complex numbers that always take dtype precedence | ||
types = {f.dtype for f in FindSymbols().visit(iet) | ||
if not issubclass(f.dtype, ctypes._Pointer)} | ||
|
||
if not any(np.issubdtype(d, np.complexfloating) for d in types): | ||
return iet, {} | ||
|
||
lib = (lang['header-complex'],) | ||
|
||
metadata = {} | ||
if lang.get('complex-namespace') is not None: | ||
metadata['namespaces'] = lang['complex-namespace'] | ||
|
||
# Some languges such as c++11 need some extra arithmetic definitions | ||
if lang.get('def-complex'): | ||
dest = compiler.get_jit_dir() | ||
hfile = dest.joinpath('complex_arith.h') | ||
with open(str(hfile), 'w') as ff: | ||
ff.write(str(lang['def-complex'])) | ||
lib += (str(hfile),) | ||
|
||
iet = _complex_dtypes(iet, lang) | ||
metadata['includes'] = lib | ||
|
||
return iet, metadata | ||
|
||
|
||
def _complex_dtypes(iet, lang): | ||
""" | ||
Lower dtypes to language specific types | ||
""" | ||
mapper = {} | ||
|
||
for s in FindSymbols('indexeds|basics|symbolics').visit(iet): | ||
if s.dtype in lang['types']: | ||
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) | ||
|
||
body = Uxreplace(mapper).visit(iet.body) | ||
params = Uxreplace(mapper).visit(iet.parameters) | ||
iet = iet._rebuild(body=body, parameters=params) | ||
|
||
return iet |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need the
name
anymore no?