Skip to content

Commit 7cd7344

Browse files
committed
apacheGH-39740: [C++] Fix filter kernel for month_day_nano intervals
1 parent 2e8bd8d commit 7cd7344

5 files changed

+270
-139
lines changed

cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc

+96-71
Original file line numberDiff line numberDiff line change
@@ -146,36 +146,40 @@ class DropNullCounter {
146146

147147
/// \brief The Filter implementation for primitive (fixed-width) types does not
148148
/// use the logical Arrow type but rather the physical C type. This way we only
149-
/// generate one take function for each byte width. We use the same
150-
/// implementation here for boolean and fixed-byte-size inputs with some
151-
/// template specialization.
152-
template <typename ArrowType>
149+
/// generate one take function for each byte width.
150+
///
151+
/// We use compile-time specialization for two variations:
152+
/// - operating on boolean data (using kIsBoolean = true)
153+
/// - operating on fixed-width data of arbitrary width (using kByteWidth = -1),
154+
/// with the actual width only known at runtime
155+
template <int32_t kByteWidth, bool kIsBoolean = false>
153156
class PrimitiveFilterImpl {
154157
public:
155-
using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value,
156-
uint8_t, typename ArrowType::c_type>::type;
157-
158158
PrimitiveFilterImpl(const ArraySpan& values, const ArraySpan& filter,
159159
FilterOptions::NullSelectionBehavior null_selection,
160160
ArrayData* out_arr)
161-
: values_is_valid_(values.buffers[0].data),
162-
values_data_(reinterpret_cast<const T*>(values.buffers[1].data)),
161+
: byte_width_(values.type->byte_width()),
162+
values_is_valid_(values.buffers[0].data),
163+
values_data_(values.buffers[1].data),
163164
values_null_count_(values.null_count),
164165
values_offset_(values.offset),
165166
values_length_(values.length),
166167
filter_(filter),
167168
null_selection_(null_selection) {
168-
if (values.type->id() != Type::BOOL) {
169+
if (kByteWidth >= 0 && !kIsBoolean) {
170+
DCHECK_EQ(kByteWidth, byte_width_);
171+
}
172+
if (!kIsBoolean) {
169173
// No offset applied for boolean because it's a bitmap
170-
values_data_ += values.offset;
174+
values_data_ += values.offset * byte_width();
171175
}
172176

173177
if (out_arr->buffers[0] != nullptr) {
174178
// May be unallocated if neither filter nor values contain nulls
175179
out_is_valid_ = out_arr->buffers[0]->mutable_data();
176180
}
177-
out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data());
178-
out_offset_ = out_arr->offset;
181+
out_data_ = out_arr->buffers[1]->mutable_data();
182+
DCHECK_EQ(out_arr->offset, 0);
179183
out_length_ = out_arr->length;
180184
out_position_ = 0;
181185
}
@@ -201,14 +205,11 @@ class PrimitiveFilterImpl {
201205
[&](int64_t position, int64_t segment_length, bool filter_valid) {
202206
if (filter_valid) {
203207
CopyBitmap(values_is_valid_, values_offset_ + position, segment_length,
204-
out_is_valid_, out_offset_ + out_position_);
208+
out_is_valid_, out_position_);
205209
WriteValueSegment(position, segment_length);
206210
} else {
207-
bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
208-
segment_length, false);
209-
memset(out_data_ + out_offset_ + out_position_, 0,
210-
segment_length * sizeof(T));
211-
out_position_ += segment_length;
211+
bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false);
212+
WriteNullSegment(segment_length);
212213
}
213214
return true;
214215
});
@@ -218,19 +219,16 @@ class PrimitiveFilterImpl {
218219
if (out_is_valid_) {
219220
// Set all to valid, so only if nulls are produced by EMIT_NULL, we need
220221
// to set out_is_valid[i] to false.
221-
bit_util::SetBitsTo(out_is_valid_, out_offset_, out_length_, true);
222+
bit_util::SetBitsTo(out_is_valid_, 0, out_length_, true);
222223
}
223224
return VisitPlainxREEFilterOutputSegments(
224225
filter_, /*filter_may_have_nulls=*/true, null_selection_,
225226
[&](int64_t position, int64_t segment_length, bool filter_valid) {
226227
if (filter_valid) {
227228
WriteValueSegment(position, segment_length);
228229
} else {
229-
bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
230-
segment_length, false);
231-
memset(out_data_ + out_offset_ + out_position_, 0,
232-
segment_length * sizeof(T));
233-
out_position_ += segment_length;
230+
bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false);
231+
WriteNullSegment(segment_length);
234232
}
235233
return true;
236234
});
@@ -260,13 +258,13 @@ class PrimitiveFilterImpl {
260258
values_length_);
261259

