Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Implement HOST_UDF aggregations #17249

Draft
wants to merge 3 commits into
base: branch-24.12
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 64 additions & 37 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

#include <cudf/types.hpp>
#include <cudf/utilities/export.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <functional>
#include <memory>
Expand Down Expand Up @@ -84,43 +87,44 @@ class aggregation {
* @brief Possible aggregation operations
*/
enum Kind {
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation
HOST_UDF ///< host side UDF aggregation
};

aggregation() = delete;
Expand Down Expand Up @@ -770,5 +774,28 @@ std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids = 1000);
template <typename Base>
std::unique_ptr<Base> make_merge_tdigest_aggregation(int max_centroids = 1000);

// We should pass as many parameters as possible to this function pointer,
// thus the UDF can have anything it needs to perform its operations.
// Currently (modify if needed):
// column_view const& input,
// cudf::device_span<size_type const> group_offsets,
// cudf::device_span<size_type const> group_labels,
// size_type num_groups,
// int max_centroids,
// rmm::cuda_stream_view stream,
// rmm::device_async_resource_ref mr
using host_udf_func_type = std::function<std::unique_ptr<column>(column_view const&,
device_span<size_type const>,
device_span<size_type const>,
size_type,
rmm::cuda_stream_view,
rmm::device_async_resource_ref)>;
Comment on lines +787 to +792
Copy link
Contributor Author

@ttnghia ttnghia Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only passes group values (the first column parameter). It seems that we should better pass the group keys too, to have all the group information needed for generic computation.

/**
* @brief make_host_udf_aggregation
* @return
*/
template <typename Base>
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_);

/** @} */ // end of group
} // namespace CUDF_EXPORT cudf
35 changes: 35 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>

#include <functional>
Expand Down Expand Up @@ -104,6 +105,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
class tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, class merge_tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class host_udf_aggregation const& agg);
};

class aggregation_finalizer { // Declares the interface for the finalizer
Expand Down Expand Up @@ -144,6 +147,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
virtual void visit(class ewma_aggregation const& agg);
virtual void visit(class host_udf_aggregation const& agg);
};

/**
Expand Down Expand Up @@ -1186,6 +1190,30 @@ class merge_tdigest_aggregation final : public groupby_aggregation, public reduc
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief
*/
class host_udf_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
host_udf_func_type host_udf_ptr;

explicit host_udf_aggregation(host_udf_func_type host_udf_ptr_)
: aggregation{HOST_UDF}, host_udf_ptr{std::move(host_udf_ptr_)}
{
}

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<host_udf_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
Expand Down Expand Up @@ -1462,6 +1490,11 @@ struct target_type_impl<Source,
using type = struct_view;
};

template <typename SourceType>
struct target_type_impl<SourceType, aggregation::HOST_UDF> {
using type = struct_view;
};

/**
* @brief Helper alias to get the accumulator type for performing aggregation
* `k` on elements of type `Source`
Expand Down Expand Up @@ -1579,6 +1612,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::MERGE_TDIGEST>(std::forward<Ts>(args)...);
case aggregation::EWMA:
return f.template operator()<aggregation::EWMA>(std::forward<Ts>(args)...);
case aggregation::HOST_UDF:
return f.template operator()<aggregation::HOST_UDF>(std::forward<Ts>(args)...);
default: {
#ifndef __CUDA_ARCH__
CUDF_FAIL("Unsupported aggregation.");
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, host_udf_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

// aggregation_finalizer ----------------------------------------

void aggregation_finalizer::visit(aggregation const& agg) {}
Expand Down Expand Up @@ -410,6 +416,11 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(host_udf_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

} // namespace detail

std::vector<std::unique_ptr<aggregation>> aggregation::get_simple_aggregations(
Expand Down Expand Up @@ -917,6 +928,18 @@ make_merge_tdigest_aggregation<groupby_aggregation>(int max_centroids);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_merge_tdigest_aggregation<reduce_aggregation>(int max_centroids);

template <typename Base>
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_)
{
return std::make_unique<detail::host_udf_aggregation>(udf_func_);
}
template CUDF_EXPORT std::unique_ptr<aggregation> make_host_udf_aggregation<aggregation>(
host_udf_func_type);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_host_udf_aggregation<groupby_aggregation>(host_udf_func_type);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_host_udf_aggregation<reduce_aggregation>(host_udf_func_type);

namespace detail {
namespace {
struct target_type_functor {
Expand Down
17 changes: 17 additions & 0 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,23 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
mr));
}

template <>
void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation const& agg)
{
// TODO: Add a name string to the aggregation so that we can look up different host UDFs.
if (cache.has_result(values, agg)) { return; }
Comment on lines +797 to +798
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Why not ask the implementer of the host udf to provide hash and equality, like the other aggregations have?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Providing a name string should be enough for hashing agg here as we will hash a pair {aggregation::kind, udf_name_str}. That will be much simpler than providing a hash and equality functor.

auto const udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).host_udf_ptr;
CUDF_EXPECTS(udf_ptr != nullptr, "errrrrrrrrr");
cache.add_result(values,
agg,
udf_ptr(get_grouped_values(),
helper.group_offsets(stream),
helper.group_labels(stream),
helper.num_groups(stream),
stream,
mr));
}

} // namespace detail

// Sort-based groupby
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ struct reduce_dispatch_functor {
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
}
case aggregation::HOST_UDF: {
CUDF_FAIL("Host UDF aggregation is not implemented in `reduction`");
}
default: CUDF_FAIL("Unsupported reduction operator");
}
}
Expand Down
39 changes: 1 addition & 38 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,44 +122,7 @@ ConfigureTest(TIMESTAMPS_TEST wrappers/timestamps_test.cu)
# * groupby tests ---------------------------------------------------------------------------------
ConfigureTest(
GROUPBY_TEST
groupby/argmin_tests.cpp
groupby/argmax_tests.cpp
groupby/collect_list_tests.cpp
groupby/collect_set_tests.cpp
groupby/correlation_tests.cpp
groupby/count_scan_tests.cpp
groupby/count_tests.cpp
groupby/covariance_tests.cpp
groupby/groupby_test_util.cpp
groupby/groups_tests.cpp
groupby/histogram_tests.cpp
groupby/keys_tests.cpp
groupby/lists_tests.cpp
groupby/m2_tests.cpp
groupby/min_tests.cpp
groupby/max_scan_tests.cpp
groupby/max_tests.cpp
groupby/mean_tests.cpp
groupby/median_tests.cpp
groupby/merge_m2_tests.cpp
groupby/merge_lists_tests.cpp
groupby/merge_sets_tests.cpp
groupby/min_scan_tests.cpp
groupby/nth_element_tests.cpp
groupby/nunique_tests.cpp
groupby/product_scan_tests.cpp
groupby/product_tests.cpp
groupby/quantile_tests.cpp
groupby/rank_scan_tests.cpp
groupby/replace_nulls_tests.cpp
groupby/shift_tests.cpp
groupby/std_tests.cpp
groupby/structs_tests.cpp
groupby/sum_of_squares_tests.cpp
groupby/sum_scan_tests.cpp
groupby/sum_tests.cpp
groupby/tdigest_tests.cu
groupby/var_tests.cpp
groupby/host_udf_tests.cu
GPUS 1
PERCENT 100
)
Expand Down
Loading
Loading