Skip to content

Commit 34a9a86

Browse files
committed
[C++][Compute] Refactor rank function implementation
1 parent e12bc56 commit 34a9a86

File tree

2 files changed

+64
-82
lines changed

2 files changed

+64
-82
lines changed

cpp/src/arrow/compute/kernels/vector_rank.cc

+62-76
Original file line number | Diff line number | Diff line change
@@ -28,45 +28,60 @@ namespace {
2828
// ----------------------------------------------------------------------
2929
// Rank implementation
3030

31-
template <typename ValueSelector,
32-
typename T = std::decay_t<std::invoke_result_t<ValueSelector, int64_t>>>
31+
constexpr uint64_t kDuplicateMask = 1ULL << 63;
32+
33+
template <typename ValueSelector>
34+
void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_selector) {
35+
using T = std::decay_t<decltype(value_selector(uint64_t(0)))>;
36+
37+
// Process non-nulls
38+
if (sorted.non_nulls_end != sorted.non_nulls_begin) {
39+
auto it = sorted.non_nulls_begin;
40+
T prev_value = value_selector(*it);
41+
T curr_value{};
42+
while (++it < sorted.non_nulls_end) {
43+
curr_value = value_selector(*it);
44+
if (curr_value == prev_value) {
45+
// Mark as duplicate
46+
*it |= kDuplicateMask;
47+
}
48+
prev_value = curr_value;
49+
}
50+
}
51+
// Process nulls
52+
if (sorted.nulls_end != sorted.nulls_begin) {
53+
auto it = sorted.nulls_begin;
54+
// Mark all other nulls as duplicates
55+
while (++it < sorted.nulls_end) {
56+
*it |= kDuplicateMask;
57+
}
58+
}
59+
}
60+
61+
bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
62+
return tiebreaker != RankOptions::First;
63+
}
64+
3365
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted,
3466
const NullPlacement null_placement,
35-
const RankOptions::Tiebreaker tiebreaker,
36-
ValueSelector&& value_selector) {
67+
const RankOptions::Tiebreaker tiebreaker) {
3768
auto length = sorted.overall_end() - sorted.overall_begin();
3869
ARROW_ASSIGN_OR_RAISE(auto rankings,
3970
MakeMutableUInt64Array(length, ctx->memory_pool()));
4071
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
4172
uint64_t rank;
4273

74+
auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; };
75+
auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };
76+
4377
switch (tiebreaker) {
4478
case RankOptions::Dense: {
45-
T curr_value, prev_value{};
4679
rank = 0;
47-
48-
if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0) {
49-
rank++;
50-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
51-
out_begin[*it] = rank;
52-
}
53-
}
54-
55-
for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
56-
curr_value = value_selector(*it);
57-
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
58-
rank++;
59-
}
60-
61-
out_begin[*it] = rank;
62-
prev_value = curr_value;
63-
}
64-
65-
if (null_placement == NullPlacement::AtEnd) {
66-
rank++;
67-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
68-
out_begin[*it] = rank;
80+
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
81+
if (!is_duplicate(*it)) {
82+
++rank;
6983
}
84+
out_begin[original_index(*it)] = rank;
7085
}
7186
break;
7287
}
@@ -80,62 +95,27 @@ Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted
8095
}
8196

8297
case RankOptions::Min: {
83-
T curr_value, prev_value{};
8498
rank = 0;
85-
86-
if (null_placement == NullPlacement::AtStart) {
87-
rank++;
88-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
89-
out_begin[*it] = rank;
90-
}
91-
}
92-
93-
for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
94-
curr_value = value_selector(*it);
95-
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
99+
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
100+
if (!is_duplicate(*it)) {
96101
rank = (it - sorted.overall_begin()) + 1;
97102
}
98-
out_begin[*it] = rank;
99-
prev_value = curr_value;
100-
}
101-
102-
if (null_placement == NullPlacement::AtEnd) {
103-
rank = sorted.non_null_count() + 1;
104-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
105-
out_begin[*it] = rank;
106-
}
103+
out_begin[original_index(*it)] = rank;
107104
}
108105
break;
109106
}
110107

111108
case RankOptions::Max: {
112-
// The algorithm for Max is just like Min, but in reverse order.
113-
T curr_value, prev_value{};
114109
rank = length;
115-
116-
if (null_placement == NullPlacement::AtEnd) {
117-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
118-
out_begin[*it] = rank;
119-
}
120-
}
121-
122-
for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin; it--) {
123-
curr_value = value_selector(*it);
124-
125-
if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
126-
rank = (it - sorted.overall_begin()) + 1;
110+
for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) {
111+
out_begin[original_index(*it)] = rank;
112+
// If the current index isn't marked as duplicate, then it's the last
113+
// tie in a row (since we iterate in reverse order), so update rank
114+
// for the next row of ties.
115+
if (!is_duplicate(*it)) {
116+
rank = it - sorted.overall_begin();
127117
}
128-
out_begin[*it] = rank;
129-
prev_value = curr_value;
130118
}
131-
132-
if (null_placement == NullPlacement::AtStart) {
133-
rank = sorted.null_count();
134-
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
135-
out_begin[*it] = rank;
136-
}
137-
}
138-
139119
break;
140120
}
141121
}
@@ -212,8 +192,11 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
212192
auto value_selector = [&array](int64_t index) {
213193
return GetView::LogicalValue(array.GetView(index));
214194
};
215-
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
216-
tiebreaker_, value_selector));
195+
if (NeedsDuplicates(tiebreaker_)) {
196+
MarkDuplicates(sorted, value_selector);
197+
}
198+
ARROW_ASSIGN_OR_RAISE(*output_,
199+
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
217200

218201
return Status::OK();
219202
}
@@ -242,8 +225,11 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
242225
auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) {
243226
return resolver.Resolve(index).Value<InType>();
244227
};
245-
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
246-
tiebreaker_, value_selector));
228+
if (NeedsDuplicates(tiebreaker_)) {
229+
MarkDuplicates(sorted, value_selector);
230+
}
231+
ARROW_ASSIGN_OR_RAISE(*output_,
232+
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
247233

248234
return Status::OK();
249235
}

cpp/src/arrow/compute/kernels/vector_sort_internal.h

+2-6
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,7 @@
2929
#include "arrow/type.h"
3030
#include "arrow/type_traits.h"
3131

32-
namespace arrow {
33-
namespace compute {
34-
namespace internal {
32+
namespace arrow::compute::internal {
3533

3634
// Visit all physical types for which sorting is implemented.
3735
#define VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) \
@@ -853,6 +851,4 @@ inline Result<std::shared_ptr<ArrayData>> MakeMutableUInt64Array(
853851
return ArrayData::Make(uint64(), length, {nullptr, std::move(data)}, /*null_count=*/0);
854852
}
855853

856-
} // namespace internal
857-
} // namespace compute
858-
} // namespace arrow
854+
} // namespace arrow::compute::internal

0 commit comments

Comments
 (0)