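```
>>> import torch
>>> topk_ids = torch.tensor([[16, 17, 18, 19],
...                          [16, 17, 18, 19],
...                          [16, 17, 18, 19],
...                          [16, 17, 18, 19],
...                          [16, 17, 18, 19],
...                          [16, 17, 18, 19]], device='cuda:0')
>>>
>>> topk_ids.view(-1).argsort()
tensor([16, 20, 8, 12, 4, 0, 21, 17, 13, 9, 1, 5, 18, 22, 10, 14, 6, 2, 3, 7, 15, 11, 23, 19], device='cuda:0')
>>> import flag_gems
>>> flag_gems.enable()
>>> topk_ids.view(-1).argsort()
tensor([24, 25, 27, 26, 30, 31, 29, 28, 0, 4, 12, 8, 20, 16, 9, 13, 17, 21, 5, 1, 2, 6, 22, 18], device='cuda:0')
```

The flattened tensor has 24 elements, so every valid `argsort` index lies in 0–23. PyTorch's native result is a valid permutation of 0–23. With `flag_gems` enabled, the result instead:
- contains out-of-range indices 24–31;
- is missing indices 3, 7, 10, 11, 14, 15, 19 and 23, which are every position holding 19 plus two positions holding 18.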