Skip to content

Commit a46782b

Browse files
For stride 5, if the concat of the convs yields an odd dim, pad it to make it even
1 parent 98226f7 commit a46782b

File tree

1 file changed

+47
-25
lines changed

1 file changed

+47
-25
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,26 +1246,48 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
12461246
RankedTensorType::get(outputShapeLevel1Concat, elementType);
12471247

12481248
// Below concats result will have the innermost dim as 2.
1249-
auto convOfmConcat =
1249+
Value convOfmConcat =
12501250
rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
12511251
ValueRange{reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv2,
12521252
reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv4,
12531253
reshapeOutputAddOneDimConv5},
12541254
-1);
1255-
SmallVector<int64_t> outputShapeForResult(paddedConvOutputShapeValue);
1256-
auto dimValueAtLastIndex =
1257-
paddedConvOutputShapeValue[paddedConvOutputShapeValue.size() - 1] * 5;
1258-
outputShapeForResult[outputShapeForResult.size() - 1] = dimValueAtLastIndex;
1255+
// Making the dim2 of concat even by padding one at the end.
1256+
bool isPaddedToMakeEven = false;
1257+
if (outputShapeLevel1Concat[2] % 2 != 0) {
1258+
SmallVector<int64_t> outputShapePadToEven(outputShapeLevel1Concat);
1259+
outputShapePadToEven[2] = outputShapePadToEven[2] + 1;
1260+
auto padToEvenOutputShapedType =
1261+
RankedTensorType::get(outputShapePadToEven, elementType);
1262+
1263+
std::array<int64_t, 8> padValueToEven = {0, 0, 0, 0, 0, 0, 1, 0};
1264+
1265+
auto onnxPadsToEvenValueConstant =
1266+
getONNXConstOpFromVector(rewriter, loc, padValueToEven);
1267+
1268+
convOfmConcat = rewriter.create<ONNXPadOp>(loc, padToEvenOutputShapedType,
1269+
convOfmConcat, onnxPadsToEvenValueConstant, onnxPaddingConstantZero,
1270+
onnxAxisValueConstantNone, rewriter.getStringAttr("constant"));
1271+
isPaddedToMakeEven = true;
1272+
}
1273+
// This is the shape of the five conv merge. Using [2] as this
1274+
// is convtranspose 1D.
1275+
SmallVector<int64_t> reshapeOutputShape(paddedConvOutputShapeValue);
1276+
reshapeOutputShape[2] =
1277+
(isPaddedToMakeEven ? (paddedConvOutputShapeValue[2] + 1)
1278+
: paddedConvOutputShapeValue[2]) *
1279+
5;
1280+
;
12591281

12601282
auto onnxConstForLastReshape =
1261-
getONNXConstOpFromVector(rewriter, loc, outputShapeForResult);
1283+
getONNXConstOpFromVector(rewriter, loc, reshapeOutputShape);
12621284

1263-
auto outputTypeBeforeSlice =
1264-
RankedTensorType::get(outputShapeForResult, elementType);
1285+
auto reshapeResultType =
1286+
RankedTensorType::get(reshapeOutputShape, elementType);
12651287
// Result is reshaped back to match the original convtranspose output
12661288
// dimensions
1267-
auto outputBeforeSlice = rewriter.create<ONNXReshapeOp>(
1268-
loc, outputTypeBeforeSlice, convOfmConcat, onnxConstForLastReshape);
1289+
auto reshapeOutput = rewriter.create<ONNXReshapeOp>(
1290+
loc, reshapeResultType, convOfmConcat, onnxConstForLastReshape);
12691291

12701292
SmallVector<int64_t> finalSliceOutputShape(convTransposeOutputShape);
12711293
auto finalSliceOutputType = RankedTensorType::get(
@@ -1276,8 +1298,8 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
12761298
{finalSliceOutputShape[finalSliceOutputShape.size() - 1]});
12771299

12781300
auto finalSlicedOutput = rewriter.create<ONNXSliceOp>(loc,
1279-
finalSliceOutputType, outputBeforeSlice, startOnnxConstant,
1280-
endOnnxConstant, axisOnnxConstant, stepOnnxConstant);
1301+
finalSliceOutputType, reshapeOutput, startOnnxConstant, endOnnxConstant,
1302+
axisOnnxConstant, stepOnnxConstant);
12811303

12821304
return finalSlicedOutput;
12831305
}
@@ -1289,13 +1311,13 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
12891311
// The phased convolutions are then merged to get the final output.
12901312
// The number of phases is determined by the strides of the convtranspose op.
12911313
// The num of phases = stride_x * stride_y.
1292-
// The phased convolutions are weights are created by slicing the weights of the
1293-
// convolution in the specified manner and output of convolutions are stiched
1294-
// together to get the final output. If the case where original weights cannot
1295-
// be sliced into conv weights directly, they are padded to make them compatible
1296-
// with the slicing. and subsequently the extra ofm generated by the padded
1297-
// weights are removed.
1298-
// Below shows the high level view of the decomposition.
1314+
// The phased convolution weights are created by slicing the weights of
1315+
// the convolution in the specified manner and the outputs of the convolutions are
1316+
// stitched together to get the final output. In the case where the original
1317+
// weights cannot be sliced into conv weights directly, they are padded to
1318+
// make them compatible with the slicing, and subsequently the extra ofm
1319+
// generated by the padded weights are removed. Below shows the high level
1320+
// view of the decomposition.
12991321
// clang-format off
13001322
/*
13011323
* +---------------+ +-----------+
@@ -1406,8 +1428,8 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
14061428
}
14071429

14081430
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);
1409-
// If the convTranspose kernel is 3x3, then the weights needs to be padded to
1410-
// 4x4
1431+
// If the convTranspose kernel is 3x3, then the weights needs to be padded
1432+
// to 4x4
14111433
bool needWeightsPadding = (kernelShape[0] == 3 && stridesShape[0] == 2);
14121434
if (needWeightsPadding) {
14131435
std::array<int64_t, 8> weightsPadValue = {0, 0, 0, 0, 0, 0, 0, 0};
@@ -1613,10 +1635,10 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
16131635
outputShapeLevel1Concat[outputShapeLevel1Concat.size() - 1] = 2;
16141636
auto level1ConcatOutputType =
16151637
RankedTensorType::get(outputShapeLevel1Concat, elementType);
1616-
// for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
1617-
// 1] The phased convs output are to be concatenated in the reverse order.
1618-
// This is observed by looking at the phased conv outputs with respect to
1619-
// convtranspose output.
1638+
// for the case where convtranspose kernel is [4, 4] and with pads [1, 1,
1639+
// 1, 1] The phased convs output are to be concatenated in the reverse
1640+
// order. This is observed by looking at the phased conv outputs with
1641+
// respect to convtranspose output.
16201642
bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
16211643
// Below concats result will have the innermost dim as 2.
16221644
auto firstConcat =

0 commit comments

Comments
 (0)