Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample argmax kernel for a single subgroup #47

Closed
wants to merge 1 commit into from

Conversation

qedawkins
Copy link
Collaborator

@qedawkins qedawkins commented Nov 30, 2023

(not expecting a review yet, this is still a draft)

@qedawkins
Copy link
Collaborator Author

Can't seem to add a reviewer, so @raikonenfnu

@kuhar
Copy link
Collaborator

kuhar commented Nov 30, 2023

@qedawkins @antiagainst ISA for the subgroup ops from llpc:

Command: amdllpc one_workgroup_argmax_subgroup_f32.comp -o /dev/null --gfxip=11.0 -v
Assembly:

.AMDGPU.disasm (size = 6467 bytes)
_amdgpu_cs_main:
BB0_0:
	s_getpc_b64 s[4:5]                                                                   ; BE844700
	s_mov_b32 s0, s1                                                                     ; BE800001
	s_mov_b32 s1, s5                                                                     ; BE810005
	v_and_b32_e32 v3, 0x3ff, v0                                                          ; 360600FF 000003FF
	s_load_b256 s[4:11], s[0:1], 0x0                                                     ; F40C0100 F8000000
	v_mov_b32_e32 v0, 0                                                                  ; 7E000280
	s_cmpk_lt_u32 s2, 0x80                                                               ; B6820080
	s_delay_alu instid0(VALU_DEP_2)                                                      ; BF870002
	v_lshlrev_b32_e32 v5, 2, v3                                                          ; 300A0682
	s_waitcnt lgkmcnt(0)                                                                 ; BF89FC07
	buffer_load_b32 v4, v5, s[4:7], 0 offen                                              ; E0500000 80410405
	s_cbranch_scc1 .LBB0_3                                                               ; BFA20000
	v_add_nc_u32_e32 v5, 0x100, v5                                                       ; 4A0A0AFF 00000100
	v_add_nc_u32_e32 v6, 64, v3                                                          ; 4A0C06C0
	v_mov_b32_e32 v0, 0                                                                  ; 7E000280
	s_lshr_b32 s0, s2, 6                                                                 ; 85008602
	s_delay_alu instid0(SALU_CYCLE_1)                                                    ; BF870009
	s_add_i32 s0, s0, -1                                                                 ; 8100C100
BB0_2:
	buffer_load_b32 v7, v5, s[4:7], 0 offen                                              ; E0500000 80410705
	s_waitcnt vmcnt(1)                                                                   ; BF8907F7
	v_mov_b32_e32 v8, v4                                                                 ; 7E100304
	v_add_nc_u32_e32 v5, 0x100, v5                                                       ; 4A0A0AFF 00000100
	s_add_i32 s0, s0, -1                                                                 ; 8100C100
	s_delay_alu instid0(SALU_CYCLE_1)                                                    ; BF870009
	s_cmp_lg_u32 s0, 0                                                                   ; BF078000
	s_waitcnt vmcnt(0)                                                                   ; BF8903F7
	v_cmp_lt_f32_e32 vcc, v8, v7                                                         ; 7C220F08
	v_max_f32_e32 v4, v8, v7                                                             ; 20080F08
	v_cndmask_b32_e32 v0, v0, v6, vcc                                                    ; 02000D00
	v_add_nc_u32_e32 v6, 64, v6                                                          ; 4A0C0CC0
	s_cbranch_scc1 .LBB0_2                                                               ; BFA20000
