// matmul_kernel.ttgir
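// Layout encodings, as inferred from their uses below: #blocked and #blocked1
// are the register layouts for the 128x16 A-tile and 16x256 B-tile global
// loads, #mma is the Ampere (mma.sync, versionMajor = 2) tensor-core
// accumulator layout, and #shared/#shared1 are swizzled shared-memory layouts
// (vec/perPhase/maxPhase select the swizzle pattern that avoids bank
// conflicts).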
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#loc = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0)
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
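// Argument roles, inferred from how each value is used below: %arg0-%arg2 are
// likely the A, B, and C base pointers; %arg3-%arg5 the M, N, K extents; and
// %arg6-%arg8 the leading strides of A, B, and C (the minor strides appear to
// have been folded to 1, i.e. row-major operands).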
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":99:0)) attributes {noinline = false} {
%c2_i32 = arith.constant 2 : i32 loc(#loc1)
%cst = arith.constant dense<16> : tensor<128x16xi32, #blocked> loc(#loc1)
%c16_i32 = arith.constant 16 : i32 loc(#loc1)
%c256_i32 = arith.constant 256 : i32 loc(#loc1)
%c128_i32 = arith.constant 128 : i32 loc(#loc1)
%c8_i32 = arith.constant 8 : i32 loc(#loc1)
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked> loc(#loc1)
%cst_1 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #blocked1> loc(#loc1)
%c0_i32 = arith.constant 0 : i32 loc(#loc1)
%c1_i32 = arith.constant 1 : i32 loc(#loc1)
%c127_i32 = arith.constant 127 : i32 loc(#loc1)
%c255_i32 = arith.constant 255 : i32 loc(#loc1)
%c15_i32 = arith.constant 15 : i32 loc(#loc1)
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1)
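// Grouped program-id decomposition with BLOCK_M = 128, BLOCK_N = 256,
// GROUP_SIZE_M = 8: %2 = ceil(M/128) and %4 = ceil(N/256) are the tile-grid
// extents, and %11/%13 become pid_m/pid_n, ordered so that groups of 8
// row-blocks share column tiles for better L2 reuse.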
%0 = tt.get_program_id x : i32 loc(#loc2)
%1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc56)
%2 = arith.divsi %1, %c128_i32 : i32 loc(#loc57)
%3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc58)
%4 = arith.divsi %3, %c256_i32 : i32 loc(#loc59)
%5 = arith.muli %4, %c8_i32 : i32 loc(#loc7)
%6 = arith.divsi %0, %5 : i32 loc(#loc8)
%7 = arith.muli %6, %c8_i32 : i32 loc(#loc9)
%8 = arith.subi %2, %7 : i32 loc(#loc10)
%9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11)
%10 = arith.remsi %0, %9 : i32 loc(#loc12)
%11 = arith.addi %7, %10 : i32 loc(#loc13)
%12 = arith.remsi %0, %5 : i32 loc(#loc14)
%13 = arith.divsi %12, %9 : i32 loc(#loc15)
%14 = arith.muli %11, %c128_i32 : i32 loc(#loc16)
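// Tile offsets: offs_am = (pid_m*128 + arange(128)) % M and offs_bn =
// (pid_n*256 + arange(256)) % N (the modulo wraps out-of-range rows/columns
// back in bounds); %38 becomes the A-tile pointer
// A + offs_am[:, None]*%arg6 + arange(16)[None, :].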
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc17)
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17)
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18)
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18)
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19)
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19)
%23 = arith.muli %13, %c256_i32 : i32 loc(#loc20)
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21)
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23)
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23)
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc24)
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked> loc(#loc25)
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked> loc(#loc25)
%32 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> loc(#loc26)
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc26)
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27)
%35 = tt.broadcast %33 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27)
%36 = arith.addi %34, %35 : tensor<128x16xi32, #blocked> loc(#loc27)
%37 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked> loc(#loc28)
%38 = tt.addptr %37, %36 : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc28)
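// B-tile pointer: %48 = B + arange(16)[:, None]*%arg7 + offs_bn[None, :].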
%39 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc29)
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc29)
%41 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc30)
%42 = arith.muli %40, %41 : tensor<16x1xi32, #blocked1> loc(#loc30)
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc31)
%44 = tt.broadcast %42 : tensor<16x1xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32)
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32)
%46 = arith.addi %44, %45 : tensor<16x256xi32, #blocked1> loc(#loc32)
%47 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x256x!tt.ptr<f32>, #blocked1> loc(#loc33)
%48 = tt.addptr %47, %46 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc33)
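// Software-pipelining prologue. %50 = ceil(K/16) is the K-loop trip count;
// %cst (dense<16>) and %51 = 16*%arg7 are the per-iteration advances of the
// A and B pointers. Two-deep double buffers are allocated in shared memory,
// and the first two K-tiles are brought in with masked async copies
// (cp.async), each guarded by a trip-count predicate (%55: at least one
// iteration, %72: at least two).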
%49 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60)
%50 = arith.divsi %49, %c16_i32 : i32 loc(#loc61)
%51 = arith.muli %arg7, %c16_i32 : i32 loc(#loc35)
%52 = tt.splat %51 : i32 -> tensor<16x256xi32, #blocked1> loc(#loc36)
%53 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%54 = triton_gpu.local_alloc : () -> !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%55 = arith.cmpi sgt, %50, %c0_i32 : i32 loc(#loc39)
%56 = tt.splat %arg5 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%57 = arith.cmpi slt, %33, %56 : tensor<1x16xi32, #blocked> loc(#loc40)
%58 = tt.broadcast %57 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%59 = triton_gpu.memdesc_subview %53[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%60 = tt.splat %55 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%61 = arith.andi %60, %58 : tensor<128x16xi1, #blocked> loc(#loc39)
%62 = triton_gpu.async_copy_global_to_local %38, %59 mask %61 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%63 = triton_gpu.async_commit_group %62 loc(#loc37)
%64 = tt.splat %arg5 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%65 = arith.cmpi slt, %40, %64 : tensor<16x1xi32, #blocked1> loc(#loc41)
%66 = tt.broadcast %65 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%67 = triton_gpu.memdesc_subview %54[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%68 = tt.splat %55 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%69 = arith.andi %68, %66 : tensor<16x256xi1, #blocked1> loc(#loc39)
%70 = triton_gpu.async_copy_global_to_local %48, %67 mask %69 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%71 = triton_gpu.async_commit_group %70 loc(#loc38)
%72 = arith.cmpi sgt, %50, %c1_i32 : i32 loc(#loc39)
%73 = tt.addptr %38, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42)
%74 = tt.addptr %48, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36)
%75 = arith.subi %arg5, %c16_i32 : i32 loc(#loc43)
%76 = tt.splat %75 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%77 = arith.cmpi slt, %33, %76 : tensor<1x16xi32, #blocked> loc(#loc40)
%78 = tt.broadcast %77 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%79 = triton_gpu.memdesc_subview %53[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%80 = tt.splat %72 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%81 = arith.andi %80, %78 : tensor<128x16xi1, #blocked> loc(#loc39)
%82 = triton_gpu.async_copy_global_to_local %73, %79 mask %81 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%83 = triton_gpu.async_commit_group %82 loc(#loc37)
%84 = tt.splat %75 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%85 = arith.cmpi slt, %40, %84 : tensor<16x1xi32, #blocked1> loc(#loc41)
%86 = tt.broadcast %85 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%87 = triton_gpu.memdesc_subview %54[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%88 = tt.splat %72 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%89 = arith.andi %88, %86 : tensor<16x256xi1, #blocked1> loc(#loc39)
%90 = triton_gpu.async_copy_global_to_local %74, %87 mask %89 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%91 = triton_gpu.async_commit_group %90 loc(#loc38)
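// Wait until at most two commit groups remain outstanding, i.e. the first
// tile's copies have landed, then preload the first K=8 slices of A and B
// into registers; each 16-wide K-tile is consumed as two k=8 TF32 MMA steps.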
%92 = triton_gpu.async_wait %71 {num = 2 : i32} loc(#loc37)
%93 = triton_gpu.memdesc_subview %59[%c0_i32, %c0_i32] : !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> loc(#loc37)
%94 = triton_gpu.local_load %93 : !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%95 = triton_gpu.memdesc_subview %67[%c0_i32, %c0_i32] : !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> loc(#loc38)
%96 = triton_gpu.local_load %95 : !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
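// Main pipelined loop. The iter_args carry: the accumulator, the A/B global
// pointers, the shared-memory write and read indices (rotated mod 2), the
// current buffer views, the async token to wait on, and the pre-loaded first
// K=8 operand slices for the upcoming dot.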
%97:10 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %73, %arg12 = %74, %arg13 = %c1_i32, %arg14 = %c0_i32, %arg15 = %59, %arg16 = %67, %arg17 = %91, %arg18 = %94, %arg19 = %96) -> (tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable>, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) : i32 {
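// %117 guards prefetching: a fresh tile is fetched only while iteration
// %arg9 + 2 still exists. Then load the second K=8 slices (column/row offset
// 8) of the current buffers.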
%116 = arith.subi %50, %c2_i32 : i32 loc(#loc39)
%117 = arith.cmpi slt, %arg9, %116 : i32 loc(#loc39)
%118 = triton_gpu.memdesc_subview %arg15[%c0_i32, %c8_i32] : !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> loc(#loc37)
%119 = triton_gpu.local_load %118 : !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%120 = triton_gpu.memdesc_subview %arg16[%c8_i32, %c0_i32] : !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> loc(#loc38)
%121 = triton_gpu.local_load %120 : !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
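// First TF32 tensor-core MMA on the pre-loaded slices, then advance the
// global pointers by one K-tile and rotate the write index mod 2.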
%122 = tt.dot %arg18, %arg19, %arg10, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44)
%123 = tt.addptr %arg11, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42)
%124 = tt.addptr %arg12, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36)
%125 = arith.addi %arg13, %c1_i32 : i32 loc(#loc39)
%126 = arith.cmpi slt, %125, %c2_i32 : i32 loc(#loc39)
%127 = arith.select %126, %125, %c0_i32 : i32 loc(#loc39)
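// Masked async copies of the tile two iterations ahead into the rotated
// buffer slot; %130 is the K extent remaining for that tile.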
%128 = arith.addi %arg9, %c2_i32 : i32 loc(#loc39)
%129 = arith.muli %128, %c16_i32 : i32 loc(#loc45)
%130 = arith.subi %arg5, %129 : i32 loc(#loc43)
%131 = tt.splat %130 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%132 = arith.cmpi slt, %33, %131 : tensor<1x16xi32, #blocked> loc(#loc40)
%133 = tt.broadcast %132 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%134 = triton_gpu.memdesc_subview %53[%127, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%135 = tt.splat %117 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%136 = arith.andi %135, %133 : tensor<128x16xi1, #blocked> loc(#loc39)
%137 = triton_gpu.async_copy_global_to_local %123, %134 mask %136 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%138 = triton_gpu.async_commit_group %137 loc(#loc37)
%139 = tt.splat %130 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%140 = arith.cmpi slt, %40, %139 : tensor<16x1xi32, #blocked1> loc(#loc41)
%141 = tt.broadcast %140 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%142 = triton_gpu.memdesc_subview %54[%127, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%143 = tt.splat %117 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%144 = arith.andi %143, %141 : tensor<16x256xi1, #blocked1> loc(#loc39)
%145 = triton_gpu.async_copy_global_to_local %124, %142 mask %144 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%146 = triton_gpu.async_commit_group %145 loc(#loc38)
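// Rotate the read index, wait for the copies issued in the previous
// iteration, and preload the first K=8 slices of the next tile.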
%147 = arith.addi %arg14, %c1_i32 : i32 loc(#loc39)
%148 = arith.cmpi slt, %147, %c2_i32 : i32 loc(#loc39)
%149 = arith.select %148, %147, %c0_i32 : i32 loc(#loc39)
%150 = triton_gpu.memdesc_subview %53[%149, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc37)
%151 = triton_gpu.async_wait %arg17 {num = 2 : i32} loc(#loc37)
%152 = triton_gpu.memdesc_subview %54[%149, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc38)
%153 = triton_gpu.memdesc_subview %150[%c0_i32, %c0_i32] : !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> loc(#loc37)
%154 = triton_gpu.local_load %153 : !tt.memdesc<128x8xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%155 = triton_gpu.memdesc_subview %152[%c0_i32, %c0_i32] : !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> loc(#loc38)
%156 = triton_gpu.local_load %155 : !tt.memdesc<8x256xf32, #shared1, #triton_gpu.shared_memory> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
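// Second MMA on the second K=8 slices; the yielded values become the next
// iteration's pre-loaded operands and buffer bookkeeping.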
%157 = tt.dot %119, %121, %122, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44)
scf.yield %157, %123, %124, %127, %149, %150, %152, %146, %154, %156 : tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x256xf32, #shared1, #triton_gpu.shared_memory, mutable>, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc39)
} loc(#loc39)
%98 = triton_gpu.async_wait {num = 0 : i32} loc(#loc39)
triton_gpu.local_dealloc %53 : !tt.memdesc<2x128x16xf32, #shared, #triton_gpu.shared_memory, mutable> loc(#loc39)
triton_gpu.local_dealloc %54 : !tt.memdesc<2x16x256xf32, #shared1, #triton_gpu.shared_memory, mutable> loc(#loc39)
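// Epilogue: build C-tile pointers C + %arg8*offs_cm[:, None] + offs_cn[None, :],
// mask to offs_cm < M and offs_cn < N, convert the accumulator from the #mma
// layout to #blocked1, and store.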
%99 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc46)
%100 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc47)
%101 = arith.muli %100, %99 : tensor<128x1xi32, #blocked1> loc(#loc47)
%102 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked1> loc(#loc48)
%103 = tt.addptr %102, %101 : tensor<128x1x!tt.ptr<f32>, #blocked1>, tensor<128x1xi32, #blocked1> loc(#loc48)
%104 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc49)
%105 = tt.broadcast %103 : tensor<128x1x!tt.ptr<f32>, #blocked1> -> tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc50)
%106 = tt.broadcast %104 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> loc(#loc50)
%107 = tt.addptr %105, %106 : tensor<128x256x!tt.ptr<f32>, #blocked1>, tensor<128x256xi32, #blocked1> loc(#loc50)
%108 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc51)
%109 = arith.cmpi slt, %99, %108 : tensor<128x1xi32, #blocked1> loc(#loc51)
%110 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> loc(#loc52)
%111 = arith.cmpi slt, %104, %110 : tensor<1x256xi32, #blocked1> loc(#loc52)
%112 = tt.broadcast %109 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53)
%113 = tt.broadcast %111 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53)
%114 = arith.andi %112, %113 : tensor<128x256xi1, #blocked1> loc(#loc53)
%115 = triton_gpu.convert_layout %97#0 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked1> loc(#loc54)
tt.store %107, %115, %114 : tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc54)
tt.return loc(#loc55)
} loc(#loc)
} loc(#loc)
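// Source-location table (debug info) mapping ops back to triton_kernel.py;
// #loc56-#loc61 are call sites into triton/language/standard.py:44, likely
// tl.cdiv used for the grid and trip-count computations.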
#loc1 = loc(unknown)
#loc2 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":117:24)
#loc3 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:22)
#loc4 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":118:27)
#loc5 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:28)
#loc6 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":119:27)
#loc7 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":120:38)
#loc8 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":121:22)
#loc9 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":122:29)
#loc10 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":123:35)
#loc11 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":123:48)
#loc12 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":124:33)
#loc13 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":124:27)
#loc14 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":125:19)
#loc15 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":125:40)
#loc16 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":127:23)
#loc17 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":127:51)
#loc18 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":127:38)
#loc19 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":127:68)
#loc20 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":128:23)
#loc21 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":128:51)
#loc22 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":128:38)
#loc23 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":128:68)
#loc24 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":130:30)
#loc25 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":130:41)
#loc26 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":130:60)
#loc27 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":130:53)
#loc28 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":130:22)
#loc29 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":131:29)
#loc30 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":131:40)
#loc31 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":131:60)
#loc32 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":131:52)
#loc33 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":131:22)
#loc34 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":134:33)
#loc35 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":139:33)
#loc36 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":139:18)
#loc37 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":135:20)
#loc38 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":136:20)
#loc39 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":134:22)
#loc40 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":135:51)
#loc41 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":136:51)
#loc42 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":138:18)
#loc43 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":135:55)
#loc44 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":137:35)
#loc45 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":135:59)
#loc46 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":144:41)
#loc47 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":144:33)
#loc48 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":144:21)
#loc49 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":144:72)
#loc50 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":144:52)
#loc51 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":145:33)
#loc52 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":145:58)
#loc53 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":145:39)
#loc54 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":146:21)
#loc55 = loc("/data/users/bertrand/tf32_gemm/triton_kernel.py":146:4)
#loc56 = loc(callsite(#loc3 at #loc4))
#loc57 = loc(callsite(#loc5 at #loc4))
#loc58 = loc(callsite(#loc3 at #loc6))
#loc59 = loc(callsite(#loc5 at #loc6))
#loc60 = loc(callsite(#loc3 at #loc34))
#loc61 = loc(callsite(#loc5 at #loc34))