Second version of linalg_ext.scan (#8135)
This patch modifies the linalg_ext.scan op so that the
initially passed identity tensor is reused to hold the
most recent result of the scan operation. This enables
a more composable transform during tiling.

Furthermore, the op no longer accepts scalars; they
must instead be expressed as rank-0 tensors or memrefs.

TEST: Added tests in convert_to_loops.mlir and tiling.mlir,
and e2e tests in scan.mlir.
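
To illustrate the change, here is a minimal before/after sketch of the op's
assembly, based on the e2e tests in this patch (%input, %init, %c0, and %acc
are placeholder SSA values):

// Before: the identity is a scalar input and the op has a single result.
%0 = iree_linalg_ext.scan
       dimension(0) inclusive(true)
       ins(%input, %c0 : tensor<6xf32>, f32)
       outs(%init : tensor<6xf32>) {
     ^bb0(%arg0 : f32, %arg1 : f32):
       %sum = arith.addf %arg0, %arg1 : f32
       iree_linalg_ext.yield %sum : f32
     } -> tensor<6xf32>

// After: the accumulator is a rank-0 tensor passed through outs, and the op
// returns both the scanned tensor and the final accumulator value.
%0:2 = iree_linalg_ext.scan
         dimension(0) inclusive(true)
         ins(%input : tensor<6xf32>)
         outs(%init, %acc : tensor<6xf32>, tensor<f32>) {
       ^bb0(%arg0 : f32, %arg1 : f32):
         %sum = arith.addf %arg0, %arg1 : f32
         iree_linalg_ext.yield %sum : f32
       } -> tensor<6xf32>, tensor<f32>
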
harsh-nod authored Jan 21, 2022
1 parent d7dbccc commit 7dd6719
Showing 5 changed files with 181 additions and 87 deletions.
iree/test/e2e/linalg_ext_ops/scan.mlir (87 changes: 56 additions & 31 deletions)
@@ -2,65 +2,80 @@ func @scan_1d_dim0_inclusive_sum() {
%input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<6xf32>

%init = linalg.init_tensor [6] : tensor<6xf32>
%c0 = arith.constant 0.0 : f32
%0 = iree_linalg_ext.scan
%t0 = util.unfoldable_constant dense<0.0> : tensor<f32>
%0:2 = iree_linalg_ext.scan
dimension(0) inclusive(true)
ins(%input, %c0 : tensor<6xf32>, f32)
outs(%init : tensor<6xf32>) {
ins(%input : tensor<6xf32>)
outs(%init, %t0 : tensor<6xf32>, tensor<f32>) {
^bb0(%arg0 : f32, %arg1 : f32):
%sum = arith.addf %arg0, %arg1 : f32
iree_linalg_ext.yield %sum : f32
} -> tensor<6xf32>
} -> tensor<6xf32>, tensor<f32>

check.expect_almost_eq_const(
%0,
%0#0,
dense<[1.0, 3.0, 6.0, 10.0, 15.0, 21.0]> : tensor<6xf32>
) : tensor<6xf32>

check.expect_almost_eq_const(
%0#1,
dense<21.0> : tensor<f32>
) : tensor<f32>

return
}

func @scan_1d_dim0_exclusive_sum() {
%input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<6xf32>

%init = linalg.init_tensor [6] : tensor<6xf32>
%c0 = arith.constant 0.0 : f32
%0 = iree_linalg_ext.scan
%t0 = util.unfoldable_constant dense<10.0> : tensor<f32>
%0:2 = iree_linalg_ext.scan
dimension(0) inclusive(false)
ins(%input, %c0 : tensor<6xf32>, f32)
outs(%init : tensor<6xf32>) {
ins(%input : tensor<6xf32>)
outs(%init, %t0 : tensor<6xf32>, tensor<f32>) {
^bb0(%arg0 : f32, %arg1 : f32):
%sum = arith.addf %arg0, %arg1 : f32
iree_linalg_ext.yield %sum : f32
} -> tensor<6xf32>
} -> tensor<6xf32>, tensor<f32>

check.expect_almost_eq_const(
%0,
dense<[0.0, 1.0, 3.0, 6.0, 10.0, 15.0]> : tensor<6xf32>
%0#0,
dense<[10.0, 11.0, 13.0, 16.0, 20.0, 25.0]> : tensor<6xf32>
) : tensor<6xf32>

check.expect_almost_eq_const(
%0#1,
dense<25.0> : tensor<f32>
) : tensor<f32>

return
}

func @scan_1d_dim0_inclusive_mul() {
%input = util.unfoldable_constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>

%init = linalg.init_tensor [6] : tensor<6xi32>
%c0 = arith.constant 1 : i32
%0 = iree_linalg_ext.scan
%t0 = util.unfoldable_constant dense<1> : tensor<i32>
%0:2 = iree_linalg_ext.scan
dimension(0) inclusive(true)
ins(%input, %c0 : tensor<6xi32>, i32)
outs(%init : tensor<6xi32>) {
ins(%input : tensor<6xi32>)
outs(%init, %t0 : tensor<6xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.muli %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
} -> tensor<6xi32>
} -> tensor<6xi32>, tensor<i32>

check.expect_eq_const(
%0,
%0#0,
dense<[1, 2, 6, 24, 120, 720]> : tensor<6xi32>
) : tensor<6xi32>

check.expect_eq_const(
%0#1,
dense<720> : tensor<i32>
) : tensor<i32>

return
}

@@ -69,21 +84,26 @@ func @scan_2d_dim0_inclusive_sum() {
[4, 5, 6]]> : tensor<2x3xi32>

%init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
%c0 = arith.constant 0 : i32
%0 = iree_linalg_ext.scan
%t0 = util.unfoldable_constant dense<[0, 0, 0]> : tensor<3xi32>
%0:2 = iree_linalg_ext.scan
dimension(0) inclusive(true)
ins(%input, %c0 : tensor<2x3xi32>, i32)
outs(%init : tensor<2x3xi32>) {
ins(%input : tensor<2x3xi32>)
outs(%init, %t0 : tensor<2x3xi32>, tensor<3xi32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
} -> tensor<2x3xi32>
} -> tensor<2x3xi32>, tensor<3xi32>

check.expect_eq_const(
%0,
%0#0,
dense<[[1, 2, 3], [5, 7, 9]]> : tensor<2x3xi32>
) : tensor<2x3xi32>

check.expect_eq_const(
%0#1,
dense<[5, 7, 9]> : tensor<3xi32>
) : tensor<3xi32>

return
}

@@ -92,20 +112,25 @@ func @scan_2d_dim1_inclusive_sum() {
[4, 5, 6]]> : tensor<2x3xi32>

%init = linalg.init_tensor [2, 3] : tensor<2x3xi32>
%c0 = arith.constant 0 : i32
%0 = iree_linalg_ext.scan
%t0 = util.unfoldable_constant dense<[0, 0]> : tensor<2xi32>
%0:2 = iree_linalg_ext.scan
dimension(1) inclusive(true)
ins(%input, %c0 : tensor<2x3xi32>, i32)
outs(%init : tensor<2x3xi32>) {
ins(%input : tensor<2x3xi32>)
outs(%init, %t0 : tensor<2x3xi32>, tensor<2xi32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
iree_linalg_ext.yield %sum : i32
} -> tensor<2x3xi32>
} -> tensor<2x3xi32>, tensor<2xi32>

check.expect_eq_const(
%0,
%0#0,
dense<[[1, 3, 6], [4, 9, 15]]> : tensor<2x3xi32>
) : tensor<2x3xi32>

check.expect_eq_const(
%0#1,
dense<[6, 15]> : tensor<2xi32>
) : tensor<2xi32>

return
}
@@ -256,7 +256,7 @@ def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
Computes the inclusive/exclusive scan along a given dimension.
}];

let arguments = (ins Variadic<AnyType>:$inputs,
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension,
BoolAttr:$inclusive
@@ -269,21 +269,22 @@

let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let hasFolder = 1;
let assemblyFormat = [{
`dimension` `(` $dimension `)`
`inclusive` `(` $inclusive `)`
attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
(`outs` `(` $outputs^ `:` type($outputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];

let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value input() {
return getInputOperand(0)->get();
}
Value identity() {
return getInputOperand(1)->get();
Value accumulator() {
return getOutputOperand(1)->get();
}
Value output() {
return getOutputOperand(0)->get();
@@ -772,33 +772,46 @@ Operation *FftOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
//===----------------------------------------------------------------------===//

static LogicalResult verifyScanOp(ScanOp op) {
if (op.getNumInputs() != 2) {
return op.emitOpError("expected two input operands");
if (op.getNumInputs() != 1) {
return op.emitOpError("expected one input operands");
}
if (op.getNumOutputs() != 1) {
return op.emitOpError("expected one output operand");
if (op.getNumOutputs() != 2) {
return op.emitOpError("expected two output operands");
}
if (!op.input().getType().isa<ShapedType>()) {
return op.emitOpError("expected first input element type to be shaped");
}
auto identityElementType = op.identity().getType();
if (!(identityElementType.isa<FloatType>() ||
identityElementType.isa<IntegerType>())) {
return op.emitOpError(
"expected second input element type to be float or integer");
}
auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
auto inputType = op.input().getType().cast<ShapedType>();
auto outputType = op.output().getType().cast<ShapedType>();
if (identityElementType != inputType.getElementType()) {
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) {
return op.emitOpError(
"expected input/identity element types to be identical");
"expected input/accumulator element types to be identical");
}
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
int64_t accumulatorRank = accumulatorType.getRank();
if (accumulatorRank != inputType.getRank() - 1) {
return op.emitOpError(
"expected accumulator rank to be equal to input rank - 1");
}
SmallVector<int64_t> expectedAccumulatorShape;
for (int i = 0; i < inputType.getRank(); i++) {
if (i != op.dimension()) expectedAccumulatorShape.push_back(inputShapes[i]);
}
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
[](std::tuple<int64_t, int64_t> s) {
return std::get<0>(s) != ShapedType::kDynamicSize &&
std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s);
})) {
return op.emitOpError("incompatible input/accumulator shapes");
}
if (inputType.getElementType() != outputType.getElementType()) {
return op.emitOpError(
"expected input/output element types to be identical");
}
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (inputShapes.size() != outputShapes.size()) {
return op.emitOpError("expected input/output to have identical ranks");
}
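
To illustrate what the updated verifier enforces (a hedged sketch; %in, %out,
and %acc are placeholder values): the accumulator must have the input's
element type, and its shape must equal the input shape with the scanned
dimension removed.

// Valid: accumulator shape equals the input shape with dimension 1 dropped.
%r:2 = iree_linalg_ext.scan
         dimension(1) inclusive(true)
         ins(%in : tensor<2x3xi32>)
         outs(%out, %acc : tensor<2x3xi32>, tensor<2xi32>) {
       ^bb0(%arg0 : i32, %arg1 : i32):
         %sum = arith.addi %arg0, %arg1 : i32
         iree_linalg_ext.yield %sum : i32
       } -> tensor<2x3xi32>, tensor<2xi32>

// Invalid: a tensor<2x3xi32> accumulator fails the rank check
// ("expected accumulator rank to be equal to input rank - 1"), and a
// tensor<3xi32> accumulator fails the shape check
// ("incompatible input/accumulator shapes").
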
@@ -862,14 +875,20 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
indices[scanDim], zero);
bool isInclusive = inclusive();
SmallVector<Value> accIndices;
for (int i = 0; i < indices.size(); i++) {
if (i != scanDim) accIndices.push_back(indices[i]);
}

auto scfIf = b.create<scf::IfOp>(
loc, TypeRange{}, cond,
[&](OpBuilder &b, Location loc) {
if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices);
b.create<memref::StoreOp>(loc, value, output(), indices);
} else {
b.create<memref::StoreOp>(loc, identity(), output(), indices);
auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
b.create<memref::StoreOp>(loc, value, output(), indices);
}
b.create<scf::YieldOp>(loc);
},
@@ -902,6 +921,9 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
output(), indices);
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
accumulator(), accIndices);
b.create<scf::YieldOp>(loc);
}
return success();
@@ -922,25 +944,61 @@ Operation *ScanOp::getTiledImplementation(OpBuilder &builder,
SmallVector<Value> tiledOperands;
tiledOperands.emplace_back(
getSlice(builder, getLoc(), input(), offsets, sizes, strides));
tiledOperands.emplace_back(identity());
tiledOperands.emplace_back(
getSlice(builder, getLoc(), output(), offsets, sizes, strides));
getSlice(builder, getLoc(), outputs[0], offsets, sizes, strides));
SmallVector<OpFoldResult> accumOffsets, accumSizes, accumStrides;
if (rank > 1) {
for (int i = 0; i < rank; i++) {
if (i != dimension()) {
accumOffsets.push_back(offsets[i]);
accumSizes.push_back(sizes[i]);
accumStrides.push_back(strides[i]);
}
}
tiledOperands.emplace_back(getSlice(
builder, getLoc(), outputs[1], accumOffsets, accumSizes, accumStrides));
} else {
tiledOperands.emplace_back(outputs[1]);
}

SmallVector<Type, 4> resultTypes;
if (hasTensorSemantics()) {
resultTypes.push_back(tiledOperands[1].getType());
resultTypes.push_back(tiledOperands[2].getType());
}

Operation *tiledScanOp = cast<LinalgExtOp>(getOperation())
.clone(builder, loc, resultTypes, tiledOperands);
for (auto result : llvm::enumerate(tiledScanOp->getResults())) {
if ((result.index() == resultTypes.size() - 1) && (rank > 1)) {
offsets = accumOffsets;
sizes = accumSizes;
strides = accumStrides;
}
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
loc, result.value(), outputs[result.index()], offsets, sizes, strides);
results.push_back(insertSliceOp.getResult());
}
return tiledScanOp;
}

static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return success(folded);
}

LogicalResult ScanOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//