262260
auto WriteNotNull = [&](int64_t index) {
263-
bit_util::SetBit(out_is_valid_, out_offset_ + out_position_);
261+
bit_util::SetBit(out_is_valid_, out_position_);
264262
// Increments out_position_
265263
WriteValue(index);
266264
};
267265

268266
auto WriteMaybeNull = [&](int64_t index) {
269-
bit_util::SetBitTo(out_is_valid_, out_offset_ + out_position_,
267+
bit_util::SetBitTo(out_is_valid_, out_position_,
270268
bit_util::GetBit(values_is_valid_, values_offset_ + index));
271269
// Increments out_position_
272270
WriteValue(index);
@@ -279,15 +277,14 @@ class PrimitiveFilterImpl {
279277
BitBlockCount data_block = data_counter.NextWord();
280278
if (filter_block.AllSet() && data_block.AllSet()) {
281279
// Fastest path: all values in block are included and not null
282-
bit_util::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
283-
filter_block.length, true);
280+
bit_util::SetBitsTo(out_is_valid_, out_position_, filter_block.length, true);
284281
WriteValueSegment(in_position, filter_block.length);
285282
in_position += filter_block.length;
286283
} else if (filter_block.AllSet()) {
287284
// Faster: all values are selected, but some values are null
288285
// Batch copy bits from values validity bitmap to output validity bitmap
289286
CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length,
290-
out_is_valid_, out_offset_ + out_position_);
287+
out_is_valid_, out_position_);
291288
WriteValueSegment(in_position, filter_block.length);
292289
in_position += filter_block.length;
293290
} else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) {
@@ -326,7 +323,7 @@ class PrimitiveFilterImpl {
326323
WriteNotNull(in_position);
327324
} else if (!is_valid) {
328325
// Filter slot is null, so we have a null in the output
329-
bit_util::ClearBit(out_is_valid_, out_offset_ + out_position_);
326+
bit_util::ClearBit(out_is_valid_, out_position_);
330327
WriteNull();
331328
}
332329
++in_position;
@@ -362,7 +359,7 @@ class PrimitiveFilterImpl {
362359
WriteMaybeNull(in_position);
363360
} else if (!is_valid) {
364361
// Filter slot is null, so we have a null in the output
365-
bit_util::ClearBit(out_is_valid_, out_offset_ + out_position_);
362+
bit_util::ClearBit(out_is_valid_, out_position_);
366363
WriteNull();
367364
}
368365
++in_position;
@@ -376,54 +373,72 @@ class PrimitiveFilterImpl {
376373
// Write the next out_position given the selected in_position for the input
377374
// data and advance out_position
378375
void WriteValue(int64_t in_position) {
379-
out_data_[out_offset_ + out_position_++] = values_data_[in_position];
376+
if constexpr (kIsBoolean) {
377+
bit_util::SetBitTo(out_data_, out_position_,
378+
bit_util::GetBit(values_data_, values_offset_ + in_position));
379+
} else {
380+
memcpy(out_data_ + out_position_ * byte_width(),
381+
values_data_ + in_position * byte_width(), byte_width());
382+
}
383+
++out_position_;
380384
}
381385

382386
void WriteValueSegment(int64_t in_start, int64_t length) {
383-
std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T));
387+
if constexpr (kIsBoolean) {
388+
CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_,
389+
out_position_);
390+
} else {
391+
memcpy(out_data_ + out_position_ * byte_width(),
392+
values_data_ + in_start * byte_width(), length * byte_width());
393+
}
384394
out_position_ += length;
385395
}
386396

