|
141 | 141 | // MXF16-CONTRACT: return %[[VAL_0]] : tensor<2x16x64x48xf32>
|
142 | 142 | // MXF16-CONTRACT: }
|
143 | 143 |
|
144 |
| -// MXBF16-DEQUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> |
145 |
| -// MXBF16-DEQUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> |
146 |
| -// MXBF16-DEQUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> |
147 |
| -// MXBF16-DEQUANT: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0)> |
148 |
| -// MXBF16-DEQUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d1)> |
149 |
| -// MXBF16-DEQUANT: #[[$ATTR_5:.+]] = affine_map<(d0, d1) -> (d0, d1)> |
| 144 | + |
| 145 | +// Perform Gemm dequantization using given scales. |
| 146 | + |
| 147 | +// MXBF16-DEQUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)> |
| 148 | +// MXBF16-DEQUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> |
| 149 | +// MXBF16-DEQUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> |
| 150 | +// MXBF16-DEQUANT: #map3 = affine_map<(d0, d1) -> (d0)> |
| 151 | +// MXBF16-DEQUANT: #map4 = affine_map<(d0, d1) -> (d1)> |
| 152 | +// MXBF16-DEQUANT: #map5 = affine_map<(d0, d1) -> (d0, d1)> |
150 | 153 | // MXBF16-DEQUANT-LABEL: func.func @entry(
|
151 |
| -// MXBF16-DEQUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xbf16>, |
152 |
| -// MXBF16-DEQUANT-SAME: %[[ARG1:.*]]: tensor<128xf32>, |
153 |
| -// MXBF16-DEQUANT-SAME: %[[ARG2:.*]]: tensor<2304x768xbf16>, |
154 |
| -// MXBF16-DEQUANT-SAME: %[[ARG3:.*]]: tensor<768xf32>, |
155 |
| -// MXBF16-DEQUANT-SAME: %[[ARG4:.*]]: tensor<128x768xf32>) -> tensor<128x768xf32> { |
156 |
| -// MXBF16-DEQUANT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG2]] : tensor<128x2304xbf16>, tensor<2304x768xbf16>) outs(%[[ARG4]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
157 |
| -// MXBF16-DEQUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32> |
158 |
| -// MXBF16-DEQUANT: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_4]], #[[$ATTR_5]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]], %[[ARG3]] : tensor<128xf32>, tensor<768xf32>) outs(%[[VAL_1]] : tensor<128x768xf32>) { |
159 |
| -// MXBF16-DEQUANT: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): |
160 |
| -// MXBF16-DEQUANT: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 |
161 |
| -// MXBF16-DEQUANT: linalg.yield %[[VAL_6]] : f32 |
162 |
| -// MXBF16-DEQUANT: } -> tensor<128x768xf32> |
163 |
| -// MXBF16-DEQUANT: %[[VAL_7:.*]] = tensor.empty() : tensor<128x768xf32> |
164 |
| -// MXBF16-DEQUANT: %[[VAL_8:.*]] = linalg.mul ins(%[[VAL_0]], %[[VAL_2]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
165 |
| -// MXBF16-DEQUANT: return %[[VAL_8]] : tensor<128x768xf32> |
166 |
| -// MXBF16-DEQUANT: } |
167 |
| - |
168 |
| -// MXI8F32-DEQUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> |
169 |
| -// MXI8F32-DEQUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> |
170 |
| -// MXI8F32-DEQUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> |
171 |
| -// MXI8F32-DEQUANT: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0)> |
172 |
| -// MXI8F32-DEQUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d1)> |
173 |
| -// MXI8F32-DEQUANT: #[[$ATTR_5:.+]] = affine_map<(d0, d1) -> (d0, d1)> |
| 154 | +// MXBF16-DEQUANT-SAME: %arg0: tensor<128x2304xbf16>, |
| 155 | +// MXBF16-DEQUANT-SAME: %arg1: tensor<128xf32>, |
| 156 | +// MXBF16-DEQUANT-SAME: %arg2: tensor<2304x768xbf16>, |
| 157 | +// MXBF16-DEQUANT-SAME: %arg3: tensor<768xf32>, |
| 158 | +// MXBF16-DEQUANT-SAME: %arg4: tensor<128x768xf32>) -> tensor<128x768xf32> { |
| 159 | +// MXBF16-DEQUANT: linalg.contract indexing_maps = [#map, #map1, #map2] |
| 160 | +// MXBF16-DEQUANT: linalg.generic {{.*}} iterator_types = ["parallel", "parallel"] |
| 161 | +// MXBF16-DEQUANT: arith.mulf |
| 162 | +// MXBF16-DEQUANT: linalg.mul |
| 163 | + |
| 164 | + |
| 165 | +// MXI8F32-DEQUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)> |
| 166 | +// MXI8F32-DEQUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> |
| 167 | +// MXI8F32-DEQUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> |
| 168 | +// MXI8F32-DEQUANT: #map3 = affine_map<(d0, d1) -> (d0)> |
| 169 | +// MXI8F32-DEQUANT: #map4 = affine_map<(d0, d1) -> (d1)> |
| 170 | +// MXI8F32-DEQUANT: #map5 = affine_map<(d0, d1) -> (d0, d1)> |
174 | 171 | // MXI8F32-DEQUANT-LABEL: func.func @entry(
|
175 |
| -// MXI8F32-DEQUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xi8>, |
176 |
| -// MXI8F32-DEQUANT-SAME: %[[ARG1:.*]]: tensor<128xf32>, |
177 |
| -// MXI8F32-DEQUANT-SAME: %[[ARG2:.*]]: tensor<2304x768xi8>, |
178 |
| -// MXI8F32-DEQUANT-SAME: %[[ARG3:.*]]: tensor<768xf32>, |
179 |
| -// MXI8F32-DEQUANT-SAME: %[[ARG4:.*]]: tensor<128x768xf32>) -> tensor<128x768xf32> { |
180 |
| -// MXI8F32-DEQUANT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG2]] : tensor<128x2304xi8>, tensor<2304x768xi8>) outs(%[[ARG4]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
181 |
| -// MXI8F32-DEQUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32> |
182 |
| -// MXI8F32-DEQUANT: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_4]], #[[$ATTR_5]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]], %[[ARG3]] : tensor<128xf32>, tensor<768xf32>) outs(%[[VAL_1]] : tensor<128x768xf32>) { |
183 |
| -// MXI8F32-DEQUANT: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): |
184 |
| -// MXI8F32-DEQUANT: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 |
185 |
| -// MXI8F32-DEQUANT: linalg.yield %[[VAL_6]] : f32 |
186 |
| -// MXI8F32-DEQUANT: } -> tensor<128x768xf32> |
187 |
| -// MXI8F32-DEQUANT: %[[VAL_7:.*]] = tensor.empty() : tensor<128x768xf32> |
188 |
| -// MXI8F32-DEQUANT: %[[VAL_8:.*]] = linalg.mul ins(%[[VAL_0]], %[[VAL_2]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
189 |
| -// MXI8F32-DEQUANT: return %[[VAL_8]] : tensor<128x768xf32> |
190 |
| -// MXI8F32-DEQUANT: } |
191 |
| - |
192 |
| -// MXF32I8-QUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> |
193 |
| -// MXF32I8-QUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> |
194 |
| -// MXF32I8-QUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> |
195 |
| -// MXF32I8-QUANT: #[[$ATTR_3:.+]] = affine_map<(d0) -> (d0)> |
196 |
| -// MXF32I8-QUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d0, d1)> |
| 172 | +// MXI8F32-DEQUANT-SAME: %arg0: tensor<128x2304xi8>, |
| 173 | +// MXI8F32-DEQUANT-SAME: %arg1: tensor<128xf32>, |
| 174 | +// MXI8F32-DEQUANT-SAME: %arg2: tensor<2304x768xi8>, |
| 175 | +// MXI8F32-DEQUANT-SAME: %arg3: tensor<768xf32>, |
| 176 | +// MXI8F32-DEQUANT-SAME: %arg4: tensor<128x768xf32>) -> tensor<128x768xf32> { |
| 177 | +// MXI8F32-DEQUANT: linalg.contract indexing_maps = [#map, #map1, #map2] |
| 178 | +// MXI8F32-DEQUANT: linalg.generic {{.*}} iterator_types = ["parallel", "parallel"] |
| 179 | +// MXI8F32-DEQUANT: arith.mulf |
| 180 | +// MXI8F32-DEQUANT: linalg.mul |
| 181 | + |
| 182 | + |
| 183 | +// Perform Gemm quantization with dynamic scale computation. |
| 184 | + |
| 185 | +// MXF32I8-QUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)> |
| 186 | +// MXF32I8-QUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> |
| 187 | +// MXF32I8-QUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> |
| 188 | +// MXF32I8-QUANT: #map3 = affine_map<(d0) -> (d0)> |
| 189 | +// MXF32I8-QUANT: #map4 = affine_map<(d0, d1) -> (d0, d1)> |
197 | 190 | // MXF32I8-QUANT-LABEL: func.func @entry(
|
198 | 191 | // MXF32I8-QUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xf32>,
|
199 | 192 | // MXF32I8-QUANT-SAME: %[[ARG1:.*]]: tensor<2304x768xf32>,
|
200 | 193 | // MXF32I8-QUANT-SAME: %[[ARG2:.*]]: tensor<128x768xi8>) -> tensor<128x768xi8> {
|
201 |
| -// MXF32I8-QUANT: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32 |
202 |
| -// MXF32I8-QUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32> |
203 |
| -// MXF32I8-QUANT: %[[VAL_2:.*]] = linalg.fill ins(%[[VAL_0]] : f32) outs(%[[VAL_1]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
204 |
| -// MXF32I8-QUANT: %[[VAL_3:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG1]] : tensor<128x2304xf32>, tensor<2304x768xf32>) outs(%[[VAL_2]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
205 |
| -// MXF32I8-QUANT: %[[VAL_4:.*]] = tensor.empty() : tensor<128x768xf32> |
206 |
| -// MXF32I8-QUANT: %[[VAL_5:.*]] = arith.constant 0xFF800000 : f32 |
207 |
| -// MXF32I8-QUANT: %[[VAL_6:.*]] = tensor.empty() : tensor<768xf32> |
208 |
| -// MXF32I8-QUANT: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_5]] : f32) outs(%[[VAL_6]] : tensor<768xf32>) -> tensor<768xf32> |
209 |
| -// MXF32I8-QUANT: %[[VAL_8:.*]] = linalg.reduce ins(%[[VAL_3]] : tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<768xf32>) dimensions = [0] |
210 |
| -// MXF32I8-QUANT: (%[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32) { |
211 |
| -// MXF32I8-QUANT: %[[VAL_11:.*]] = math.absf %[[VAL_9]] : f32 |
212 |
| -// MXF32I8-QUANT: %[[VAL_12:.*]] = arith.maximumf %[[VAL_11]], %[[VAL_10]] : f32 |
213 |
| -// MXF32I8-QUANT: linalg.yield %[[VAL_12]] : f32 |
214 |
| -// MXF32I8-QUANT: } |
215 |
| -// MXF32I8-QUANT: %[[VAL_13:.*]] = arith.constant 0 : i32 |
216 |
| -// MXF32I8-QUANT: %[[VAL_14:.*]] = arith.constant 0.000000e+00 : f32 |
217 |
| -// MXF32I8-QUANT: %[[VAL_15:.*]] = tensor.empty() : tensor<768xf32> |
218 |
| -// MXF32I8-QUANT: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_14]] : f32) outs(%[[VAL_15]] : tensor<768xf32>) -> tensor<768xf32> |
219 |
| -// MXF32I8-QUANT: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_3]]], iterator_types = ["parallel"]} ins(%[[VAL_8]] : tensor<768xf32>) outs(%[[VAL_16]] : tensor<768xf32>) { |
220 |
| -// MXF32I8-QUANT: ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32): |
221 |
| -// MXF32I8-QUANT: %[[VAL_20:.*]] = llvm.intr.frexp(%[[VAL_18]]) : (f32) -> !llvm.struct<(f32, i32)> |
222 |
| -// MXF32I8-QUANT: %[[VAL_21:.*]] = llvm.extractvalue %[[VAL_20]][1] : !llvm.struct<(f32, i32)> |
223 |
| -// MXF32I8-QUANT: %[[VAL_22:.*]] = arith.constant 7 : i32 |
224 |
| -// MXF32I8-QUANT: %[[VAL_23:.*]] = arith.subi %[[VAL_21]], %[[VAL_22]] : i32 |
225 |
| -// MXF32I8-QUANT: %[[VAL_24:.*]] = arith.subi %[[VAL_13]], %[[VAL_23]] : i32 |
226 |
| -// MXF32I8-QUANT: %[[VAL_25:.*]] = arith.sitofp %[[VAL_24]] : i32 to f32 |
227 |
| -// MXF32I8-QUANT: %[[VAL_26:.*]] = math.exp2 %[[VAL_25]] : f32 |
228 |
| -// MXF32I8-QUANT: linalg.yield %[[VAL_26]] : f32 |
229 |
| -// MXF32I8-QUANT: } -> tensor<768xf32> |
230 |
| -// MXF32I8-QUANT: %[[VAL_27:.*]] = linalg.fill ins(%[[VAL_5]] : f32) outs(%[[VAL_4]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
231 |
| -// MXF32I8-QUANT: %[[VAL_28:.*]] = linalg.broadcast ins(%[[VAL_17]] : tensor<768xf32>) outs(%[[VAL_27]] : tensor<128x768xf32>) dimensions = [0] |
232 |
| -// MXF32I8-QUANT: %[[VAL_29:.*]] = linalg.mul ins(%[[VAL_3]], %[[VAL_28]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_2]] : tensor<128x768xf32>) -> tensor<128x768xf32> |
233 |
| -// MXF32I8-QUANT: %[[VAL_30:.*]] = tensor.empty() : tensor<128x768xi8> |
234 |
| -// MXF32I8-QUANT: %[[VAL_31:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_4]], #[[$ATTR_4]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_29]] : tensor<128x768xf32>) outs(%[[VAL_30]] : tensor<128x768xi8>) { |
235 |
| -// MXF32I8-QUANT: ^bb0(%[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: i8): |
236 |
| -// MXF32I8-QUANT: %[[VAL_34:.*]] = arith.fptosi %[[VAL_32]] : f32 to i8 |
237 |
| -// MXF32I8-QUANT: linalg.yield %[[VAL_34]] : i8 |
238 |
| -// MXF32I8-QUANT: } -> tensor<128x768xi8> |
239 |
| -// MXF32I8-QUANT: return %[[VAL_31]] : tensor<128x768xi8> |
240 |
| -// MXF32I8-QUANT: } |
| 194 | +// MXF32I8-QUANT: linalg.contract indexing_maps = [#map, #map1, #map2] |
| 195 | +// MXF32I8-QUANT: linalg.reduce {{.*}} dimensions = [0] |
| 196 | +// MXF32I8-QUANT: math.absf |
| 197 | +// MXF32I8-QUANT: arith.maximumf |
| 198 | +// MXF32I8-QUANT: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} |
| 199 | +// MXF32I8-QUANT: llvm.intr.frexp |
| 200 | +// MXF32I8-QUANT: llvm.extractvalue |
| 201 | +// MXF32I8-QUANT: arith.constant 7 |
| 202 | +// MXF32I8-QUANT: arith.subi |
| 203 | +// MXF32I8-QUANT: arith.subi |
| 204 | +// MXF32I8-QUANT: arith.sitofp |
| 205 | +// MXF32I8-QUANT: math.exp2 |
| 206 | +// MXF32I8-QUANT: linalg.broadcast |
| 207 | +// MXF32I8-QUANT: linalg.mul |
| 208 | +// MXF32I8-QUANT: linalg.generic |
| 209 | +// MXF32I8-QUANT: arith.fptosi |
0 commit comments