Skip to content

Commit a19a42f

Browse files
author
Xiang Li
committed
Add test for only triton-to-structured.
1 parent cd193c5 commit a19a42f

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: triton-shared-opt --triton-to-structured --canonicalize %s | FileCheck %s
2+
3+
module attributes {} {
4+
tt.func public @gather_kernel(%arg0: !tt.ptr<i64> { tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> { tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> { tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5+
%c0_i32 = arith.constant 0 : i32
6+
%0 = tt.get_program_id x : i32
7+
%1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
8+
%2 = arith.muli %0, %arg6 : i32
9+
%3 = tt.addptr %arg1, %2 : !tt.ptr<i64>, i32
10+
%4 = tt.splat %3 : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
11+
%5 = tt.addptr %4, %1 : tensor<512x!tt.ptr<i64>>, tensor<512xi32>
12+
%6 = tt.splat %arg3 : i32 -> tensor<512xi32>
13+
%7 = arith.cmpi slt, %1, %6 : tensor<512xi32>
14+
%8 = tt.load %5, %7 : tensor<512x!tt.ptr<i64>>
15+
%9 = arith.cmpi eq, %arg4, %c0_i32 : i32
16+
%10 = scf.if %9 -> (tensor<512x!tt.ptr<i64>>) {
17+
%15 = arith.extsi %arg5 : i32 to i64
18+
%16 = tt.splat %15 : i64 -> tensor<512xi64>
19+
%17 = arith.muli %8, %16 : tensor<512xi64>
20+
%18 = tt.splat %arg0 : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
21+
%19 = tt.addptr %18, %17 : tensor<512x!tt.ptr<i64>>, tensor<512xi64>
22+
%20 = tt.addptr %19, %1 : tensor<512x!tt.ptr<i64>>, tensor<512xi32>
23+
scf.yield %20 : tensor<512x!tt.ptr<i64>>
24+
} else {
25+
%15 = arith.muli %0, %arg5 : i32
26+
%16 = tt.addptr %arg0, %15 : !tt.ptr<i64>, i32
27+
%17 = tt.splat %16 : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
28+
%18 = tt.addptr %17, %8 : tensor<512x!tt.ptr<i64>>, tensor<512xi64>
29+
scf.yield %18 : tensor<512x!tt.ptr<i64>>
30+
}
31+
%11 = tt.load %10, %7 : tensor<512x!tt.ptr<i64>>
32+
%12 = tt.addptr %arg2, %2 : !tt.ptr<i64>, i32
33+
%13 = tt.splat %12 : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
34+
%14 = tt.addptr %13, %1 : tensor<512x!tt.ptr<i64>>, tensor<512xi32>
35+
tt.store %14, %11, %7 : tensor<512x!tt.ptr<i64>>
36+
tt.return
37+
}
38+
}
39+
40+
// CHECK-LABEL: tt.func public @gather_kernel(
41+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<i64> {tt.divisibility = 16 : i32},
42+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<i64> {tt.divisibility = 16 : i32},
43+
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<i64> {tt.divisibility = 16 : i32},
44+
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32 {tt.divisibility = 16 : i32},
45+
// CHECK-SAME: %[[VAL_4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32 {tt.divisibility = 16 : i32},
46+
// CHECK-SAME: %[[VAL_5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32 {tt.divisibility = 16 : i32},
47+
// CHECK-SAME: %[[VAL_6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
48+
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
49+
// CHECK: %[[VAL_8:.*]] = arith.constant 512 : index
50+
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : i32
51+
// CHECK: %[[VAL_10:.*]] = tt.get_program_id x : i32
52+
// CHECK: %[[VAL_11:.*]] = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
53+
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_6]] : i32
54+
// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i32 to index
55+
// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_12]] : i32 to index
56+
// CHECK: %[[VAL_15:.*]] = tts.make_tptr %[[VAL_1]] to sizes: [512], strides: [1], offsets: {{\[}}%[[VAL_14]]], shape: [0], order: [] : <i64> to tensor<512x!tt.ptr<i64>>
57+
// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_3]] : i32 -> tensor<512xi32>
58+
// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_11]], %[[VAL_16]] : tensor<512xi32>
59+
// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
60+
// CHECK: %[[VAL_19:.*]] = arith.minsi %[[VAL_18]], %[[VAL_8]] : index
61+
// CHECK: %[[VAL_20:.*]] = arith.maxsi %[[VAL_19]], %[[VAL_7]] : index
62+
// CHECK: %[[VAL_21:.*]] = "tts.load"(%[[VAL_15]], %[[VAL_20]]) <{operandSegmentSizes = array<i32: 1, 1, 0>, static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<512x!tt.ptr<i64>>, index) -> tensor<512xi64>
63+
// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_9]] : i32
64+
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (tensor<512x!tt.ptr<i64>>) {
65+
// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_5]] : i32 to i64
66+
// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_24]] : i64 -> tensor<512xi64>
67+
// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_25]] : tensor<512xi64>
68+
// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
69+
// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<512x!tt.ptr<i64>>, tensor<512xi64>
70+
// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_28]], %[[VAL_11]] : tensor<512x!tt.ptr<i64>>, tensor<512xi32>
71+
// CHECK: %[[VAL_30:.*]], %[[VAL_31:.*]], %[[VAL_32:.*]] = "tts.get_structured_state"(%[[VAL_29]]) <{resultSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<512x!tt.ptr<i64>>) -> (tensor<512x!tt.ptr<i64>>, index, index)
72+
// CHECK: scf.yield %[[VAL_30]] : tensor<512x!tt.ptr<i64>>
73+
// CHECK: } else {
74+
// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_10]], %[[VAL_5]] : i32
75+
// CHECK: %[[VAL_34:.*]] = tt.addptr %[[VAL_0]], %[[VAL_33]] : !tt.ptr<i64>, i32
76+
// CHECK: %[[VAL_35:.*]] = tt.splat %[[VAL_34]] : !tt.ptr<i64> -> tensor<512x!tt.ptr<i64>>
77+
// CHECK: %[[VAL_36:.*]] = tt.addptr %[[VAL_35]], %[[VAL_21]] : tensor<512x!tt.ptr<i64>>, tensor<512xi64>
78+
// CHECK: %[[VAL_37:.*]], %[[VAL_38:.*]], %[[VAL_39:.*]] = "tts.get_structured_state"(%[[VAL_36]]) <{resultSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<512x!tt.ptr<i64>>) -> (tensor<512x!tt.ptr<i64>>, index, index)
79+
// CHECK: scf.yield %[[VAL_37]] : tensor<512x!tt.ptr<i64>>
80+
// CHECK: }
81+
// CHECK: %[[VAL_40:.*]] = tt.load %[[VAL_23]], %[[VAL_17]] : tensor<512x!tt.ptr<i64>>
82+
// CHECK: %[[VAL_41:.*]] = tts.make_tptr %[[VAL_2]] to sizes: [512], strides: [1], offsets: {{\[}}%[[VAL_13]]], shape: [0], order: [] : <i64> to tensor<512x!tt.ptr<i64>>
83+
// CHECK: %[[VAL_42:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
84+
// CHECK: %[[VAL_43:.*]] = arith.minsi %[[VAL_42]], %[[VAL_8]] : index
85+
// CHECK: %[[VAL_44:.*]] = arith.maxsi %[[VAL_43]], %[[VAL_7]] : index
86+
// CHECK: "tts.store"(%[[VAL_41]], %[[VAL_40]], %[[VAL_44]]) <{static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<512x!tt.ptr<i64>>, tensor<512xi64>, index) -> ()
87+
// CHECK: tt.return
88+
// CHECK: }

0 commit comments

Comments
 (0)