Skip to content

Commit 698fdf2

Browse files
authored
Merge pull request #2388 from devitocodes/dtype-print
API: fix printer dtype processing
2 parents be7c403 + 3c95932 commit 698fdf2

File tree

5 files changed

+46
-10
lines changed

5 files changed

+46
-10
lines changed

devito/passes/iet/linearization.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,12 @@ def key1(f, d):
7171
"""
7272
if f.is_regular:
7373
# For paddable objects the following holds:
74-
# `same dim + same halo => same (auto-)padding`
75-
return (d, f._size_halo[d], f.is_autopaddable)
74+
# `same dim + same halo + same dtype => same (auto-)padding`
75+
# Bundle need the actual function dtype
76+
if f.is_Bundle:
77+
return (d, f._size_halo[d], f.is_autopaddable, f.c0.dtype)
78+
else:
79+
return (d, f._size_halo[d], f.is_autopaddable, f.dtype)
7680
else:
7781
return False
7882

devito/symbolics/printer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def dtype(self):
3939
def compiler(self):
4040
return self._settings['compiler']
4141

42+
def single_prec(self, expr=None):
43+
dtype = sympy_dtype(expr) if expr is not None else self.dtype
44+
return dtype in [np.float32, np.float16]
45+
4246
def parenthesize(self, item, level, strict=False):
4347
if isinstance(item, BooleanFunction):
4448
return "(%s)" % self._print(item)
@@ -104,9 +108,8 @@ def _print_math_func(self, expr, nest=False, known=None):
104108
except KeyError:
105109
return super()._print_math_func(expr, nest=nest, known=known)
106110

107-
dtype = sympy_dtype(expr)
108-
if dtype is np.float32:
109-
cname += 'f'
111+
if self.single_prec(expr):
112+
cname = '%sf' % cname
110113

111114
args = ', '.join((self._print(arg) for arg in expr.args))
112115

@@ -116,7 +119,7 @@ def _print_Pow(self, expr):
116119
# Need to override because of issue #1627
117120
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
118121
try:
119-
if expr.exp == -1 and self.dtype == np.float32:
122+
if expr.exp == -1 and self.single_prec():
120123
PREC = precedence(expr)
121124
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
122125
except AttributeError:
@@ -196,8 +199,8 @@ def _print_Float(self, expr):
196199
elif rv.startswith('.0'):
197200
rv = '0.' + rv[2:]
198201

199-
if self.dtype == np.float32:
200-
rv = rv + 'F'
202+
if self.single_prec():
203+
rv = '%sF' % rv
201204

202205
return rv
203206

@@ -252,8 +255,8 @@ def _print_ComponentAccess(self, expr):
252255

253256
def _print_TrigonometricFunction(self, expr):
254257
func_name = str(expr.func)
255-
if self.dtype == np.float32:
256-
func_name += 'f'
258+
if self.single_prec():
259+
func_name = '%sf' % func_name
257260
return '%s(%s)' % (func_name, self._print(*expr.args))
258261

259262
def _print_DefFunction(self, expr):

devito/types/sparse.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,14 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
10191019

10201020
return super().inject(field, expr, implicit_dims=implicit_dims)
10211021

1022+
@property
1023+
def forward(self):
1024+
"""Symbol for the time-forward state of the TimeFunction."""
1025+
i = int(self.time_order / 2) if self.time_order >= 2 else 1
1026+
_t = self.dimensions[self._time_position]
1027+
1028+
return self._subs(_t, _t + i * _t.spacing)
1029+
10221030

10231031
class PrecomputedSparseFunction(AbstractSparseFunction):
10241032
"""

docker/Dockerfile.intel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ RUN apt-get update -y && apt-get dist-upgrade -y && \
6060
libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
6161

6262
ENV MPI4PY_FLAGS='. /opt/intel/oneapi/setvars.sh intel64 && '
63+
ENV MPI4PY_RC_RECV_MPROBE=0
6364

6465
##############################################################
6566
# ICC image

tests/test_linearize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,23 @@ def test_inc_w_default_dims():
592592
assert f.shape[0]*k._default_value == 35
593593
assert np.all(g.data[3] == f.shape[0]*k._default_value)
594594
assert np.all(g.data[4:] == 0)
595+
596+
597+
def test_different_dtype():
598+
space_order = 4
599+
600+
grid = Grid(shape=(4, 4))
601+
602+
f = Function(name='f', grid=grid, space_order=space_order)
603+
b = Function(name='b', grid=grid, space_order=space_order, dtype=np.float64)
604+
605+
f.data[:] = 2.1
606+
b.data[:] = 1.3
607+
608+
eq = Eq(f, b.dx + f.dy)
609+
610+
op1 = Operator(eq, opt=('advanced', {'linearize': True}))
611+
612+
# Check generated code has different strides for different dtypes
613+
assert "bL0(x,y) b[(x)*y_stride0 + (y)]" in str(op1)
614+
assert "L0(x,y) f[(x)*y_stride1 + (y)]" in str(op1)

0 commit comments

Comments
 (0)