
Commit 772484d

Refactor the code, simplify conditional test case checks.
1 parent baf4a7d · commit 772484d

File tree: 4 files changed (+84, -123 lines)

test/Integration/mlir-gen-matmul.mlir

Lines changed: 60 additions & 91 deletions
@@ -141,100 +141,69 @@
 // MXF16-CONTRACT: return %[[VAL_0]] : tensor<2x16x64x48xf32>
 // MXF16-CONTRACT: }
 
-// MXBF16-DEQUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// MXBF16-DEQUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// MXBF16-DEQUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// MXBF16-DEQUANT: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0)>
-// MXBF16-DEQUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d1)>
-// MXBF16-DEQUANT: #[[$ATTR_5:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// Perform Gemm dequntization using given scales.
+
+// MXBF16-DEQUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
+// MXBF16-DEQUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// MXBF16-DEQUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// MXBF16-DEQUANT: #map3 = affine_map<(d0, d1) -> (d0)>
+// MXBF16-DEQUANT: #map4 = affine_map<(d0, d1) -> (d1)>
+// MXBF16-DEQUANT: #map5 = affine_map<(d0, d1) -> (d0, d1)>
 // MXBF16-DEQUANT-LABEL: func.func @entry(
-// MXBF16-DEQUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xbf16>,
-// MXBF16-DEQUANT-SAME: %[[ARG1:.*]]: tensor<128xf32>,
-// MXBF16-DEQUANT-SAME: %[[ARG2:.*]]: tensor<2304x768xbf16>,
-// MXBF16-DEQUANT-SAME: %[[ARG3:.*]]: tensor<768xf32>,
-// MXBF16-DEQUANT-SAME: %[[ARG4:.*]]: tensor<128x768xf32>) -> tensor<128x768xf32> {
-// MXBF16-DEQUANT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG2]] : tensor<128x2304xbf16>, tensor<2304x768xbf16>) outs(%[[ARG4]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXBF16-DEQUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXBF16-DEQUANT: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_4]], #[[$ATTR_5]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]], %[[ARG3]] : tensor<128xf32>, tensor<768xf32>) outs(%[[VAL_1]] : tensor<128x768xf32>) {
-// MXBF16-DEQUANT: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
-// MXBF16-DEQUANT: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
-// MXBF16-DEQUANT: linalg.yield %[[VAL_6]] : f32
-// MXBF16-DEQUANT: } -> tensor<128x768xf32>
-// MXBF16-DEQUANT: %[[VAL_7:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXBF16-DEQUANT: %[[VAL_8:.*]] = linalg.mul ins(%[[VAL_0]], %[[VAL_2]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXBF16-DEQUANT: return %[[VAL_8]] : tensor<128x768xf32>
-// MXBF16-DEQUANT: }
-
-// MXI8F32-DEQUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// MXI8F32-DEQUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// MXI8F32-DEQUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// MXI8F32-DEQUANT: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0)>
-// MXI8F32-DEQUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d1)>
-// MXI8F32-DEQUANT: #[[$ATTR_5:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// MXBF16-DEQUANT-SAME: %arg0: tensor<128x2304xbf16>,
+// MXBF16-DEQUANT-SAME: %arg1: tensor<128xf32>,
+// MXBF16-DEQUANT-SAME: %arg2: tensor<2304x768xbf16>,
+// MXBF16-DEQUANT-SAME: %arg3: tensor<768xf32>,
+// MXBF16-DEQUANT-SAME: %arg4: tensor<128x768xf32>) -> tensor<128x768xf32> {
+// MXBF16-DEQUANT: linalg.contract indexing_maps = [#map, #map1, #map2]
+// MXBF16-DEQUANT: linalg.generic {{.*}} iterator_types = ["parallel", "parallel"]
+// MXBF16-DEQUANT: arith.mulf
+// MXBF16-DEQUANT: linalg.mul
+
+
+// MXI8F32-DEQUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
+// MXI8F32-DEQUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// MXI8F32-DEQUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// MXI8F32-DEQUANT: #map3 = affine_map<(d0, d1) -> (d0)>
+// MXI8F32-DEQUANT: #map4 = affine_map<(d0, d1) -> (d1)>
+// MXI8F32-DEQUANT: #map5 = affine_map<(d0, d1) -> (d0, d1)>
 // MXI8F32-DEQUANT-LABEL: func.func @entry(
-// MXI8F32-DEQUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xi8>,
-// MXI8F32-DEQUANT-SAME: %[[ARG1:.*]]: tensor<128xf32>,
-// MXI8F32-DEQUANT-SAME: %[[ARG2:.*]]: tensor<2304x768xi8>,
-// MXI8F32-DEQUANT-SAME: %[[ARG3:.*]]: tensor<768xf32>,
-// MXI8F32-DEQUANT-SAME: %[[ARG4:.*]]: tensor<128x768xf32>) -> tensor<128x768xf32> {
-// MXI8F32-DEQUANT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG2]] : tensor<128x2304xi8>, tensor<2304x768xi8>) outs(%[[ARG4]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXI8F32-DEQUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXI8F32-DEQUANT: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_4]], #[[$ATTR_5]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]], %[[ARG3]] : tensor<128xf32>, tensor<768xf32>) outs(%[[VAL_1]] : tensor<128x768xf32>) {
-// MXI8F32-DEQUANT: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
-// MXI8F32-DEQUANT: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
-// MXI8F32-DEQUANT: linalg.yield %[[VAL_6]] : f32
-// MXI8F32-DEQUANT: } -> tensor<128x768xf32>
-// MXI8F32-DEQUANT: %[[VAL_7:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXI8F32-DEQUANT: %[[VAL_8:.*]] = linalg.mul ins(%[[VAL_0]], %[[VAL_2]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXI8F32-DEQUANT: return %[[VAL_8]] : tensor<128x768xf32>
-// MXI8F32-DEQUANT: }
-
-// MXF32I8-QUANT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// MXF32I8-QUANT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// MXF32I8-QUANT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// MXF32I8-QUANT: #[[$ATTR_3:.+]] = affine_map<(d0) -> (d0)>
-// MXF32I8-QUANT: #[[$ATTR_4:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// MXI8F32-DEQUANT-SAME: %arg0: tensor<128x2304xi8>,
+// MXI8F32-DEQUANT-SAME: %arg1: tensor<128xf32>,
+// MXI8F32-DEQUANT-SAME: %arg2: tensor<2304x768xi8>,
+// MXI8F32-DEQUANT-SAME: %arg3: tensor<768xf32>,
+// MXI8F32-DEQUANT-SAME: %arg4: tensor<128x768xf32>) -> tensor<128x768xf32> {
+// MXI8F32-DEQUANT: linalg.contract indexing_maps = [#map, #map1, #map2]
+// MXI8F32-DEQUANT: linalg.generic {{.*}} iterator_types = ["parallel", "parallel"]
+// MXI8F32-DEQUANT: arith.mulf
+// MXI8F32-DEQUANT: linalg.mul
+
+
+// Perform Gemm quntization with dynamic scale computation.
+
+// MXF32I8-QUANT: #map = affine_map<(d0, d1, d2) -> (d0, d2)>
+// MXF32I8-QUANT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// MXF32I8-QUANT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// MXF32I8-QUANT: #map3 = affine_map<(d0) -> (d0)>
+// MXF32I8-QUANT: #map4 = affine_map<(d0, d1) -> (d0, d1)>
 // MXF32I8-QUANT-LABEL: func.func @entry(
 // MXF32I8-QUANT-SAME: %[[ARG0:.*]]: tensor<128x2304xf32>,
 // MXF32I8-QUANT-SAME: %[[ARG1:.*]]: tensor<2304x768xf32>,
 // MXF32I8-QUANT-SAME: %[[ARG2:.*]]: tensor<128x768xi8>) -> tensor<128x768xi8> {
-// MXF32I8-QUANT: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
-// MXF32I8-QUANT: %[[VAL_1:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_2:.*]] = linalg.fill ins(%[[VAL_0]] : f32) outs(%[[VAL_1]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_3:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG1]] : tensor<128x2304xf32>, tensor<2304x768xf32>) outs(%[[VAL_2]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_4:.*]] = tensor.empty() : tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_5:.*]] = arith.constant 0xFF800000 : f32
-// MXF32I8-QUANT: %[[VAL_6:.*]] = tensor.empty() : tensor<768xf32>
-// MXF32I8-QUANT: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_5]] : f32) outs(%[[VAL_6]] : tensor<768xf32>) -> tensor<768xf32>
-// MXF32I8-QUANT: %[[VAL_8:.*]] = linalg.reduce ins(%[[VAL_3]] : tensor<128x768xf32>) outs(%[[VAL_7]] : tensor<768xf32>) dimensions = [0]
-// MXF32I8-QUANT: (%[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32) {
-// MXF32I8-QUANT: %[[VAL_11:.*]] = math.absf %[[VAL_9]] : f32
-// MXF32I8-QUANT: %[[VAL_12:.*]] = arith.maximumf %[[VAL_11]], %[[VAL_10]] : f32
-// MXF32I8-QUANT: linalg.yield %[[VAL_12]] : f32
-// MXF32I8-QUANT: }
-// MXF32I8-QUANT: %[[VAL_13:.*]] = arith.constant 0 : i32
-// MXF32I8-QUANT: %[[VAL_14:.*]] = arith.constant 0.000000e+00 : f32
-// MXF32I8-QUANT: %[[VAL_15:.*]] = tensor.empty() : tensor<768xf32>
-// MXF32I8-QUANT: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_14]] : f32) outs(%[[VAL_15]] : tensor<768xf32>) -> tensor<768xf32>
-// MXF32I8-QUANT: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_3]], #[[$ATTR_3]]], iterator_types = ["parallel"]} ins(%[[VAL_8]] : tensor<768xf32>) outs(%[[VAL_16]] : tensor<768xf32>) {
-// MXF32I8-QUANT: ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32):
-// MXF32I8-QUANT: %[[VAL_20:.*]] = llvm.intr.frexp(%[[VAL_18]]) : (f32) -> !llvm.struct<(f32, i32)>
-// MXF32I8-QUANT: %[[VAL_21:.*]] = llvm.extractvalue %[[VAL_20]][1] : !llvm.struct<(f32, i32)>
-// MXF32I8-QUANT: %[[VAL_22:.*]] = arith.constant 7 : i32
-// MXF32I8-QUANT: %[[VAL_23:.*]] = arith.subi %[[VAL_21]], %[[VAL_22]] : i32
-// MXF32I8-QUANT: %[[VAL_24:.*]] = arith.subi %[[VAL_13]], %[[VAL_23]] : i32
-// MXF32I8-QUANT: %[[VAL_25:.*]] = arith.sitofp %[[VAL_24]] : i32 to f32
-// MXF32I8-QUANT: %[[VAL_26:.*]] = math.exp2 %[[VAL_25]] : f32
-// MXF32I8-QUANT: linalg.yield %[[VAL_26]] : f32
-// MXF32I8-QUANT: } -> tensor<768xf32>
-// MXF32I8-QUANT: %[[VAL_27:.*]] = linalg.fill ins(%[[VAL_5]] : f32) outs(%[[VAL_4]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_28:.*]] = linalg.broadcast ins(%[[VAL_17]] : tensor<768xf32>) outs(%[[VAL_27]] : tensor<128x768xf32>) dimensions = [0]
-// MXF32I8-QUANT: %[[VAL_29:.*]] = linalg.mul ins(%[[VAL_3]], %[[VAL_28]] : tensor<128x768xf32>, tensor<128x768xf32>) outs(%[[VAL_2]] : tensor<128x768xf32>) -> tensor<128x768xf32>
-// MXF32I8-QUANT: %[[VAL_30:.*]] = tensor.empty() : tensor<128x768xi8>
-// MXF32I8-QUANT: %[[VAL_31:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_4]], #[[$ATTR_4]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_29]] : tensor<128x768xf32>) outs(%[[VAL_30]] : tensor<128x768xi8>) {
-// MXF32I8-QUANT: ^bb0(%[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: i8):
-// MXF32I8-QUANT: %[[VAL_34:.*]] = arith.fptosi %[[VAL_32]] : f32 to i8
-// MXF32I8-QUANT: linalg.yield %[[VAL_34]] : i8
-// MXF32I8-QUANT: } -> tensor<128x768xi8>
-// MXF32I8-QUANT: return %[[VAL_31]] : tensor<128x768xi8>
-// MXF32I8-QUANT: }
+// MXF32I8-QUANT: linalg.contract indexing_maps = [#map, #map1, #map2]
+// MXF32I8-QUANT: linalg.reduce {{.*}} dimensions = [0]
+// MXF32I8-QUANT: math.absf
+// MXF32I8-QUANT: arith.maximumf
+// MXF32I8-QUANT: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]}
+// MXF32I8-QUANT: llvm.intr.frexp
+// MXF32I8-QUANT: llvm.extractvalue
+// MXF32I8-QUANT: arith.constant 7
+// MXF32I8-QUANT: arith.subi
+// MXF32I8-QUANT: arith.subi
+// MXF32I8-QUANT: arith.sitofp
+// MXF32I8-QUANT: math.exp2
+// MXF32I8-QUANT: linalg.broadcast
+// MXF32I8-QUANT: linalg.mul
+// MXF32I8-QUANT: linalg.generic
+// MXF32I8-QUANT: arith.fptosi
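
The reworked checks above match the default attribute aliases printed by MLIR (#map, #map1, ...) and spot-check only the key operations, rather than capturing every alias and SSA value into FileCheck variables. A minimal sketch of the two styles on a made-up op (hypothetical CHECK prefix, not taken from this test):

// Old style: capture aliases and values, then refer back to the captures.
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: %[[RES:.*]] = linalg.mul ins(%[[A:.*]], %[[B:.*]] : tensor<8x8xf32>, tensor<8x8xf32>) outs(%[[OUT:.*]] : tensor<8x8xf32>)

// New style: rely on the printed alias names and match the ops loosely.
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: linalg.mul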

tools/mlir-gen/MLIRGen.cpp

Lines changed: 18 additions & 22 deletions
@@ -184,7 +184,7 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
   builder.setInsertionPoint(module);
 }
 
-void MLIRGenerator::getKernelTypes(KernelArgs &args, bool isQuantKernel) {
+void MLIRGenerator::getKernelTypes(KernelArgs &args) {
   // Input type, also first layer's input
   TensorType currentType = getShape({batch, layers.front()}, PACK_INPUT);
 
@@ -200,14 +200,13 @@ void MLIRGenerator::getKernelTypes(KernelArgs &args, bool isQuantKernel) {
     arg.index = i;
     arg.input.type = currentType;
     // Scale inputs are only needed for dequantization.
-    if (isQuantKernel && quantType == QuantizationType::Dequant)
-      arg.inputScale.type = getShape({batch}, INPUT_SCALE, isQuantKernel);
+    if (quantType == QuantizationType::Dequant)
+      arg.inputScale.type = getShape({batch}, INPUT_SCALE);
     arg.weight.type = getShape({inputSize, outputSize}, PACK_WEIGHT);
-    if (isQuantKernel && quantType == QuantizationType::Dequant)
-      arg.weightScale.type =
-          getShape({outputSize}, WEIGHT_SCALE, isQuantKernel);
+    if (quantType == QuantizationType::Dequant)
+      arg.weightScale.type = getShape({outputSize}, WEIGHT_SCALE);
     arg.bias.type = getShape({outputSize}, PACK_OUTPUT);
-    arg.output.type = getShape({batch, outputSize}, PACK_OUTPUT, isQuantKernel);
+    arg.output.type = getShape({batch, outputSize}, PACK_OUTPUT);
     args.push_back(arg);
 
     // Update next input type with the output type of this layer
@@ -255,15 +254,15 @@ Value MLIRGenerator::createLayer(LayerArgs &args, bool hasMixedType) {
   return chain;
 }
 
-void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
+void MLIRGenerator::createKernel(bool hasMixedType) {
   assert(((kernelType == KernelType::Const) ||
           (kernelType == KernelType::Args)) &&
          "Invalid kernel type");
   OpBuilder::InsertionGuard guard(builder);
 
   // Get all kernel types first
   KernelArgs args;
-  getKernelTypes(args, isQuantKernel);
+  getKernelTypes(args);
   assert(args.size() > 0 && "Invalid model size");
   unsigned lastLayer = args.size() - 1;
   auto &firstArg = args[0];
@@ -275,11 +274,11 @@ void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
   SmallVector<Type, 1> inputTypes{firstArg.input.type};
   if (kernelType == KernelType::Args) {
     for (auto &layer : args) {
-      if (isQuantKernel && quantType == QuantizationType::Dequant)
+      if (quantType == QuantizationType::Dequant)
         inputTypes.push_back(layer.inputScale.type);
 
       inputTypes.push_back(layer.weight.type);
-      if (isQuantKernel && quantType == QuantizationType::Dequant)
+      if (quantType == QuantizationType::Dequant)
         inputTypes.push_back(layer.weightScale.type);
 
       if (enableBias)
@@ -297,13 +296,12 @@ void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
   // * Layer: input/weights/bias/output = args
   firstArg.input.value = func.getArgument(0);
   // Scales are only needed for dequantization
-  if (isQuantKernel && quantType == QuantizationType::Dequant)
+  if (quantType == QuantizationType::Dequant)
     firstArg.inputScale.value = func.getArgument(1);
 
   // Argument position is input + N * { weight/bias } + output
   // First weight is at position 1, every two
-  unsigned argPos =
-      !(isQuantKernel && quantType == QuantizationType::Dequant) ? 1 : 2;
+  unsigned argPos = !(quantType == QuantizationType::Dequant) ? 1 : 2;
   // Caches the output to chain into the next layer's input
   Value lastOutput;
   for (auto &arg : args) {
@@ -314,7 +312,7 @@ void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
     // Initialize weights and biases
     if (kernelType == KernelType::Args) {
       arg.weight.value = func.getArgument(argPos++);
-      if (isQuantKernel && quantType == QuantizationType::Dequant)
+      if (quantType == QuantizationType::Dequant)
         arg.weightScale.value = func.getArgument(argPos++);
       if (enableBias)
         arg.bias.value = func.getArgument(argPos++);
@@ -341,7 +339,7 @@ void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
     // Now pass the input through all layers.Separated the quantization layer
     // creation to simplify the design and reduce code complxity as there will
     // be more ways to introduce quantization ops in the future.
-    if (isQuantKernel)
+    if (quantType != QuantizationType::None)
       lastOutput = createQuantLayer(arg);
     else
       lastOutput = createLayer(arg, hasMixedType);
@@ -351,10 +349,9 @@ void MLIRGenerator::createKernel(bool hasMixedType, bool isQuantKernel) {
   builder.create<func::ReturnOp>(loc, lastArg.output.value);
 }
 
-int MLIRGenerator::generate(StringRef filename, bool hasMixedType,
-                            bool isQuantKernel) {
+int MLIRGenerator::generate(StringRef filename, bool hasMixedType) {
   // First, populate the module with all functions
-  createKernel(hasMixedType, isQuantKernel);
+  createKernel(hasMixedType);
 
   // Verify
   if (failed(module.verify())) {
@@ -1025,15 +1022,14 @@ Value MLIRGenerator::lowerSoftmax(Value input, Value output) {
   return softmax;
 }
 
-TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type,
-                                   bool isQuantKernel) {
+TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
   // Already packed type, just return ND tensor
   if (dims.size() > 2)
     return RankedTensorType::get(dims, type == PACK_OUTPUT ? dataTypes[1]
                                                            : dataTypes[0]);
 
   if (!tiles.size()) {
-    if (isQuantKernel) {
+    if (quantType != QuantizationType::None) {
       if (type == INPUT_SCALE || type == WEIGHT_SCALE) {
         return RankedTensorType::get(dims, dataTypes[2]);
       } else if (type == PACK_OUTPUT) {
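
Common to all the hunks above: the separate isQuantKernel flag carried no information beyond the quantization mode the generator already stores, so a non-None quantType now implies a quantization kernel, and Dequant alone decides whether scale operands exist. A minimal standalone sketch of that invariant, with hypothetical names that merely mirror the generator's members (not the tool's actual code):

#include <cassert>

// Hypothetical mirror of the QuantizationType enum used by MLIRGen.
enum class QuantizationType { None, Quant, Dequant };

struct GeneratorState {
  QuantizationType quantType = QuantizationType::None;

  // Replaces the old isQuantKernel boolean parameter.
  bool isQuantKernel() const { return quantType != QuantizationType::None; }

  // Scale inputs and weights are only needed when dequantizing.
  bool needsScales() const { return quantType == QuantizationType::Dequant; }
};

int main() {
  GeneratorState gen;
  gen.quantType = QuantizationType::Dequant;
  assert(gen.isQuantKernel() && gen.needsScales());

  gen.quantType = QuantizationType::Quant;
  assert(gen.isQuantKernel() && !gen.needsScales());
  return 0;
}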

tools/mlir-gen/MLIRGen.h

Lines changed: 5 additions & 8 deletions
@@ -121,8 +121,7 @@ class MLIRGenerator {
   };
 
   /// Return shaped type (packed if requested)
-  TensorType getShape(ArrayRef<int64_t>, PackingType,
-                      bool isQuantKernel = false);
+  TensorType getShape(ArrayRef<int64_t>, PackingType);
 
   /// Return a zero-init tensor for matmul outputs
   Value getZeroInitTensor(TensorType);
@@ -242,7 +241,7 @@ class MLIRGenerator {
 
   /// Creates the kernel types from layer definitions and options. Boolean
   /// indicates if mixed type (quantization) is used.
-  void getKernelTypes(KernelArgs &, bool isQuantKernel = false);
+  void getKernelTypes(KernelArgs &);
 
   /// Creates a layer function, to be called by the kernel. Boolean indicates
   /// if mixed type (quantization) is used.
@@ -254,7 +253,7 @@ class MLIRGenerator {
   /// Creates a kernel (N * {GEMM + AddBias + ReLU} + Softmax)
   /// AddBias, ReLU and Softmax are optional. Boolean indicates if mixed type
   /// (quantization) is used.
-  void createKernel(bool hasMixedType = false, bool isQuantKernel = false);
+  void createKernel(bool hasMixedType = false);
 
 public:
   /// Creates a specific module. Different configurations need different modules
@@ -267,10 +266,8 @@ class MLIRGenerator {
 
   /// Generates the whole IR and write to file
   /// Return 0 on success, 1 on failure. 'hasMixedType' indicates simple mixed
-  /// type without quant. 'isQuantKernel' indicates a quantization kernel with
-  /// quant/dequant ops
-  int generate(StringRef filename, bool hasMixedType = false,
-               bool isQuantKernel = false);
+  /// type without quant.
+  int generate(StringRef filename, bool hasMixedType = false);
 };
 
 } // namespace mlir
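
Taken together, the simplified declarations above are (copied from the + lines, doc comments omitted):

  TensorType getShape(ArrayRef<int64_t>, PackingType);
  void getKernelTypes(KernelArgs &);
  void createKernel(bool hasMixedType = false);
  int generate(StringRef filename, bool hasMixedType = false);

A hypothetical call site only picks the mixed-type flag; the quantization behaviour follows from how the generator object itself was configured (constructor arguments elided, not the tool's actual driver code):

  MLIRGenerator gen(/* command-line options, including the quantization mode */);
  if (gen.generate(outputFilename, /*hasMixedType=*/false))
    return 1; // non-zero means failure, per the doc comment above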
