Skip to content

Commit

Permalink
compiler: Propagate metadata down to _arg_defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Sep 30, 2024
1 parent 2ae6822 commit 22e5e96
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 31 deletions.
9 changes: 8 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,11 @@ def input(self):
def temporaries(self):
return tuple(i for i in self.parameters if i.is_TempFunction)

@cached_property
def transients(self):
return tuple(i for i in self.parameters
if i.is_AbstractFunction and i.is_transient)

@cached_property
def objects(self):
return tuple(i for i in self.parameters if i.is_Object)
Expand Down Expand Up @@ -560,7 +565,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):

# Prepare to process data-carriers
args = kwargs['args'] = ReducerMap()
kwargs['metadata'] = self.threads_info
kwargs['metadata'] = {'platform': self._platform,
'transients': self.transients,
**self.threads_info}

overrides, defaults = split(self.input, lambda p: p.name in kwargs)

Expand Down
39 changes: 22 additions & 17 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,30 +470,22 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
if is_gpu_create(obj, self.gpu_create):
mmap = self.lang._map_alloc(obj)

cdims = make_zero_init(obj)
eq = Eq(obj[cdims], 0)

irs, _ = self.rcompile(eq)

init = irs.iet.body.body[0]

name = self.sregistry.make_name(prefix='init')
efunc = make_callable(name, init)
init = Call(name, efunc.parameters)
efuncs, init = make_zero_init(obj, self.rcompile, self.sregistry)

mmap = (mmap, init)
else:
mmap = self.lang._map_to(obj)
efunc = ()
efuncs = ()

unmap = [self.lang._map_update(obj),
self.lang._map_release(obj, devicerm=devicerm)]
# Copy back to host memory, release device memory
unmap = (self.lang._map_update(obj),
self.lang._map_release(obj, devicerm=devicerm))
else:
mmap = self.lang._map_to(obj)
efunc = ()
efuncs = ()
unmap = self.lang._map_delete(obj, devicerm=devicerm)

storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efunc)
storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs)

@iet_pass
def place_transfers(self, iet, data_movs=None, **kwargs):
Expand Down Expand Up @@ -566,7 +558,7 @@ def process(self, graph):
self.place_casts(graph)


def make_zero_init(obj):
def make_zero_init(obj, rcompile, sregistry):
cdims = []
for d, (h0, h1), s in zip(obj.dimensions, obj._size_halo, obj.symbolic_shape):
if d.is_NonlinearDerived:
Expand All @@ -578,4 +570,17 @@ def make_zero_init(obj):
M = d.symbolic_max + h1
cdims.append(CustomDimension(name=d.name, parent=d,
symbolic_min=m, symbolic_max=M))
return cdims

eq = Eq(obj[cdims], 0)

irs, byproduct = rcompile(eq)

init = irs.iet.body.body[0]

name = sregistry.make_name(prefix='init')
efunc = make_callable(name, init)
init = Call(name, efunc.parameters)

efuncs = [efunc] + [i.root for i in byproduct.funcs]

return efuncs, init
30 changes: 17 additions & 13 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,6 @@ def __staggered_setup__(self, **kwargs):
def _functions(self):
return {self.function}

@property
def _data_buffer(self):
"""
Reference to the data. Unlike :attr:`data` and :attr:`data_with_halo`,
this *never* returns a view of the data. This method is for internal use only.
"""
return self._data_allocated

@property
def _data_alignment(self):
return self._allocator.guaranteed_alignment
Expand Down Expand Up @@ -558,6 +550,14 @@ def _data_allocated(self):
self._halo_exchange()
return np.asarray(self._data)

def _data_buffer(self, **kwargs):
"""
Reference to the data. Unlike :attr:`data` and :attr:`data_with_halo`,
this *never* returns a view of the data. This method is for internal
use only.
"""
return self._data_allocated

def _data_in_region(self, region, dim, side):
"""
The data values in a given region.
Expand Down Expand Up @@ -691,9 +691,11 @@ def _C_make_dataobj(self, alias=None, **args):
"""
key = alias or self
data = args[key.name]

dataobj = byref(self._C_ctype._type_())
dataobj._obj.data = data.ctypes.data_as(c_restrict_void_p)
dataobj._obj.size = (c_ulong*self.ndim)(*data.shape)

# MPI-related fields
dataobj._obj.npsize = (c_ulong*self.ndim)(*[i - sum(j) for i, j in
zip(data.shape, self._size_padding)])
Expand Down Expand Up @@ -809,7 +811,7 @@ def _arg_names(self):
"""Tuple of argument names introduced by this function."""
return (self.name,)

def _arg_defaults(self, alias=None):
def _arg_defaults(self, alias=None, metadata=None):
"""
A map of default argument values defined by this symbol.
Expand All @@ -819,15 +821,15 @@ def _arg_defaults(self, alias=None):
To bind the argument values to different names.
"""
key = alias or self
args = ReducerMap({key.name: self._data_buffer})
args = ReducerMap({key.name: self._data_buffer(metadata=metadata)})

# Collect default dimension arguments from all indices
for a, i, s in zip(key.dimensions, self.dimensions, self.shape):
args.update(i._arg_defaults(_min=0, size=s, alias=a))

return args

def _arg_values(self, **kwargs):
def _arg_values(self, metadata=None, **kwargs):
"""
A map of argument values after evaluating user input. If no
user input is provided, return a default value.
Expand All @@ -843,7 +845,8 @@ def _arg_values(self, **kwargs):
new = kwargs.pop(self.name)
if isinstance(new, DiscreteFunction):
# Set new values and re-derive defaults
values = new._arg_defaults(alias=self).reduce_all()
values = new._arg_defaults(alias=self, metadata=metadata)
values = values.reduce_all()
else:
# We've been provided a pure-data replacement (array)
values = {self.name: new}
Expand All @@ -852,7 +855,8 @@ def _arg_values(self, **kwargs):
size = s - sum(self._size_nodomain[i])
values.update(i._arg_defaults(size=size))
else:
values = self._arg_defaults(alias=self).reduce_all()
values = self._arg_defaults(alias=self, metadata=metadata)
values = values.reduce_all()

return values

Expand Down

0 comments on commit 22e5e96

Please sign in to comment.