1
- // Copyright (C) 2021 Igor Baratta
1
+ // Copyright (C) 2021-2025 Igor Baratta and Paul T. Kühner
2
2
//
3
3
// This file is part of DOLFINx (https://www.fenicsproject.org)
4
4
//
7
7
#pragma once
8
8
9
9
#include < algorithm>
10
+ #include < bit>
10
11
#include < cassert>
11
12
#include < concepts>
12
13
#include < cstdint>
13
14
#include < functional>
14
15
#include < iterator>
16
+ #include < limits>
15
17
#include < numeric>
16
18
#include < span>
17
19
#include < type_traits>
20
22
21
23
namespace dolfinx
22
24
{
25
+
26
+ struct __unsigned_projection
27
+ {
28
+ // Transforms the projected value to an unsigned int (if signed), while
29
+ // maintaining relative order by
30
+ // x ↦ x + |std::numeric_limits<I>::min()|
31
+ template <std::signed_integral T>
32
+ constexpr std::make_unsigned_t <T> operator ()(T e) const noexcept
33
+ {
34
+ using uT = std::make_unsigned_t <T>;
35
+
36
+ // Assert binary structure for bit shift
37
+ static_assert (static_cast <uT>(std::numeric_limits<T>::min ())
38
+ + static_cast <uT>(std::numeric_limits<T>::max ())
39
+ == static_cast <uT>(T (-1 )));
40
+ static_assert (std::numeric_limits<uT>::digits
41
+ == std::numeric_limits<T>::digits + 1 );
42
+ static_assert (std::bit_cast<uT>(std::numeric_limits<T>::min ())
43
+ == (uT (1 ) << (sizeof (T) * 8 - 1 )));
44
+
45
+ return std::bit_cast<uT>(std::forward<T>(e))
46
+ ^ (uT (1 ) << (sizeof (T) * 8 - 1 ));
47
+ }
48
+ };
49
+
50
+ // / Projection from signed to signed int
51
+ inline constexpr __unsigned_projection unsigned_projection{};
52
+
23
53
struct __radix_sort
24
54
{
25
55
// / @brief Sort a range with radix sorting algorithm. The bucket size
@@ -46,10 +76,11 @@ struct __radix_sort
46
76
// / @tparam BITS The number of bits to sort at a time.
47
77
// / @param[in, out] range The range to sort.
48
78
// / @param[in] P Element projection.
49
- template <
50
- std::ranges::random_access_range R, typename P = std::identity,
51
- std::remove_cvref_t <std::invoke_result_t <P, std::iter_value_t <R>>> BITS
52
- = 8 >
79
+ template <std::ranges::random_access_range R, typename P = std::identity,
80
+ std::make_unsigned_t <std::remove_cvref_t <
81
+ std::invoke_result_t <P, std::iter_value_t <R>>>>
82
+ BITS
83
+ = 8 >
53
84
requires std::integral<decltype (BITS)>
54
85
constexpr void operator ()(R&& range, P proj = {}) const
55
86
{
@@ -58,19 +89,36 @@ struct __radix_sort
58
89
59
90
// index type (if no projection is provided it holds I == T)
60
91
using I = std::remove_cvref_t <std::invoke_result_t <P, T>>;
92
+ using uI = std::make_unsigned_t <I>;
93
+
94
+ if constexpr (!std::is_same_v<uI, I>)
95
+ {
96
+ __radix_sort ()(std::forward<R>(range), [&](const T& e) -> uI
97
+ { return unsigned_projection (proj (e)); });
98
+ return ;
99
+ }
61
100
62
101
if (range.size () <= 1 )
63
102
return ;
64
103
65
- T max_value = proj (*std::ranges::max_element (range, std::less{}, proj));
104
+ uI max_value = proj (*std::ranges::max_element (range, std::less{}, proj));
66
105
67
106
// Sort N bits at a time
68
- constexpr I bucket_size = 1 << BITS;
69
- T mask = (T (1 ) << BITS) - 1 ;
107
+ constexpr uI bucket_size = 1 << BITS;
108
+ uI mask = (uI (1 ) << BITS) - 1 ;
70
109
71
110
// Compute number of iterations, most significant digit (N bits) of
72
111
// maxvalue
73
112
I its = 0 ;
113
+
114
+ // optimize for case where all first bits are set - then order will not
115
+ // depend on it
116
+ bool all_first_bit = std::ranges::all_of (
117
+ range, [&](const auto & e)
118
+ { return proj (e) & (uI (1 ) << (sizeof (uI) * 8 - 1 )); });
119
+ if (all_first_bit)
120
+ max_value = max_value & ~(uI (1 ) << (sizeof (uI) * 8 - 1 ));
121
+
74
122
while (max_value)
75
123
{
76
124
max_value >>= BITS;
@@ -81,7 +129,7 @@ struct __radix_sort
81
129
std::array<I, bucket_size> counter;
82
130
std::array<I, bucket_size + 1 > offset;
83
131
84
- I mask_offset = 0 ;
132
+ uI mask_offset = 0 ;
85
133
std::vector<T> buffer (range.size ());
86
134
std::span<T> current_perm = range;
87
135
std::span<T> next_perm = buffer;
@@ -100,8 +148,8 @@ struct __radix_sort
100
148
std::next (offset.begin ()));
101
149
for (const auto & c : current_perm)
102
150
{
103
- I bucket = (proj (c) & mask) >> mask_offset;
104
- I new_pos = offset[bucket + 1 ] - counter[bucket];
151
+ uI bucket = (proj (c) & mask) >> mask_offset;
152
+ uI new_pos = offset[bucket + 1 ] - counter[bucket];
105
153
next_perm[new_pos] = c;
106
154
counter[bucket]--;
107
155
}
0 commit comments