Skip to content

Commit 41215e2

Browse files
authored
Support reading bloom filters from Parquet files and filter row groups using them (#17289)
This PR adds support to read bloom filters from Parquet files and use them to filter row groups based on `col == literal` like predicate(s), if provided. Related to #17164. Authors: - Muhammad Haseeb (https://github.com/mhaseeb123) Approvers: - Yunsong Wang (https://github.com/PointKernel) - Vukasin Milovanovic (https://github.com/vuule) - Karthikeyan (https://github.com/karthikeyann) - Bradley Dice (https://github.com/bdice) URL: #17289
1 parent fe75cb8 commit 41215e2

16 files changed

+1098
-67
lines changed

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ add_library(
516516
src/datetime/timezone.cpp
517517
src/io/orc/writer_impl.cu
518518
src/io/parquet/arrow_schema_writer.cpp
519+
src/io/parquet/bloom_filter_reader.cu
519520
src/io/parquet/compact_protocol_reader.cpp
520521
src/io/parquet/compact_protocol_writer.cpp
521522
src/io/parquet/decode_preprocess.cu

cpp/src/io/parquet/bloom_filter_reader.cu

Lines changed: 683 additions & 0 deletions
Large diffs are not rendered by default.

cpp/src/io/parquet/compact_protocol_reader.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -658,14 +658,43 @@ void CompactProtocolReader::read(ColumnChunk* c)
658658
function_builder(this, op);
659659
}
660660

661+
void CompactProtocolReader::read(BloomFilterAlgorithm* alg)
662+
{
663+
auto op = std::make_tuple(parquet_field_union_enumerator(1, alg->algorithm));
664+
function_builder(this, op);
665+
}
666+
667+
void CompactProtocolReader::read(BloomFilterHash* hash)
668+
{
669+
auto op = std::make_tuple(parquet_field_union_enumerator(1, hash->hash));
670+
function_builder(this, op);
671+
}
672+
673+
void CompactProtocolReader::read(BloomFilterCompression* comp)
674+
{
675+
auto op = std::make_tuple(parquet_field_union_enumerator(1, comp->compression));
676+
function_builder(this, op);
677+
}
678+
679+
void CompactProtocolReader::read(BloomFilterHeader* bf)
680+
{
681+
auto op = std::make_tuple(parquet_field_int32(1, bf->num_bytes),
682+
parquet_field_struct(2, bf->algorithm),
683+
parquet_field_struct(3, bf->hash),
684+
parquet_field_struct(4, bf->compression));
685+
function_builder(this, op);
686+
}
687+
661688
void CompactProtocolReader::read(ColumnChunkMetaData* c)
662689
{
663690
using optional_size_statistics =
664691
parquet_field_optional<SizeStatistics, parquet_field_struct<SizeStatistics>>;
665692
using optional_list_enc_stats =
666693
parquet_field_optional<std::vector<PageEncodingStats>,
667694
parquet_field_struct_list<PageEncodingStats>>;
668-
auto op = std::make_tuple(parquet_field_enum<Type>(1, c->type),
695+
using optional_i64 = parquet_field_optional<int64_t, parquet_field_int64>;
696+
using optional_i32 = parquet_field_optional<int32_t, parquet_field_int32>;
697+
auto op = std::make_tuple(parquet_field_enum<Type>(1, c->type),
669698
parquet_field_enum_list(2, c->encodings),
670699
parquet_field_string_list(3, c->path_in_schema),
671700
parquet_field_enum<Compression>(4, c->codec),
@@ -677,6 +706,8 @@ void CompactProtocolReader::read(ColumnChunkMetaData* c)
677706
parquet_field_int64(11, c->dictionary_page_offset),
678707
parquet_field_struct(12, c->statistics),
679708
optional_list_enc_stats(13, c->encoding_stats),
709+
optional_i64(14, c->bloom_filter_offset),
710+
optional_i32(15, c->bloom_filter_length),
680711
optional_size_statistics(16, c->size_statistics));
681712
function_builder(this, op);
682713
}

cpp/src/io/parquet/compact_protocol_reader.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -108,6 +108,10 @@ class CompactProtocolReader {
108108
void read(IntType* t);
109109
void read(RowGroup* r);
110110
void read(ColumnChunk* c);
111+
void read(BloomFilterAlgorithm* bf);
112+
void read(BloomFilterHash* bf);
113+
void read(BloomFilterCompression* bf);
114+
void read(BloomFilterHeader* bf);
111115
void read(ColumnChunkMetaData* c);
112116
void read(PageHeader* p);
113117
void read(DataPageHeader* d);

cpp/src/io/parquet/parquet.hpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -382,12 +382,62 @@ struct ColumnChunkMetaData {
382382
// Set of all encodings used for pages in this column chunk. This information can be used to
383383
// determine if all data pages are dictionary encoded for example.
384384
std::optional<std::vector<PageEncodingStats>> encoding_stats;
385+
// Byte offset from beginning of file to Bloom filter data.
386+
std::optional<int64_t> bloom_filter_offset;
387+
// Size of Bloom filter data including the serialized header, in bytes. Added in 2.10 so readers
388+
// may not read this field from old files and it can be obtained after the BloomFilterHeader has
389+
// been deserialized. Writers should write this field so readers can read the bloom filter in a
390+
// single I/O.
391+
std::optional<int32_t> bloom_filter_length;
385392
// Optional statistics to help estimate total memory when converted to in-memory representations.
386393
// The histograms contained in these statistics can also be useful in some cases for more
387394
// fine-grained nullability/list length filter pushdown.
388395
std::optional<SizeStatistics> size_statistics;
389396
};
390397

398+
/**
399+
* @brief The algorithm used in bloom filter
400+
*/
401+
struct BloomFilterAlgorithm {
402+
// Block-based Bloom filter.
403+
enum class Algorithm { UNDEFINED, SPLIT_BLOCK };
404+
Algorithm algorithm{Algorithm::SPLIT_BLOCK};
405+
};
406+
407+
/**
408+
* @brief The hash function used in Bloom filter
409+
*/
410+
struct BloomFilterHash {
411+
// xxHash_64
412+
enum class Hash { UNDEFINED, XXHASH };
413+
Hash hash{Hash::XXHASH};
414+
};
415+
416+
/**
417+
* @brief The compression used in the bloom filter
418+
*/
419+
struct BloomFilterCompression {
420+
enum class Compression { UNDEFINED, UNCOMPRESSED };
421+
Compression compression{Compression::UNCOMPRESSED};
422+
};
423+
424+
/**
425+
* @brief Bloom filter header struct
426+
*
427+
* The bloom filter data of a column chunk stores this header at the beginning
428+
* followed by the filter bitset.
429+
*/
430+
struct BloomFilterHeader {
431+
// The size of bitset in bytes
432+
int32_t num_bytes;
433+
// The algorithm for setting bits
434+
BloomFilterAlgorithm algorithm;
435+
// The hash function used for bloom filter
436+
BloomFilterHash hash;
437+
// The compression used in the bloom filter
438+
BloomFilterCompression compression;
439+
};
440+
391441
/**
392442
* @brief Thrift-derived struct describing a chunk of data for a particular
393443
* column

cpp/src/io/parquet/predicate_pushdown.cpp

Lines changed: 88 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
3232
#include <thrust/iterator/counting_iterator.h>
3333

3434
#include <algorithm>
35+
#include <limits>
3536
#include <numeric>
3637
#include <optional>
3738
#include <unordered_set>
@@ -388,6 +389,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {
388389
} // namespace
389390

390391
std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::filter_row_groups(
392+
host_span<std::unique_ptr<datasource> const> sources,
391393
host_span<std::vector<size_type> const> row_group_indices,
392394
host_span<data_type const> output_dtypes,
393395
host_span<int const> output_column_schemas,
@@ -396,7 +398,6 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
396398
{
397399
auto mr = cudf::get_current_device_resource_ref();
398400
// Create row group indices.
399-
std::vector<std::vector<size_type>> filtered_row_group_indices;
400401
std::vector<std::vector<size_type>> all_row_group_indices;
401402
host_span<std::vector<size_type> const> input_row_group_indices;
402403
if (row_group_indices.empty()) {
@@ -412,18 +413,22 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
412413
} else {
413414
input_row_group_indices = row_group_indices;
414415
}
415-
auto const total_row_groups = std::accumulate(input_row_group_indices.begin(),
416-
input_row_group_indices.end(),
417-
0,
418-
[](size_type sum, auto const& per_file_row_groups) {
419-
return sum + per_file_row_groups.size();
420-
});
416+
auto const total_row_groups = std::accumulate(
417+
input_row_group_indices.begin(),
418+
input_row_group_indices.end(),
419+
size_t{0},
420+
[](size_t sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); });
421+
422+
// Check if we have less than 2B total row groups.
423+
CUDF_EXPECTS(total_row_groups <= std::numeric_limits<cudf::size_type>::max(),
424+
"Total number of row groups exceed the size_type's limit");
421425

422426
// Converts Column chunk statistics to a table
423427
// where min(col[i]) = columns[i*2], max(col[i])=columns[i*2+1]
424428
// For each column, it contains #sources * #column_chunks_per_src rows.
425429
std::vector<std::unique_ptr<column>> columns;
426-
stats_caster const stats_col{total_row_groups, per_file_metadata, input_row_group_indices};
430+
stats_caster const stats_col{
431+
static_cast<size_type>(total_row_groups), per_file_metadata, input_row_group_indices};
427432
for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) {
428433
auto const schema_idx = output_column_schemas[col_idx];
429434
auto const& dtype = output_dtypes[col_idx];
@@ -452,44 +457,23 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
452457
CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8,
453458
"Filter expression must return a boolean column");
454459

455-
auto const host_bitmask = [&] {
456-
auto const num_bitmasks = num_bitmask_words(predicate.size());
457-
if (predicate.nullable()) {
458-
return cudf::detail::make_host_vector_sync(
459-
device_span<bitmask_type const>(predicate.null_mask(), num_bitmasks), stream);
460-
} else {
461-
auto bitmask = cudf::detail::make_host_vector<bitmask_type>(num_bitmasks, stream);
462-
std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0});
463-
return bitmask;
464-
}
465-
}();
460+
// Filter stats table with StatsAST expression and collect filtered row group indices
461+
auto const filtered_row_group_indices = collect_filtered_row_group_indices(
462+
stats_table, stats_expr.get_stats_expr(), input_row_group_indices, stream);
466463

467-
auto validity_it = cudf::detail::make_counting_transform_iterator(
468-
0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); });
464+
// Span of row groups to apply bloom filtering on.
465+
auto const bloom_filter_input_row_groups =
466+
filtered_row_group_indices.has_value()
467+
? host_span<std::vector<size_type> const>(filtered_row_group_indices.value())
468+
: input_row_group_indices;
469469

470-
auto const is_row_group_required = cudf::detail::make_host_vector_sync(
471-
device_span<uint8_t const>(predicate.data<uint8_t>(), predicate.size()), stream);
470+
// Apply bloom filtering on the bloom filter input row groups
471+
auto const bloom_filtered_row_groups = apply_bloom_filters(
472+
sources, bloom_filter_input_row_groups, output_dtypes, output_column_schemas, filter, stream);
472473

473-
// Return only filtered row groups based on predicate
474-
// if all are required or all are nulls, return.
475-
if (std::all_of(is_row_group_required.cbegin(),
476-
is_row_group_required.cend(),
477-
[](auto i) { return bool(i); }) or
478-
predicate.null_count() == predicate.size()) {
479-
return std::nullopt;
480-
}
481-
size_type is_required_idx = 0;
482-
for (auto const& input_row_group_index : input_row_group_indices) {
483-
std::vector<size_type> filtered_row_groups;
484-
for (auto const rg_idx : input_row_group_index) {
485-
if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) {
486-
filtered_row_groups.push_back(rg_idx);
487-
}
488-
++is_required_idx;
489-
}
490-
filtered_row_group_indices.push_back(std::move(filtered_row_groups));
491-
}
492-
return {std::move(filtered_row_group_indices)};
474+
// Return bloom filtered row group indices iff collected
475+
return bloom_filtered_row_groups.has_value() ? bloom_filtered_row_groups
476+
: filtered_row_group_indices;
493477
}
494478

495479
// convert column named expression to column index reference expression
@@ -510,14 +494,14 @@ named_to_reference_converter::named_to_reference_converter(
510494
std::reference_wrapper<ast::expression const> named_to_reference_converter::visit(
511495
ast::literal const& expr)
512496
{
513-
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
497+
_converted_expr = std::reference_wrapper<ast::expression const>(expr);
514498
return expr;
515499
}
516500

517501
std::reference_wrapper<ast::expression const> named_to_reference_converter::visit(
518502
ast::column_reference const& expr)
519503
{
520-
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
504+
_converted_expr = std::reference_wrapper<ast::expression const>(expr);
521505
return expr;
522506
}
523507

@@ -531,7 +515,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi
531515
}
532516
auto col_index = col_index_it->second;
533517
_col_ref.emplace_back(col_index);
534-
_stats_expr = std::reference_wrapper<ast::expression const>(_col_ref.back());
518+
_converted_expr = std::reference_wrapper<ast::expression const>(_col_ref.back());
535519
return std::reference_wrapper<ast::expression const>(_col_ref.back());
536520
}
537521

@@ -546,7 +530,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi
546530
} else if (cudf::ast::detail::ast_operator_arity(op) == 1) {
547531
_operators.emplace_back(op, new_operands.front());
548532
}
549-
_stats_expr = std::reference_wrapper<ast::expression const>(_operators.back());
533+
_converted_expr = std::reference_wrapper<ast::expression const>(_operators.back());
550534
return std::reference_wrapper<ast::expression const>(_operators.back());
551535
}
552536

@@ -640,4 +624,60 @@ class names_from_expression : public ast::detail::expression_transformer {
640624
return names_from_expression(expr, skip_names).to_vector();
641625
}
642626

627+
std::optional<std::vector<std::vector<size_type>>> collect_filtered_row_group_indices(
628+
cudf::table_view table,
629+
std::reference_wrapper<ast::expression const> ast_expr,
630+
host_span<std::vector<size_type> const> input_row_group_indices,
631+
rmm::cuda_stream_view stream)
632+
{
633+
// Filter the input table using AST expression
634+
auto predicate_col = cudf::detail::compute_column(
635+
table, ast_expr.get(), stream, cudf::get_current_device_resource_ref());
636+
auto predicate = predicate_col->view();
637+
CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8,
638+
"Filter expression must return a boolean column");
639+
640+
auto const host_bitmask = [&] {
641+
auto const num_bitmasks = num_bitmask_words(predicate.size());
642+
if (predicate.nullable()) {
643+
return cudf::detail::make_host_vector_sync(
644+
device_span<bitmask_type const>(predicate.null_mask(), num_bitmasks), stream);
645+
} else {
646+
auto bitmask = cudf::detail::make_host_vector<bitmask_type>(num_bitmasks, stream);
647+
std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0});
648+
return bitmask;
649+
}
650+
}();
651+
652+
auto validity_it = cudf::detail::make_counting_transform_iterator(
653+
0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); });
654+
655+
// Return only filtered row groups based on predicate
656+
auto const is_row_group_required = cudf::detail::make_host_vector_sync(
657+
device_span<uint8_t const>(predicate.data<uint8_t>(), predicate.size()), stream);
658+
659+
// Return if all are required, or all are nulls.
660+
if (predicate.null_count() == predicate.size() or std::all_of(is_row_group_required.cbegin(),
661+
is_row_group_required.cend(),
662+
[](auto i) { return bool(i); })) {
663+
return std::nullopt;
664+
}
665+
666+
// Collect indices of the filtered row groups
667+
size_type is_required_idx = 0;
668+
std::vector<std::vector<size_type>> filtered_row_group_indices;
669+
for (auto const& input_row_group_index : input_row_group_indices) {
670+
std::vector<size_type> filtered_row_groups;
671+
for (auto const rg_idx : input_row_group_index) {
672+
if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) {
673+
filtered_row_groups.push_back(rg_idx);
674+
}
675+
++is_required_idx;
676+
}
677+
filtered_row_group_indices.push_back(std::move(filtered_row_groups));
678+
}
679+
680+
return {filtered_row_group_indices};
681+
}
682+
643683
} // namespace cudf::io::parquet::detail

cpp/src/io/parquet/reader_impl_helpers.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -1030,6 +1030,7 @@ std::vector<std::string> aggregate_reader_metadata::get_pandas_index_names() con
10301030

10311031
std::tuple<int64_t, size_type, std::vector<row_group_info>, std::vector<size_t>>
10321032
aggregate_reader_metadata::select_row_groups(
1033+
host_span<std::unique_ptr<datasource> const> sources,
10331034
host_span<std::vector<size_type> const> row_group_indices,
10341035
int64_t skip_rows_opt,
10351036
std::optional<size_type> const& num_rows_opt,
@@ -1042,7 +1043,7 @@ aggregate_reader_metadata::select_row_groups(
10421043
// if filter is not empty, then gather row groups to read after predicate pushdown
10431044
if (filter.has_value()) {
10441045
filtered_row_group_indices = filter_row_groups(
1045-
row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream);
1046+
sources, row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream);
10461047
if (filtered_row_group_indices.has_value()) {
10471048
row_group_indices =
10481049
host_span<std::vector<size_type> const>(filtered_row_group_indices.value());

0 commit comments

Comments
 (0)