Skip to content

[flang][fir] Add FIR structured control flow ops to SCF dialect pass. #140374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

NexMing
Copy link
Contributor

@NexMing NexMing commented May 17, 2025

This patch only supports the conversion from fir.do_loop to scf.for. The current pass is still under development, and future work will focus on gradually improving this conversion pass.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels May 17, 2025
@llvmbot
Copy link
Member

llvmbot commented May 17, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: MingYan (NexMing)

Changes

Convert FIR structured control flow ops to SCF dialect.


Full diff: https://github.com/llvm/llvm-project/pull/140374.diff

6 Files Affected:

  • (modified) flang/include/flang/Optimizer/Support/InitFIR.h (+2)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+11)
  • (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1)
  • (added) flang/lib/Optimizer/Transforms/FIRToSCF.cpp (+103)
  • (added) flang/test/Fir/FirToSCF/do-loop.fir (+147)
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 1868fbb201970..fa7c430ed631c 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -30,6 +30,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
+#include <mlir/Dialect/SCF/Transforms/Passes.h>
 
 namespace fir::support {
 
@@ -103,6 +104,7 @@ inline void registerMLIRPassesForFortranTools() {
   mlir::registerPrintOpStatsPass();
   mlir::registerInlinerPass();
   mlir::registerSCCPPass();
+  mlir::registerSCFPasses();
   mlir::affine::registerAffineScalarReplacementPass();
   mlir::registerSymbolDCEPass();
   mlir::registerLocationSnapshotPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 6dbabd523f88a..dc8a5b9141ad2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -72,6 +72,7 @@ std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
+std::unique_ptr<mlir::Pass> createFIRToSCFPass();
 std::unique_ptr<mlir::Pass>
 createAddDebugInfoPass(fir::AddDebugInfoOptions options = {});
 
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c0d88a8e19f80..da3d9bc751927 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -76,6 +76,17 @@ def AffineDialectDemotion : Pass<"demote-affine", "::mlir::func::FuncOp"> {
   ];
 }
 
+def FIRToSCFPass : Pass<"fir-to-scf"> {
+  let summary = "Convert FIR structured control flow ops to SCF dialect.";
+  let description = [{
+    Convert FIR structured control flow ops to SCF dialect.
+  }];
+  let constructor = "::fir::createFIRToSCFPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::scf::SCFDialect"
+  ];
+}
+
 def AnnotateConstantOperands : Pass<"annotate-constant"> {
   let summary = "Annotate constant operands to all FIR operations";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 170b6e2cca225..846d6c64dbd04 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_flang_library(FIRTransforms
   CUFComputeSharedMemoryOffsetsAndSize.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
+  FIRToSCF.cpp
   MemoryUtils.cpp
   MemoryAllocation.cpp
   StackArrays.cpp
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
new file mode 100644
index 0000000000000..02810f1bdba4e
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -0,0 +1,103 @@
+//===-- FIRToSCF.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_FIRTOSCFPASS
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
+  using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = doLoopOp.getLoc();
+    bool hasFinalValue = doLoopOp.getFinalValue().has_value();
+
+    // Get loop values from the DoLoopOp
+    auto low = doLoopOp.getLowerBound();
+    auto high = doLoopOp.getUpperBound();
+    assert(low && high && "must be a Value");
+    auto step = doLoopOp.getStep();
+    llvm::SmallVector<mlir::Value> iterArgs;
+    if (hasFinalValue)
+      iterArgs.push_back(low);
+    iterArgs.append(doLoopOp.getIterOperands().begin(),
+                    doLoopOp.getIterOperands().end());
+
+    // Caculate the trip count.
+    auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
+    auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
+    auto tripCount = rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
+    auto scfForOp =
+        rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs);
+
+    auto &loopOps = doLoopOp.getBody()->getOperations();
+    auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
+    auto results = resultOp.getOperands();
+    Block *loweredBody = scfForOp.getBody();
+
+    loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
+                                        loopOps.begin(),
+                                        std::prev(loopOps.end()));
+
+    rewriter.setInsertionPointToStart(loweredBody);
+    Value iv =
+        rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step);
+    iv = rewriter.create<arith::AddIOp>(loc, low, iv);
+
+    if (!results.empty()) {
+      rewriter.setInsertionPointToEnd(loweredBody);
+      rewriter.create<scf::YieldOp>(resultOp->getLoc(), results);
+    }
+    doLoopOp.getInductionVar().replaceAllUsesWith(iv);
+    rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
+                                hasFinalValue
+                                    ? scfForOp.getRegionIterArgs().drop_front()
+                                    : scfForOp.getRegionIterArgs());
+
+    // Copy loop annotations from the do loop to the loop entry condition.
+    if (auto ann = doLoopOp.getLoopAnnotation())
+      scfForOp->setAttr("loop_annotation", *ann);
+
+    rewriter.replaceOp(doLoopOp, scfForOp);
+    return success();
+  }
+};
+
+void FIRToSCFPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  patterns.add<DoLoopConversion>(patterns.getContext());
+  ConversionTarget target(getContext());
+  target.addIllegalOp<fir::DoLoopOp>();
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
+  return std::make_unique<FIRToSCFPass>();
+}
diff --git a/flang/test/Fir/FirToSCF/do-loop.fir b/flang/test/Fir/FirToSCF/do-loop.fir
new file mode 100644
index 0000000000000..c3c24ccc1db71
--- /dev/null
+++ b/flang/test/Fir/FirToSCF/do-loop.fir
@@ -0,0 +1,147 @@
+// RUN: fir-opt %s --fir-to-scf | FileCheck %s
+
+// CHECK-LABEL:   func.func @simple_loop(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
+// CHECK:             %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_0]], %[[VAL_10]] : index
+// CHECK:             %[[VAL_12:.*]] = fir.array_coor %[[ARG0]](%[[VAL_2]]) %[[VAL_11]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_12]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @simple_loop(%arg0: !fir.ref<!fir.array<100xi32>>) {
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %c1_i32 = arith.constant 1 : i32
+  fir.do_loop %arg1 = %c1 to %c100 step %c1 {
+    %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    fir.store %c1_i32 to %1 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_negtive_step(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] {
+// CHECK:             %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_0]], %[[VAL_11]] : index
+// CHECK:             %[[VAL_13:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_12]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[VAL_4]] to %[[VAL_13]] : !fir.ref<i32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_negtive_step(%arg0: !fir.ref<!fir.array<100xi32>>) {
+  %c100 = arith.constant 100 : index
+  %c1 = arith.constant 1 : index
+  %c-1 = arith.constant -1 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %c1_i32 = arith.constant 1 : i32
+  fir.do_loop %arg1 = %c100 to %c1 step %c-1 {
+    %1 = fir.array_coor %arg0(%0) %arg1 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    fir.store %c1_i32 to %1 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_results(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (i32) {
+// CHECK:             %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK:             %[[VAL_14:.*]] = fir.array_coor %[[ARG0]](%[[VAL_3]]) %[[VAL_13]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+// CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_11]], %[[VAL_15]] : i32
+// CHECK:             scf.yield %[[VAL_16]] : i32
+// CHECK:           }
+// CHECK:           fir.store %[[VAL_9]] to %[[ARG1]] : !fir.ref<i32>
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_results(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c100 = arith.constant 100 : index
+  %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %1 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (i32) {
+    %2 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %3 = fir.load %2 : !fir.ref<i32>
+    %4 = arith.addi %arg3, %3 : i32
+    fir.result %4 : i32
+  }
+  fir.store %1 to %arg1 : !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL:   func.func @loop_with_final_value(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.array<100xi32>>,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<i32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_3:.*]] = fir.alloca index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_2]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.divsi %[[VAL_6]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (index, i32) {
+// CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_11]], %[[VAL_0]] : index
+// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_16:.*]] = fir.array_coor %[[ARG0]](%[[VAL_4]]) %[[VAL_15]] : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:             %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_0]] overflow<nsw> : index
+// CHECK:             %[[VAL_19:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] overflow<nsw> : i32
+// CHECK:             scf.yield %[[VAL_18]], %[[VAL_19]] : index, i32
+// CHECK:           }
+// CHECK:           fir.store %[[VAL_20:.*]]#0 to %[[VAL_3]] : !fir.ref<index>
+// CHECK:           fir.store %[[VAL_20]]#1 to %[[ARG1]] : !fir.ref<i32>
+// CHECK:           return
+// CHECK:         }
+func.func @loop_with_final_value(%arg0: !fir.ref<!fir.array<100xi32>>, %arg1: !fir.ref<i32>) {
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c100 = arith.constant 100 : index
+  %0 = fir.alloca index
+  %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %2:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %c0_i32) -> (index, i32) {
+    %3 = fir.array_coor %arg0(%1) %arg2 : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %4 = fir.load %3 : !fir.ref<i32>
+    %5 = arith.addi %arg2, %c1 overflow<nsw> : index
+    %6 = arith.addi %arg3, %4 overflow<nsw> : i32
+    fir.result %5, %6 : index, i32
+  }
+  fir.store %2#0 to %0 : !fir.ref<index>
+  fir.store %2#1 to %arg1 : !fir.ref<i32>
+  return
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants