More IREEGPUAttrs.cpp cleanups #19142

Merged 1 commit on Nov 14, 2024

Changes from all commits:
165 changes: 66 additions & 99 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -220,88 +220,45 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
   switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+  case MMAIntrinsic::MFMA_F64_16x16x4_F64:
     return {f64, f64, f64};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
     return {f32, f32, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
     return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {f16, f16, f16};
+  case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
     return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
+  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
+    return {bf16, bf16, bf16};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
     return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
     return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
     return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
     return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
-    return {bf16, bf16, bf16};
-  }
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+  case MMAIntrinsic::WMMA_I32_16x16x16_I8:
     return {i8, i8, i32};
-  }
-  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
   }
   assert(false && "unexpected enum value");
   return {};
 }
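
The cleanup in this hunk is plain C++ case-label stacking: consecutive `case` labels with no statements between them share a single body, so once every grouped case returns the same tuple, the per-case braces and duplicated `return` statements can go. A minimal, self-contained sketch of the pattern — the enum, struct, and strings here are illustrative stand-ins, not the IREE types:

```cpp
#include <cstdio>

enum class Intrinsic { F16_16x16x16, F16_32x32x8, I8_16x16x32 };

struct ElemTypes { const char *a, *b, *c; };

// Stacked case labels with nothing between them fall into one shared body;
// no braces or [[fallthrough]] annotations are needed.
static ElemTypes getABC(Intrinsic i) {
  switch (i) {
  case Intrinsic::F16_16x16x16:
  case Intrinsic::F16_32x32x8:
    return {"f16", "f16", "f32"};
  case Intrinsic::I8_16x16x32:
    return {"i8", "i8", "i32"};
  }
  return {};
}

int main() {
  ElemTypes t = getABC(Intrinsic::F16_32x32x8);
  std::printf("%s x %s -> %s\n", t.a, t.b, t.c); // f16 x f16 -> f32
}
```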
@@ -498,11 +455,15 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
   return IREE::GPU::getContractionLayout(contract, layout);
 }
 
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
   // Not supporting any block size other than 1 at the moment.
   return 1;
 }
 
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
 static uint32_t getArchID(MMAIntrinsic intrinsic) {
   return static_cast<int>(intrinsic) & 0xFF00;
 }
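
The shape of this refactor — a file-local free function keyed on the plain `MMAIntrinsic` enum, with the attribute method reduced to a one-line forwarding call — lets later code (like `createMmaOp` below) query intrinsic properties without materializing an attribute. A hedged sketch of that structure; `MMAAttrLike` and its members are hypothetical stand-ins, not the real MMAAttr API:

```cpp
#include <cstdint>
#include <cstdio>

enum class MMAIntrinsic { MFMA_F32_16x16x16_F16 /* ... */ };

// File-local helper keyed on the enum alone: callable with no attribute.
static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
  // Only block size 1 is supported at the moment.
  return 1;
}

// Hypothetical stand-in for MMAAttr: the member simply forwards to the
// enum-keyed helper.
struct MMAAttrLike {
  MMAIntrinsic intrinsic;
  int64_t getBlockSize() const { return ::getBlockSize(intrinsic); }
};

int main() {
  MMAAttrLike attr{MMAIntrinsic::MFMA_F32_16x16x16_F16};
  std::printf("block size = %lld\n", (long long)attr.getBlockSize()); // 1
}
```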
Expand Down Expand Up @@ -704,6 +665,31 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
}
}

static Value createMmaOp(OpBuilder &builder, Location loc,
MMAIntrinsic intrinsic, Type resultType, Value lhs,
Value rhs, Value acc) {
auto getVecOrSingleElem = [&](Value vec) -> Value {
bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
};
auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
if (is_AMD_MFMA(intrinsic)) {
// MFMA intrinsics want single-element operands of element type, not vector.
lhs = getVecOrSingleElem(lhs);
rhs = getVecOrSingleElem(rhs);
return builder
.create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
layout.kSize, getBlockSize(intrinsic), lhs, rhs,
acc)
.getResult();
}
if (is_AMD_WMMA(intrinsic)) {
return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
.getResult();
}
return {};
}

// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
// type.
FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
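
The `getVecOrSingleElem` lambda handles the operand-shape mismatch the comment describes: when a fragment type is a 1-element vector, the MFMA op expects the underlying scalar, so the lambda peels it off (via `vector.extract`); wider fragments pass through unchanged. A standalone sketch of the same dispatch, using plain C++ stand-ins for MLIR values and types (all names here are assumptions for illustration):

```cpp
#include <cstdio>
#include <variant>
#include <vector>

// Stand-ins: a "vector value" is a list of floats; a "scalar value" is one.
using Scalar = float;
using Vec = std::vector<float>;
using Operand = std::variant<Scalar, Vec>;

// Mirror of getVecOrSingleElem: unwrap a 1-element vector to its scalar
// (the analogue of vector.extract %v[0]); leave wider vectors alone.
static Operand vecOrSingleElem(const Vec &v) {
  if (v.size() == 1)
    return v[0];
  return v;
}

int main() {
  Operand a = vecOrSingleElem({3.0f});       // becomes a Scalar
  Operand b = vecOrSingleElem({1.0f, 2.0f}); // stays a Vec
  std::printf("a is scalar: %d, b is scalar: %d\n",
              std::holds_alternative<Scalar>(a),
              std::holds_alternative<Scalar>(b)); // 1, 0
}
```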
@@ -719,23 +705,9 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
   if (cType != resultType) {
     return failure();
   }
-  auto getVecOrSingleElem = [&](Value vec) -> Value {
-    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
-    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
-  };
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  if (is_AMD_MFMA(intrinsic)) {
-    // MFMA intrinsics want single-element operands of element type, not vector.
-    lhs = getVecOrSingleElem(lhs);
-    rhs = getVecOrSingleElem(rhs);
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  } else if (is_AMD_WMMA(intrinsic)) {
-    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
-        .getResult();
-  }
+  if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+                                resultType, lhs, rhs, acc)) {
+    return value;
+  }
   return failure();
 }
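
`createMmaOp` signals an unsupported intrinsic by returning a default-constructed `Value`, which converts to `false` in a boolean context; the `if (Value value = ...)` init-statement then folds the call and the check into one scope, falling through to `failure()` when no op was built. A minimal sketch of that control-flow idiom with a stand-in handle type (assumed for illustration, not MLIR's `Value`):

```cpp
#include <cstdio>

// Stand-in for mlir::Value: a nullable handle that tests false when empty.
struct Handle {
  const void *ptr = nullptr;
  explicit operator bool() const { return ptr != nullptr; }
};

static Handle createOp(bool supported) {
  static int storage = 0;
  if (supported)
    return {&storage};
  return {}; // empty handle == "could not build the op"
}

int main() {
  // C++17 if-with-initializer: `value` is scoped to the if/else.
  if (Handle value = createOp(/*supported=*/false)) {
    std::puts("built the op");
  } else {
    std::puts("fall through to failure()"); // printed
  }
}
```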
@@ -1168,23 +1140,18 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
   SmallVector<Value> intrinsicsAcc =
       distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
 
-  // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
-  // to create the target intrinsics.
-  auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
-  auto [intrinsicAType, intrinsicBType, intrinsicCType] =
-      intrinsicMma.getABCVectorTypes();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType intrinCType =
+      getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
 
   // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
   for (int mu = 0; mu < getUnrollM(); ++mu) {
     for (int nu = 0; nu < getUnrollN(); ++nu) {
       for (int ku = 0; ku < getUnrollK(); ++ku) {
-        // Assume intrinsicMma.buildMmaOperation() success: validation should be
-        // completed prior to mutating IR.
         Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
         Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
         Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
-        acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
-                                              rhs, acc);
+        acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
       }
     }
   }

Collaborator review comment on lines +1144 to +1145 (the intrinCType initialization):
Not required, but any thoughts on moving this into the createMmaOp itself?
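
The triple loop indexes the distributed fragments as flattened row-major 2D arrays: LHS tiles are laid out as [unrollM][unrollK], RHS as [unrollN][unrollK], and the accumulator as [unrollM][unrollN], with each (mu, nu) accumulator updated in place across the K unroll. A small arithmetic sketch of exactly that indexing, with scalar multiplies standing in for the MMA intrinsics and assumed unroll factors:

```cpp
#include <cstdio>

int main() {
  // Assumed unroll factors, for illustration only.
  const int unrollM = 2, unrollN = 3, unrollK = 2;
  // Each fragment array is flattened row-major with K innermost for LHS/RHS.
  int lhs[unrollM * unrollK], rhs[unrollN * unrollK],
      acc[unrollM * unrollN] = {};
  for (int i = 0; i < unrollM * unrollK; ++i) lhs[i] = i + 1;
  for (int i = 0; i < unrollN * unrollK; ++i) rhs[i] = i + 1;

  for (int mu = 0; mu < unrollM; ++mu)
    for (int nu = 0; nu < unrollN; ++nu)
      for (int ku = 0; ku < unrollK; ++ku)
        // Same index arithmetic as the IREE loop; the scalar multiply-add
        // stands in for one createMmaOp call updating acc in place.
        acc[mu * unrollN + nu] +=
            lhs[mu * unrollK + ku] * rhs[nu * unrollK + ku];

  for (int mu = 0; mu < unrollM; ++mu)
    for (int nu = 0; nu < unrollN; ++nu)
      std::printf("acc[%d][%d] = %d\n", mu, nu, acc[mu * unrollN + nu]);
}
```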