-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sample argmax kernel for a single subgroup #47
Conversation
Can't seem to add a reviewer, so @raikonenfnu |
@qedawkins @antiagainst ISA for the subgroup ops from llpc: Command:
llvm IR for BB4: ._crit_edge: ; preds = %.lr.ph, %.entry
%laneMax.0.lcssa = phi float [ %10, %.entry ], [ %17, %.lr.ph ]
%laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %19, %.lr.ph ]
%21 = bitcast float %laneMax.0.lcssa to i32
%22 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %21, i32 -8388608)
%23 = bitcast i32 %22 to float
%24 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %22, i32 177, i32 15, i32 15, i1 true)
%25 = bitcast i32 %24 to float
%26 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %23, float %25)
%27 = bitcast float %26 to i32
%28 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %27, i32 78, i32 15, i32 15, i1 true)
%29 = bitcast i32 %28 to float
%30 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %26, float %29)
%31 = bitcast float %30 to i32
%32 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %31, i32 321, i32 15, i32 15, i1 true)
%33 = bitcast i32 %32 to float
%34 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %30, float %33)
%35 = bitcast float %34 to i32
%36 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %35, i32 320, i32 15, i32 15, i1 true)
%37 = bitcast i32 %36 to float
%38 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %34, float %37)
%39 = bitcast float %38 to i32
%40 = call i32 @llvm.amdgcn.permlanex16(i32 undef, i32 %39, i32 -1, i32 -1, i1 true, i1 false)
%41 = bitcast i32 %40 to float
%42 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %38, float %41)
%43 = bitcast float %42 to i32
%44 = call i32 @llvm.amdgcn.permlane64(i32 %43)
%45 = bitcast i32 %44 to float
%46 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %42, float %45)
%47 = bitcast float %46 to i32
%48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
%49 = bitcast i32 %48 to float
%50 = fcmp oeq float %laneMax.0.lcssa, %49
%51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
%52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
%.fr1 = freeze i64 %52
%53 = trunc i64 %.fr1 to i32
%54 = icmp eq i32 %8, %53
br i1 %54, label %55, label %56 |
ISA for gfx90:
llvm IR for BB4: ._crit_edge: ; preds = %.lr.ph, %.entry
%laneMax.0.lcssa = phi float [ %9, %.entry ], [ %16, %.lr.ph ]
%laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %18, %.lr.ph ]
%20 = bitcast float %laneMax.0.lcssa to i32
%21 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %20, i32 -8388608)
%22 = bitcast i32 %21 to float
%23 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %21, i32 177, i32 15, i32 15, i1 true)
%24 = bitcast i32 %23 to float
%25 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %22, float %24)
%26 = bitcast float %25 to i32
%27 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %26, i32 78, i32 15, i32 15, i1 true)
%28 = bitcast i32 %27 to float
%29 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %25, float %28)
%30 = bitcast float %29 to i32
%31 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %30, i32 321, i32 15, i32 15, i1 true)
%32 = bitcast i32 %31 to float
%33 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %29, float %32)
%34 = bitcast float %33 to i32
%35 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %34, i32 320, i32 15, i32 15, i1 true)
%36 = bitcast i32 %35 to float
%37 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %33, float %36)
%38 = bitcast float %37 to i32
%39 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %38, i32 322, i32 10, i32 15, i1 true)
%40 = bitcast i32 %39 to float
%41 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %37, float %40)
%42 = bitcast float %41 to i32
%43 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %42, i32 323, i32 8, i32 15, i1 true)
%44 = bitcast i32 %43 to float
%45 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %41, float %44)
%46 = bitcast float %45 to i32
%47 = call i32 @llvm.amdgcn.readlane(i32 %46, i32 63)
%48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
%49 = bitcast i32 %48 to float
%50 = fcmp oeq float %laneMax.0.lcssa, %49
%51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
%52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
%.fr1 = freeze i64 %52
%53 = trunc i64 %.fr1 to i32
%54 = icmp eq i32 %LocalInvocationId.i0, %53
br i1 %54, label %55, label %56 |
Awesome, thanks Quinn and Jakub! We can get these intrinsics via HIP (courtesy from ChatGPT): For GLSL __inline__ __device__ float warpMax(float val) {
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down(val, offset));
}
return val;
} For GLSL For GLSL |
uint laneCount = gl_WorkGroupSize.x; | ||
|
||
float16_t laneMax = Input.data[laneID]; | ||
uint laneResult = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not initialize the laneResult with laneID?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, this should be laneID
// Final reduction with one subgroup | ||
float16_t wgMax = subgroupMax(laneMax); | ||
|
||
bool eq = wgMax == laneMax; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a reason why we cannot just use if (wgMax == laneMax) Output.data = uvec2(laneResult, upper32bits);
directly rather than checking wgMax == laneMax
for the ballot and then checking result from ballot?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The semantics of argmax require the smallest index to be returned in case there are multiple maximum elements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahhh, that makes sense! :)
Closing due to inactivity |
This PR is based on #47. I opened a new one because the old one got stale.
(not expecting a review yet, this is still a draft)