Skip to content

Commit 1825c02

Browse files
authored
(dialects): csl_stencil apply op (#2781)
This PR implements the `csl_stencil.apply` op as outlined in Step 2 of #2747. This operation combines a `csl_stencil.prefetch` (symmetric buffer communication across a given stencil shape) with a `stencil.apply`. Please see the docstring of the op for a detailed description. --------- Co-authored-by: n-io <[email protected]>
1 parent bac07d1 commit 1825c02

File tree

3 files changed

+420
-10
lines changed

3 files changed

+420
-10
lines changed

tests/dialects/test_csl_stencil.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from xdsl.builder import Builder
2+
from xdsl.dialects.builtin import IntegerAttr, IntegerType, MemRefType, TensorType, f32
3+
from xdsl.dialects.csl.csl_stencil import AccessOp, ApplyOp
4+
from xdsl.dialects.stencil import IndexAttr, TempType
5+
from xdsl.ir import Region, SSAValue
6+
from xdsl.utils.test_value import TestSSAValue
7+
8+
9+
def test_access_patterns():
    """Check the per-argument access patterns reported by ``ApplyOp.get_accesses()``.

    An ``ApplyOp`` is assembled from two hand-built regions whose block
    arguments are read through ``AccessOp`` at known offsets. The test then
    asserts the ``visual_pattern()``, ``is_diagonal`` and ``get_diagonals()``
    results for each of the four (region, block argument) combinations.
    """
    temp_t = TempType(5, f32)
    temp = TestSSAValue(temp_t)
    tens_t = TensorType(f32, (5,))
    mref_t = MemRefType(tens_t, (4,))
    mref = TestSSAValue(mref_t)

    @Builder.implicit_region((mref_t, temp_t))
    def region0(args: tuple[SSAValue, ...]):
        first_arg, second_arg = args
        # First argument: the four axis-aligned neighbours, x-axis before y-axis.
        for offset in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            AccessOp(first_arg, IndexAttr.get(*offset), tens_t)
        # Second argument: two opposite diagonal neighbours.
        for offset in ((1, 1), (-1, -1)):
            AccessOp(second_arg, IndexAttr.get(*offset), tens_t)

    @Builder.implicit_region((temp_t, temp_t))
    def region1(args: tuple[SSAValue, ...]):
        first_arg, second_arg = args
        # Both arguments are only read at the origin.
        for block_arg in (first_arg, second_arg):
            AccessOp(block_arg, IndexAttr.get(0, 0), tens_t)

    properties = {
        "swaps": None,
        "topo": None,
        "num_chunks": IntegerAttr(1, IntegerType(64)),
    }
    regions = [
        Region(region0.detach_block(0)),
        Region(region1.detach_block(0)),
    ]
    apply = ApplyOp(
        operands=[temp, mref, []],
        properties=properties,
        regions=regions,
        result_types=[tens_t],
    )

    # The unpacking below requires get_accesses() to yield exactly four
    # entries; the names follow the <region>_<argument> order used in the
    # assertions.
    accesses = tuple(apply.get_accesses())
    r0_t0_acc, r0_t1_acc, r1_t0_acc, r1_t1_acc = accesses

    assert r0_t0_acc.visual_pattern() == " X \nXOX\n X "
    assert r0_t1_acc.visual_pattern() == "X  \n O \n  X"

    assert not r0_t0_acc.is_diagonal
    assert r0_t1_acc.is_diagonal

    diagonals = tuple(r0_t1_acc.get_diagonals())
    assert len(diagonals) == 2

    assert r1_t0_acc.visual_pattern() == "X"
    assert r1_t1_acc.visual_pattern() == "X"

tests/filecheck/dialects/csl/csl-stencil-ops.mlir

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ builtin.module {
3030
}
3131
}
3232

