-
Notifications
You must be signed in to change notification settings - Fork 113
Closed
Description
Describe the bug
def test_sort_with_inf():
    """Regression test: descending ``torch.sort`` on float16 data containing +/-inf.

    Sorts a single row holding two finite values, one ``+inf`` and three
    ``-inf`` twice: once with plain PyTorch (the reference) and once under
    ``flag_gems.use_gems()``. It then checks that the FlagGems result agrees
    with the reference.

    The original reproduction only printed both results, so it passed even
    when FlagGems returned ``nan`` values and out-of-range indices. The
    assertions below make that failure visible.
    """
    y = torch.tensor(
        [[-5.4792, -8.0438, float("inf"), -float("inf"), -float("inf"), -float("inf")]],
        dtype=torch.float16,
        device="cpu",
    )
    ref_value, ref_index = torch.sort(y, dim=1, descending=True)
    print("PyTorch Sorted Values:", ref_value)
    print("PyTorch Sorted Indices:", ref_index)

    with flag_gems.use_gems():
        res_value, res_index = torch.sort(y, dim=1, descending=True)
    print("FlagGems Sorted Values:", res_value)
    print("FlagGems Sorted Indices:", res_index)

    # Sorting only permutes the input, so the values must match exactly,
    # including where +inf/-inf land. assert_close defaults to
    # equal_nan=False, so any nan in res_value fails here.
    torch.testing.assert_close(res_value, ref_value, rtol=0, atol=0)

    # The three -inf entries tie, and torch.sort is not stable by default, so
    # their indices may legitimately come back in any order. Comparing against
    # ref_index element-wise would be flaky. Instead require that:
    #   1. the indices are a permutation of 0..n-1 (no out-of-range or
    #      duplicated entries), and
    #   2. each index points at the value reported in its position.
    n = y.shape[1]
    expected_perm = torch.arange(n, device=res_index.device).expand_as(res_index)
    assert torch.equal(torch.sort(res_index, dim=1).values, expected_perm), (
        f"sort indices are not a permutation of range({n}): {res_index}"
    )
    torch.testing.assert_close(torch.gather(y, 1, res_index), res_value, rtol=0, atol=0)
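Result — PyTorch sorts the row correctly. FlagGems drops the `+inf`, replaces all four infinities with `nan`, and returns out-of-range indices (6 and 7 for a row of length 6):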
PyTorch Sorted Values: tensor([[ inf, -5.4805, -8.0469, -inf, -inf, -inf]],
dtype=torch.float16)
PyTorch Sorted Indices: tensor([[2, 0, 1, 3, 4, 5]])
FlagGems Sorted Values: tensor([[-5.4805, -8.0469, nan, nan, nan, nan]],
dtype=torch.float16)
FlagGems Sorted Indices: tensor([[4, 5, 7, 6, 2, 3]])
Metadata
Assignees
Labels
No labels