BB0_3:
	s_waitcnt vmcnt(0)                                                                   ; BF8903F7
	v_mov_b32_e32 v1, v4                                                                 ; 7E020304
	s_not_b64 exec, exec                                                                 ; BEFE1F7E
	v_mov_b32_e32 v1, 0xff800000                                                         ; 7E0202FF FF800000
	s_not_b64 exec, exec                                                                 ; BEFE1F7E
	s_or_saveexec_b64 s[0:1], -1                                                         ; BE8023C1
	s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
	v_max_f32_dpp v1, v1, v1 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 200202FA FF08B101
	v_max_f32_dpp v1, v1, v1 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 200202FA FF084E01
	s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
	v_max_f32_dpp v1, v1, v1 row_half_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1     ; 200202FA FF094101
	v_max_f32_dpp v1, v1, v1 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1          ; 200202FA FF094001
	s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
	v_permlanex16_b32 v2, v1, -1, -1 op_sel:[1,0]                                        ; D65C0802 03058301
	v_max_f32_e32 v1, v1, v2                                                             ; 20020501
	s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
	v_permlane64_b32 v2, v1                                                              ; 7E04CF01
	v_max_f32_e32 v1, v1, v2                                                             ; 20020501
	s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
	s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
	v_mov_b32_e32 v5, v1                                                                 ; 7E0A0301
	v_cmp_eq_f32_e32 vcc, v4, v5                                                         ; 7C240B04
	s_ctz_i32_b32 s0, vcc_hi                                                             ; BE80086B
	s_ctz_i32_b32 s1, vcc_lo                                                             ; BE81086A
	s_add_i32 s0, s0, 32                                                                 ; 8100A000
	s_delay_alu instid0(SALU_CYCLE_1) | instskip(NEXT) | instid1(SALU_CYCLE_1)           ; BF870499
	s_min_u32 s0, s1, s0                                                                 ; 89800001
	v_cmp_eq_u32_e32 vcc, s0, v3                                                         ; 7C940600
	s_and_saveexec_b64 s[0:1], vcc                                                       ; BE80216A
	s_cbranch_execz .LBB0_5                                                              ; BFA50000
	buffer_store_b32 v0, off, s[8:11], 0                                                 ; E0680000 80020000
BB0_5:
	s_nop 0                                                                              ; BF800000
	s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)                                                 ; BFB60003
	s_endpgm                                                                             ; BFB00000

llvm IR for BB4:

._crit_edge:                                      ; preds = %.lr.ph, %.entry
  %laneMax.0.lcssa = phi float [ %10, %.entry ], [ %17, %.lr.ph ]
  %laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %19, %.lr.ph ]
  %21 = bitcast float %laneMax.0.lcssa to i32
  %22 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %21, i32 -8388608)
  %23 = bitcast i32 %22 to float
  %24 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %22, i32 177, i32 15, i32 15, i1 true)
  %25 = bitcast i32 %24 to float
  %26 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %23, float %25)
  %27 = bitcast float %26 to i32
  %28 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %27, i32 78, i32 15, i32 15, i1 true)
  %29 = bitcast i32 %28 to float
  %30 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %26, float %29)
  %31 = bitcast float %30 to i32
  %32 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %31, i32 321, i32 15, i32 15, i1 true)
  %33 = bitcast i32 %32 to float
  %34 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %30, float %33)
  %35 = bitcast float %34 to i32
  %36 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %35, i32 320, i32 15, i32 15, i1 true)
  %37 = bitcast i32 %36 to float
  %38 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %34, float %37)
  %39 = bitcast float %38 to i32
  %40 = call i32 @llvm.amdgcn.permlanex16(i32 undef, i32 %39, i32 -1, i32 -1, i1 true, i1 false)
  %41 = bitcast i32 %40 to float
  %42 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %38, float %41)
  %43 = bitcast float %42 to i32
  %44 = call i32 @llvm.amdgcn.permlane64(i32 %43)
  %45 = bitcast i32 %44 to float
  %46 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %42, float %45)
  %47 = bitcast float %46 to i32
  %48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
  %49 = bitcast i32 %48 to float
  %50 = fcmp oeq float %laneMax.0.lcssa, %49
  %51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
  %52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
  %.fr1 = freeze i64 %52
  %53 = trunc i64 %.fr1 to i32
  %54 = icmp eq i32 %8, %53
  br i1 %54, label %55, label %56

@kuhar
Copy link
Collaborator

kuhar commented Nov 30, 2023

ISA for gfx90:

_amdgpu_cs_main:
BB0_0:
	s_getpc_b64 s[4:5]                                                                   ; BE841C00
	s_mov_b32 s0, s1                                                                     ; BE800001
	s_mov_b32 s1, s5                                                                     ; BE810005
	s_load_dwordx8 s[4:11], s[0:1], 0x0                                                  ; C00E0100 00000000
	v_lshlrev_b32_e32 v6, 2, v0                                                          ; 240C0082
	s_cmpk_lt_u32 s2, 0x80                                                               ; B6020080
	v_mov_b32_e32 v4, 0                                                                  ; 7E080280
	s_waitcnt lgkmcnt(0)                                                                 ; BF8CC07F
	buffer_load_dword v5, v6, s[4:7], 0 offen                                            ; E0501000 80010506
	s_cbranch_scc1 .LBB0_3                                                               ; BF850000
	s_lshr_b32 s0, s2, 6                                                                 ; 8F008602
	s_add_i32 s0, s0, -1                                                                 ; 8100C100
	v_add_u32_e32 v6, 0x100, v6                                                          ; 680C0CFF 00000100
	v_add_u32_e32 v7, 64, v0                                                             ; 680E00C0
	v_mov_b32_e32 v4, 0                                                                  ; 7E080280
BB0_2:
	buffer_load_dword v8, v6, s[4:7], 0 offen                                            ; E0501000 80010806
	s_waitcnt vmcnt(1)                                                                   ; BF8C0F71
	v_mov_b32_e32 v9, v5                                                                 ; 7E120305
	s_add_i32 s0, s0, -1                                                                 ; 8100C100
	v_add_u32_e32 v6, 0x100, v6                                                          ; 680C0CFF 00000100
	s_cmp_lg_u32 s0, 0                                                                   ; BF078000
	s_waitcnt vmcnt(0)                                                                   ; BF8C0F70
	v_cmp_lt_f32_e32 vcc, v9, v8                                                         ; 7C821109
	v_max_f32_e32 v5, v9, v8                                                             ; 160A1109
	v_cndmask_b32_e32 v4, v4, v7, vcc                                                    ; 00080F04
	v_add_u32_e32 v7, 64, v7                                                             ; 680E0EC0
	s_cbranch_scc1 .LBB0_2                                                               ; BF850000
BB0_3:
	s_or_saveexec_b64 s[0:1], -1                                                         ; BE8021C1
	v_mov_b32_e32 v1, 0xff800000                                                         ; 7E0202FF FF800000
	s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
	s_waitcnt vmcnt(0)                                                                   ; BF8C0F70
	v_mov_b32_e32 v2, v5                                                                 ; 7E040305
	s_not_b64 exec, exec                                                                 ; BEFE057E
	v_mov_b32_e32 v2, 0xff800000                                                         ; 7E0402FF FF800000
	s_not_b64 exec, exec                                                                 ; BEFE057E
	s_or_saveexec_b64 s[0:1], -1                                                         ; BE8021C1
	v_max_f32_dpp v2, v2, v2 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 160404FA FF08B102
	v_mov_b32_e32 v3, 0xff800000                                                         ; 7E0602FF FF800000
	s_nop 0                                                                              ; BF800000
	v_max_f32_dpp v2, v2, v2 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 160404FA FF084E02
	s_nop 1                                                                              ; BF800001
	v_max_f32_dpp v2, v2, v2 row_half_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1     ; 160404FA FF094102
	s_nop 1                                                                              ; BF800001
	v_max_f32_dpp v2, v2, v2 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1          ; 160404FA FF094002
	s_nop 1                                                                              ; BF800001
	v_mov_b32_dpp v3, v2 row_bcast:15 row_mask:0xa bank_mask:0xf bound_ctrl:1            ; 7E0602FA AF094202
	v_max_f32_e32 v2, v2, v3                                                             ; 16040702
	s_nop 1                                                                              ; BF800001
	v_mov_b32_dpp v1, v2 row_bcast:31 row_mask:0x8 bank_mask:0xf bound_ctrl:1            ; 7E0202FA 8F094302
	v_max_f32_e32 v1, v2, v1                                                             ; 16020302
	v_readlane_b32 s2, v1, 63                                                            ; D2890002 00017F01
	s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
	v_cmp_eq_f32_e32 vcc, s2, v5                                                         ; 7C840A02
	s_ff1_i32_b32 s0, vcc_hi                                                             ; BE80106B
	s_add_i32 s0, s0, 32                                                                 ; 8100A000
	s_ff1_i32_b32 s1, vcc_lo                                                             ; BE81106A
	s_min_u32 s0, s1, s0                                                                 ; 83800001
	v_cmp_eq_u32_e32 vcc, s0, v0                                                         ; 7D940000
	s_and_saveexec_b64 s[0:1], vcc                                                       ; BE80206A
	s_cbranch_execz .LBB0_5                                                              ; BF880000
	buffer_store_dword v4, off, s[8:11], 0                                               ; E0700000 80020400
