@@ -146,36 +146,40 @@ class DropNullCounter {
146
146
147
147
// / \brief The Filter implementation for primitive (fixed-width) types does not
148
148
// / use the logical Arrow type but rather the physical C type. This way we only
149
- // / generate one take function for each byte width. We use the same
150
- // / implementation here for boolean and fixed-byte-size inputs with some
151
- // / template specialization.
152
- template <typename ArrowType>
149
+ // / generate one filter function for each byte width.
150
+ // /
151
+ // / We use compile-time specialization for two variations:
152
+ // / - operating on boolean data (using kIsBoolean = true)
153
+ // / - operating on fixed-width data of arbitrary width (using kByteWidth = -1),
154
+ // / with the actual width only known at runtime
155
+ template <int32_t kByteWidth , bool kIsBoolean = false >
153
156
class PrimitiveFilterImpl {
154
157
public:
155
- using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value,
156
- uint8_t , typename ArrowType::c_type>::type;
157
-
158
158
PrimitiveFilterImpl (const ArraySpan& values, const ArraySpan& filter,
159
159
FilterOptions::NullSelectionBehavior null_selection,
160
160
ArrayData* out_arr)
161
- : values_is_valid_(values.buffers[0 ].data),
162
- values_data_ (reinterpret_cast <const T*>(values.buffers[1 ].data)),
161
+ : byte_width_(values.type->byte_width ()),
162
+ values_is_valid_(values.buffers[0 ].data),
163
+ values_data_(values.buffers[1 ].data),
163
164
values_null_count_(values.null_count),
164
165
values_offset_(values.offset),
165
166
values_length_(values.length),
166
167
filter_(filter),
167
168
null_selection_(null_selection) {
168
- if (values.type ->id () != Type::BOOL) {
169
+ if (kByteWidth >= 0 && !kIsBoolean ) {
170
+ DCHECK_EQ (kByteWidth , byte_width_);
171
+ }
172
+ if (!kIsBoolean ) {
169
173
// Byte-addressed data only: boolean values are a bitmap, so their offset is applied per bit access instead
170
- values_data_ += values.offset ;
174
+ values_data_ += values.offset * byte_width () ;
171
175
}
172
176
173
177
if (out_arr->buffers [0 ] != nullptr ) {
174
178
// May be unallocated if neither filter nor values contain nulls
175
179
out_is_valid_ = out_arr->buffers [0 ]->mutable_data ();
176
180
}
177
- out_data_ = reinterpret_cast <T*>( out_arr->buffers [1 ]->mutable_data () );
178
- out_offset_ = out_arr->offset ;
181
+ out_data_ = out_arr->buffers [1 ]->mutable_data ();
182
+ DCHECK_EQ ( out_arr->offset , 0 ) ;
179
183
out_length_ = out_arr->length ;
180
184
out_position_ = 0 ;
181
185
}
@@ -201,14 +205,11 @@ class PrimitiveFilterImpl {
201
205
[&](int64_t position, int64_t segment_length, bool filter_valid) {
202
206
if (filter_valid) {
203
207
CopyBitmap (values_is_valid_, values_offset_ + position, segment_length,
204
- out_is_valid_, out_offset_ + out_position_);
208
+ out_is_valid_, out_position_);
205
209
WriteValueSegment (position, segment_length);
206
210
} else {
207
- bit_util::SetBitsTo (out_is_valid_, out_offset_ + out_position_,
208
- segment_length, false );
209
- memset (out_data_ + out_offset_ + out_position_, 0 ,
210
- segment_length * sizeof (T));
211
- out_position_ += segment_length;
211
+ bit_util::SetBitsTo (out_is_valid_, out_position_, segment_length, false );
212
+ WriteNullSegment (segment_length);
212
213
}
213
214
return true ;
214
215
});
@@ -218,19 +219,16 @@ class PrimitiveFilterImpl {
218
219
if (out_is_valid_) {
219
220
// Set all to valid, so only if nulls are produced by EMIT_NULL, we need
220
221
// to set out_is_valid[i] to false.
221
- bit_util::SetBitsTo (out_is_valid_, out_offset_ , out_length_, true );
222
+ bit_util::SetBitsTo (out_is_valid_, 0 , out_length_, true );
222
223
}
223
224
return VisitPlainxREEFilterOutputSegments (
224
225
filter_, /* filter_may_have_nulls=*/ true , null_selection_,
225
226
[&](int64_t position, int64_t segment_length, bool filter_valid) {
226
227
if (filter_valid) {
227
228
WriteValueSegment (position, segment_length);
228
229
} else {
229
- bit_util::SetBitsTo (out_is_valid_, out_offset_ + out_position_,
230
- segment_length, false );
231
- memset (out_data_ + out_offset_ + out_position_, 0 ,
232
- segment_length * sizeof (T));
233
- out_position_ += segment_length;
230
+ bit_util::SetBitsTo (out_is_valid_, out_position_, segment_length, false );
231
+ WriteNullSegment (segment_length);
234
232
}
235
233
return true ;
236
234
});
@@ -260,13 +258,13 @@ class PrimitiveFilterImpl {
260
258
values_length_);
261
259
262
260
auto WriteNotNull = [&](int64_t index ) {
263
- bit_util::SetBit (out_is_valid_, out_offset_ + out_position_);
261
+ bit_util::SetBit (out_is_valid_, out_position_);
264
262
// Increments out_position_
265
263
WriteValue (index );
266
264
};
267
265
268
266
auto WriteMaybeNull = [&](int64_t index ) {
269
- bit_util::SetBitTo (out_is_valid_, out_offset_ + out_position_,
267
+ bit_util::SetBitTo (out_is_valid_, out_position_,
270
268
bit_util::GetBit (values_is_valid_, values_offset_ + index ));
271
269
// Increments out_position_
272
270
WriteValue (index );
@@ -279,15 +277,14 @@ class PrimitiveFilterImpl {
279
277
BitBlockCount data_block = data_counter.NextWord ();
280
278
if (filter_block.AllSet () && data_block.AllSet ()) {
281
279
// Fastest path: all values in block are included and not null
282
- bit_util::SetBitsTo (out_is_valid_, out_offset_ + out_position_,
283
- filter_block.length , true );
280
+ bit_util::SetBitsTo (out_is_valid_, out_position_, filter_block.length , true );
284
281
WriteValueSegment (in_position, filter_block.length );
285
282
in_position += filter_block.length ;
286
283
} else if (filter_block.AllSet ()) {
287
284
// Faster: all values are selected, but some values are null
288
285
// Batch copy bits from values validity bitmap to output validity bitmap
289
286
CopyBitmap (values_is_valid_, values_offset_ + in_position, filter_block.length ,
290
- out_is_valid_, out_offset_ + out_position_);
287
+ out_is_valid_, out_position_);
291
288
WriteValueSegment (in_position, filter_block.length );
292
289
in_position += filter_block.length ;
293
290
} else if (filter_block.NoneSet () && null_selection_ == FilterOptions::DROP) {
@@ -326,7 +323,7 @@ class PrimitiveFilterImpl {
326
323
WriteNotNull (in_position);
327
324
} else if (!is_valid) {
328
325
// Filter slot is null, so we have a null in the output
329
- bit_util::ClearBit (out_is_valid_, out_offset_ + out_position_);
326
+ bit_util::ClearBit (out_is_valid_, out_position_);
330
327
WriteNull ();
331
328
}
332
329
++in_position;
@@ -362,7 +359,7 @@ class PrimitiveFilterImpl {
362
359
WriteMaybeNull (in_position);
363
360
} else if (!is_valid) {
364
361
// Filter slot is null, so we have a null in the output
365
- bit_util::ClearBit (out_is_valid_, out_offset_ + out_position_);
362
+ bit_util::ClearBit (out_is_valid_, out_position_);
366
363
WriteNull ();
367
364
}
368
365
++in_position;
@@ -376,54 +373,72 @@ class PrimitiveFilterImpl {
376
373
// Write the next out_position given the selected in_position for the input
377
374
// data and advance out_position
378
375
void WriteValue (int64_t in_position) {
379
- out_data_[out_offset_ + out_position_++] = values_data_[in_position];
376
+ if constexpr (kIsBoolean ) {
377
+ bit_util::SetBitTo (out_data_, out_position_,
378
+ bit_util::GetBit (values_data_, values_offset_ + in_position));
379
+ } else {
380
+ memcpy (out_data_ + out_position_ * byte_width (),
381
+ values_data_ + in_position * byte_width (), byte_width ());
382
+ }
383
+ ++out_position_;
380
384
}
381
385
382
386
void WriteValueSegment (int64_t in_start, int64_t length) {
383
- std::memcpy (out_data_ + out_position_, values_data_ + in_start, length * sizeof (T));
387
+ if constexpr (kIsBoolean ) {
388
+ CopyBitmap (values_data_, values_offset_ + in_start, length, out_data_,
389
+ out_position_);
390
+ } else {
391
+ memcpy (out_data_ + out_position_ * byte_width (),
392
+ values_data_ + in_start * byte_width (), length * byte_width ());
393
+ }
384
394
out_position_ += length;
385
395
}
386
396
387
397
void WriteNull () {
388
- // Zero the memory
389
- out_data_[out_offset_ + out_position_++] = T{};
398
+ if constexpr (kIsBoolean ) {
399
+ // Zero the bit
400
+ bit_util::ClearBit (out_data_, out_position_);
401
+ } else {
402
+ // Zero the memory
403
+ memset (out_data_ + out_position_ * byte_width (), 0 , byte_width ());
404
+ }
405
+ ++out_position_;
406
+ }
407
+
408
+ void WriteNullSegment (int64_t length) {
409
+ if constexpr (kIsBoolean ) {
410
+ // Zero the bits
411
+ bit_util::SetBitsTo (out_data_, out_position_, length, false );
412
+ } else {
413
+ // Zero the memory
414
+ memset (out_data_ + out_position_ * byte_width (), 0 , length * byte_width ());
415
+ }
416
+ out_position_ += length;
417
+ }
418
+
419
+ constexpr int32_t byte_width () const {
420
+ if constexpr (kByteWidth >= 0 ) {
421
+ return kByteWidth ;
422
+ } else {
423
+ return byte_width_;
424
+ }
390
425
}
391
426
392
427
private:
428
+ int32_t byte_width_;
393
429
const uint8_t * values_is_valid_;
394
- const T * values_data_;
430
+ const uint8_t * values_data_;
395
431
int64_t values_null_count_;
396
432
int64_t values_offset_;
397
433
int64_t values_length_;
398
434
const ArraySpan& filter_;
399
435
FilterOptions::NullSelectionBehavior null_selection_;
400
436
uint8_t * out_is_valid_ = NULLPTR;
401
- T* out_data_;
402
- int64_t out_offset_;
437
+ uint8_t * out_data_;
403
438
int64_t out_length_;
404
439
int64_t out_position_;
405
440
};
406
441
407
- template <>
408
- inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) {
409
- bit_util::SetBitTo (out_data_, out_offset_ + out_position_++,
410
- bit_util::GetBit (values_data_, values_offset_ + in_position));
411
- }
412
-
413
- template <>
414
- inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start,
415
- int64_t length) {
416
- CopyBitmap (values_data_, values_offset_ + in_start, length, out_data_,
417
- out_offset_ + out_position_);
418
- out_position_ += length;
419
- }
420
-
421
- template <>
422
- inline void PrimitiveFilterImpl<BooleanType>::WriteNull() {
423
- // Zero the bit
424
- bit_util::ClearBit (out_data_, out_offset_ + out_position_++);
425
- }
426
-
427
442
Status PrimitiveFilterExec (KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
428
443
const ArraySpan& values = batch[0 ].array ;
429
444
const ArraySpan& filter = batch[1 ].array ;
@@ -459,22 +474,32 @@ Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult
459
474
460
475
switch (bit_width) {
461
476
case 1 :
462
- PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out_arr).Exec ();
477
+ PrimitiveFilterImpl<1 , /* kIsBoolean=*/ true >(values, filter, null_selection, out_arr)
478
+ .Exec ();
463
479
break ;
464
480
case 8 :
465
- PrimitiveFilterImpl<UInt8Type >(values, filter, null_selection, out_arr).Exec ();
481
+ PrimitiveFilterImpl<1 >(values, filter, null_selection, out_arr).Exec ();
466
482
break ;
467
483
case 16 :
468
- PrimitiveFilterImpl<UInt16Type >(values, filter, null_selection, out_arr).Exec ();
484
+ PrimitiveFilterImpl<2 >(values, filter, null_selection, out_arr).Exec ();
469
485
break ;
470
486
case 32 :
471
- PrimitiveFilterImpl<UInt32Type >(values, filter, null_selection, out_arr).Exec ();
487
+ PrimitiveFilterImpl<4 >(values, filter, null_selection, out_arr).Exec ();
472
488
break ;
473
489
case 64 :
474
- PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out_arr).Exec ();
490
+ PrimitiveFilterImpl<8 >(values, filter, null_selection, out_arr).Exec ();
491
+ break ;
492
+ case 128 :
493
+ // For INTERVAL_MONTH_DAY_NANO, DECIMAL128
494
+ PrimitiveFilterImpl<16 >(values, filter, null_selection, out_arr).Exec ();
495
+ break ;
496
+ case 256 :
497
+ // For DECIMAL256
498
+ PrimitiveFilterImpl<32 >(values, filter, null_selection, out_arr).Exec ();
475
499
break ;
476
500
default :
477
- DCHECK (false ) << " Invalid values bit width" ;
501
+ // Not specialized on byte width: the width is only known at runtime
502
+ PrimitiveFilterImpl<-1 >(values, filter, null_selection, out_arr).Exec ();
478
503
break ;
479
504
}
480
505
return Status::OK ();
@@ -1050,10 +1075,10 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
1050
1075
{InputType (match::Primitive ()), plain_filter, PrimitiveFilterExec},
1051
1076
{InputType (match::BinaryLike ()), plain_filter, BinaryFilterExec},
1052
1077
{InputType (match::LargeBinaryLike ()), plain_filter, BinaryFilterExec},
1053
- {InputType (Type::FIXED_SIZE_BINARY), plain_filter, FSBFilterExec},
1054
1078
{InputType (null ()), plain_filter, NullFilterExec},
1055
- {InputType (Type::DECIMAL128), plain_filter, FSBFilterExec},
1056
- {InputType (Type::DECIMAL256), plain_filter, FSBFilterExec},
1079
+ {InputType (Type::FIXED_SIZE_BINARY), plain_filter, PrimitiveFilterExec},
1080
+ {InputType (Type::DECIMAL128), plain_filter, PrimitiveFilterExec},
1081
+ {InputType (Type::DECIMAL256), plain_filter, PrimitiveFilterExec},
1057
1082
{InputType (Type::DICTIONARY), plain_filter, DictionaryFilterExec},
1058
1083
{InputType (Type::EXTENSION), plain_filter, ExtensionFilterExec},
1059
1084
{InputType (Type::LIST), plain_filter, ListFilterExec},
@@ -1068,10 +1093,10 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
1068
1093
{InputType (match::Primitive ()), ree_filter, PrimitiveFilterExec},
1069
1094
{InputType (match::BinaryLike ()), ree_filter, BinaryFilterExec},
1070
1095
{InputType (match::LargeBinaryLike ()), ree_filter, BinaryFilterExec},
1071
- {InputType (Type::FIXED_SIZE_BINARY), ree_filter, FSBFilterExec},
1072
1096
{InputType (null ()), ree_filter, NullFilterExec},
1073
- {InputType (Type::DECIMAL128), ree_filter, FSBFilterExec},
1074
- {InputType (Type::DECIMAL256), ree_filter, FSBFilterExec},
1097
+ {InputType (Type::FIXED_SIZE_BINARY), ree_filter, PrimitiveFilterExec},
1098
+ {InputType (Type::DECIMAL128), ree_filter, PrimitiveFilterExec},
1099
+ {InputType (Type::DECIMAL256), ree_filter, PrimitiveFilterExec},
1075
1100
{InputType (Type::DICTIONARY), ree_filter, DictionaryFilterExec},
1076
1101
{InputType (Type::EXTENSION), ree_filter, ExtensionFilterExec},
1077
1102
{InputType (Type::LIST), ree_filter, ListFilterExec},
0 commit comments