387397
void WriteNull() {
388-
// Zero the memory
389-
out_data_[out_offset_ + out_position_++] = T{};
398+
if constexpr (kIsBoolean) {
399+
// Zero the bit
400+
bit_util::ClearBit(out_data_, out_position_);
401+
} else {
402+
// Zero the memory
403+
memset(out_data_ + out_position_ * byte_width(), 0, byte_width());
404+
}
405+
++out_position_;
406+
}
407+
408+
void WriteNullSegment(int64_t length) {
409+
if constexpr (kIsBoolean) {
410+
// Zero the bits
411+
bit_util::SetBitsTo(out_data_, out_position_, length, false);
412+
} else {
413+
// Zero the memory
414+
memset(out_data_ + out_position_ * byte_width(), 0, length * byte_width());
415+
}
416+
out_position_ += length;
417+
}
418+
419+
constexpr int32_t byte_width() const {
420+
if constexpr (kByteWidth >= 0) {
421+
return kByteWidth;
422+
} else {
423+
return byte_width_;
424+
}
390425
}
391426

392427
private:
428+
int32_t byte_width_;
393429
const uint8_t* values_is_valid_;
394-
const T* values_data_;
430+
const uint8_t* values_data_;
395431
int64_t values_null_count_;
396432
int64_t values_offset_;
397433
int64_t values_length_;
398434
const ArraySpan& filter_;
399435
FilterOptions::NullSelectionBehavior null_selection_;
400436
uint8_t* out_is_valid_ = NULLPTR;
401-
T* out_data_;
402-
int64_t out_offset_;
437+
uint8_t* out_data_;
403438
int64_t out_length_;
404439
int64_t out_position_;
405440
};
406441

407-
template <>
408-
inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) {
409-
bit_util::SetBitTo(out_data_, out_offset_ + out_position_++,
410-
bit_util::GetBit(values_data_, values_offset_ + in_position));
411-
}
412-
413-
template <>
414-
inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start,
415-
int64_t length) {
416-
CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_,
417-
out_offset_ + out_position_);
418-
out_position_ += length;
419-
}
420-
421-
template <>
422-
inline void PrimitiveFilterImpl<BooleanType>::WriteNull() {
423-
// Zero the bit
424-
bit_util::ClearBit(out_data_, out_offset_ + out_position_++);
425-
}
426-
427442
Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
428443
const ArraySpan& values = batch[0].array;
429444
const ArraySpan& filter = batch[1].array;
@@ -459,22 +474,32 @@ Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult
459474

