Skip to content

Commit 2054aca

Browse files
committed
apacheGH-45216: [C++][Compute] Refactor Rank implementation
1 parent 2b5f56c commit 2054aca

File tree

4 files changed

+143
-198
lines changed

4 files changed

+143
-198
lines changed

cpp/src/arrow/compute/kernels/vector_rank.cc

+71-84
Original file line numberDiff line numberDiff line change
@@ -28,114 +28,95 @@ namespace {
2828
// ----------------------------------------------------------------------
2929
// Rank implementation
3030

31-
template <typename ValueSelector,
32-
typename T = std::decay_t<std::invoke_result_t<ValueSelector, int64_t>>>
31+
constexpr uint64_t kDuplicateMask = 1ULL << 63;
32+
33+
bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
34+
return tiebreaker != RankOptions::First;
35+
}
36+
37+
template <typename ValueSelector>
38+
void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_selector) {
39+
using T = decltype(value_selector(int64_t{}));
40+
41+
// Process non-nulls
42+
if (sorted.non_nulls_end != sorted.non_nulls_begin) {
43+
auto it = sorted.non_nulls_begin;
44+
T prev_value = value_selector(*it);
45+
while (++it < sorted.non_nulls_end) {
46+
T curr_value = value_selector(*it);
47+
if (curr_value == prev_value) {
48+
*it |= kDuplicateMask;
49+
}
50+
prev_value = curr_value;
51+
}
52+
}
53+
54+
// Process nulls
55+
if (sorted.nulls_end != sorted.nulls_begin) {
56+
// TODO this should be able to distinguish between NaNs and real nulls (GH-45193)
57+
auto it = sorted.nulls_begin;
58+
while (++it < sorted.nulls_end) {
59+
*it |= kDuplicateMask;
60+
}
61+
}
62+
}
63+
3364
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted,
3465
const NullPlacement null_placement,
35-
const RankOptions::Tiebreaker tiebreaker,
36-
ValueSelector&& value_selector) {
66+
const RankOptions::Tiebreaker tiebreaker) {
3767
auto length = sorted.overall_end() - sorted.overall_begin();
3868
ARROW_ASSIGN_OR_RAISE(auto rankings,
3969
MakeMutableUInt64Array(length, ctx->memory_pool()));
4070
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
4171
uint64_t rank;
4272

73+
auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; };
74+
auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };
75+
4376
switch (tiebreaker) {
4477
case RankOptions::Dense: {
45-
T curr_value, prev_value{};
4678
rank = 0;
47-
48-
if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0) {
49-
rank++;
50-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
51-
out_begin[*it] = rank;
52-
}
53-
}
54-
55-
for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
56-
curr_value = value_selector(*it);
57-
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
58-
rank++;
59-
}
60-
61-
out_begin[*it] = rank;
62-
prev_value = curr_value;
63-
}
64-
65-
if (null_placement == NullPlacement::AtEnd) {
66-
rank++;
67-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
68-
out_begin[*it] = rank;
79+
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
80+
if (!is_duplicate(*it)) {
81+
++rank;
6982
}
83+
out_begin[original_index(*it)] = rank;
7084
}
7185
break;
7286
}
7387

7488
case RankOptions::First: {
7589
rank = 0;
7690
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) {
91+
// No duplicate marks expected for RankOptions::First
92+
DCHECK(!is_duplicate(*it));
7793
out_begin[*it] = ++rank;
7894
}
7995
break;
8096
}
8197

8298
case RankOptions::Min: {
83-
T curr_value, prev_value{};
8499
rank = 0;
85-
86-
if (null_placement == NullPlacement::AtStart) {
87-
rank++;
88-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
89-
out_begin[*it] = rank;
90-
}
91-
}
92-
93-
for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
94-
curr_value = value_selector(*it);
95-
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
100+
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
101+
if (!is_duplicate(*it)) {
96102
rank = (it - sorted.overall_begin()) + 1;
97103
}
98-
out_begin[*it] = rank;
99-
prev_value = curr_value;
100-
}
101-
102-
if (null_placement == NullPlacement::AtEnd) {
103-
rank = sorted.non_null_count() + 1;
104-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
105-
out_begin[*it] = rank;
106-
}
104+
out_begin[original_index(*it)] = rank;
107105
}
108106
break;
109107
}
110108

