Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuner]: Add a utility function to query supported MMA intrinsics #19124

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ iree_compiler_cc_library(
"ROCDLKernelConfig.cpp",
"ROCDLLowerExecutableTarget.cpp",
"ROCDLSelectLoweringStrategy.cpp",
"TestLLVMGPUQueryMMAPass.cpp",
"Verifiers.cpp",
],
hdrs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ iree_cc_library(
"ROCDLKernelConfig.cpp"
"ROCDLLowerExecutableTarget.cpp"
"ROCDLSelectLoweringStrategy.cpp"
"TestLLVMGPUQueryMMAPass.cpp"
"Verifiers.cpp"
DEPS
::PassHeaders
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,9 @@ def TestLLVMGPUScalarizeMathOpPass :
let summary = "Test pass for several legalization patterns.";
}

// Test-only pass: walks the module and prints, for each executable variant
// with a GPU target description, the MMA intrinsics that target supports.
// Implementation lives in TestLLVMGPUQueryMMAPass.cpp.
def TestLLVMGPUQueryMMAPass :
    Pass<"iree-test-llvmgpu-query-mma", "ModuleOp"> {
  let summary = "Test pass for querying the supported mma intrinsic instructions.";
}

#endif // IREE_CODEGEN_LLVMGPU_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "iree-test-llvmgpu-query-mma"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_TESTLLVMGPUQUERYMMAPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"