460475
switch (bit_width) {
461476
case 1:
462-
PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out_arr).Exec();
477+
PrimitiveFilterImpl<1, /*kIsBoolean=*/true>(values, filter, null_selection, out_arr)
478+
.Exec();
463479
break;
464480
case 8:
465-
PrimitiveFilterImpl<UInt8Type>(values, filter, null_selection, out_arr).Exec();
481+
PrimitiveFilterImpl<1>(values, filter, null_selection, out_arr).Exec();
466482
break;
467483
case 16:
468-
PrimitiveFilterImpl<UInt16Type>(values, filter, null_selection, out_arr).Exec();
484+
PrimitiveFilterImpl<2>(values, filter, null_selection, out_arr).Exec();
469485
break;
470486
case 32:
471-
PrimitiveFilterImpl<UInt32Type>(values, filter, null_selection, out_arr).Exec();
487+
PrimitiveFilterImpl<4>(values, filter, null_selection, out_arr).Exec();
472488
break;
473489
case 64:
474-
PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out_arr).Exec();
490+
PrimitiveFilterImpl<8>(values, filter, null_selection, out_arr).Exec();
491+
break;
492+
case 128:
493+
// For INTERVAL_MONTH_DAY_NANO, DECIMAL128
494+
PrimitiveFilterImpl<16>(values, filter, null_selection, out_arr).Exec();
495+
break;
496+
case 256:
497+
// For DECIMAL256
498+
PrimitiveFilterImpl<32>(values, filter, null_selection, out_arr).Exec();
475499
break;
476500
default:
477-
DCHECK(false) << "Invalid values bit width";
501+
// Non-specializing on byte width
502+
PrimitiveFilterImpl<-1>(values, filter, null_selection, out_arr).Exec();
478503
break;
479504
}
480505
return Status::OK();
@@ -1050,10 +1075,10 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
10501075
{InputType(match::Primitive()), plain_filter, PrimitiveFilterExec},
10511076
{InputType(match::BinaryLike()), plain_filter, BinaryFilterExec},
10521077
{InputType(match::LargeBinaryLike()), plain_filter, BinaryFilterExec},
1053-
{InputType(Type::FIXED_SIZE_BINARY), plain_filter, FSBFilterExec},
10541078
{InputType(null()), plain_filter, NullFilterExec},
1055-
{InputType(Type::DECIMAL128), plain_filter, FSBFilterExec},
1056-
{InputType(Type::DECIMAL256), plain_filter, FSBFilterExec},
1079+
{InputType(Type::FIXED_SIZE_BINARY), plain_filter, PrimitiveFilterExec},
1080+
{InputType(Type::DECIMAL128), plain_filter, PrimitiveFilterExec},
1081+
{InputType(Type::DECIMAL256), plain_filter, PrimitiveFilterExec},
10571082
{InputType(Type::DICTIONARY), plain_filter, DictionaryFilterExec},
10581083
{InputType(Type::EXTENSION), plain_filter, ExtensionFilterExec},
10591084
{InputType(Type::LIST), plain_filter, ListFilterExec},
@@ -1068,10 +1093,10 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
10681093
{InputType(match::Primitive()), ree_filter, PrimitiveFilterExec},
10691094
{InputType(match::BinaryLike()), ree_filter, BinaryFilterExec},
10701095
{InputType(match::LargeBinaryLike()), ree_filter, BinaryFilterExec},
1071-
{InputType(Type::FIXED_SIZE_BINARY), ree_filter, FSBFilterExec},
10721096
{InputType(null()), ree_filter, NullFilterExec},
1073-
{InputType(Type::DECIMAL128), ree_filter, FSBFilterExec},
1074-
{InputType(Type::DECIMAL256), ree_filter, FSBFilterExec},
1097+
{InputType(Type::FIXED_SIZE_BINARY), ree_filter, PrimitiveFilterExec},
1098+
{InputType(Type::DECIMAL128), ree_filter, PrimitiveFilterExec},
1099+
{InputType(Type::DECIMAL256), ree_filter, PrimitiveFilterExec},
10751100
{InputType(Type::DICTIONARY), ree_filter, DictionaryFilterExec},
10761101
{InputType(Type::EXTENSION), ree_filter, ExtensionFilterExec},
10771102
{InputType(Type::LIST), ree_filter, ListFilterExec},

cpp/src/arrow/compute/kernels/vector_selection_internal.cc

+16-6
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ Status PreallocatePrimitiveArrayData(KernelContext* ctx, int64_t length, int bit
7777
if (bit_width == 1) {
7878
ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(length));
7979
} else {
80-
ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(length * bit_width / 8));
80+
ARROW_ASSIGN_OR_RAISE(out->buffers[1],
81+
ctx->Allocate(bit_util::BytesForBits(length * bit_width)));
8182
}
8283
return Status::OK();
8384
}
@@ -899,10 +900,6 @@ Status FilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
899900

900901
} // namespace
901902

902-
Status FSBFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
903-
return FilterExec<FSBSelectionImpl>(ctx, batch, out);
904-
}
905-
906903
Status ListFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
907904
return FilterExec<ListSelectionImpl<ListType>>(ctx, batch, out);
908905
}
@@ -946,7 +943,20 @@ Status LargeVarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch,
946943
}
947944

948945
Status FSBTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
949-
return TakeExec<FSBSelectionImpl>(ctx, batch, out);
946+
const ArraySpan& values = batch[0].array;
947+
const auto byte_width = values.type->byte_width();
948+
// Use primitive Take implementation (presumably faster) for some byte widths
949+
switch (byte_width) {
950+
case 1:
951+
case 2:
952+
case 4:
953+
case 8:
954+
case 16:
955+
case 32:
956+
return PrimitiveTakeExec(ctx, batch, out);
957+
default:
958+
return TakeExec<FSBSelectionImpl>(ctx, batch, out);
959+
}
950960
}
951961

952962
Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {

cpp/src/arrow/compute/kernels/vector_selection_internal.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ void VisitPlainxREEFilterOutputSegments(
7070
FilterOptions::NullSelectionBehavior null_selection,
7171
const EmitREEFilterSegment& emit_segment);
7272

73-
Status FSBFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
7473
Status ListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
7574
Status LargeListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
7675
Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
@@ -79,6 +78,7 @@ Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
7978

8079
Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
8180
Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
81+
Status PrimitiveTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
8282
Status FSBTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
8383
Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
8484
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);

0 commit comments

Comments
 (0)