111109
case RankOptions::Max: {
112-
// The algorithm for Max is just like Min, but in reverse order.
113-
T curr_value, prev_value{};
114110
rank = length;
115-
116-
if (null_placement == NullPlacement::AtEnd) {
117-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
118-
out_begin[*it] = rank;
119-
}
120-
}
121-
122-
for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin; it--) {
123-
curr_value = value_selector(*it);
124-
125-
if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
126-
rank = (it - sorted.overall_begin()) + 1;
127-
}
128-
out_begin[*it] = rank;
129-
prev_value = curr_value;
130-
}
131-
132-
if (null_placement == NullPlacement::AtStart) {
133-
rank = sorted.null_count();
134-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
135-
out_begin[*it] = rank;
111+
for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) {
112+
out_begin[original_index(*it)] = rank;
113+
// If the current index isn't marked as duplicate, then it's the last
114+
// tie in a row (since we iterate in reverse order), so update rank
115+
// for the next row of ties.
116+
if (!is_duplicate(*it)) {
117+
rank = it - sorted.overall_begin();
136118
}
137119
}
138-
139120
break;
140121
}
141122
}
@@ -209,11 +190,14 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
209190
array_sorter(indices_begin_, indices_end_, array, 0,
210191
ArraySortOptions(order_, null_placement_), ctx_));
211192

212-
auto value_selector = [&array](int64_t index) {
213-
return GetView::LogicalValue(array.GetView(index));
214-
};
215-
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
216-
tiebreaker_, value_selector));
193+
if (NeedsDuplicates(tiebreaker_)) {
194+
auto value_selector = [&array](int64_t index) {
195+
return GetView::LogicalValue(array.GetView(index));
196+
};
197+
MarkDuplicates(sorted, value_selector);
198+
}
199+
ARROW_ASSIGN_OR_RAISE(*output_,
200+
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
217201

218202
return Status::OK();
219203
}
@@ -238,13 +222,16 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
238222
SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_,
239223
physical_chunks_, order_, null_placement_));
240224

241-
const auto arrays = GetArrayPointers(physical_chunks_);
242-
auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) {
243-
return resolver.Resolve(index).Value<InType>();
244-
};
245-
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
246-
tiebreaker_, value_selector));
247-
225+
if (NeedsDuplicates(tiebreaker_)) {
226+
const auto arrays = GetArrayPointers(physical_chunks_);
227+
auto value_selector = [resolver =
228+
ChunkedArrayResolver(span(arrays))](int64_t index) {
229+
return resolver.Resolve(index).Value<InType>();
230+
};
231+
MarkDuplicates(sorted, value_selector);
232+
}
233+
ARROW_ASSIGN_OR_RAISE(*output_,
234+
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
248235
return Status::OK();
249236
}
250237

cpp/src/arrow/compute/kernels/vector_sort.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class ChunkedArraySorter : public TypeVisitor {
121121
CompressedChunkLocation* nulls_middle,
122122
CompressedChunkLocation* nulls_end,
123123
CompressedChunkLocation* temp_indices, int64_t null_count) {
124-
if (has_null_like_values<typename ArrayType::TypeClass>::value) {
124+
if (has_null_like_values<typename ArrayType::TypeClass>()) {
125125
PartitionNullsOnly<StablePartitioner>(nulls_begin, nulls_end, arrays,
126126
null_count, null_placement_);
127127
}
@@ -781,7 +781,7 @@ class TableSorter {
781781
CompressedChunkLocation* nulls_middle,
782782
CompressedChunkLocation* nulls_end,
783783
CompressedChunkLocation* temp_indices, int64_t null_count) {
784-
if constexpr (has_null_like_values<ArrowType>::value) {
784+
if constexpr (has_null_like_values<ArrowType>()) {
785785
// Merge rows with a null or a null-like in the first sort key
786786
auto& comparator = comparator_;
787787
const auto& first_sort_key = sort_keys_[0];

0 commit comments

Comments
 (0)