Skip to content

Commit

Permalink
[MLIR][ROCDL] Add GFX940 SMFMAC (2:4 sparsity) instructions to the RO…
Browse files Browse the repository at this point in the history
…CDL dialect (#124435)

# Overview

This PR adds 2:4 structured sparsity (sparse A, dense B) matrix multiply
instructions to ROCDL.

# Testing

I've added tests to Dialect/mlir and Target/mlir
  • Loading branch information
SamGinzburg authored Jan 27, 2025
1 parent 14ffff3 commit 43a50de
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 0 deletions.
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,24 @@ def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;

// 2:4 Sparsity ops (GFX940)
def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">;
def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">;
def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">;
def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.bf16">;
def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;


//===---------------------------------------------------------------------===//
// WMMA intrinsics
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
Expand Down
87 changes: 87 additions & 0 deletions mlir/test/Dialect/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return
}


llvm.func @rocdl.smfmac(%arg0 : i32,
%arg1 : vector<4 x f16>,
%arg2 : vector<8 x f16>,
%arg3 : vector<4 x f32>,
%arg4 : vector<16 x f32>,
%arg5 : vector<4 x i16>,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
%arg9 : vector<16xi32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32

// CHECK-LABEL: rocdl.smfmac
// CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
i32, i32, i32) -> vector<4xi32>

// CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
i32, i32, i32) -> vector<16xi32>

// CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

llvm.return %r0 : vector<4 x f32>
}

llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {
Expand Down
89 changes: 89 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,95 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return %r0 : vector<32 x f32>
}

llvm.func @rocdl.smfmac(%arg0 : i32,
%arg1 : vector<4 x f16>,
%arg2 : vector<8 x f16>,
%arg3 : vector<4 x f32>,
%arg4 : vector<16 x f32>,
%arg5 : vector<4 x i16>,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
%arg9 : vector<16xi32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32

// CHECK-LABEL: rocdl.smfmac

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
i32, i32, i32) -> vector<4xi32>

// CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
i32, i32, i32) -> vector<16xi32>

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
i32, i32, i32) -> vector<4xf32>

// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>


// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>

llvm.return %r0 : vector<4 x f32>
}


llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {
Expand Down

0 comments on commit 43a50de

Please sign in to comment.