Commit 7dd6719

Second version of linalg_ext.scan (#8135)
This patch modifies the linalg_ext.scan op to reuse the identity tensor passed to the op (now carried as a second output operand) to hold the most recent result of the scan. This enables a more composable transform during tiling. Furthermore, the op no longer accepts scalars; scalar operands must instead be expressed as rank-0 tensors or memrefs. TEST: added tests in convert_to_loops.mlir and tiling.mlir, and e2e tests in scan.mlir.
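As a quick illustration of the new interface (a sketch assembled from the updated e2e tests below, not additional committed code): the accumulator is now the second outs operand rather than a scalar input, its shape is the input shape with the scanned dimension dropped, and the op returns both the scanned result and the final accumulator. A 2-D row-wise inclusive sum then reads:

    %t0 = util.unfoldable_constant dense<[0, 0]> : tensor<2xi32>
    %0:2 = iree_linalg_ext.scan
           dimension(1) inclusive(true)
           ins(%input : tensor<2x3xi32>)
           outs(%init, %t0 : tensor<2x3xi32>, tensor<2xi32>) {
             ^bb0(%arg0 : i32, %arg1 : i32):
               %sum = arith.addi %arg0, %arg1 : i32
               iree_linalg_ext.yield %sum : i32
           } -> tensor<2x3xi32>, tensor<2xi32>
    // %0#0 holds the running sums along each row; %0#1 holds the final sum of each row.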
File tree

5 files changed: +181, -87 lines
iree/test/e2e/linalg_ext_ops/scan.mlir

Lines changed: 56 additions & 31 deletions
@@ -2,65 +2,80 @@ func @scan_1d_dim0_inclusive_sum() {
   %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<6xf32>

   %init = linalg.init_tensor [6] : tensor<6xf32>
-  %c0 = arith.constant 0.0 : f32
-  %0 = iree_linalg_ext.scan
+  %t0 = util.unfoldable_constant dense<0.0> : tensor<f32>
+  %0:2 = iree_linalg_ext.scan
          dimension(0) inclusive(true)
-         ins(%input, %c0 : tensor<6xf32>, f32)
-         outs(%init : tensor<6xf32>) {
+         ins(%input : tensor<6xf32>)
+         outs(%init, %t0 : tensor<6xf32>, tensor<f32>) {
            ^bb0(%arg0 : f32, %arg1 : f32):
              %sum = arith.addf %arg0, %arg1 : f32
              iree_linalg_ext.yield %sum : f32
-         } -> tensor<6xf32>
+         } -> tensor<6xf32>, tensor<f32>

   check.expect_almost_eq_const(
-      %0,
+      %0#0,
       dense<[1.0, 3.0, 6.0, 10.0, 15.0, 21.0]> : tensor<6xf32>
   ) : tensor<6xf32>

+  check.expect_almost_eq_const(
+      %0#1,
+      dense<21.0> : tensor<f32>
+  ) : tensor<f32>
+
   return
 }

 func @scan_1d_dim0_exclusive_sum() {
   %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<6xf32>

   %init = linalg.init_tensor [6] : tensor<6xf32>
-  %c0 = arith.constant 0.0 : f32
-  %0 = iree_linalg_ext.scan
+  %t0 = util.unfoldable_constant dense<10.0> : tensor<f32>
+  %0:2 = iree_linalg_ext.scan
          dimension(0) inclusive(false)
-         ins(%input, %c0 : tensor<6xf32>, f32)
-         outs(%init : tensor<6xf32>) {
+         ins(%input : tensor<6xf32>)
+         outs(%init, %t0 : tensor<6xf32>, tensor<f32>) {
            ^bb0(%arg0 : f32, %arg1 : f32):
              %sum = arith.addf %arg0, %arg1 : f32
              iree_linalg_ext.yield %sum : f32
-         } -> tensor<6xf32>
+         } -> tensor<6xf32>, tensor<f32>

   check.expect_almost_eq_const(
-      %0,
-      dense<[0.0, 1.0, 3.0, 6.0, 10.0, 15.0]> : tensor<6xf32>
+      %0#0,
+      dense<[10.0, 11.0, 13.0, 16.0, 20.0, 25.0]> : tensor<6xf32>
   ) : tensor<6xf32>

+  check.expect_almost_eq_const(
+      %0#1,
+      dense<25.0> : tensor<f32>
+  ) : tensor<f32>
+
   return
 }

 func @scan_1d_dim0_inclusive_mul() {
   %input = util.unfoldable_constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>

   %init = linalg.init_tensor [6] : tensor<6xi32>
-  %c0 = arith.constant 1 : i32
-  %0 = iree_linalg_ext.scan
+  %t0 = util.unfoldable_constant dense<1> : tensor<i32>
+  %0:2 = iree_linalg_ext.scan
          dimension(0) inclusive(true)
-         ins(%input, %c0 : tensor<6xi32>, i32)
-         outs(%init : tensor<6xi32>) {
+         ins(%input : tensor<6xi32>)
+         outs(%init, %t0 : tensor<6xi32>, tensor<i32>) {
            ^bb0(%arg0 : i32, %arg1 : i32):
              %sum = arith.muli %arg0, %arg1 : i32
              iree_linalg_ext.yield %sum : i32
-         } -> tensor<6xi32>
+         } -> tensor<6xi32>, tensor<i32>

   check.expect_eq_const(
-      %0,
+      %0#0,
       dense<[1, 2, 6, 24, 120, 720]> : tensor<6xi32>
   ) : tensor<6xi32>

+  check.expect_eq_const(
+      %0#1,
+      dense<720> : tensor<i32>
+  ) : tensor<i32>
+
   return
 }

@@ -69,21 +84,26 @@ func @scan_2d_dim0_inclusive_sum() {
                                            [4, 5, 6]]> : tensor<2x3xi32>

   %init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
-  %c0 = arith.constant 0 : i32
-  %0 = iree_linalg_ext.scan
+  %t0 = util.unfoldable_constant dense<[0, 0, 0]> : tensor<3xi32>
+  %0:2 = iree_linalg_ext.scan
          dimension(0) inclusive(true)
-         ins(%input, %c0 : tensor<2x3xi32>, i32)
-         outs(%init : tensor<2x3xi32>) {
+         ins(%input : tensor<2x3xi32>)
+         outs(%init, %t0 : tensor<2x3xi32>, tensor<3xi32>) {
            ^bb0(%arg0 : i32, %arg1 : i32):
              %sum = arith.addi %arg0, %arg1 : i32
              iree_linalg_ext.yield %sum : i32
-         } -> tensor<2x3xi32>
+         } -> tensor<2x3xi32>, tensor<3xi32>

   check.expect_eq_const(
-      %0,
+      %0#0,
       dense<[[1, 2, 3], [5, 7, 9]]> : tensor<2x3xi32>
   ) : tensor<2x3xi32>

+  check.expect_eq_const(
+      %0#1,
+      dense<[5, 7, 9]> : tensor<3xi32>
+  ) : tensor<3xi32>
+
   return
 }

@@ -92,20 +112,25 @@ func @scan_2d_dim1_inclusive_sum() {
                                            [4, 5, 6]]> : tensor<2x3xi32>

   %init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
-  %c0 = arith.constant 0 : i32
-  %0 = iree_linalg_ext.scan
+  %t0 = util.unfoldable_constant dense<[0, 0]> : tensor<2xi32>
+  %0:2 = iree_linalg_ext.scan
          dimension(1) inclusive(true)
-         ins(%input, %c0 : tensor<2x3xi32>, i32)
-         outs(%init : tensor<2x3xi32>) {
+         ins(%input : tensor<2x3xi32>)
+         outs(%init, %t0 : tensor<2x3xi32>, tensor<2xi32>) {
            ^bb0(%arg0 : i32, %arg1 : i32):
              %sum = arith.addi %arg0, %arg1 : i32
              iree_linalg_ext.yield %sum : i32
-         } -> tensor<2x3xi32>
+         } -> tensor<2x3xi32>, tensor<2xi32>

   check.expect_eq_const(
-      %0,
+      %0#0,
       dense<[[1, 3, 6], [4, 9, 15]]> : tensor<2x3xi32>
   ) : tensor<2x3xi32>

+  check.expect_eq_const(
+      %0#1,
+      dense<[6, 15]> : tensor<2xi32>
+  ) : tensor<2xi32>
+
   return
 }

llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 5 additions & 4 deletions
@@ -256,7 +256,7 @@ def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
     Computes the inclusive/exclusive scan along a given dimension.
   }];

-  let arguments = (ins Variadic<AnyType>:$inputs,
+  let arguments = (ins Variadic<AnyShaped>:$inputs,
                        Variadic<AnyShaped>:$outputs,
                        I64Attr:$dimension,
                        BoolAttr:$inclusive
@@ -269,21 +269,22 @@

   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region AnyRegion:$region);
+  let hasFolder = 1;
   let assemblyFormat = [{
     `dimension` `(` $dimension `)`
     `inclusive` `(` $inclusive `)`
     attr-dict
     `ins` `(` $inputs `:` type($inputs) `)`
-    (`outs` `(` $outputs^ `:` type($outputs) `)`)?
+    `outs` `(` $outputs `:` type($outputs) `)`
     $region (`->` type($results)^)?
   }];

   let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
     Value input() {
       return getInputOperand(0)->get();
     }
-    Value identity() {
-      return getInputOperand(1)->get();
+    Value accumulator() {
+      return getOutputOperand(1)->get();
     }
     Value output() {
       return getOutputOperand(0)->get();

llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 75 additions & 17 deletions
@@ -772,33 +772,46 @@ Operation *FftOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
 //===----------------------------------------------------------------------===//

 static LogicalResult verifyScanOp(ScanOp op) {
-  if (op.getNumInputs() != 2) {
-    return op.emitOpError("expected two input operands");
+  if (op.getNumInputs() != 1) {
+    return op.emitOpError("expected one input operand");
   }
-  if (op.getNumOutputs() != 1) {
-    return op.emitOpError("expected one output operand");
+  if (op.getNumOutputs() != 2) {
+    return op.emitOpError("expected two output operands");
   }
   if (!op.input().getType().isa<ShapedType>()) {
     return op.emitOpError("expected first input element type to be shaped");
   }
-  auto identityElementType = op.identity().getType();
-  if (!(identityElementType.isa<FloatType>() ||
-        identityElementType.isa<IntegerType>())) {
-    return op.emitOpError(
-        "expected second input element type to be float or integer");
-  }
+  auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
   auto inputType = op.input().getType().cast<ShapedType>();
   auto outputType = op.output().getType().cast<ShapedType>();
-  if (identityElementType != inputType.getElementType()) {
+  ArrayRef<int64_t> inputShapes = inputType.getShape();
+  ArrayRef<int64_t> outputShapes = outputType.getShape();
+  if (accumulatorType.getElementType() != inputType.getElementType()) {
     return op.emitOpError(
-        "expected input/identity element types to be identical");
+        "expected input/accumulator element types to be identical");
+  }
+  ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
+  int64_t accumulatorRank = accumulatorType.getRank();
+  if (accumulatorRank != inputType.getRank() - 1) {
+    return op.emitOpError(
+        "expected accumulator rank to be equal to input rank - 1");
+  }
+  SmallVector<int64_t> expectedAccumulatorShape;
+  for (int i = 0; i < inputType.getRank(); i++) {
+    if (i != op.dimension()) expectedAccumulatorShape.push_back(inputShapes[i]);
+  }
+  if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
+                   [](std::tuple<int64_t, int64_t> s) {
+                     return std::get<0>(s) != ShapedType::kDynamicSize &&
+                            std::get<1>(s) != ShapedType::kDynamicSize &&
+                            std::get<0>(s) != std::get<1>(s);
+                   })) {
+    return op.emitOpError("incompatible input/accumulator shapes");
   }
   if (inputType.getElementType() != outputType.getElementType()) {
     return op.emitOpError(
         "expected input/output element types to be identical");
   }
-  ArrayRef<int64_t> inputShapes = inputType.getShape();
-  ArrayRef<int64_t> outputShapes = outputType.getShape();
   if (inputShapes.size() != outputShapes.size()) {
     return op.emitOpError("expected input/output to have identical ranks");
   }
@@ -862,14 +875,20 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
   auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                       indices[scanDim], zero);
   bool isInclusive = inclusive();
+  SmallVector<Value> accIndices;
+  for (int i = 0; i < indices.size(); i++) {
+    if (i != scanDim) accIndices.push_back(indices[i]);
+  }
+
   auto scfIf = b.create<scf::IfOp>(
       loc, TypeRange{}, cond,
       [&](OpBuilder &b, Location loc) {
         if (isInclusive) {
           auto value = b.create<memref::LoadOp>(loc, input(), indices);
           b.create<memref::StoreOp>(loc, value, output(), indices);
         } else {
-          b.create<memref::StoreOp>(loc, identity(), output(), indices);
+          auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
+          b.create<memref::StoreOp>(loc, value, output(), indices);
         }
         b.create<scf::YieldOp>(loc);
       },
@@ -902,6 +921,9 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
         b.create<memref::StoreOp>(
             loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
             output(), indices);
+        b.create<memref::StoreOp>(
+            loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
+            accumulator(), accIndices);
         b.create<scf::YieldOp>(loc);
       }
   return success();
@@ -922,25 +944,61 @@ Operation *ScanOp::getTiledImplementation(OpBuilder &builder,
   SmallVector<Value> tiledOperands;
   tiledOperands.emplace_back(
       getSlice(builder, getLoc(), input(), offsets, sizes, strides));
-  tiledOperands.emplace_back(identity());
   tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), output(), offsets, sizes, strides));
+      getSlice(builder, getLoc(), outputs[0], offsets, sizes, strides));
+  SmallVector<OpFoldResult> accumOffsets, accumSizes, accumStrides;
+  if (rank > 1) {
+    for (int i = 0; i < rank; i++) {
+      if (i != dimension()) {
+        accumOffsets.push_back(offsets[i]);
+        accumSizes.push_back(sizes[i]);
+        accumStrides.push_back(strides[i]);
+      }
+    }
+    tiledOperands.emplace_back(getSlice(
+        builder, getLoc(), outputs[1], accumOffsets, accumSizes, accumStrides));
+  } else {
+    tiledOperands.emplace_back(outputs[1]);
+  }

   SmallVector<Type, 4> resultTypes;
   if (hasTensorSemantics()) {
+    resultTypes.push_back(tiledOperands[1].getType());
     resultTypes.push_back(tiledOperands[2].getType());
   }

   Operation *tiledScanOp = cast<LinalgExtOp>(getOperation())
                                .clone(builder, loc, resultTypes, tiledOperands);
   for (auto result : llvm::enumerate(tiledScanOp->getResults())) {
+    if ((result.index() == resultTypes.size() - 1) && (rank > 1)) {
+      offsets = accumOffsets;
+      sizes = accumSizes;
+      strides = accumStrides;
+    }
     auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
         loc, result.value(), outputs[result.index()], offsets, sizes, strides);
     results.push_back(insertSliceOp.getResult());
   }
   return tiledScanOp;
 }

+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
+    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
+LogicalResult ScanOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ReverseOp
 //===----------------------------------------------------------------------===//
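The new folder only rewrites operands whose defining op is a memref.cast that memref::CastOp::canFoldIntoConsumerOp accepts, i.e. a cast toward a more dynamic type. A hypothetical before/after on buffers (illustrative IR, not part of this commit; %buf, %out and %acc are made-up names and the scan region is elided):

    // Before: the input reaches the scan through a shape-erasing cast.
    %cast = memref.cast %buf : memref<6xf32> to memref<?xf32>
    iree_linalg_ext.scan dimension(0) inclusive(true)
                         ins(%cast : memref<?xf32>)
                         outs(%out, %acc : memref<6xf32>, memref<f32>) { ... }

    // After ScanOp::fold: the cast is bypassed and the op reads %buf directly.
    iree_linalg_ext.scan dimension(0) inclusive(true)
                         ins(%buf : memref<6xf32>)
                         outs(%out, %acc : memref<6xf32>, memref<f32>) { ... }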
