Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jan 9, 2025
1 parent 2b5f56c commit 50c3682
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 198 deletions.
155 changes: 71 additions & 84 deletions cpp/src/arrow/compute/kernels/vector_rank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,114 +28,95 @@ namespace {
// ----------------------------------------------------------------------
// Rank implementation

template <typename ValueSelector,
typename T = std::decay_t<std::invoke_result_t<ValueSelector, int64_t>>>
constexpr uint64_t kDuplicateMask = 1ULL << 63;

bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
return tiebreaker != RankOptions::First;
}

template <typename ValueSelector>
void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_selector) {
using T = decltype(value_selector(int64_t{}));

// Process non-nulls
if (sorted.non_nulls_end != sorted.non_nulls_begin) {
auto it = sorted.non_nulls_begin;
T prev_value = value_selector(*it);
while (++it < sorted.non_nulls_end) {
T curr_value = value_selector(*it);
if (curr_value == prev_value) {
*it |= kDuplicateMask;
}
prev_value = curr_value;
}
}

// Process nulls
if (sorted.nulls_end != sorted.nulls_begin) {
// TODO this should be able to distinguish between NaNs and real nulls (GH-45193)
auto it = sorted.nulls_begin;
while (++it < sorted.nulls_end) {
*it |= kDuplicateMask;
}
}
}

Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted,
const NullPlacement null_placement,
const RankOptions::Tiebreaker tiebreaker,
ValueSelector&& value_selector) {
const RankOptions::Tiebreaker tiebreaker) {
auto length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableUInt64Array(length, ctx->memory_pool()));
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
uint64_t rank;

auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; };
auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };

switch (tiebreaker) {
case RankOptions::Dense: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
rank++;
}

out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
++rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}

case RankOptions::First: {
rank = 0;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) {
// No duplicate marks expected for RankOptions::First
DCHECK(!is_duplicate(*it));
out_begin[*it] = ++rank;
}
break;
}

case RankOptions::Min: {
T curr_value, prev_value{};
rank = 0;

if (null_placement == NullPlacement::AtStart) {
rank++;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
curr_value = value_selector(*it);
if (it == sorted.non_nulls_begin || curr_value != prev_value) {
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
if (!is_duplicate(*it)) {
rank = (it - sorted.overall_begin()) + 1;
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtEnd) {
rank = sorted.non_null_count() + 1;
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
out_begin[original_index(*it)] = rank;
}
break;
}

case RankOptions::Max: {
// The algorithm for Max is just like Min, but in reverse order.
T curr_value, prev_value{};
rank = length;

if (null_placement == NullPlacement::AtEnd) {
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
}
}

for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin; it--) {
curr_value = value_selector(*it);

if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
rank = (it - sorted.overall_begin()) + 1;
}
out_begin[*it] = rank;
prev_value = curr_value;
}

if (null_placement == NullPlacement::AtStart) {
rank = sorted.null_count();
for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
out_begin[*it] = rank;
for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) {
out_begin[original_index(*it)] = rank;
// If the current index isn't marked as duplicate, then it's the last
// tie in a row (since we iterate in reverse order), so update rank
// for the next row of ties.
if (!is_duplicate(*it)) {
rank = it - sorted.overall_begin();
}
}

break;
}
}
Expand Down Expand Up @@ -209,11 +190,14 @@ class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
array_sorter(indices_begin_, indices_end_, array, 0,
ArraySortOptions(order_, null_placement_), ctx_));

auto value_selector = [&array](int64_t index) {
return GetView::LogicalValue(array.GetView(index));
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));
if (NeedsDuplicates(tiebreaker_)) {
auto value_selector = [&array](int64_t index) {
return GetView::LogicalValue(array.GetView(index));
};
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));

return Status::OK();
}
Expand All @@ -238,13 +222,16 @@ class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArra
SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_,
physical_chunks_, order_, null_placement_));

const auto arrays = GetArrayPointers(physical_chunks_);
auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) {
return resolver.Resolve(index).Value<InType>();
};
ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_, value_selector));

if (NeedsDuplicates(tiebreaker_)) {
const auto arrays = GetArrayPointers(physical_chunks_);
auto value_selector = [resolver =
ChunkedArrayResolver(span(arrays))](int64_t index) {
return resolver.Resolve(index).Value<InType>();
};
MarkDuplicates(sorted, value_selector);
}
ARROW_ASSIGN_OR_RAISE(*output_,
CreateRankings(ctx_, sorted, null_placement_, tiebreaker_));
return Status::OK();
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/vector_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ChunkedArraySorter : public TypeVisitor {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t null_count) {
if (has_null_like_values<typename ArrayType::TypeClass>::value) {
if (has_null_like_values<typename ArrayType::TypeClass>()) {
PartitionNullsOnly<StablePartitioner>(nulls_begin, nulls_end, arrays,
null_count, null_placement_);
}
Expand Down Expand Up @@ -781,7 +781,7 @@ class TableSorter {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t null_count) {
if constexpr (has_null_like_values<ArrowType>::value) {
if constexpr (has_null_like_values<ArrowType>()) {
// Merge rows with a null or a null-like in the first sort key
auto& comparator = comparator_;
const auto& first_sort_key = sort_keys_[0];
Expand Down
Loading

0 comments on commit 50c3682

Please sign in to comment.