Skip to content

Commit

Permalink
Merge pull request #2388 from devitocodes/dtype-print
Browse files Browse the repository at this point in the history
API: fix printer dtype processing
  • Loading branch information
mloubout authored Jun 18, 2024
2 parents be7c403 + 3c95932 commit 698fdf2
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 10 deletions.
8 changes: 6 additions & 2 deletions devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ def key1(f, d):
"""
if f.is_regular:
# For paddable objects the following holds:
# `same dim + same halo => same (auto-)padding`
return (d, f._size_halo[d], f.is_autopaddable)
# `same dim + same halo + same dtype => same (auto-)padding`
# Bundle need the actual function dtype
if f.is_Bundle:
return (d, f._size_halo[d], f.is_autopaddable, f.c0.dtype)
else:
return (d, f._size_halo[d], f.is_autopaddable, f.dtype)
else:
return False

Expand Down
19 changes: 11 additions & 8 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def dtype(self):
def compiler(self):
return self._settings['compiler']

def single_prec(self, expr=None):
dtype = sympy_dtype(expr) if expr is not None else self.dtype
return dtype in [np.float32, np.float16]

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
Expand Down Expand Up @@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None):
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

dtype = sympy_dtype(expr)
if dtype is np.float32:
cname += 'f'
if self.single_prec(expr):
cname = '%sf' % cname

args = ', '.join((self._print(arg) for arg in expr.args))

Expand All @@ -116,7 +119,7 @@ def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.dtype == np.float32:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
Expand Down Expand Up @@ -196,8 +199,8 @@ def _print_Float(self, expr):
elif rv.startswith('.0'):
rv = '0.' + rv[2:]

if self.dtype == np.float32:
rv = rv + 'F'
if self.single_prec():
rv = '%sF' % rv

return rv

Expand Down Expand Up @@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr):

def _print_TrigonometricFunction(self, expr):
func_name = str(expr.func)
if self.dtype == np.float32:
func_name += 'f'
if self.single_prec():
func_name = '%sf' % func_name
return '%s(%s)' % (func_name, self._print(*expr.args))

def _print_DefFunction(self, expr):
Expand Down
8 changes: 8 additions & 0 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,14 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):

return super().inject(field, expr, implicit_dims=implicit_dims)

@property
def forward(self):
"""Symbol for the time-forward state of the TimeFunction."""
i = int(self.time_order / 2) if self.time_order >= 2 else 1
_t = self.dimensions[self._time_position]

return self._subs(_t, _t + i * _t.spacing)


class PrecomputedSparseFunction(AbstractSparseFunction):
"""
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile.intel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ RUN apt-get update -y && apt-get dist-upgrade -y && \
libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev

ENV MPI4PY_FLAGS='. /opt/intel/oneapi/setvars.sh intel64 && '
ENV MPI4PY_RC_RECV_MPROBE=0

##############################################################
# ICC image
Expand Down
20 changes: 20 additions & 0 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,23 @@ def test_inc_w_default_dims():
assert f.shape[0]*k._default_value == 35
assert np.all(g.data[3] == f.shape[0]*k._default_value)
assert np.all(g.data[4:] == 0)


def test_different_dtype():
space_order = 4

grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, space_order=space_order)
b = Function(name='b', grid=grid, space_order=space_order, dtype=np.float64)

f.data[:] = 2.1
b.data[:] = 1.3

eq = Eq(f, b.dx + f.dy)

op1 = Operator(eq, opt=('advanced', {'linearize': True}))

# Check generated code has different strides for different dtypes
assert "bL0(x,y) b[(x)*y_stride0 + (y)]" in str(op1)
assert "L0(x,y) f[(x)*y_stride1 + (y)]" in str(op1)

0 comments on commit 698fdf2

Please sign in to comment.