@@ -28,114 +28,95 @@ namespace {
28
28
// ----------------------------------------------------------------------
29
29
// Rank implementation
30
30
31
- template <typename ValueSelector,
32
- typename T = std::decay_t <std::invoke_result_t <ValueSelector, int64_t >>>
31
// Bit 63 of a sorted index is used as a flag: when set, the index has been
// marked as a duplicate (tie) of the entry that precedes it in sort order.
// The flag must be masked off before the index is used to address an array.
constexpr uint64_t kDuplicateMask = uint64_t{1} << 63;
32
+
33
+ bool NeedsDuplicates (RankOptions::Tiebreaker tiebreaker) {
34
+ return tiebreaker != RankOptions::First;
35
+ }
36
+
37
+ template <typename ValueSelector>
38
+ void MarkDuplicates (const NullPartitionResult& sorted, ValueSelector&& value_selector) {
39
+ using T = decltype (value_selector (int64_t {}));
40
+
41
+ // Process non-nulls
42
+ if (sorted.non_nulls_end != sorted.non_nulls_begin ) {
43
+ auto it = sorted.non_nulls_begin ;
44
+ T prev_value = value_selector (*it);
45
+ while (++it < sorted.non_nulls_end ) {
46
+ T curr_value = value_selector (*it);
47
+ if (curr_value == prev_value) {
48
+ *it |= kDuplicateMask ;
49
+ }
50
+ prev_value = curr_value;
51
+ }
52
+ }
53
+
54
+ // Process nulls
55
+ if (sorted.nulls_end != sorted.nulls_begin ) {
56
+ // TODO this should be able to distinguish between NaNs and real nulls (GH-45193)
57
+ auto it = sorted.nulls_begin ;
58
+ while (++it < sorted.nulls_end ) {
59
+ *it |= kDuplicateMask ;
60
+ }
61
+ }
62
+ }
63
+
33
64
Result<Datum> CreateRankings (ExecContext* ctx, const NullPartitionResult& sorted,
34
65
const NullPlacement null_placement,
35
- const RankOptions::Tiebreaker tiebreaker,
36
- ValueSelector&& value_selector) {
66
+ const RankOptions::Tiebreaker tiebreaker) {
37
67
auto length = sorted.overall_end () - sorted.overall_begin ();
38
68
ARROW_ASSIGN_OR_RAISE (auto rankings,
39
69
MakeMutableUInt64Array (length, ctx->memory_pool ()));
40
70
auto out_begin = rankings->GetMutableValues <uint64_t >(1 );
41
71
uint64_t rank;
42
72
73
+ auto is_duplicate = [](uint64_t index ) { return (index & kDuplicateMask ) != 0 ; };
74
+ auto original_index = [](uint64_t index ) { return index & ~kDuplicateMask ; };
75
+
43
76
switch (tiebreaker) {
44
77
case RankOptions::Dense: {
45
- T curr_value, prev_value{};
46
78
rank = 0 ;
47
-
48
- if (null_placement == NullPlacement::AtStart && sorted.null_count () > 0 ) {
49
- rank++;
50
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
51
- out_begin[*it] = rank;
52
- }
53
- }
54
-
55
- for (auto it = sorted.non_nulls_begin ; it < sorted.non_nulls_end ; it++) {
56
- curr_value = value_selector (*it);
57
- if (it == sorted.non_nulls_begin || curr_value != prev_value) {
58
- rank++;
59
- }
60
-
61
- out_begin[*it] = rank;
62
- prev_value = curr_value;
63
- }
64
-
65
- if (null_placement == NullPlacement::AtEnd) {
66
- rank++;
67
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
68
- out_begin[*it] = rank;
79
+ for (auto it = sorted.overall_begin (); it < sorted.overall_end (); ++it) {
80
+ if (!is_duplicate (*it)) {
81
+ ++rank;
69
82
}
83
+ out_begin[original_index (*it)] = rank;
70
84
}
71
85
break ;
72
86
}
73
87
74
88
case RankOptions::First: {
75
89
rank = 0 ;
76
90
for (auto it = sorted.overall_begin (); it < sorted.overall_end (); it++) {
91
+ // No duplicate marks expected for RankOptions::First
92
+ DCHECK (!is_duplicate (*it));
77
93
out_begin[*it] = ++rank;
78
94
}
79
95
break ;
80
96
}
81
97
82
98
case RankOptions::Min: {
83
- T curr_value, prev_value{};
84
99
rank = 0 ;
85
-
86
- if (null_placement == NullPlacement::AtStart) {
87
- rank++;
88
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
89
- out_begin[*it] = rank;
90
- }
91
- }
92
-
93
- for (auto it = sorted.non_nulls_begin ; it < sorted.non_nulls_end ; it++) {
94
- curr_value = value_selector (*it);
95
- if (it == sorted.non_nulls_begin || curr_value != prev_value) {
100
+ for (auto it = sorted.overall_begin (); it < sorted.overall_end (); ++it) {
101
+ if (!is_duplicate (*it)) {
96
102
rank = (it - sorted.overall_begin ()) + 1 ;
97
103
}
98
- out_begin[*it] = rank;
99
- prev_value = curr_value;
100
- }
101
-
102
- if (null_placement == NullPlacement::AtEnd) {
103
- rank = sorted.non_null_count () + 1 ;
104
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
105
- out_begin[*it] = rank;
106
- }
104
+ out_begin[original_index (*it)] = rank;
107
105
}
108
106
break ;
109
107
}
110
108
111
109
case RankOptions::Max: {
112
- // The algorithm for Max is just like Min, but in reverse order.
113
- T curr_value, prev_value{};
114
110
rank = length;
115
-
116
- if (null_placement == NullPlacement::AtEnd) {
117
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
118
- out_begin[*it] = rank;
119
- }
120
- }
121
-
122
- for (auto it = sorted.non_nulls_end - 1 ; it >= sorted.non_nulls_begin ; it--) {
123
- curr_value = value_selector (*it);
124
-
125
- if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
126
- rank = (it - sorted.overall_begin ()) + 1 ;
127
- }
128
- out_begin[*it] = rank;
129
- prev_value = curr_value;
130
- }
131
-
132
- if (null_placement == NullPlacement::AtStart) {
133
- rank = sorted.null_count ();
134
- for (auto it = sorted.nulls_begin ; it < sorted.nulls_end ; it++) {
135
- out_begin[*it] = rank;
111
+ for (auto it = sorted.overall_end () - 1 ; it >= sorted.overall_begin (); --it) {
112
+ out_begin[original_index (*it)] = rank;
113
+ // If the current index isn't marked as duplicate, then it's the last
114
+ // tie in a row (since we iterate in reverse order), so update rank
115
+ // for the next row of ties.
116
+ if (!is_duplicate (*it)) {
117
+ rank = it - sorted.overall_begin ();
136
118
}
137
119
}
138
-
139
120
break ;
140
121
}
141
122
}
@@ -209,11 +190,14 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
209
190
array_sorter (indices_begin_, indices_end_, array, 0 ,
210
191
ArraySortOptions (order_, null_placement_), ctx_));
211
192
212
- auto value_selector = [&array](int64_t index ) {
213
- return GetView::LogicalValue (array.GetView (index ));
214
- };
215
- ARROW_ASSIGN_OR_RAISE (*output_, CreateRankings (ctx_, sorted, null_placement_,
216
- tiebreaker_, value_selector));
193
+ if (NeedsDuplicates (tiebreaker_)) {
194
+ auto value_selector = [&array](int64_t index ) {
195
+ return GetView::LogicalValue (array.GetView (index ));
196
+ };
197
+ MarkDuplicates (sorted, value_selector);
198
+ }
199
+ ARROW_ASSIGN_OR_RAISE (*output_,
200
+ CreateRankings (ctx_, sorted, null_placement_, tiebreaker_));
217
201
218
202
return Status::OK ();
219
203
}
@@ -238,13 +222,16 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
238
222
SortChunkedArray (ctx_, indices_begin_, indices_end_, physical_type_,
239
223
physical_chunks_, order_, null_placement_));
240
224
241
- const auto arrays = GetArrayPointers (physical_chunks_);
242
- auto value_selector = [resolver = ChunkedArrayResolver (span (arrays))](int64_t index ) {
243
- return resolver.Resolve (index ).Value <InType>();
244
- };
245
- ARROW_ASSIGN_OR_RAISE (*output_, CreateRankings (ctx_, sorted, null_placement_,
246
- tiebreaker_, value_selector));
247
-
225
+ if (NeedsDuplicates (tiebreaker_)) {
226
+ const auto arrays = GetArrayPointers (physical_chunks_);
227
+ auto value_selector = [resolver =
228
+ ChunkedArrayResolver (span (arrays))](int64_t index ) {
229
+ return resolver.Resolve (index ).Value <InType>();
230
+ };
231
+ MarkDuplicates (sorted, value_selector);
232
+ }
233
+ ARROW_ASSIGN_OR_RAISE (*output_,
234
+ CreateRankings (ctx_, sorted, null_placement_, tiebreaker_));
248
235
return Status::OK ();
249
236
}
250
237
0 commit comments