@@ -1246,26 +1246,48 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
           RankedTensorType::get(outputShapeLevel1Concat, elementType);
 
   // Below concats result will have the innermost dim as 2.
-  auto convOfmConcat =
+  Value convOfmConcat =
       rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
           ValueRange{reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv2,
               reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv4,
               reshapeOutputAddOneDimConv5},
           -1);
-  SmallVector<int64_t> outputShapeForResult(paddedConvOutputShapeValue);
-  auto dimValueAtLastIndex =
-      paddedConvOutputShapeValue[paddedConvOutputShapeValue.size() - 1] * 5;
-  outputShapeForResult[outputShapeForResult.size() - 1] = dimValueAtLastIndex;
+  // Making the dim2 of concat even by padding one at the end.
+  bool isPaddedToMakeEven = false;
+  if (outputShapeLevel1Concat[2] % 2 != 0) {
+    SmallVector<int64_t> outputShapePadToEven(outputShapeLevel1Concat);
+    outputShapePadToEven[2] = outputShapePadToEven[2] + 1;
+    auto padToEvenOutputShapedType =
+        RankedTensorType::get(outputShapePadToEven, elementType);
+
+    std::array<int64_t, 8> padValueToEven = {0, 0, 0, 0, 0, 0, 1, 0};
+
+    auto onnxPadsToEvenValueConstant =
+        getONNXConstOpFromVector(rewriter, loc, padValueToEven);
+
+    convOfmConcat = rewriter.create<ONNXPadOp>(loc, padToEvenOutputShapedType,
+        convOfmConcat, onnxPadsToEvenValueConstant, onnxPaddingConstantZero,
+        onnxAxisValueConstantNone, rewriter.getStringAttr("constant"));
+    isPaddedToMakeEven = true;
+  }
+  // This is the shape of the five conv merge. Using [2] as this
+  // is convtranspose 1D.
+  SmallVector<int64_t> reshapeOutputShape(paddedConvOutputShapeValue);
+  reshapeOutputShape[2] =
+      (isPaddedToMakeEven ? (paddedConvOutputShapeValue[2] + 1)
+                          : paddedConvOutputShapeValue[2]) *
+      5;
+  ;
 
   auto onnxConstForLastReshape =
-      getONNXConstOpFromVector(rewriter, loc, outputShapeForResult);
+      getONNXConstOpFromVector(rewriter, loc, reshapeOutputShape);
 
-  auto outputTypeBeforeSlice =
-      RankedTensorType::get(outputShapeForResult, elementType);
+  auto reshapeResultType =
+      RankedTensorType::get(reshapeOutputShape, elementType);
 
   // Result is reshaped back to match the original convtranspose output
   // dimensions
-  auto outputBeforeSlice = rewriter.create<ONNXReshapeOp>(
-      loc, outputTypeBeforeSlice, convOfmConcat, onnxConstForLastReshape);
+  auto reshapeOutput = rewriter.create<ONNXReshapeOp>(
+      loc, reshapeResultType, convOfmConcat, onnxConstForLastReshape);
 
   SmallVector<int64_t> finalSliceOutputShape(convTransposeOutputShape);
   auto finalSliceOutputType = RankedTensorType::get(
@@ -1276,8 +1298,8 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
       {finalSliceOutputShape[finalSliceOutputShape.size() - 1]});
 
   auto finalSlicedOutput = rewriter.create<ONNXSliceOp>(loc,
-      finalSliceOutputType, outputBeforeSlice, startOnnxConstant,
-      endOnnxConstant, axisOnnxConstant, stepOnnxConstant);
+      finalSliceOutputType, reshapeOutput, startOnnxConstant, endOnnxConstant,
+      axisOnnxConstant, stepOnnxConstant);
 
   return finalSlicedOutput;
 }
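The two hunks above merge the five phased conv outputs: they are concatenated on a new innermost axis, dim 2 of the concat is padded to an even size when needed, the result is reshaped so the last dimension becomes five times the per-phase length, and a final slice trims it back to the true convtranspose output length. The standalone sketch below (plain C++, not onnx-mlir API; numPhases, perPhaseLen and convTransposeLen are invented illustration values) shows why a trailing-axis concat followed by a row-major reshape interleaves the phases.

// Hedged illustration only (plain C++, not onnx-mlir): stacking per-phase
// outputs on a new trailing axis and then doing a row-major reshape
// interleaves them sample by sample; a final slice drops any padded tail.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int numPhases = 5;   // one output per phased convolution
  const int perPhaseLen = 4; // hypothetical per-phase output length
  std::vector<std::vector<int>> phase(numPhases, std::vector<int>(perPhaseLen));
  for (int p = 0; p < numPhases; ++p)
    for (int i = 0; i < perPhaseLen; ++i)
      phase[p][i] = 100 * p + i; // phase p, sample i (easy to read back)

  // Concat on a new innermost axis gives a [perPhaseLen, numPhases] tensor;
  // flattening it row-major (the reshape) interleaves the phases.
  std::vector<int> reshaped(perPhaseLen * numPhases);
  for (int i = 0; i < perPhaseLen; ++i)
    for (int p = 0; p < numPhases; ++p)
      reshaped[i * numPhases + p] = phase[p][i];

  // Final slice back to the true convtranspose length (mirrors the final
  // ONNXSliceOp), discarding samples introduced only to make shapes line up.
  const int convTransposeLen = 18; // hypothetical true output length
  std::vector<int> result(
      reshaped.begin(), reshaped.begin() + convTransposeLen);

  assert(result[0] == 0 && result[1] == 100 && result[5] == 1);
  std::printf("kept %zu of %zu interleaved samples\n", result.size(),
      reshaped.size());
  return 0;
}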
@@ -1289,13 +1311,13 @@ Value decomposeConvT1dIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
 // The phased convolutions are then merged to get the final output.
 // The number of phases is determined by the strides of the convtranspose op.
 // The num of phases = stride_x * stride_y.
-// The phased convolutions are weights are created by slicing the weights of the
-// convolution in the specified manner and output of convolutions are stiched
-// together to get the final output. If the case where original weights cannot
-// be sliced into conv weights directly, they are padded to make them compatible
-// with the slicing. and subsequently the extra ofm generated by the padded
-// weights are removed.
-// Below shows the high level view of the decomposition.
+// The phased convolutions are weights are created by slicing the weights of
+// the convolution in the specified manner and output of convolutions are
+// stiched together to get the final output. If the case where original
+// weights cannot be sliced into conv weights directly, they are padded to
+// make them compatible with the slicing. and subsequently the extra ofm
+// generated by the padded weights are removed. Below shows the high level
+// view of the decomposition.
 // clang-format off
 /*
  * +---------------+ +-----------+
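The comment rewrapped above (before the diagram) states the core identity: a convtranspose splits into stride_x * stride_y ordinary convolutions over sliced weights, and their outputs are stitched back together. The hedged sketch below checks that identity for a 1-D, stride-2 case in plain C++; the input, weight and stride values are invented for illustration and nothing here is onnx-mlir code.

// Hedged illustration only: verify, for a 1-D stride-2 case, that a
// ConvTranspose equals two phased convolutions over strided weight slices,
// with their outputs interleaved. All sizes/values below are invented.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int stride = 2;
  const std::vector<int> input = {1, 2, 3};
  const std::vector<int> weight = {5, -1, 2, 4}; // kernel size 4

  // Direct ConvTranspose1d (no padding): out[stride*i + k] += in[i] * w[k].
  const int outLen = (input.size() - 1) * stride + weight.size();
  std::vector<int> direct(outLen, 0);
  for (size_t i = 0; i < input.size(); ++i)
    for (size_t k = 0; k < weight.size(); ++k)
      direct[stride * i + k] += input[i] * weight[k];

  // Phased form: phase p convolves the input with the weight slice
  // w[stride*j + p]; its results land only on positions o with o % stride == p.
  std::vector<int> phased(outLen, 0);
  for (int p = 0; p < stride; ++p) {
    std::vector<int> subKernel;
    for (size_t k = p; k < weight.size(); k += stride)
      subKernel.push_back(weight[k]);
    for (size_t i = 0; i < input.size(); ++i)
      for (size_t j = 0; j < subKernel.size(); ++j)
        phased[stride * (i + j) + p] += input[i] * subKernel[j];
  }

  for (int o = 0; o < outLen; ++o)
    assert(direct[o] == phased[o]); // stitched phases match the direct result
  std::printf("phased decomposition matches ConvTranspose1d\n");
  return 0;
}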
@@ -1406,8 +1428,8 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
   }
 
   onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);
-  // If the convTranspose kernel is 3x3, then the weights needs to be padded to
-  // 4x4
+  // If the convTranspose kernel is 3x3, then the weights needs to be padded
+  // to 4x4
   bool needWeightsPadding = (kernelShape[0] == 3 && stridesShape[0] == 2);
   if (needWeightsPadding) {
     std::array<int64_t, 8> weightsPadValue = {0, 0, 0, 0, 0, 0, 0, 0};
@@ -1613,10 +1635,10 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
   outputShapeLevel1Concat[outputShapeLevel1Concat.size() - 1] = 2;
   auto level1ConcatOutputType =
       RankedTensorType::get(outputShapeLevel1Concat, elementType);
-  // for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
-  // 1] The phased convs output are to be concatenated in the reverse order.
-  // This is observed by looking at the phased conv outputs with respect to
-  // convtranspose output.
+  // for the case where convtranspose kernel is [4, 4] and with pads [1, 1,
+  // 1, 1] The phased convs output are to be concatenated in the reverse
+  // order. This is observed by looking at the phased conv outputs with
+  // respect to convtranspose output.
   bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
   // Below concats result will have the innermost dim as 2.
   auto firstConcat =