BB0_5:
	s_endpgm                                                                             ; BF810000

llvm IR for BB4:

._crit_edge:                                      ; preds = %.lr.ph, %.entry
  %laneMax.0.lcssa = phi float [ %9, %.entry ], [ %16, %.lr.ph ]
  %laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %18, %.lr.ph ]
  %20 = bitcast float %laneMax.0.lcssa to i32
  %21 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %20, i32 -8388608)
  %22 = bitcast i32 %21 to float
  %23 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %21, i32 177, i32 15, i32 15, i1 true)
  %24 = bitcast i32 %23 to float
  %25 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %22, float %24)
  %26 = bitcast float %25 to i32
  %27 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %26, i32 78, i32 15, i32 15, i1 true)
  %28 = bitcast i32 %27 to float
  %29 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %25, float %28)
  %30 = bitcast float %29 to i32
  %31 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %30, i32 321, i32 15, i32 15, i1 true)
  %32 = bitcast i32 %31 to float
  %33 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %29, float %32)
  %34 = bitcast float %33 to i32
  %35 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %34, i32 320, i32 15, i32 15, i1 true)
  %36 = bitcast i32 %35 to float
  %37 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %33, float %36)
  %38 = bitcast float %37 to i32
  %39 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %38, i32 322, i32 10, i32 15, i1 true)
  %40 = bitcast i32 %39 to float
  %41 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %37, float %40)
  %42 = bitcast float %41 to i32
  %43 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %42, i32 323, i32 8, i32 15, i1 true)
  %44 = bitcast i32 %43 to float
  %45 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %41, float %44)
  %46 = bitcast float %45 to i32
  %47 = call i32 @llvm.amdgcn.readlane(i32 %46, i32 63)
  %48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
  %49 = bitcast i32 %48 to float
  %50 = fcmp oeq float %laneMax.0.lcssa, %49
  %51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
  %52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
  %.fr1 = freeze i64 %52
  %53 = trunc i64 %.fr1 to i32
  %54 = icmp eq i32 %LocalInvocationId.i0, %53
  br i1 %54, label %55, label %56

@antiagainst
Copy link
Collaborator

antiagainst commented Nov 30, 2023

Awesome, thanks Quinn and Jakub! We can get these intrinsics via HIP (courtesy from ChatGPT):

For GLSL subgroupMax:

__inline__ __device__ float warpMax(float val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val = max(val, __shfl_down(val, offset));
    }
    return val;
}

For GLSL subgroupBallot, we can use __ballot.

For GLSL subgroupBallotFindLSB, we can use __ffsll I think.

uint laneCount = gl_WorkGroupSize.x;

float16_t laneMax = Input.data[laneID];
uint laneResult = 0;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not initialize the laneResult with laneID?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, this should be laneID

// Final reduction with one subgroup
float16_t wgMax = subgroupMax(laneMax);

bool eq = wgMax == laneMax;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason why we cannot just use if (wgMax == laneMax) Output.data = uvec2(laneResult, upper32bits); directly rather than checking wgMax == laneMax for the ballot and then checking result from ballot?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantics of argmax require the smallest index to be returned in case there are multiple maximum elements.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhh, that makes sense! :)

@kuhar
Copy link
Collaborator

kuhar commented Aug 1, 2024

Closing due to inactivity

@kuhar kuhar closed this Aug 1, 2024
kuhar pushed a commit that referenced this pull request Aug 2, 2024
This PR is based on #47. I opened a new one because the old one got stale.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants