Skip to content

Commit

Permalink
(dialects): csl_stencil apply op (#2781)
Browse files Browse the repository at this point in the history
This PR implements the `csl_stencil.apply` op as outlined in Step 2 of
#2747.

This operation combines a `csl_stencil.prefetch` (symmetric buffer
communication across a given stencil shape) with a `stencil.apply`.
Please see the docstring of the op for a detailed description.

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Jun 28, 2024
1 parent bac07d1 commit 1825c02
Show file tree
Hide file tree
Showing 3 changed files with 420 additions and 10 deletions.
56 changes: 56 additions & 0 deletions tests/dialects/test_csl_stencil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from xdsl.builder import Builder
from xdsl.dialects.builtin import IntegerAttr, IntegerType, MemRefType, TensorType, f32
from xdsl.dialects.csl.csl_stencil import AccessOp, ApplyOp
from xdsl.dialects.stencil import IndexAttr, TempType
from xdsl.ir import Region, SSAValue
from xdsl.utils.test_value import TestSSAValue


def test_access_patterns():
    """Check the access patterns reported by ``ApplyOp.get_accesses()``.

    Builds a ``csl_stencil`` ``ApplyOp`` with two regions, each taking two
    block arguments, and asserts on ``visual_pattern()``, ``is_diagonal`` and
    ``get_diagonals()`` of the four objects that ``get_accesses()`` yields.
    """
    temp_t = TempType(5, f32)
    temp = TestSSAValue(temp_t)
    mref = TestSSAValue(mref_t := MemRefType(tens_t := TensorType(f32, (5,)), (4,)))

    # First region: block arguments are typed (memref, temp).
    @Builder.implicit_region((mref_t, temp_t))
    def region0(args: tuple[SSAValue, ...]):
        t0, t1 = args
        # t0 is accessed at (-1, 0), (1, 0), (0, -1) and (0, 1).
        for x in (-1, 1):
            AccessOp(t0, IndexAttr.get(x, 0), tens_t)
        for y in (-1, 1):
            AccessOp(t0, IndexAttr.get(0, y), tens_t)

        # t1 is accessed only at (1, 1) and (-1, -1).
        AccessOp(t1, IndexAttr.get(1, 1), tens_t)
        AccessOp(t1, IndexAttr.get(-1, -1), tens_t)

    # Second region: both block arguments are temps, each accessed at (0, 0).
    @Builder.implicit_region((temp_t, temp_t))
    def region1(args: tuple[SSAValue, ...]):
        t0, t1 = args
        AccessOp(t0, IndexAttr.get(0, 0), tens_t)
        AccessOp(t1, IndexAttr.get(0, 0), tens_t)

    apply = ApplyOp(
        operands=[temp, mref, []],
        properties={
            "swaps": None,
            "topo": None,
            "num_chunks": IntegerAttr(1, IntegerType(64)),
        },
        regions=[
            # The blocks are detached from the builder regions so they can be
            # wrapped in fresh Region objects owned by the op.
            Region(region0.detach_block(0)),
            Region(region1.detach_block(0)),
        ],
        result_types=[tens_t],
    )

    # The unpacking requires exactly four results; the names suggest the order
    # is (region 0, arg 0), (region 0, arg 1), (region 1, arg 0), (region 1, arg 1).
    r0_t0_acc, r0_t1_acc, r1_t0_acc, r1_t1_acc = tuple(apply.get_accesses())

    assert r0_t0_acc.visual_pattern() == " X \nXOX\n X "
    # NOTE(review): the rows of this expected string have widths 2, 3 and 2,
    # unlike the 3-wide rows above. Runs of spaces may have been collapsed when
    # this file was copied (possibly "X  \n O \n  X") - confirm against upstream.
    assert r0_t1_acc.visual_pattern() == "X \n O \n X"

    assert not r0_t0_acc.is_diagonal
    assert r0_t1_acc.is_diagonal

    assert len(tuple(r0_t1_acc.get_diagonals())) == 2

    assert r1_t0_acc.visual_pattern() == "X"
    assert r1_t1_acc.visual_pattern() == "X"
119 changes: 116 additions & 3 deletions tests/filecheck/dialects/csl/csl-stencil-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ builtin.module {
}
}

