11/*
2- * Copyright (c) 2023-2024 , NVIDIA CORPORATION.
2+ * Copyright (c) 2023-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
3232#include < thrust/iterator/counting_iterator.h>
3333
3434#include < algorithm>
35+ #include < limits>
3536#include < numeric>
3637#include < optional>
3738#include < unordered_set>
@@ -388,6 +389,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {
388389} // namespace
389390
390391std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::filter_row_groups (
392+ host_span<std::unique_ptr<datasource> const > sources,
391393 host_span<std::vector<size_type> const > row_group_indices,
392394 host_span<data_type const > output_dtypes,
393395 host_span<int const > output_column_schemas,
@@ -396,7 +398,6 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
396398{
397399 auto mr = cudf::get_current_device_resource_ref ();
398400 // Create row group indices.
399- std::vector<std::vector<size_type>> filtered_row_group_indices;
400401 std::vector<std::vector<size_type>> all_row_group_indices;
401402 host_span<std::vector<size_type> const > input_row_group_indices;
402403 if (row_group_indices.empty ()) {
@@ -412,18 +413,22 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
412413 } else {
413414 input_row_group_indices = row_group_indices;
414415 }
415- auto const total_row_groups = std::accumulate (input_row_group_indices.begin (),
416- input_row_group_indices.end (),
417- 0 ,
418- [](size_type sum, auto const & per_file_row_groups) {
419- return sum + per_file_row_groups.size ();
420- });
416+ auto const total_row_groups = std::accumulate (
417+ input_row_group_indices.begin (),
418+ input_row_group_indices.end (),
419+ size_t {0 },
420+ [](size_t sum, auto const & per_file_row_groups) { return sum + per_file_row_groups.size (); });
421+
422+ // Check if we have less than 2B total row groups.
423+ CUDF_EXPECTS (total_row_groups <= std::numeric_limits<cudf::size_type>::max (),
424+ " Total number of row groups exceed the size_type's limit" );
421425
422426 // Converts Column chunk statistics to a table
423427 // where min(col[i]) = columns[i*2], max(col[i])=columns[i*2+1]
424428 // For each column, it contains #sources * #column_chunks_per_src rows.
425429 std::vector<std::unique_ptr<column>> columns;
426- stats_caster const stats_col{total_row_groups, per_file_metadata, input_row_group_indices};
430+ stats_caster const stats_col{
431+ static_cast <size_type>(total_row_groups), per_file_metadata, input_row_group_indices};
427432 for (size_t col_idx = 0 ; col_idx < output_dtypes.size (); col_idx++) {
428433 auto const schema_idx = output_column_schemas[col_idx];
429434 auto const & dtype = output_dtypes[col_idx];
@@ -452,44 +457,23 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
452457 CUDF_EXPECTS (predicate.type ().id () == cudf::type_id::BOOL8,
453458 " Filter expression must return a boolean column" );
454459
455- auto const host_bitmask = [&] {
456- auto const num_bitmasks = num_bitmask_words (predicate.size ());
457- if (predicate.nullable ()) {
458- return cudf::detail::make_host_vector_sync (
459- device_span<bitmask_type const >(predicate.null_mask (), num_bitmasks), stream);
460- } else {
461- auto bitmask = cudf::detail::make_host_vector<bitmask_type>(num_bitmasks, stream);
462- std::fill (bitmask.begin (), bitmask.end (), ~bitmask_type{0 });
463- return bitmask;
464- }
465- }();
460+ // Filter stats table with StatsAST expression and collect filtered row group indices
461+ auto const filtered_row_group_indices = collect_filtered_row_group_indices (
462+ stats_table, stats_expr.get_stats_expr (), input_row_group_indices, stream);
466463
467- auto validity_it = cudf::detail::make_counting_transform_iterator (
468- 0 , [bitmask = host_bitmask.data ()](auto bit_index) { return bit_is_set (bitmask, bit_index); });
464+ // Span of row groups to apply bloom filtering on.
465+ auto const bloom_filter_input_row_groups =
466+ filtered_row_group_indices.has_value ()
467+ ? host_span<std::vector<size_type> const >(filtered_row_group_indices.value ())
468+ : input_row_group_indices;
469469
470- auto const is_row_group_required = cudf::detail::make_host_vector_sync (
471- device_span<uint8_t const >(predicate.data <uint8_t >(), predicate.size ()), stream);
470+ // Apply bloom filtering on the bloom filter input row groups
471+ auto const bloom_filtered_row_groups = apply_bloom_filters (
472+ sources, bloom_filter_input_row_groups, output_dtypes, output_column_schemas, filter, stream);
472473
473- // Return only filtered row groups based on predicate
474- // if all are required or all are nulls, return.
475- if (std::all_of (is_row_group_required.cbegin (),
476- is_row_group_required.cend (),
477- [](auto i) { return bool (i); }) or
478- predicate.null_count () == predicate.size ()) {
479- return std::nullopt ;
480- }
481- size_type is_required_idx = 0 ;
482- for (auto const & input_row_group_index : input_row_group_indices) {
483- std::vector<size_type> filtered_row_groups;
484- for (auto const rg_idx : input_row_group_index) {
485- if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) {
486- filtered_row_groups.push_back (rg_idx);
487- }
488- ++is_required_idx;
489- }
490- filtered_row_group_indices.push_back (std::move (filtered_row_groups));
491- }
492- return {std::move (filtered_row_group_indices)};
474+ // Return bloom filtered row group indices iff collected
475+ return bloom_filtered_row_groups.has_value () ? bloom_filtered_row_groups
476+ : filtered_row_group_indices;
493477}
494478
495479// convert column named expression to column index reference expression
@@ -510,14 +494,14 @@ named_to_reference_converter::named_to_reference_converter(
510494std::reference_wrapper<ast::expression const > named_to_reference_converter::visit (
511495 ast::literal const & expr)
512496{
513- _stats_expr = std::reference_wrapper<ast::expression const >(expr);
497+ _converted_expr = std::reference_wrapper<ast::expression const >(expr);
514498 return expr;
515499}
516500
517501std::reference_wrapper<ast::expression const > named_to_reference_converter::visit (
518502 ast::column_reference const & expr)
519503{
520- _stats_expr = std::reference_wrapper<ast::expression const >(expr);
504+ _converted_expr = std::reference_wrapper<ast::expression const >(expr);
521505 return expr;
522506}
523507
@@ -531,7 +515,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi
531515 }
532516 auto col_index = col_index_it->second ;
533517 _col_ref.emplace_back (col_index);
534- _stats_expr = std::reference_wrapper<ast::expression const >(_col_ref.back ());
518+ _converted_expr = std::reference_wrapper<ast::expression const >(_col_ref.back ());
535519 return std::reference_wrapper<ast::expression const >(_col_ref.back ());
536520}
537521
@@ -546,7 +530,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi
546530 } else if (cudf::ast::detail::ast_operator_arity (op) == 1 ) {
547531 _operators.emplace_back (op, new_operands.front ());
548532 }
549- _stats_expr = std::reference_wrapper<ast::expression const >(_operators.back ());
533+ _converted_expr = std::reference_wrapper<ast::expression const >(_operators.back ());
550534 return std::reference_wrapper<ast::expression const >(_operators.back ());
551535}
552536
@@ -640,4 +624,60 @@ class names_from_expression : public ast::detail::expression_transformer {
640624 return names_from_expression (expr, skip_names).to_vector ();
641625}
642626
627+ std::optional<std::vector<std::vector<size_type>>> collect_filtered_row_group_indices (
628+ cudf::table_view table,
629+ std::reference_wrapper<ast::expression const > ast_expr,
630+ host_span<std::vector<size_type> const > input_row_group_indices,
631+ rmm::cuda_stream_view stream)
632+ {
633+ // Filter the input table using AST expression
634+ auto predicate_col = cudf::detail::compute_column (
635+ table, ast_expr.get (), stream, cudf::get_current_device_resource_ref ());
636+ auto predicate = predicate_col->view ();
637+ CUDF_EXPECTS (predicate.type ().id () == cudf::type_id::BOOL8,
638+ " Filter expression must return a boolean column" );
639+
640+ auto const host_bitmask = [&] {
641+ auto const num_bitmasks = num_bitmask_words (predicate.size ());
642+ if (predicate.nullable ()) {
643+ return cudf::detail::make_host_vector_sync (
644+ device_span<bitmask_type const >(predicate.null_mask (), num_bitmasks), stream);
645+ } else {
646+ auto bitmask = cudf::detail::make_host_vector<bitmask_type>(num_bitmasks, stream);
647+ std::fill (bitmask.begin (), bitmask.end (), ~bitmask_type{0 });
648+ return bitmask;
649+ }
650+ }();
651+
652+ auto validity_it = cudf::detail::make_counting_transform_iterator (
653+ 0 , [bitmask = host_bitmask.data ()](auto bit_index) { return bit_is_set (bitmask, bit_index); });
654+
655+ // Return only filtered row groups based on predicate
656+ auto const is_row_group_required = cudf::detail::make_host_vector_sync (
657+ device_span<uint8_t const >(predicate.data <uint8_t >(), predicate.size ()), stream);
658+
659+ // Return if all are required, or all are nulls.
660+ if (predicate.null_count () == predicate.size () or std::all_of (is_row_group_required.cbegin (),
661+ is_row_group_required.cend (),
662+ [](auto i) { return bool (i); })) {
663+ return std::nullopt ;
664+ }
665+
666+ // Collect indices of the filtered row groups
667+ size_type is_required_idx = 0 ;
668+ std::vector<std::vector<size_type>> filtered_row_group_indices;
669+ for (auto const & input_row_group_index : input_row_group_indices) {
670+ std::vector<size_type> filtered_row_groups;
671+ for (auto const rg_idx : input_row_group_index) {
672+ if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) {
673+ filtered_row_groups.push_back (rg_idx);
674+ }
675+ ++is_required_idx;
676+ }
677+ filtered_row_group_indices.push_back (std::move (filtered_row_groups));
678+ }
679+
680+ return {filtered_row_group_indices};
681+ }
682+
643683} // namespace cudf::io::parquet::detail
0 commit comments