From 7432caeafc24e18c188d20fbd6ed9c3a1e3ac64a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Sun, 16 Oct 2022 15:30:14 +0000 Subject: [PATCH 1/3] Add AccumulateOp. --- .../Dialect/Iterators/IR/IteratorsOps.td | 57 +++++++++++++++++++ .../lib/Dialect/Iterators/IR/Iterators.cpp | 38 +++++++++++++ .../test/Dialect/Iterators/accumulate.mlir | 29 ++++++++++ 3 files changed, 124 insertions(+) create mode 100644 experimental/iterators/test/Dialect/Iterators/accumulate.mlir diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td index e8e47af68533..679ae94fd114 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td @@ -79,6 +79,63 @@ def Iterators_PrintOp : Iterators_Base_Op<"print", [ // High-level iterators //===----------------------------------------------------------------------===// +def Iterators_AccumulateOp : Iterators_Op<"accumulate", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Accumulate the elements of a stream into one element"; + let description = [{ + Accumulate the elements of the input stream into a single element, i.e., + compute their generalized sum. This is similar to + [`std::accumulate`](https://en.cppreference.com/w/cpp/algorithm/accumulate) + in C++ and + [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce) + with *initializer* in Python. The accumulator is initialized with the + provided value (which must be of the same type as the elements of the result + stream); the logic of the accumulation is given by the provided accumulate + function. 
+ + Pseudo-code: + ``` + accumulator = initVal + while (next = upstream->next()): + accumulator = @accumulateFuncRef(accumulator, next->value()) + return accumulator + + Example: + ```mlir + %input = ... + %zero_tuple = ... + %0 = iterators.accumulate(%input, %zero_tuple) with @sum + : (!iterators.stream) -> !iterators.stream> + ``` + }]; + let arguments = (ins + Iterators_Stream:$input, + AnyType:$initVal, + FlatSymbolRefAttr:$accumulateFuncRef + ); + let results = (outs Iterators_Stream:$result); + let assemblyFormat = [{ + `(` $input `,` $initVal `)` `with` $accumulateFuncRef attr-dict `:` + `(` qualified(type($input)) `)` `->` qualified(type($result)) + custom(type($initVal), ref(type($result))) + }]; + let extraClassDeclaration = [{ + /// Return the accumulate function op that the accumulateFuncRef refers to. + func::FuncOp getAccumulateFunc() { + return SymbolTable::lookupNearestSymbolFrom( + *this, getAccumulateFuncRefAttr()); + } + }]; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. + void $cppClass::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "accumulated"); + } + }]; +} + def Iterators_ConstantStreamOp : Iterators_Op<"constantstream", [ PredOpTrait<"element type of return type must be tuple with matching types", CPred<[{ diff --git a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp index 87a68d6881fa..c67b227a8945 100644 --- a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp +++ b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp @@ -62,6 +62,44 @@ void IteratorsDialect::initialize() { // Iterators operations //===----------------------------------------------------------------------===// +/// Implement SymbolUserOpInterface. 
+LogicalResult +AccumulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + MLIRContext *context = getContext(); + Type inputElementType = + getInput().getType().dyn_cast().getElementType(); + Type resultElementType = + getResult().getType().dyn_cast().getElementType(); + + // Verify accumulate function. + func::FuncOp accumulateFunc = getAccumulateFunc(); + if (!accumulateFunc) + return emitOpError() << "uses the symbol '" << getAccumulateFuncRef() + << "', which does not reference a valid function"; + + Type accumulateFuncType = FunctionType::get( + context, {resultElementType, inputElementType}, resultElementType); + if (accumulateFunc.getFunctionType() != accumulateFuncType) { + return emitOpError() << "uses the symbol '" << getAccumulateFuncRef() + << "', which does not refer to a function of type " + << accumulateFuncType; + } + + return success(); +} + +static ParseResult parseAccumulateInitValType(AsmParser & /*parser*/, + Type &initValType, + Type resultType) { + auto resultStreamType = resultType.cast(); + initValType = resultStreamType.getElementType(); + return success(); +} + +static void printAccumulateInitValType(AsmPrinter & /*printer*/, + Operation * /*op*/, Type /*initValType*/, + Type /*resultType*/) {} + static ParseResult parseInsertValueType(AsmParser & /*parser*/, Type &valueType, Type stateType, IntegerAttr indexAttr) { int64_t index = indexAttr.getValue().getSExtValue(); diff --git a/experimental/iterators/test/Dialect/Iterators/accumulate.mlir b/experimental/iterators/test/Dialect/Iterators/accumulate.mlir new file mode 100644 index 000000000000..a13855cde3db --- /dev/null +++ b/experimental/iterators/test/Dialect/Iterators/accumulate.mlir @@ -0,0 +1,29 @@ +// RUN: iterators-opt %s \ +// RUN: | FileCheck %s + +func.func private @accumulate_sum_tuple( + %acc : tuple, %val : tuple) -> tuple { + %acci = tuple.to_elements %acc : tuple + %vali = tuple.to_elements %val : tuple + %i = arith.addi %acci, %vali : i32 + %result = 
tuple.from_elements %i : tuple + return %result : tuple +} + +// CHECK-LABEL: func.func @main() { +func.func @main() { +// CHECK-NEXT: %[[V0:.*]] = "iterators.constantstream"{{.*}} + %input = "iterators.constantstream"() { value = [] } : + () -> (!iterators.stream>) + + // CHECK: %[[V1:.*]] = tuple.from_elements %{{.*}} : tuple + %hundred = arith.constant 100 : i32 + %init_value = tuple.from_elements %hundred : tuple + + // CHECK: %[[V2:accumulated.*]] = iterators.accumulate(%[[V0]], %[[V1]]) with @accumulate_sum_tuple : (!iterators.stream>) -> !iterators.stream> + %accumulated = iterators.accumulate(%input, %init_value) + with @accumulate_sum_tuple : + (!iterators.stream>) -> + !iterators.stream> + return +} From d49ed91f0bb8ec7f58a78037109b2d4d17e02ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Sun, 16 Oct 2022 16:45:15 +0000 Subject: [PATCH 2/3] Implement lowering for AccumulateOp. --- .../IteratorsToLLVM/IteratorAnalysis.cpp | 18 ++ .../IteratorsToLLVM/IteratorsToLLVM.cpp | 242 ++++++++++++++++++ .../IteratorsToLLVM/accumulate.mlir | 71 +++++ .../Dialect/Iterators/CPU/accumulate.mlir | 79 ++++++ 4 files changed, 410 insertions(+) create mode 100644 experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir create mode 100644 experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp index acf620b37e10..604611fd3f7b 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp @@ -57,6 +57,23 @@ class StateTypeComputer { TypeConverter typeConverter; }; +/// The state of AccumulateOp consists of the state of its upstream iterator, +/// i.e., the state of the iterator that produces its input stream, the initial +/// value of the accumulator, 
and a Boolean indicating whether the iterator has +/// returned a result already (which is initialized to false and set to true in +/// the first call to next in order to ensure that only a single result is +/// returned). +template <> +StateType +StateTypeComputer::operator()(AccumulateOp op, + llvm::SmallVector upstreamStateTypes) { + MLIRContext *context = op->getContext(); + Type hasReturned = IntegerType::get(context, /*width=*/1); + Type initValType = op.getInitVal().getType(); + return StateType::get(context, + {upstreamStateTypes[0], initValType, hasReturned}); +} + /// The state of ConstantStreamOp consists of a single number that corresponds /// to the index of the next struct returned by the iterator. template <> @@ -180,6 +197,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis( // TODO: Verify that operands do not come from bbArgs. .Case< // clang-format off + AccumulateOp, ConstantStreamOp, FilterOp, MapOp, diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 0fef8b6694ea..f610318d4d04 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -308,6 +308,244 @@ struct PrintOpLowering : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// AccumulateOp. +//===----------------------------------------------------------------------===// + +/// Builds IR that opens the nested upstream iterator and sets `hasReturned` to +/// false. 
Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : +/// -> !upstream_state +/// %1 = call @iterators.upstream.open.0(%0) : +/// (!upstream_state) -> !upstream_state +/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) : +/// +/// %false = arith.constant false +/// %3 = iterators.insertvalue %false into %2[2] : +/// !iterators.state +static Value buildOpenBody(AccumulateOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + Type upstreamStateType = upstreamInfos[0].stateType; + + // Extract upstream state. + Value initialUpstreamState = b.create( + upstreamStateType, initialState, b.getIndexAttr(0)); + + // Call Open on upstream. + SymbolRefAttr openFunc = upstreamInfos[0].openFunc; + auto openCallOp = + b.create(openFunc, upstreamStateType, initialUpstreamState); + + // Update upstream state. + Value updatedUpstreamState = openCallOp->getResult(0); + Value updatedState = b.create( + initialState, b.getIndexAttr(0), updatedUpstreamState); + + // Reset hasReturned to false. + Value constFalse = b.create(/*value=*/0, /*width=*/1); + updatedState = b.create( + updatedState, b.getIndexAttr(2), constFalse); + + return updatedState; +} + +/// Builds IR that consumes all elements of the upstream iterator and combines +/// them into a single one using the given accumulate function. 
Pseudo-code: +/// +/// if hasReturned: return {} +/// hasReturned = True +/// accumulator = initVal +/// while (next = upstream->Next()): +/// accumulator = accumulate(accumulator, next) +/// return accumulator +/// +/// Possible output: +/// +/// %upstream_state = iterators.extractvalue %arg0[0] : !state_type +/// %init_val = iterators.extractvalue %arg0[1] : !state_type +/// %has_returned = iterators.extractvalue %arg0[2] : !state_type +/// %3:2 = scf.if %has_returned -> (!upstream_state, !element_type) { +/// scf.yield %upstream_state, %init_val : !upstream_state, !element_type +/// } else { +/// %5:3 = scf.while (%arg1 = %upstream_state, %arg2 = %init_val) : +/// (!upstream_state, !element_type) -> +/// (!upstream_state, !element_type, !element_type) { +/// %6:3 = func.call @iterators.upstream.next.0(%arg1) : +/// (!upstream_state) -> (!upstream_state, i1, !element_type) +/// scf.condition(%6#1) %6#0, %arg2, %6#2 : +/// !upstream_state, !element_type, !element_type +/// } do { +/// ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type): +/// %8 = func.call @accumulate_func(%arg2, %arg3) : +/// (!element_type, !element_type) -> !element_type +/// scf.yield %arg1, %8 : !upstream_state, !element_type +/// } +/// scf.yield %5#0, %5#1 : !upstream_state, !element_type +/// } +/// %true = arith.constant true +/// %4 = arith.xori %true, %has_returned : i1 +/// %state_0 = iterators.insertvalue %3#0 into %arg0[0] : !state_type +/// %state_1 = iterators.insertvalue %true into %state_0[2] : !state_type +static llvm::SmallVector +buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState, + ArrayRef upstreamInfos, Type elementType) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + Type i1 = b.getI1Type(); + + // Extract input element type. + StreamType inputStreamType = op.getInput().getType().cast(); + Type inputElementType = inputStreamType.getElementType(); + + // Extract upstream state and init value. 
+ Type upstreamStateType = upstreamInfos[0].stateType; + Value initialUpstreamState = b.create( + upstreamStateType, initialState, b.getIndexAttr(0)); + Value initValue = b.create( + elementType, initialState, b.getIndexAttr(1)); + + // Check if the iterator has returned an element already (since it should + // return one only in the first call to next). + Value hasReturned = + b.create(i1, initialState, b.getIndexAttr(2)); + SmallVector ifReturnTypes{upstreamStateType, elementType}; + auto ifOp = b.create( + hasReturned, + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Don't modify state; return init value. + b.create(ValueRange{initialUpstreamState, initValue}); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Create while loop using init value as initial accumulator. + SmallVector whileInputs = {initialUpstreamState, initValue}; + SmallVector whileResultTypes = { + upstreamStateType, // Updated upstream state. + elementType, // Accumulator. + inputElementType // Element from last next call. + }; + scf::WhileOp whileOp = b.create( + whileResultTypes, whileInputs, + /*beforeBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + + Value upstreamState = args[0]; + Value accumulator = args[1]; + + // Call next function. 
+ SmallVector nextResultTypes = {upstreamStateType, i1, + inputElementType}; + SymbolRefAttr nextFunc = upstreamInfos[0].nextFunc; + auto nextCall = b.create(nextFunc, nextResultTypes, + upstreamState); + + Value updatedUpstreamState = nextCall->getResult(0); + Value hasNext = nextCall->getResult(1); + Value maybeNextElement = nextCall->getResult(2); + b.create( + hasNext, ValueRange{updatedUpstreamState, accumulator, + maybeNextElement}); + }, + /*afterBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + + Value upstreamState = args[0]; + Value accumulator = args[1]; + Value nextElement = args[2]; + + // Call accumulate function. + auto accumulateCall = + b.create(elementType, op.getAccumulateFuncRef(), + ValueRange{accumulator, nextElement}); + Value newAccumulator = accumulateCall->getResult(0); + + b.create(ValueRange{upstreamState, newAccumulator}); + }); + + Value updatedState = whileOp->getResult(0); + Value accumulator = whileOp->getResult(1); + + b.create(ValueRange{updatedState, accumulator}); + }); + + // Compute hasNext: we have an element iff we have not returned before, i.e., + // iff "not hasReturned". We simulate "not" with "xor true". + Value constTrue = b.create(/*value=*/1, /*width=*/1); + Value hasNext = b.create(constTrue, hasReturned); + + // Update state. + Value finalUpstreamState = ifOp->getResult(0); + Value finalState = b.create( + initialState, b.getIndexAttr(0), finalUpstreamState); // upstreamState + finalState = b.create(finalState, b.getIndexAttr(2), + constTrue); // hasReturned + Value nextElement = ifOp->getResult(1); + + return {finalState, hasNext, nextElement}; +} + +/// Builds IR that closes the nested upstream iterator. 
Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : +/// !iterators.state -> !upstream_state +/// %1 = call @iterators.upstream.close.0(%0) : +/// (!upstream_state) -> !upstream_state +/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) : +/// !iterators.state +static Value buildCloseBody(AccumulateOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + Type upstreamStateType = upstreamInfos[0].stateType; + + // Extract upstream state. + Value initialUpstreamState = b.create( + upstreamStateType, initialState, b.getIndexAttr(0)); + + // Call Close on upstream. + SymbolRefAttr closeFunc = upstreamInfos[0].closeFunc; + auto closeCallOp = b.create(closeFunc, upstreamStateType, + initialUpstreamState); + + // Update upstream state. + Value updatedUpstreamState = closeCallOp->getResult(0); + return b + .create(initialState, b.getIndexAttr(0), + updatedUpstreamState) + .getResult(); +} + +/// Builds IR that initializes the iterator state with the state of the upstream +/// iterator, the init value, and `hasReturned` = false. Possible output: +/// +/// %0 = ... +/// %1 = arith.constant false +/// %2 = iterators.createstate(%0, %init_val, %1) : !iterators.state +static Value buildStateCreation(AccumulateOp op, AccumulateOp::Adaptor adaptor, + OpBuilder &builder, StateType stateType) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + Value upstreamState = adaptor.getInput(); + Value initVal = adaptor.getInitVal(); + Value constFalse = b.create(/*value=*/0, /*width=*/1); + return b.create( + stateType, ValueRange{upstreamState, initVal, constFalse}); +} + //===----------------------------------------------------------------------===// // ConstantStreamOp. 
//===----------------------------------------------------------------------===// @@ -1543,6 +1781,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder, return llvm::TypeSwitch(op) .Case< // clang-format off + AccumulateOp, ConstantStreamOp, FilterOp, MapOp, @@ -1563,6 +1802,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState, return llvm::TypeSwitch>(op) .Case< // clang-format off + AccumulateOp, ConstantStreamOp, FilterOp, MapOp, @@ -1584,6 +1824,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder, return llvm::TypeSwitch(op) .Case< // clang-format off + AccumulateOp, ConstantStreamOp, FilterOp, MapOp, @@ -1603,6 +1844,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder, return llvm::TypeSwitch(op) .Case< // clang-format off + AccumulateOp, ConstantStreamOp, FilterOp, MapOp, diff --git a/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir b/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir new file mode 100644 index 000000000000..2d9e8f6ae2de --- /dev/null +++ b/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir @@ -0,0 +1,71 @@ +// RUN: iterators-opt %s -convert-iterators-to-llvm \ +// RUN: | FileCheck --enable-var-scope %s + +func.func private @sum_tuple( + %acc : tuple, %val : tuple) -> tuple { + %acci = tuple.to_elements %acc : tuple + %vali = tuple.to_elements %val : tuple + %i = arith.addi %acci, %vali : i32 + %result = tuple.from_elements %i : tuple + return %result : tuple +} + +// CHECK-LABEL: func.func private @iterators.accumulate.close.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple, i1>) -> +// CHECK-SAME: !iterators.state<[[upstreamStateType]], tuple, i1> { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{.*}}.close.{{.*}}(%[[V0]]) : 
([[upstreamStateType]]) -> [[upstreamStateType]] +// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: return %[[V2]] : !iterators.state<[[upstreamStateType]], tuple, i1> + +// CHECK-LABEL: func.func private @iterators.accumulate.next.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple, i1>) -> +// CHECK-SAME: (!iterators.state<[[upstreamStateType]], tuple, i1>, i1, tuple) { +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V3:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V4:.*]]:2 = scf.if %[[V3]] -> ([[upstreamStateType]], tuple) { +// CHECK-NEXT: scf.yield %[[V1]], %[[V2]] : [[upstreamStateType]], tuple +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[V5:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V1]], %[[arg2:.*]] = %[[V2]]) : ([[upstreamStateType]], tuple) -> ([[upstreamStateType]], tuple, tuple) { +// CHECK-NEXT: %[[V6:.*]]:3 = func.call @iterators.{{.*}}.next.{{.*}}(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, tuple) +// CHECK-NEXT: scf.condition(%[[V6]]#1) %[[V6]]#0, %[[arg2]], %[[V6]]#2 : [[upstreamStateType]], tuple, tuple +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[arg1:.*]]: [[upstreamStateType]], %[[arg2:.*]]: tuple, %[[arg3:.*]]: tuple): +// CHECK-NEXT: %[[V7:.*]] = func.call @sum_tuple(%[[arg2]], %[[arg3]]) : (tuple, tuple) -> tuple +// CHECK-NEXT: scf.yield %[[arg1]], %[[V7]] : [[upstreamStateType]], tuple +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[V5]]#0, %[[V5]]#1 : [[upstreamStateType]], tuple +// CHECK-NEXT: } +// CHECK-NEXT: %[[V8:.*]] = arith.constant true +// CHECK-NEXT: %[[V9:.*]] = arith.xori %[[V8]], 
%[[V3]] : i1 +// CHECK-NEXT: %[[Va:.*]] = iterators.insertvalue %[[V4]]#0 into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[Vb:.*]] = iterators.insertvalue %[[V8]] into %[[Va]][2] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: return %[[Vb]], %[[V9]], %[[V4]]#1 : !iterators.state<[[upstreamStateType]], tuple, i1>, i1, tuple + +// CHECK-LABEL: func.func private @iterators.accumulate.open.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple, i1>) -> +// CHECK-SAME: !iterators.state<[[upstreamStateType]], tuple, i1> { +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V2:.*]] = call @iterators.{{.*}}.open.{{.*}}(%[[V1]]) : ([[upstreamStateType]]) -> [[upstreamStateType]] +// CHECK-NEXT: %[[V3:.*]] = iterators.insertvalue %[[V2]] into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: %[[V4:.*]] = arith.constant false +// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V3]][2] : !iterators.state<[[upstreamStateType]], tuple, i1> +// CHECK-NEXT: return %[[V5]] : !iterators.state<[[upstreamStateType]], tuple, i1> + +// CHECK-LABEL: func.func @main() +func.func @main() { + // CHECK-DAG: %[[V0:.*]] = iterators.createstate{{.*}} : [[upstreamStateType:!iterators.state<[^>]*>]] + %input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream>) + + // CHECK-DAG: %[[V1:.*]] = tuple.from_elements %{{.*}} : tuple + %hundred = arith.constant 0 : i32 + %init_value = tuple.from_elements %hundred : tuple + + // CHECK-DAG: %[[V2:.*]] = arith.constant false + // CHECK-NEXT: %[[V3:.*]] = iterators.createstate(%[[V0]], %[[V1]], %[[V2]]) : !iterators.state<[[upstreamStateType]], tuple, i1> + %accumulated = iterators.accumulate(%input, %init_value) with @sum_tuple + : (!iterators.stream>) -> !iterators.stream> + return + // 
CHECK-NEXT: return +} diff --git a/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir new file mode 100644 index 000000000000..7cf0fff5aec6 --- /dev/null +++ b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir @@ -0,0 +1,79 @@ +// RUN: iterators-opt %s \ +// RUN: -convert-iterators-to-llvm \ +// RUN: -decompose-iterator-states \ +// RUN: -decompose-tuples \ +// RUN: -convert-func-to-llvm \ +// RUN: -convert-scf-to-cf -convert-cf-to-llvm \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: | FileCheck %s + +func.func private @accumulate_sum_tuple( + %acc : tuple, %val : tuple) -> tuple { + %acci = tuple.to_elements %acc : tuple + %vali = tuple.to_elements %val : tuple + %i = arith.addi %acci, %vali : i32 + %result = tuple.from_elements %i : tuple + return %result : tuple +} + +// CHECK-LABEL: test_accumulate_sum_tuple +// CHECK-NEXT: (160) +// CHECK-NEXT: - +func.func @test_accumulate_sum_tuple() { + iterators.print("test_accumulate_sum_tuple") + %input = "iterators.constantstream"() + { value = [[0 : i32], [10 : i32], [20 : i32], [30 : i32]] } + : () -> (!iterators.stream>) + %hundred = arith.constant 100 : i32 + %init_value = tuple.from_elements %hundred : tuple + %accumulated = iterators.accumulate(%input, %init_value) + with @accumulate_sum_tuple + : (!iterators.stream>) -> !iterators.stream> + "iterators.sink"(%accumulated) : (!iterators.stream>) -> () + return +} + +func.func private @accumulate_avg_tuple( + %acc : tuple, %val : tuple) -> tuple { + %cnt, %sum = tuple.to_elements %acc : tuple + %vali = tuple.to_elements %val : tuple + %one = arith.constant 1 : i32 + %new_cnt = arith.addi %cnt, %one : i32 + %new_sum = arith.addi %sum, %vali : i32 + %result = tuple.from_elements %new_cnt, %new_sum : tuple + return %result : tuple +} + +func.func private @avg(%input : tuple) -> tuple { + %cnt, %sum = 
tuple.to_elements %input : tuple + %cntf = arith.sitofp %cnt : i32 to f32 + %sumf = arith.sitofp %sum : i32 to f32 + %avg = arith.divf %sumf, %cntf : f32 + %result = tuple.from_elements %avg : tuple + return %result : tuple +} + +// CHECK-LABEL: test_accumulate_avg_tuple +// CHECK-NEXT: (15) +// CHECK-NEXT: - +func.func @test_accumulate_avg_tuple() { + iterators.print("test_accumulate_avg_tuple") + %input = "iterators.constantstream"() + { value = [[0 : i32], [10 : i32], [20 : i32], [30 : i32]] } + : () -> (!iterators.stream>) + %zero = arith.constant 0 : i32 + %init_value = tuple.from_elements %zero, %zero : tuple + %accumulated = iterators.accumulate(%input, %init_value) + with @accumulate_avg_tuple + : (!iterators.stream>) -> !iterators.stream> + %mapped = "iterators.map"(%accumulated) {mapFuncRef = @avg} + : (!iterators.stream>) -> (!iterators.stream>) + "iterators.sink"(%mapped) : (!iterators.stream>) -> () + return +} + +func.func @main() { + call @test_accumulate_sum_tuple() : () -> () + call @test_accumulate_avg_tuple() : () -> () + return +} From 17ed75cb1b15e2463da7f54e70c2101144e6616f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 17 Oct 2022 12:36:02 +0000 Subject: [PATCH 3/3] XXX: Implement histogram as showcase for AccumulateOp. 
--- .../Dialect/Iterators/CPU/accumulate.mlir | 64 ++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir index 7cf0fff5aec6..9a54ec02dcff 100644 --- a/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir +++ b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir @@ -2,9 +2,20 @@ // RUN: -convert-iterators-to-llvm \ // RUN: -decompose-iterator-states \ // RUN: -decompose-tuples \ +// RUN: -inline -canonicalize \ +// RUN: -one-shot-bufferize="allow-return-allocs" \ +// RUN: -buffer-hoisting \ +// RUN: -buffer-deallocation \ +// RUN: -convert-bufferization-to-memref \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ // RUN: -convert-func-to-llvm \ -// RUN: -convert-scf-to-cf -convert-cf-to-llvm \ -// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -canonicalize \ +// RUN: -convert-cf-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s func.func private @accumulate_sum_tuple( @@ -72,8 +83,57 @@ func.func @test_accumulate_avg_tuple() { return } +func.func private @unpack_i32(%input : tuple) -> i32 { + %i = tuple.to_elements %input : tuple + return %i : i32 +} + +func.func private @accumulate_histogram( + %hist : tensor<4xi32>, %val : i32) -> tensor<4xi32> { + %idx = arith.index_cast %val : i32 to index + %oldCount = tensor.extract %hist[%idx] : tensor<4xi32> + %one = arith.constant 1 : i32 + %newCount = arith.addi %oldCount, %one : i32 + %newHist = tensor.insert %newCount into %hist[%idx] : tensor<4xi32> + return %newHist : tensor<4xi32> +} + +func.func private @tensor_to_struct(%input : tensor<4xi32>) -> tuple { + %idx0 = arith.constant 0 : 
index + %idx1 = arith.constant 1 : index + %idx2 = arith.constant 2 : index + %idx3 = arith.constant 3 : index + %i0 = tensor.extract %input[%idx0] : tensor<4xi32> + %i1 = tensor.extract %input[%idx1] : tensor<4xi32> + %i2 = tensor.extract %input[%idx2] : tensor<4xi32> + %i3 = tensor.extract %input[%idx3] : tensor<4xi32> + %tuple = tuple.from_elements %i0, %i1, %i2, %i3 : tuple + return %tuple : tuple +} + +// CHECK-LABEL: test_accumulate_histogram +// CHECK-NEXT: (1, 2, 1, 0) +// CHECK-NEXT: - +func.func @test_accumulate_histogram() { + iterators.print("test_accumulate_histogram") + %input = "iterators.constantstream"() + { value = [[0 : i32], [1 : i32], [1 : i32], [2 : i32]] } + : () -> (!iterators.stream>) + %unpacked = "iterators.map"(%input) {mapFuncRef = @unpack_i32} + : (!iterators.stream>) -> (!iterators.stream) + %init_value = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi32> + %accumulated = iterators.accumulate(%unpacked, %init_value) + with @accumulate_histogram + : (!iterators.stream) -> !iterators.stream> + %transposed = "iterators.map"(%accumulated) {mapFuncRef = @tensor_to_struct} + : (!iterators.stream>) -> (!iterators.stream>) + "iterators.sink"(%transposed) : (!iterators.stream>) -> () + return +} + func.func @main() { call @test_accumulate_sum_tuple() : () -> () call @test_accumulate_avg_tuple() : () -> () + call @test_accumulate_histogram() : () -> () return }