// CHECK-NEXT: builtin.module {
// CHECK: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> memref<4xtensor<510xf32>>
Expand Down Expand Up @@ -59,8 +59,7 @@ builtin.module {
// CHECK-NEXT: }
// CHECK-NEXT: }


// CHECK-GENERIC-NEXT: "builtin.module"() ({
// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel_func", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
// CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
Expand Down Expand Up @@ -90,3 +89,117 @@ builtin.module {
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()

// -----

// Input for the `csl_stencil.apply` test case: a single function that loads
// %a into a temp, runs a two-region `csl_stencil.apply` on it, and stores the
// result to %b.
builtin.module {
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>

// The apply takes the loaded temp (%0) and an empty 510-element tensor (%1).
// "num_chunks" is 2 and the first region's receive buffer holds 255-element
// tensors (presumably 510 / num_chunks; confirm against the op's docstring).
// "swaps" lists four exchanges, to [1, 0], [-1, 0], [0, 1] and [0, -1].
%1 = tensor.empty() : tensor<510xf32>
%2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// First region. Block arguments: %recv (memref of four 255-element tensors),
// %offset (index) and %iter_arg (510-element tensor).
^0(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
// reduces chunks from neighbours into one chunk (clear_recv_buf_cb)
// The four accesses below use the same offsets as the four "swaps" entries.
%4 = csl_stencil.access %recv[1, 0] : memref<4xtensor<255xf32>>
%5 = csl_stencil.access %recv[-1, 0] : memref<4xtensor<255xf32>>
%6 = csl_stencil.access %recv[0, 1] : memref<4xtensor<255xf32>>
%7 = csl_stencil.access %recv[0, -1] : memref<4xtensor<255xf32>>

// Sum the four accessed 255-element tensors.
%8 = arith.addf %4, %5 : tensor<255xf32>
%9 = arith.addf %8, %6 : tensor<255xf32>
%10 = arith.addf %9, %7 : tensor<255xf32>

// Insert the 255-element sum into %iter_arg (%offset is passed as the third
// operand) and yield the resulting 510-element tensor.
%11 = "tensor.insert_slice"(%10, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %11 : tensor<510xf32>
}, {
// Second region. Block arguments: %3 (same temp type as %0) and %rcv
// (510-element tensor).
^0(%3 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
// takes combined chunks and applies further compute (communicate_cb)
// Both accesses read offset [0, 0]; the two extract_slice ops then take
// 510-element slices of the 512-element tensors at static offsets 1 and -1.
%12 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%13 = csl_stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>

// Add the two slices to %rcv.
%16 = arith.addf %rcv, %14 : tensor<510xf32>
%17 = arith.addf %16, %15 : tensor<510xf32>

// Multiply by a 510-element tensor filled with the constant 1.666600e-01.
%18 = arith.constant 1.666600e-01 : f32
%19 = tensor.empty() : tensor<510xf32>
%20 = linalg.fill ins(%18 : f32) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
%21 = arith.mulf %17, %20 : tensor<510xf32>

csl_stencil.yield %21 : tensor<510xf32>
})

// Store the apply result to %b with bounds [0, 0] : [1, 1].
stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^0(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
// CHECK-NEXT: %3 = csl_stencil.access %recv[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %4 = csl_stencil.access %recv[-1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %5 = csl_stencil.access %recv[0, 1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %6 = csl_stencil.access %recv[0, -1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %7 = arith.addf %3, %4 : tensor<255xf32>
// CHECK-NEXT: %8 = arith.addf %7, %5 : tensor<255xf32>
// CHECK-NEXT: %9 = arith.addf %8, %6 : tensor<255xf32>
// CHECK-NEXT: %10 = "tensor.insert_slice"(%9, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%11 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
// CHECK-NEXT: %12 = csl_stencil.access %11[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %13 = csl_stencil.access %11[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %16 = arith.addf %rcv, %14 : tensor<510xf32>
// CHECK-NEXT: %17 = arith.addf %16, %15 : tensor<510xf32>
// CHECK-NEXT: %18 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %19 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %20 = linalg.fill ins(%18 : f32) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %21 = arith.mulf %17, %20 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %21 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
// CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-GENERIC-NEXT: %1 = "tensor.empty"() : () -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %2 = "csl_stencil.apply"(%0, %1) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> ({
// CHECK-GENERIC-NEXT: ^1(%recv : memref<4xtensor<255xf32>>, %offset : index, %iter_arg : tensor<510xf32>):
// CHECK-GENERIC-NEXT: %3 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %4 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[-1, 0], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %5 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[0, 1], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %6 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index[0, -1], "offset_mapping" = #stencil.index[0, 1]}> : (memref<4xtensor<255xf32>>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %7 = "arith.addf"(%3, %4) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %8 = "arith.addf"(%7, %5) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %9 = "arith.addf"(%8, %6) <{"fastmath" = #arith.fastmath<none>}> : (tensor<255xf32>, tensor<255xf32>) -> tensor<255xf32>
// CHECK-GENERIC-NEXT: %10 = "tensor.insert_slice"(%9, %iter_arg, %offset) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%10) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: }, {
// CHECK-GENERIC-NEXT: ^2(%11 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %rcv : tensor<510xf32>):
// CHECK-GENERIC-NEXT: %12 = "csl_stencil.access"(%11) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %13 = "csl_stencil.access"(%11) <{"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-GENERIC-NEXT: %14 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %15 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %16 = "arith.addf"(%rcv, %14) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %15) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %18 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
// CHECK-GENERIC-NEXT: %19 = "tensor.empty"() : () -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %21 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%21) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>>
// CHECK-GENERIC-NEXT: "stencil.store"(%2, %b) {"bounds" = #stencil.bounds[0, 0] : [1, 1]} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
Loading

0 comments on commit 1825c02

Please sign in to comment.