Skip to content

Commit 7f09044

Browse files
authored
(transform): csl_stencil canonicalization pass (#2814)
Adds a canonicalisation pass for `csl_stencil.apply`. The op takes an empty tensor as `iter_arg`, which it does not manage itself. The conversion pass in #2803 initialises a fresh `iter_arg` for each instance of `apply`. This canonicalisation pass identifies where an existing `iter_arg` can be re-used, effectively removing redundant allocations. --------- Co-authored-by: n-io <[email protected]>
1 parent 2d67e7c commit 7f09044

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: xdsl-opt %s -p canonicalize --split-input-file | filecheck %s
2+
3+
4+
builtin.module {
5+
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
6+
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
7+
8+
%1 = tensor.empty() : tensor<510xf32>
9+
%2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
10+
^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>):
11+
%6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>>
12+
%7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
13+
csl_stencil.yield %7 : tensor<510xf32>
14+
}, {
15+
^0(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>):
16+
csl_stencil.yield %9 : tensor<510xf32>
17+
})
18+
stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
19+
20+
%10 = tensor.empty() : tensor<510xf32>
21+
%11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
22+
^0(%12 : memref<4xtensor<255xf32>>, %13 : index, %14 : tensor<510xf32>):
23+
%15 = csl_stencil.access %12[1, 0] : memref<4xtensor<255xf32>>
24+
%16 = "tensor.insert_slice"(%15, %14, %13) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
25+
csl_stencil.yield %16 : tensor<510xf32>
26+
}, {
27+
^0(%17 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %18 : tensor<510xf32>):
28+
csl_stencil.yield %18 : tensor<510xf32>
29+
})
30+
stencil.store %11 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
31+
32+
%19 = tensor.empty() : tensor<510xf32>
33+
%20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
34+
^0(%21 : memref<4xtensor<255xf32>>, %22 : index, %23 : tensor<510xf32>):
35+
%24 = csl_stencil.access %21[1, 0] : memref<4xtensor<255xf32>>
36+
%25 = "tensor.insert_slice"(%24, %23, %22) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
37+
csl_stencil.yield %25 : tensor<510xf32>
38+
}, {
39+
^0(%26 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %27 : tensor<510xf32>):
40+
csl_stencil.yield %27 : tensor<510xf32>
41+
})
42+
stencil.store %20 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
43+
func.return
44+
}
45+
}
46+
47+
48+
// CHECK-NEXT: builtin.module {
49+
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
50+
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
51+
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
52+
// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
53+
// CHECK-NEXT: ^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>):
54+
// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>>
55+
// CHECK-NEXT: %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
56+
// CHECK-NEXT: csl_stencil.yield %7 : tensor<510xf32>
57+
// CHECK-NEXT: }, {
58+
// CHECK-NEXT: ^1(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>):
59+
// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32>
60+
// CHECK-NEXT: })
61+
// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
62+
// CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
63+
// CHECK-NEXT: ^0(%4 : memref<4xtensor<255xf32>>, %5 : index, %6 : tensor<510xf32>):
64+
// CHECK-NEXT: %7 = csl_stencil.access %4[1, 0] : memref<4xtensor<255xf32>>
65+
// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %6, %5) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
66+
// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32>
67+
// CHECK-NEXT: }, {
68+
// CHECK-NEXT: ^1(%9 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>):
69+
// CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32>
70+
// CHECK-NEXT: })
71+
// CHECK-NEXT: stencil.store %3 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
72+
// CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
73+
// CHECK-NEXT: ^0(%5 : memref<4xtensor<255xf32>>, %6 : index, %7 : tensor<510xf32>):
74+
// CHECK-NEXT: %8 = csl_stencil.access %5[1, 0] : memref<4xtensor<255xf32>>
75+
// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %7, %6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
76+
// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32>
77+
// CHECK-NEXT: }, {
78+
// CHECK-NEXT: ^1(%10 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %11 : tensor<510xf32>):
79+
// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32>
80+
// CHECK-NEXT: })
81+
// CHECK-NEXT: stencil.store %4 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
82+
// CHECK-NEXT: func.return
83+
// CHECK-NEXT: }
84+
// CHECK-NEXT: }

xdsl/dialects/csl/csl_stencil.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
var_result_def,
3636
)
3737
from xdsl.parser import AttrParser, Parser
38+
from xdsl.pattern_rewriter import RewritePattern
3839
from xdsl.printer import Printer
3940
from xdsl.traits import (
4041
HasAncestor,
42+
HasCanonicalisationPatternsTrait,
4143
HasParent,
4244
IsolatedFromAbove,
4345
IsTerminator,
@@ -146,6 +148,16 @@ def __init__(
146148
)
147149

148150

151+
class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait):
    """
    Trait that exposes the canonicalization patterns attached to
    `csl_stencil.apply`.
    """

    @classmethod
    def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
        """
        Return one instance of every rewrite pattern that canonicalization
        should try on the op carrying this trait.
        """
        # Imported at call time rather than at module level: the patterns
        # module itself imports this dialect, so a top-level import would
        # presumably be circular.
        from xdsl.transforms.canonicalization_patterns.csl_stencil import (
            RedundantIterArgInitialisation,
        )

        patterns: tuple[RewritePattern, ...] = (RedundantIterArgInitialisation(),)
        return patterns
159+
160+
149161
@irdl_op_definition
150162
class ApplyOp(IRDLOperation):
151163
"""
@@ -205,7 +217,13 @@ class ApplyOp(IRDLOperation):
205217

206218
res = var_result_def(stencil.TempType)
207219

208-
traits = frozenset([IsolatedFromAbove(), RecursiveMemoryEffect()])
220+
traits = frozenset(
221+
[
222+
IsolatedFromAbove(),
223+
ApplyOpHasCanonicalizationPatternsTrait(),
224+
RecursiveMemoryEffect(),
225+
]
226+
)
209227

210228
def print(self, printer: Printer):
211229
def print_arg(arg: tuple[SSAValue, Attribute]):
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from xdsl.dialects import tensor
2+
from xdsl.dialects.csl import csl_stencil
3+
from xdsl.ir import OpResult
4+
from xdsl.pattern_rewriter import (
5+
PatternRewriter,
6+
RewritePattern,
7+
op_type_rewrite_pattern,
8+
)
9+
10+
11+
class RedundantIterArgInitialisation(RewritePattern):
    """
    Shares one `iter_arg` buffer between consecutive `csl_stencil.apply` ops.

    If the matched op's `iter_arg` has at most one use, every later
    `csl_stencil.apply` reachable through `next_op` whose own `iter_arg` is a
    single-use `tensor.empty` result of the same type is rewired to use the
    matched op's `iter_arg`, and that `tensor.empty` op is replaced away.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter
    ) -> None:
        shared_buffer = op.iter_arg

        # Only start from a buffer that is not already shared with other users.
        if len(shared_buffer.uses) > 1:
            return

        candidate = op.next_op
        while candidate is not None:
            if isinstance(candidate, csl_stencil.ApplyOp):
                other_buffer = candidate.iter_arg
                if (
                    len(other_buffer.uses) == 1
                    and isinstance(other_buffer, OpResult)
                    and isinstance(other_buffer.op, tensor.EmptyOp)
                    and shared_buffer.type == other_buffer.type
                ):
                    # No new ops are inserted; the uses of the redundant
                    # `tensor.empty` result are redirected to the shared buffer.
                    rewriter.replace_op(other_buffer.op, [], [shared_buffer])
            candidate = candidate.next_op

0 commit comments

Comments
 (0)