namespace {

/// Test pass that prints, for each executable variant in the module carrying
/// a GPU target description, the variant's name and the MMA intrinsics the
/// target supports. Output goes to llvm::outs() so lit/FileCheck tests can
/// match it.
struct TestLLVMGPUQueryMMAPass final
    : impl::TestLLVMGPUQueryMMAPassBase<TestLLVMGPUQueryMMAPass> {
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    // Map from executable-variant op to the MMA intrinsics its GPU target
    // supports; populated by the shared query utility.
    llvm::SmallDenseMap<Operation *, SmallVector<IREE::GPU::MMAIntrinsic>>
        mmaMap;
    queryMMAIntrinsics(moduleOp, mmaMap);
    for (const auto &[op, mmaAttrs] : mmaMap) {
      // The map is keyed by Operation*; recover the variant op to print a
      // human-readable name, falling back for unexpected key kinds.
      if (auto variantOp = llvm::dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
        llvm::outs() << "Executable Variant Name: " << variantOp.getName()
                     << "\n";
      } else {
        llvm::outs() << "Executable Variant Name: " << "Unnamed Operation"
                     << "\n";
      }
      llvm::outs() << "MMA Intrinsics: ";
      for (IREE::GPU::MMAIntrinsic mma : mmaAttrs) {
        llvm::outs() << mma << " ";
      }
      llvm::outs() << "\n";
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ iree_lit_test_suite(
"promote_matmul_to_fit_mma.mlir",
"tensor_pad.mlir",
"tensorcore_vectorization.mlir",
"test_query_mma.mlir",
"transform_dialect_bufferize.mlir",
"transform_dialect_eliminate_gpu_barriers.mlir",
"transform_dialect_hoist_allocs.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ iree_lit_test_suite(
"rocdl_pipeline_test.mlir"
"tensor_pad.mlir"
"tensorcore_vectorization.mlir"
"test_query_mma.mlir"
"transform_dialect_bufferize.mlir"
"transform_dialect_eliminate_gpu_barriers.mlir"
"transform_dialect_hoist_allocs.mlir"
Expand Down
104 changes: 104 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/test_query_mma.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// RUN: iree-opt --split-input-file --iree-test-llvmgpu-query-mma %s | FileCheck %s

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
wgp = <compute = int32, storage = b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647]>>, waves_per_eu = 2 : i64}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
// Single variant carrying a full #iree_gpu.target: the pass should print the
// variant name followed by both MMA intrinsics from the wgp description.
module {
hal.executable private @main {
hal.executable.variant public @main target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @entry_point layout(#pipeline_layout)
builtin.module {
func.func @fn() {
return
}
}
}
}
}

// CHECK: Executable Variant Name
// CHECK-SAME: main
// CHECK: MMA Intrinsics
// CHECK-SAME: MFMA_F32_16x16x4_F32
// CHECK-SAME: MFMA_F32_16x16x16_F16
// CHECK-LABEL: func.func @fn

// -----

#executable_target_rocm_hsaco_fb0 = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
wgp = <compute = int32, storage = b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647]>>, waves_per_eu = 2 : i64}>
#executable_target_rocm_hsaco_fb1 = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
wgp = <compute = int32, storage = b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647]>>, waves_per_eu = 2 : i64}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
// Two variants with different GPU targets: each variant's own MMA intrinsic
// list should be printed independently.
module {
hal.executable private @main_0 {
hal.executable.variant public @main_0 target(#executable_target_rocm_hsaco_fb0) {
hal.executable.export public @entry_point_0 layout(#pipeline_layout)
builtin.module {
func.func @fn_0() {
return
}
}
}
}
hal.executable private @main_1 {
hal.executable.variant public @main_1 target(#executable_target_rocm_hsaco_fb1) {
hal.executable.export public @entry_point layout(#pipeline_layout)
builtin.module {
func.func @fn_1() {
return
}
}
}
}
}

// CHECK: Executable Variant Name
// CHECK-SAME: main_0
// CHECK: MMA Intrinsics
// CHECK-SAME: MFMA_F32_16x16x4_F32
// CHECK-SAME: MFMA_F32_16x16x16_F16
// CHECK: Executable Variant Name
// CHECK-SAME: main_1
// CHECK: MMA Intrinsics
// CHECK-SAME: MFMA_F32_32x32x8_F16
// CHECK-SAME: MFMA_F32_16x16x16_BF16

// -----

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
// Variant whose target has no iree.gpu.target attribute: the pass should
// print nothing for it (checked with CHECK-NOT below).
module {
hal.executable private @main {
hal.executable.variant public @main target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @entry_point layout(#pipeline_layout)
builtin.module {
func.func @fn_empty() {
return
}
}
}
}
}

// CHECK-NOT: Executable Variant Name
// CHECK-NOT: MMA Intrinsics
// CHECK-LABEL: func.func @fn_empty
19 changes: 19 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -1028,4 +1029,22 @@ std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {
return std::nullopt;
}

/// Walks `moduleOp` and records, for each hal.executable.variant op that
/// carries a GPU target attribute, the MMA intrinsics that target supports.
/// Variants without a GPU target description are omitted from
/// `mmaAttributesMap`.
void queryMMAIntrinsics(
    mlir::ModuleOp moduleOp,
    llvm::SmallDenseMap<Operation *, SmallVector<IREE::GPU::MMAIntrinsic>>
        &mmaAttributesMap) {
  moduleOp.walk([&](IREE::HAL::ExecutableVariantOp executableOp) {
    if (IREE::GPU::TargetAttr target = getGPUTargetAttr(executableOp)) {
      // Project each MMA attribute in the target's workgroup-processor (wgp)
      // description down to its intrinsic enum value.
      SmallVector<IREE::GPU::MMAIntrinsic> mmaIntrinsics;
      llvm::append_range(
          mmaIntrinsics,
          llvm::map_range(target.getWgp().getMma(),
                          [](IREE::GPU::MMAAttr attr) {
                            return attr.getIntrinsic().getValue();
                          }));
      mmaAttributesMap[executableOp] = std::move(mmaIntrinsics);
    }
  });
}

} // namespace mlir::iree_compiler
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);
/// Returns std::nullopt if none found.
std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func);

/// Populates `mmaAttributesMap` with, for each `hal.executable.variant` op in
/// `moduleOp` that carries a GPU target description, the list of MMA
/// intrinsics that target supports. Variants without a GPU target are omitted.
void queryMMAIntrinsics(
    mlir::ModuleOp moduleOp,
    llvm::SmallDenseMap<Operation *, SmallVector<IREE::GPU::MMAIntrinsic>>
        &mmaAttributesMap);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
Loading