More IREEGPUAttrs.cpp cleanups (#19142)
Two things in this PR:
1. Make a big switch statement more concise.
2. Currently, `DataTiledMMAAttr::buildMmaOperation` creates an `MMAAttr` just to call `buildMmaOperation` on it, reusing that implementation. Besides being roundabout, this required a comment explaining why the error status was discarded: `MMAAttr::buildMmaOperation` is fallible, but at that point we were already past validation and mutating IR. This PR refactors both methods to call a shared, infallible helper (a minimal sketch of the pattern follows below).
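
For context, here is a minimal, self-contained C++ sketch of that pattern, under the assumption that "infallible" means signaling an unsupported input with an empty result rather than an error status. All names (`Kind`, `makeOpName`, `buildOpName`) are illustrative placeholders, not IREE APIs; the real helper is `createMmaOp` in the diff below.

```cpp
#include <iostream>
#include <optional>
#include <string>

enum class Kind { Mfma, Wmma, Unsupported };

// Infallible shared helper (hypothetical): returns an empty string for
// unsupported kinds instead of an error, mirroring how createMmaOp returns
// a null Value.
static std::string makeOpName(Kind kind) {
  switch (kind) {
  case Kind::Mfma:
    return "amdgpu.mfma";
  case Kind::Wmma:
    return "amdgpu.wmma";
  default:
    return {};
  }
}

// Fallible public wrapper (hypothetical): converts the empty sentinel into
// an explicit failure, the way MMAAttr::buildMmaOperation wraps createMmaOp.
static std::optional<std::string> buildOpName(Kind kind) {
  if (std::string name = makeOpName(kind); !name.empty())
    return name;
  return std::nullopt;
}

int main() {
  // A call site that has already validated its input (like the unrolled loop
  // in DataTiledMMAAttr::buildMmaOperation) calls the helper directly, so
  // there is no error status to discard.
  std::cout << makeOpName(Kind::Mfma) << "\n";
  // An unvalidated call site goes through the fallible wrapper instead.
  if (!buildOpName(Kind::Unsupported))
    std::cout << "failure\n";
  return 0;
}
```

The split keeps fallibility at the validation boundary: once inputs are known-good, downstream code can build IR without having to discard a `FailureOr` it can no longer meaningfully handle.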

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Nov 14, 2024
1 parent f828914 commit 2a2bd06
165 changes: 66 additions & 99 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -220,88 +220,45 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
   switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+  case MMAIntrinsic::MFMA_F64_16x16x4_F64:
     return {f64, f64, f64};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
     return {f32, f32, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
     return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {f16, f16, f16};
+  case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
     return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
+  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
+    return {bf16, bf16, bf16};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
     return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
     return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
     return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
     return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
-    return {bf16, bf16, bf16};
-  }
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+  case MMAIntrinsic::WMMA_I32_16x16x16_I8:
     return {i8, i8, i32};
-  }
-  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
   }
   assert(false && "unexpected enum value");
   return {};
 }
@@ -498,11 +455,15 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
   return IREE::GPU::getContractionLayout(contract, layout);
 }
 
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
   // Not supporting any block size other than 1 at the moment.
   return 1;
 }
 
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
 static uint32_t getArchID(MMAIntrinsic intrinsic) {
   return static_cast<int>(intrinsic) & 0xFF00;
 }
@@ -704,6 +665,31 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
   }
 }
 
+static Value createMmaOp(OpBuilder &builder, Location loc,
+                         MMAIntrinsic intrinsic, Type resultType, Value lhs,
+                         Value rhs, Value acc) {
+  auto getVecOrSingleElem = [&](Value vec) -> Value {
+    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+  };
+  auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
+  if (is_AMD_MFMA(intrinsic)) {
+    // MFMA intrinsics want single-element operands of element type, not vector.
+    lhs = getVecOrSingleElem(lhs);
+    rhs = getVecOrSingleElem(rhs);
+    return builder
+        .create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
+                                layout.kSize, getBlockSize(intrinsic), lhs, rhs,
+                                acc)
+        .getResult();
+  }
+  if (is_AMD_WMMA(intrinsic)) {
+    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
+        .getResult();
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -719,23 +705,9 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
   if (cType != resultType) {
     return failure();
   }
-  auto getVecOrSingleElem = [&](Value vec) -> Value {
-    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
-    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
-  };
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  if (is_AMD_MFMA(intrinsic)) {
-    // MFMA intrinsics want single-element operands of element type, not vector.
-    lhs = getVecOrSingleElem(lhs);
-    rhs = getVecOrSingleElem(rhs);
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  } else if (is_AMD_WMMA(intrinsic)) {
-    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
-        .getResult();
+  if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+                                resultType, lhs, rhs, acc)) {
+    return value;
   }
   return failure();
 }
@@ -1168,23 +1140,18 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
   SmallVector<Value> intrinsicsAcc =
       distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
 
-  // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
-  // to create the target intrinsics.
-  auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
-  auto [intrinsicAType, intrinsicBType, intrinsicCType] =
-      intrinsicMma.getABCVectorTypes();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType intrinCType =
+      getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
 
   // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
   for (int mu = 0; mu < getUnrollM(); ++mu) {
     for (int nu = 0; nu < getUnrollN(); ++nu) {
       for (int ku = 0; ku < getUnrollK(); ++ku) {
-        // Assume intrinsicMma.buildMmaOperation() success: validation should be
-        // completed prior to mutating IR.
         Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
         Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
         Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
-        acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
-                                              rhs, acc);
+        acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
       }
     }
   }