Skip to content

Commit

Permalink
misc: rework multituple fir easier use
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 26, 2023
1 parent ccd4f72 commit 2e5c8e2
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 64 deletions.
2 changes: 1 addition & 1 deletion devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _normalize_kwargs(cls, **kwargs):
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)

# GPU parallelism
o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4))
o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4, 4, 4))
o['par-collapse-ncores'] = 1 # Always collapse (meaningful if `par-tile=False`)
o['par-collapse-work'] = 1 # Always collapse (meaningful if `par-tile=False`)
o['par-chunk-nonaffine'] = oo.pop('par-chunk-nonaffine', cls.PAR_CHUNK_NONAFFINE)
Expand Down
4 changes: 3 additions & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ class OptOption(object):
class ParTileArg(tuple):

def __new__(cls, items, shm=0, tag=None):
if items is None:
items = tuple()
obj = super().__new__(cls, items)
obj.shm = shm
obj.tag = tag
Expand All @@ -340,7 +342,7 @@ class ParTile(tuple, OptOption):

def __new__(cls, items, default=None):
if not items:
return None
return tuple()
elif isinstance(items, bool):
if not default:
raise ValueError("Expected `default` value, got None")
Expand Down
14 changes: 5 additions & 9 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from devito.passes.iet.languages.C import CBB
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
from devito.tools import filter_ordered
from devito.tools import filter_ordered, UnboundTuple
from devito.types import DeviceMap, Symbol

__all__ = ['DeviceAccizer', 'DeviceAccDataManager', 'AccOrchestrator']
Expand All @@ -30,7 +30,8 @@ def _make_clauses(cls, ncollapsed=0, reduction=None, tile=None, **kwargs):
clauses = []

if tile:
clauses.append('tile(%s)' % ','.join(str(i) for i in tile))
stile = [str(tile[i]) for i in range(ncollapsed)]
clauses.append('tile(%s)' % ','.join(stile))
elif ncollapsed > 1:
clauses.append('collapse(%d)' % ncollapsed)

Expand Down Expand Up @@ -159,18 +160,13 @@ def _make_partree(self, candidates, nthreads=None):
assert candidates

root, collapsable = self._select_candidates(candidates)
ncollapsable = len(collapsable)
ncollapsable = len(collapsable) + 1

if self._is_offloadable(root) and \
all(i.is_Affine for i in [root] + collapsable) and \
self.par_tile:
tile = self.par_tile.next()
assert isinstance(tile, tuple)
nremainder = (ncollapsable + 1) - len(tile)
if nremainder >= 0:
tile += (tile[-1],)*nremainder
else:
tile = tile[:ncollapsable + 1]
assert isinstance(tile, UnboundTuple)

body = self.DeviceIteration(gpu_fit=self.gpu_fit, tile=tile,
ncollapsed=ncollapsable, **root.args)
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def __init__(self, sregistry, options, platform, compiler):
super().__init__(sregistry, options, platform, compiler)

self.gpu_fit = options['gpu-fit']
self.par_tile = UnboundTuple(options['par-tile'])
self.par_tile = UnboundTuple(*options['par-tile'])
self.par_disabled = options['par-disabled']

def _make_threaded_prodders(self, partree):
Expand Down Expand Up @@ -613,7 +613,7 @@ def _make_partree(self, candidates, nthreads=None, index=None):

if self._is_offloadable(root):
body = self.DeviceIteration(gpu_fit=self.gpu_fit,
ncollapsed=len(collapsable) + 1,
ncollapsed=len(collapsable)+1,
tile=self.par_tile.next(),
**root.args)
partree = ParallelTree([], body, nthreads=nthreads)
Expand Down
128 changes: 80 additions & 48 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,76 @@ def __hash__(self):
return self._hash


class UnboundedMultiTuple(object):
class UnboundTuple(object):
"""
An UnboundedTuple is a tuple that can be
infinitely iterated over.
Examples
--------
>>> ub = UnboundTuple((1, 2),(3, 4))
>>> ub
UnboundTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
UnboundTuple(1, 2)
>>> ub.next()
UnboundTuple(3, 4)
>>> ub.next()
UnboundTuple(3, 4)
"""

def __init__(self, *items):
nitems = []
for i in as_tuple(items):
if isinstance(i, Iterable):
nitems.append(UnboundTuple(*i))
elif i is not None:
nitems.append(i)

self.items = tuple(nitems)
self.last = len(self.items)
self.current = 0

@property
def default(self):
return self.items[0]

def next(self):
if self.last == 0:
return None
item = self.items[self.current]
if self.current == self.last-1 or self.current == -1:
self.current = -1
else:
self.current += 1
return item

