|
6 | 6 |
|
7 | 7 | from conftest import opts_tiling, assert_structure
|
8 | 8 | from devito import (ConditionalDimension, Constant, Grid, Function, TimeFunction,
|
9 |
| - Eq, solve, Operator, SubDomain, SubDomainSet) |
| 9 | + Eq, solve, Operator, SubDomain, SubDomainSet, Lt) |
10 | 10 | from devito.ir import FindNodes, Expression, Iteration
|
11 | 11 | from devito.tools import timed_region
|
12 | 12 |
|
@@ -693,3 +693,101 @@ class Dummy(SubDomainSet):
|
693 | 693 | # Switch the thickness symbols between MultiSubDimensions with the rebuild
|
694 | 694 | remixed = [d._rebuild(thickness=t) for d, t in zip(sdims, tkns[::-1])]
|
695 | 695 | assert [d.thickness for d in remixed] == tkns[::-1]
|
| 696 | + |
| 697 | + |
| 698 | +class TestSubDomain_w_condition(object): |
| 699 | + |
| 700 | + def test_condition_w_subdomain_v0(self): |
| 701 | + |
| 702 | + shape = (10, ) |
| 703 | + grid = Grid(shape=shape) |
| 704 | + x, = grid.dimensions |
| 705 | + |
| 706 | + class Middle(SubDomain): |
| 707 | + name = 'middle' |
| 708 | + |
| 709 | + def define(self, dimensions): |
| 710 | + return {x: ('middle', 2, 4)} |
| 711 | + |
| 712 | + mid = Middle() |
| 713 | + my_grid = Grid(shape=shape, subdomains=(mid, )) |
| 714 | + |
| 715 | + f = Function(name='f', grid=my_grid) |
| 716 | + |
| 717 | + sdf = Function(name='sdf', grid=my_grid) |
| 718 | + sdf.data[5:] = 1 |
| 719 | + |
| 720 | + condition = Lt(sdf[mid.dimensions[0]], 1) |
| 721 | + |
| 722 | + ci = ConditionalDimension(name='ci', condition=condition, |
| 723 | + parent=mid.dimensions[0]) |
| 724 | + |
| 725 | + op = Operator(Eq(f, f + 10, implicit_dims=ci, |
| 726 | + subdomain=my_grid.subdomains['middle'])) |
| 727 | + op.apply() |
| 728 | + |
| 729 | + assert_structure(op, ['x'], 'x') |
| 730 | + |
| 731 | + def test_condition_w_subdomain_v1(self): |
| 732 | + |
| 733 | + shape = (10, 10) |
| 734 | + grid = Grid(shape=shape) |
| 735 | + x, y = grid.dimensions |
| 736 | + |
| 737 | + class Middle(SubDomain): |
| 738 | + name = 'middle' |
| 739 | + |
| 740 | + def define(self, dimensions): |
| 741 | + return {x: x, y: ('middle', 2, 4)} |
| 742 | + |
| 743 | + mid = Middle() |
| 744 | + my_grid = Grid(shape=shape, subdomains=(mid, )) |
| 745 | + |
| 746 | + sdf = Function(name='sdf', grid=grid) |
| 747 | + sdf.data[:, 5:] = 1 |
| 748 | + sdf.data[2:6, 3:5] = 1 |
| 749 | + |
| 750 | + x1, y1 = mid.dimensions |
| 751 | + |
| 752 | + condition = Lt(sdf[x1, y1], 1) |
| 753 | + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) |
| 754 | + |
| 755 | + f = Function(name='f', grid=my_grid) |
| 756 | + op = Operator(Eq(f, f + 10, implicit_dims=ci, |
| 757 | + subdomain=my_grid.subdomains['middle'])) |
| 758 | + |
| 759 | + op.apply() |
| 760 | + |
| 761 | + assert_structure(op, ['xy'], 'xy') |
| 762 | + |
| 763 | + def test_condition_w_subdomain_v2(self): |
| 764 | + |
| 765 | + shape = (10, 10) |
| 766 | + grid = Grid(shape=shape) |
| 767 | + x, y = grid.dimensions |
| 768 | + |
| 769 | + class Middle(SubDomain): |
| 770 | + name = 'middle' |
| 771 | + |
| 772 | + def define(self, dimensions): |
| 773 | + return {x: ('middle', 2, 4), y: ('middle', 2, 4)} |
| 774 | + |
| 775 | + mid = Middle() |
| 776 | + my_grid = Grid(shape=shape, subdomains=(mid, )) |
| 777 | + |
| 778 | + sdf = Function(name='sdf', grid=my_grid) |
| 779 | + sdf.data[2:4, 5:] = 1 |
| 780 | + sdf.data[2:6, 3:5] = 1 |
| 781 | + |
| 782 | + x1, y1 = mid.dimensions |
| 783 | + |
| 784 | + condition = Lt(sdf[x1, y1], 1) |
| 785 | + ci = ConditionalDimension(name='ci', condition=condition, parent=y1) |
| 786 | + |
| 787 | + f = Function(name='f', grid=my_grid) |
| 788 | + op = Operator(Eq(f, f + 10, implicit_dims=ci, |
| 789 | + subdomain=my_grid.subdomains['middle'])) |
| 790 | + |
| 791 | + op.apply() |
| 792 | + |
| 793 | + assert_structure(op, ['xy'], 'xy') |
0 commit comments