Skip to content

Commit daf782e

Browse files
authored
Merge pull request #2050 from devitocodes/cond_subd
compiler: Fix placement of ConditionalDimension in subdomain
2 parents cafebc1 + c35a0c1 commit daf782e

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def guard(clusters):
238238
if cd._factor is not None:
239239
k = d
240240
else:
241-
dims = pull_dims(cd.condition)
241+
dims = pull_dims(cd.condition, flag=False)
242242
k = max(dims, default=d, key=lambda i: c.ispace.index(i))
243243

244244
# Pull `cd` from any expr

tests/test_subdomains.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from conftest import opts_tiling, assert_structure
88
from devito import (ConditionalDimension, Constant, Grid, Function, TimeFunction,
9-
Eq, solve, Operator, SubDomain, SubDomainSet)
9+
Eq, solve, Operator, SubDomain, SubDomainSet, Lt)
1010
from devito.ir import FindNodes, Expression, Iteration
1111
from devito.tools import timed_region
1212

@@ -693,3 +693,101 @@ class Dummy(SubDomainSet):
693693
# Switch the thickness symbols between MultiSubDimensions with the rebuild
694694
remixed = [d._rebuild(thickness=t) for d, t in zip(sdims, tkns[::-1])]
695695
assert [d.thickness for d in remixed] == tkns[::-1]
696+
697+
698+
class TestSubDomain_w_condition(object):
699+
700+
def test_condition_w_subdomain_v0(self):
701+
702+
shape = (10, )
703+
grid = Grid(shape=shape)
704+
x, = grid.dimensions
705+
706+
class Middle(SubDomain):
707+
name = 'middle'
708+
709+
def define(self, dimensions):
710+
return {x: ('middle', 2, 4)}
711+
712+
mid = Middle()
713+
my_grid = Grid(shape=shape, subdomains=(mid, ))
714+
715+
f = Function(name='f', grid=my_grid)
716+
717+
sdf = Function(name='sdf', grid=my_grid)
718+
sdf.data[5:] = 1
719+
720+
condition = Lt(sdf[mid.dimensions[0]], 1)
721+
722+
ci = ConditionalDimension(name='ci', condition=condition,
723+
parent=mid.dimensions[0])
724+
725+
op = Operator(Eq(f, f + 10, implicit_dims=ci,
726+
subdomain=my_grid.subdomains['middle']))
727+
op.apply()
728+
729+
assert_structure(op, ['x'], 'x')
730+
731+
def test_condition_w_subdomain_v1(self):
732+
733+
shape = (10, 10)
734+
grid = Grid(shape=shape)
735+
x, y = grid.dimensions
736+
737+
class Middle(SubDomain):
738+
name = 'middle'
739+
740+
def define(self, dimensions):
741+
return {x: x, y: ('middle', 2, 4)}
742+
743+
mid = Middle()
744+
my_grid = Grid(shape=shape, subdomains=(mid, ))
745+
746+
sdf = Function(name='sdf', grid=grid)
747+
sdf.data[:, 5:] = 1
748+
sdf.data[2:6, 3:5] = 1
749+
750+
x1, y1 = mid.dimensions
751+
752+
condition = Lt(sdf[x1, y1], 1)
753+
ci = ConditionalDimension(name='ci', condition=condition, parent=y1)
754+
755+
f = Function(name='f', grid=my_grid)
756+
op = Operator(Eq(f, f + 10, implicit_dims=ci,
757+
subdomain=my_grid.subdomains['middle']))
758+
759+
op.apply()
760+
761+
assert_structure(op, ['xy'], 'xy')
762+
763+
def test_condition_w_subdomain_v2(self):
764+
765+
shape = (10, 10)
766+
grid = Grid(shape=shape)
767+
x, y = grid.dimensions
768+
769+
class Middle(SubDomain):
770+
name = 'middle'
771+
772+
def define(self, dimensions):
773+
return {x: ('middle', 2, 4), y: ('middle', 2, 4)}
774+
775+
mid = Middle()
776+
my_grid = Grid(shape=shape, subdomains=(mid, ))
777+
778+
sdf = Function(name='sdf', grid=my_grid)
779+
sdf.data[2:4, 5:] = 1
780+
sdf.data[2:6, 3:5] = 1
781+
782+
x1, y1 = mid.dimensions
783+
784+
condition = Lt(sdf[x1, y1], 1)
785+
ci = ConditionalDimension(name='ci', condition=condition, parent=y1)
786+
787+
f = Function(name='f', grid=my_grid)
788+
op = Operator(Eq(f, f + 10, implicit_dims=ci,
789+
subdomain=my_grid.subdomains['middle']))
790+
791+
op.apply()
792+
793+
assert_structure(op, ['xy'], 'xy')

0 commit comments

Comments
 (0)