def __len__(self):
return self.last

def __repr__(self):
sitems = [s.__repr__() for s in self.items]
return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems))

def __getitem__(self, idx):
if isinstance(idx, slice):
start = idx.start or 0
stop = idx.stop or self.last
if stop < 0:
stop = self.last + stop
step = idx.step or 1
return UnboundTuple(self[i] for i in range(start, stop, step))
try:
if idx >= self.last-1:
return self.items[self.last-1]
else:
return self.items[idx]
except TypeError:
# Slice, ...
return UnboundTuple(self[i] for i in idx)


class UnboundedMultiTuple(UnboundTuple):

"""
An UnboundedMultiTuple is an ordered collection of tuples that can be
Expand All @@ -562,10 +631,10 @@ class UnboundedMultiTuple(object):
--------
>>> ub = UnboundedMultiTuple([1, 2], [3, 4])
>>> ub
UnboundedMultiTuple((1, 2), (3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.iter()
>>> ub
UnboundedMultiTuple(*(1, 2), (3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
1
>>> ub.next()
Expand All @@ -574,7 +643,7 @@ class UnboundedMultiTuple(object):
>>> ub.iter() # No effect, tip has reached the last tuple
>>> ub.iter() # No effect, tip has reached the last tuple
>>> ub
UnboundedMultiTuple((1, 2), *(3, 4))
UnboundedMultiTuple(UnboundTuple(1, 2), UnboundTuple(3, 4))
>>> ub.next()
3
>>> ub.next()
Expand All @@ -585,52 +654,15 @@ class UnboundedMultiTuple(object):
"""

def __init__(self, *items):
# Normalize input
nitems = []
for i in as_tuple(items):
if isinstance(i, Iterable):
nitems.append(tuple(i))
else:
raise ValueError("Expected sequence, got %s" % type(i))

self.items = tuple(nitems)
self.tip = -1
self.curiter = None

def __repr__(self):
items = [str(i) for i in self.items]
if self.curiter is not None:
items[self.tip] = "*%s" % items[self.tip]
return "%s(%s)" % (self.__class__.__name__, ", ".join(items))
super().__init__(*items)
self.current = -1

def iter(self):
if not self.items:
raise ValueError("No tuples available")
self.tip = min(self.tip + 1, max(len(self.items) - 1, 0))
self.curiter = iter(self.items[self.tip])
self.current = min(self.current + 1, self.last - 1)
self.items[self.current].current = 0
return

def next(self):
if self.curiter is None:
if self.items[self.current].current == -1:
raise StopIteration
return next(self.curiter)


class UnboundTuple(object):
"""
A simple data structure that returns the last element forever once reached
"""

def __init__(self, items):
self.items = as_tuple(items)
self.last = len(self.items)
self.current = 0

def next(self):
if self.last == 0:
return None
item = self.items[self.current]
self.current = min(self.last - 1, self.current+1)
return item

def __len__(self):
return self.last
return self.items[self.current].next()
10 changes: 7 additions & 3 deletions tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,21 @@ def test_tile_insteadof_collapse(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
stile = (32, 4, 4, 4) if par_tile != (32, 4, 4, 8) else (32, 4, 4, 8)
assert len(trees) == 4

assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(32,4) present(u)'
# Only the AFFINE Iterations are tiled
strtile = ','.join([str(i) for i in stile])
assert trees[3][1].pragmas[0].value ==\
'acc parallel loop collapse(4) present(src,src_coords,u)'
'acc parallel loop tile(%s) present(src,src_coords,u)' % strtile

@pytest.mark.parametrize('par_tile', [((32, 4, 4), (8, 8)), ((32, 4), (8, 8)),
((32, 4, 4), (8, 8, 8))])
((32, 4, 4), (8, 8, 8)),
((32, 4, 4), (8, 8), None)])
def test_multiple_tile_sizes(self, par_tile):
grid = Grid(shape=(3, 3, 3))
t = grid.stepping_dim
Expand All @@ -136,8 +139,9 @@ def test_multiple_tile_sizes(self, par_tile):
'acc parallel loop tile(32,4,4) present(u)'
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(8,8) present(u)'
sclause = 'collapse(4)' if par_tile[-1] is None else 'tile(8,8,8,8)'
assert trees[3][1].pragmas[0].value ==\
'acc parallel loop collapse(4) present(src,src_coords,u)'
'acc parallel loop %s present(src,src_coords,u)' % sclause

def test_multi_tile_blocking_structure(self):
grid = Grid(shape=(8, 8, 8))
Expand Down

0 comments on commit 2e5c8e2

Please sign in to comment.