Skip to content

Fix argsort datatype bug #519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 20, 2025
Merged

Fix argsort datatype bug #519

merged 4 commits into from
Jun 20, 2025

Conversation

MARD1NO
Copy link
Collaborator

@MARD1NO MARD1NO commented Mar 27, 2025

对于padding的值我们会根据 desecend 参数给pad 当前输入数据类型的最大值 or 最小值。

#505

在原issue的case输入是int64,而后续kernel内部转成int32,导致问题

修复后结果正常:

image

@MARD1NO MARD1NO requested a review from StrongSpoon March 27, 2025 06:15
@Rugu7
Copy link
Contributor

Rugu7 commented Mar 31, 2025

int 8类型也有类似错误。使用下述补丁修复,无问题:

# git diff src/flag_gems/ops/topk.py
diff --git a/src/flag_gems/ops/topk.py b/src/flag_gems/ops/topk.py
index afbdd07..06ae3f6 100644
--- a/src/flag_gems/ops/topk.py
+++ b/src/flag_gems/ops/topk.py
@@ -23,6 +23,8 @@ _MIN_FLOAT16_VAL: tl.constexpr = torch.finfo(torch.float16).min
 _MAX_FLOAT16_VAL: tl.constexpr = torch.finfo(torch.float16).max
 _MIN_BFLOAT16_VAL: tl.constexpr = torch.finfo(torch.bfloat16).min
 _MAX_BFLOAT16_VAL: tl.constexpr = torch.finfo(torch.bfloat16).max
+_MIN_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).min
+_MAX_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).max
 _MIN_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).min
 _MAX_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).max
 _MIN_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).min
@@ -58,7 +60,12 @@ def _get_iinfo_val(
     dtype,
     return_max,
 ):
-    if dtype is tl.int16:
+    if dtype is tl.int8:
+        if return_max:
+            return _MAX_INT8_VAL
+        else:
+            return _MIN_INT8_VAL
+    elif dtype is tl.int16:
         if return_max:
             return _MAX_INT16_VAL
         else:

@StrongSpoon
Copy link
Collaborator

StrongSpoon commented Mar 31, 2025

Feiyu implemented functions get_dtype_max and get_dtype_min in file limits.py and I think you could consider reusing them.

@MARD1NO
Copy link
Collaborator Author

MARD1NO commented Mar 31, 2025

Feiyu implemented functions get_dtype_max and get_dtype_min in file limits.py and I think you could consider reusing them.

yes, I reuse it to simplify get_iinfo_val function,

StrongSpoon
StrongSpoon previously approved these changes Apr 28, 2025
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@StrongSpoon StrongSpoon self-assigned this Jun 20, 2025
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@StrongSpoon StrongSpoon merged commit 18347b3 into master Jun 20, 2025
9 of 14 checks passed
@StrongSpoon StrongSpoon deleted the fix_argsort_bug branch June 20, 2025 09:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants