Skip to content

Commit 1473fe4

Browse files
Fix signed version of radix_sort (#3724)
* Fix signed version of radix_sort * Wrong operation * abs * All signs? * wrong test... * |INT_MIN| > |INT_MAX| fix * Finalize * Switch to bit operation * Tidy code path * Move projection out of funciton and rely on compile time recursion * Optimize for first bit --------- Co-authored-by: Jørgen Schartum Dokken <[email protected]>
1 parent 92e1890 commit 1473fe4

File tree

2 files changed

+94
-16
lines changed

2 files changed

+94
-16
lines changed

cpp/dolfinx/common/sort.h

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2021 Igor Baratta
1+
// Copyright (C) 2021-2025 Igor Baratta and Paul T. Kühner
22
//
33
// This file is part of DOLFINx (https://www.fenicsproject.org)
44
//
@@ -7,11 +7,13 @@
77
#pragma once
88

99
#include <algorithm>
10+
#include <bit>
1011
#include <cassert>
1112
#include <concepts>
1213
#include <cstdint>
1314
#include <functional>
1415
#include <iterator>
16+
#include <limits>
1517
#include <numeric>
1618
#include <span>
1719
#include <type_traits>
@@ -20,6 +22,34 @@
2022

2123
namespace dolfinx
2224
{
25+
26+
struct __unsigned_projection
27+
{
28+
// Transforms the projected value to an unsigned int (if signed), while
29+
// maintaining relative order by
30+
// x ↦ x + |std::numeric_limits<I>::min()|
31+
template <std::signed_integral T>
32+
constexpr std::make_unsigned_t<T> operator()(T e) const noexcept
33+
{
34+
using uT = std::make_unsigned_t<T>;
35+
36+
// Assert binary structure for bit shift
37+
static_assert(static_cast<uT>(std::numeric_limits<T>::min())
38+
+ static_cast<uT>(std::numeric_limits<T>::max())
39+
== static_cast<uT>(T(-1)));
40+
static_assert(std::numeric_limits<uT>::digits
41+
== std::numeric_limits<T>::digits + 1);
42+
static_assert(std::bit_cast<uT>(std::numeric_limits<T>::min())
43+
== (uT(1) << (sizeof(T) * 8 - 1)));
44+
45+
return std::bit_cast<uT>(std::forward<T>(e))
46+
^ (uT(1) << (sizeof(T) * 8 - 1));
47+
}
48+
};
49+
50+
/// Projection from signed to signed int
51+
inline constexpr __unsigned_projection unsigned_projection{};
52+
2353
struct __radix_sort
2454
{
2555
/// @brief Sort a range with radix sorting algorithm. The bucket size
@@ -46,10 +76,11 @@ struct __radix_sort
4676
/// @tparam BITS The number of bits to sort at a time.
4777
/// @param[in, out] range The range to sort.
4878
/// @param[in] P Element projection.
49-
template <
50-
std::ranges::random_access_range R, typename P = std::identity,
51-
std::remove_cvref_t<std::invoke_result_t<P, std::iter_value_t<R>>> BITS
52-
= 8>
79+
template <std::ranges::random_access_range R, typename P = std::identity,
80+
std::make_unsigned_t<std::remove_cvref_t<
81+
std::invoke_result_t<P, std::iter_value_t<R>>>>
82+
BITS
83+
= 8>
5384
requires std::integral<decltype(BITS)>
5485
constexpr void operator()(R&& range, P proj = {}) const
5586
{
@@ -58,19 +89,36 @@ struct __radix_sort
5889

5990
// index type (if no projection is provided it holds I == T)
6091
using I = std::remove_cvref_t<std::invoke_result_t<P, T>>;
92+
using uI = std::make_unsigned_t<I>;
93+
94+
if constexpr (!std::is_same_v<uI, I>)
95+
{
96+
__radix_sort()(std::forward<R>(range), [&](const T& e) -> uI
97+
{ return unsigned_projection(proj(e)); });
98+
return;
99+
}
61100

62101
if (range.size() <= 1)
63102
return;
64103

65-
T max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
104+
uI max_value = proj(*std::ranges::max_element(range, std::less{}, proj));
66105

67106
// Sort N bits at a time
68-
constexpr I bucket_size = 1 << BITS;
69-
T mask = (T(1) << BITS) - 1;
107+
constexpr uI bucket_size = 1 << BITS;
108+
uI mask = (uI(1) << BITS) - 1;
70109

71110
// Compute number of iterations, most significant digit (N bits) of
72111
// maxvalue
73112
I its = 0;
113+
114+
// optimize for case where all first bits are set - then order will not
115+
// depend on it
116+
bool all_first_bit = std::ranges::all_of(
117+
range, [&](const auto& e)
118+
{ return proj(e) & (uI(1) << (sizeof(uI) * 8 - 1)); });
119+
if (all_first_bit)
120+
max_value = max_value & ~(uI(1) << (sizeof(uI) * 8 - 1));
121+
74122
while (max_value)
75123
{
76124
max_value >>= BITS;
@@ -81,7 +129,7 @@ struct __radix_sort
81129
std::array<I, bucket_size> counter;
82130
std::array<I, bucket_size + 1> offset;
83131

84-
I mask_offset = 0;
132+
uI mask_offset = 0;
85133
std::vector<T> buffer(range.size());
86134
std::span<T> current_perm = range;
87135
std::span<T> next_perm = buffer;
@@ -100,8 +148,8 @@ struct __radix_sort
100148
std::next(offset.begin()));
101149
for (const auto& c : current_perm)
102150
{
103-
I bucket = (proj(c) & mask) >> mask_offset;
104-
I new_pos = offset[bucket + 1] - counter[bucket];
151+
uI bucket = (proj(c) & mask) >> mask_offset;
152+
uI new_pos = offset[bucket + 1] - counter[bucket];
105153
next_perm[new_pos] = c;
106154
counter[bucket]--;
107155
}

cpp/test/common/sort.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
1-
// Copyright (C) 2021 Igor A. Baratta
1+
// Copyright (C) 2021-2025 Igor A. Baratta and Paul T. Kühner
22
//
33
// This file is part of DOLFINx (https://www.fenicsproject.org)
44
//
55
// SPDX-License-Identifier: LGPL-3.0-or-later
66

77
#include <algorithm>
88
#include <array>
9+
#include <bitset>
910
#include <catch2/catch_template_test_macros.hpp>
1011
#include <catch2/catch_test_macros.hpp>
1112
#include <catch2/generators/catch_generators.hpp>
13+
#include <cstdint>
1214
#include <dolfinx/common/sort.h>
1315
#include <functional>
16+
#include <iostream>
17+
#include <limits>
1418
#include <numeric>
1519
#include <random>
20+
#include <type_traits>
1621
#include <vector>
1722

18-
TEMPLATE_TEST_CASE("Test radix sort", "[vector][template]", std::int32_t,
19-
std::int64_t)
23+
TEMPLATE_TEST_CASE("Test radix sort", "[vector][template]", std::int16_t,
24+
std::int32_t, std::int64_t, std::uint16_t, std::uint32_t,
25+
std::uint64_t)
2026
{
2127
auto vec_size = GENERATE(100, 1000, 10000);
2228
std::vector<TestType> vec;
2329
vec.reserve(vec_size);
2430

2531
// Generate a vector of ints with a Uniform Int distribution
26-
std::uniform_int_distribution<TestType> distribution(0, 10000);
32+
std::uniform_int_distribution<TestType> distribution(-10000, 10000);
2733
std::mt19937 engine;
2834
auto generator = std::bind(distribution, engine);
2935
std::generate_n(std::back_inserter(vec), vec_size, generator);
@@ -35,12 +41,31 @@ TEMPLATE_TEST_CASE("Test radix sort", "[vector][template]", std::int32_t,
3541
REQUIRE(std::ranges::is_sorted(vec));
3642
}
3743

44+
TEMPLATE_TEST_CASE("Test radix sort (limits)", "[vector][template]",
45+
std::int16_t, std::int32_t, std::int64_t, std::uint16_t,
46+
std::uint32_t, std::uint64_t)
47+
{
48+
std::vector<TestType> vec{0, std::numeric_limits<TestType>::max(),
49+
std::numeric_limits<TestType>::min()};
50+
dolfinx::radix_sort(vec);
51+
REQUIRE(std::ranges::is_sorted(vec));
52+
REQUIRE(std::ranges::equal(
53+
vec, std::vector<TestType>{std::numeric_limits<TestType>::min(), 0,
54+
std::numeric_limits<TestType>::max()}));
55+
}
56+
3857
TEMPLATE_TEST_CASE("Test radix sort (projection)", "[radix]", std::int16_t,
39-
std::int32_t, std::int64_t)
58+
std::int32_t, std::int64_t, std::uint16_t, std::uint32_t,
59+
std::uint64_t)
4060
{
4161
// Check projection into same type array
4262
{
4363
std::vector<TestType> vec = {3, 6, 2, 1, 5, 4, 0};
64+
if constexpr (std::is_signed_v<TestType>)
65+
{
66+
vec[1] *= -1;
67+
vec[4] *= -1;
68+
}
4469
std::vector<TestType> indices(vec.size());
4570
std::iota(indices.begin(), indices.end(), 0);
4671

@@ -53,6 +78,11 @@ TEMPLATE_TEST_CASE("Test radix sort (projection)", "[radix]", std::int16_t,
5378
{
5479
std::vector<std::array<TestType, 1>> vec_array{{3}, {6}, {2}, {1},
5580
{5}, {4}, {0}};
81+
if constexpr (std::is_signed_v<TestType>)
82+
{
83+
vec_array[1][0] *= -1;
84+
vec_array[4][0] *= -1;
85+
}
5686
std::vector<TestType> indices(vec_array.size());
5787
std::iota(indices.begin(), indices.end(), 0);
5888

0 commit comments

Comments
 (0)