33-
// CHECK-NEXT: builtin.module {
33+
// CHECK: builtin.module {
3434
// CHECK-NEXT: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
3535
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
3636
// CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> memref<4xtensor<510xf32>>
@@ -59,8 +59,7 @@ builtin.module {
5959
// CHECK-NEXT: }
6060
// CHECK-NEXT: }
6161

62-
63-
// CHECK-GENERIC-NEXT: "builtin.module"() ({
62+
// CHECK-GENERIC: "builtin.module"() ({
6463
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel_func", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
6564
// CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
6665
// CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
@@ -90,3 +89,117 @@ builtin.module {
9089
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
9190
// CHECK-GENERIC-NEXT: }) : () -> ()
9291
// CHECK-GENERIC-NEXT: }) : () -> ()
92+
93+
// -----
94+
95+
builtin.module {
96+
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
97+
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
98+
99+
%1 = tensor.empty() : tensor<510xf32>
100+
%2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
101+
^0(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
102+
// reduces chunks from neighbours into one chunk (clear_recv_buf_cb)
103+
%4 = csl_stencil.access %recv[1, 0] : memref<4xtensor<255xf32>>
104+
%5 = csl_stencil.access %recv[-1, 0] : memref<4xtensor<255xf32>>
105+
%6 = csl_stencil.access %recv[0, 1] : memref<4xtensor<255xf32>>
106+
%7 = csl_stencil.access %recv[0, -1] : memref<4xtensor<255xf32>>
107+
108+
%8 = arith.addf %4, %5 : tensor<255xf32>
109+
%9 = arith.addf %8, %6 : tensor<255xf32>
110+
%10 = arith.addf %9, %7 : tensor<255xf32>
111+
112+
%11 = "tensor.insert_slice"(%10, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
113+
csl_stencil.yield %11 : tensor<510xf32>
114+
}, {
115+
^0(%3 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
116+
// takes combined chunks and applies further compute (communicate_cb)
117+
%12 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
118+
%13 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
119+
%14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
120+
%15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
121+
122+
%16 = arith.addf %rcv, %14 : tensor<510xf32>
123+
%17 = arith.addf %16, %15 : tensor<510xf32>
124+
125+
%18 = arith.constant 1.666600e-01 : f32
126+
%19 = tensor.empty() : tensor<510xf32>
127+
%20 = linalg.fill ins(%18 : f32) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
128+
%21 = arith.mulf %17, %20 : tensor<510xf32>
129+
130+
csl_stencil.yield %21 : tensor<510xf32>
131+
})
132+
133+
stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
134+
func.return
135+
}
136+
}
137+
138+
// CHECK: builtin.module {
139+
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
140+
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
141+
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
142+
// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
143+
// CHECK-NEXT: ^0(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
144+
// CHECK-NEXT: %3 = csl_stencil.access %recv[1, 0] : memref<4xtensor<255xf32>>
145+
// CHECK-NEXT: %4 = csl_stencil.access %recv[-1, 0] : memref<4xtensor<255xf32>>
146+
// CHECK-NEXT: %5 = csl_stencil.access %recv[0, 1] : memref<4xtensor<255xf32>>
147+
// CHECK-NEXT: %6 = csl_stencil.access %recv[0, -1] : memref<4xtensor<255xf32>>
148+
// CHECK-NEXT: %7 = arith.addf %3, %4 : tensor<255xf32>
149+
// CHECK-NEXT: %8 = arith.addf %7, %5 : tensor<255xf32>
150+
// CHECK-NEXT: %9 = arith.addf %8, %6 : tensor<255xf32>
151+
// CHECK-NEXT: %10 = "tensor.insert_slice"(%9, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
152+
// CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32>
153+
// CHECK-NEXT: }, {
154+
// CHECK-NEXT: ^1(%11 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
155+
// CHECK-NEXT: %12 = csl_stencil.access %11[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
156+
// CHECK-NEXT: %13 = csl_stencil.access %11[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
157+
// CHECK-NEXT: %14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
158+
// CHECK-NEXT: %15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
159+
// CHECK-NEXT: %16 = arith.addf %rcv, %14 : tensor<510xf32>
160+
// CHECK-NEXT: %17 = arith.addf %16, %15 : tensor<510xf32>
161+
// CHECK-NEXT: %18 = arith.constant 1.666600e-01 : f32
162+
// CHECK-NEXT: %19 = tensor.empty() : tensor<510xf32>
163+
// CHECK-NEXT: %20 = linalg.fill ins(%18 : f32) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
164+
// CHECK-NEXT: %21 = arith.mulf %17, %20 : tensor<510xf32>
165+
// CHECK-NEXT: csl_stencil.yield %21 : tensor<510xf32>
166+
// CHECK-NEXT: })
167+
// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
168+
// CHECK-NEXT: func.return
169+
// CHECK-NEXT: }
170+
// CHECK-NEXT: }
171+
172+
// CHECK-GENERIC: "builtin.module"() ({
173+
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
174+
// CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
175+
// CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
176+
// CHECK-GENERIC-NEXT: %1 = "tensor.empty"() : () -> tensor<510xf32>
177+
// CHECK-GENERIC-NEXT: %2 = "csl_stencil.apply"(%0, %1) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> ({
178+
// CHECK-GENERIC-NEXT: ^1(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
179+
// CHECK-GENERIC-NEXT: %3 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
180+
// CHECK-GENERIC-NEXT: %4 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[-1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
181+
// CHECK-GENERIC-NEXT: %5 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[0, 1], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
182+
// CHECK-GENERIC-NEXT: %6 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[0, -1], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
183+
// CHECK-GENERIC-NEXT: %7 = "arith.addf"(%3, %4) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
184+
// CHECK-GENERIC-NEXT: %8 = "arith.addf"(%7, %5) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
185+
// CHECK-GENERIC-NEXT: %9 = "arith.addf"(%8, %6) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
186+
// CHECK-GENERIC-NEXT: %10 = "tensor.insert_slice"(%9, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
187+
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%10) : (tensor<510xf32>) -> ()
188+
// CHECK-GENERIC-NEXT: }, {
189+
// CHECK-GENERIC-NEXT: ^2(%11 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
190+
// CHECK-GENERIC-NEXT: %12 = "csl_stencil.access"(%11) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
191+
// CHECK-GENERIC-NEXT: %13 = "csl_stencil.access"(%11) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
192+
// CHECK-GENERIC-NEXT: %14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
193+
// CHECK-GENERIC-NEXT: %15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
194+
// CHECK-GENERIC-NEXT: %16 = "arith.addf"(%rcv, %14) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
195+
// CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %15) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
196+
// CHECK-GENERIC-NEXT: %18 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
197+
// CHECK-GENERIC-NEXT: %19 = "tensor.empty"() : () -> tensor<510xf32>
198+
// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (f32, tensor<510xf32>) -> tensor<510xf32>
199+
// CHECK-GENERIC-NEXT: %21 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
200+
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%21) : (tensor<510xf32>) -> ()
201+
// CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>>
202+
// CHECK-GENERIC-NEXT: "stencil.store"(%2, %b) {"bounds" = #stencil.bounds[0, 0] : [1, 1]} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
203+
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
204+
// CHECK-GENERIC-NEXT: }) : () -> ()
205+
// CHECK-GENERIC-NEXT: }) : () -> ()

0 commit comments

Comments
 (0)