Skip to content

Commit bba150c

Browse files
committed
Implement host udf aggregation
Signed-off-by: Nghia Truong <[email protected]>
1 parent 9d5041c commit bba150c

File tree

5 files changed

+142
-37
lines changed

5 files changed

+142
-37
lines changed

cpp/include/cudf/aggregation.hpp

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
#include <cudf/types.hpp>
2020
#include <cudf/utilities/export.hpp>
21+
#include <cudf/utilities/span.hpp>
22+
23+
#include <rmm/cuda_stream_view.hpp>
2124

2225
#include <functional>
2326
#include <memory>
@@ -84,43 +87,44 @@ class aggregation {
8487
* @brief Possible aggregation operations
8588
*/
8689
enum Kind {
87-
SUM, ///< sum reduction
88-
PRODUCT, ///< product reduction
89-
MIN, ///< min reduction
90-
MAX, ///< max reduction
91-
COUNT_VALID, ///< count number of valid elements
92-
COUNT_ALL, ///< count number of elements
93-
ANY, ///< any reduction
94-
ALL, ///< all reduction
95-
SUM_OF_SQUARES, ///< sum of squares reduction
96-
MEAN, ///< arithmetic mean reduction
97-
M2, ///< sum of squares of differences from the mean
98-
VARIANCE, ///< variance
99-
STD, ///< standard deviation
100-
MEDIAN, ///< median reduction
101-
QUANTILE, ///< compute specified quantile(s)
102-
ARGMAX, ///< Index of max element
103-
ARGMIN, ///< Index of min element
104-
NUNIQUE, ///< count number of unique elements
105-
NTH_ELEMENT, ///< get the nth element
106-
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
107-
EWMA, ///< get exponential weighted moving average at current index
108-
RANK, ///< get rank of current index
109-
COLLECT_LIST, ///< collect values into a list
110-
COLLECT_SET, ///< collect values into a list without duplicate entries
111-
LEAD, ///< window function, accesses row at specified offset following current row
112-
LAG, ///< window function, accesses row at specified offset preceding current row
113-
PTX, ///< PTX UDF based reduction
114-
CUDA, ///< CUDA UDF based reduction
115-
MERGE_LISTS, ///< merge multiple lists values into one list
116-
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
117-
MERGE_M2, ///< merge partial values of M2 aggregation,
118-
COVARIANCE, ///< covariance between two sets of elements
119-
CORRELATION, ///< correlation between two sets of elements
120-
TDIGEST, ///< create a tdigest from a set of input values
121-
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
122-
HISTOGRAM, ///< compute frequency of each element
123-
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
90+
SUM, ///< sum reduction
91+
PRODUCT, ///< product reduction
92+
MIN, ///< min reduction
93+
MAX, ///< max reduction
94+
COUNT_VALID, ///< count number of valid elements
95+
COUNT_ALL, ///< count number of elements
96+
ANY, ///< any reduction
97+
ALL, ///< all reduction
98+
SUM_OF_SQUARES, ///< sum of squares reduction
99+
MEAN, ///< arithmetic mean reduction
100+
M2, ///< sum of squares of differences from the mean
101+
VARIANCE, ///< variance
102+
STD, ///< standard deviation
103+
MEDIAN, ///< median reduction
104+
QUANTILE, ///< compute specified quantile(s)
105+
ARGMAX, ///< Index of max element
106+
ARGMIN, ///< Index of min element
107+
NUNIQUE, ///< count number of unique elements
108+
NTH_ELEMENT, ///< get the nth element
109+
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
110+
EWMA, ///< get exponential weighted moving average at current index
111+
RANK, ///< get rank of current index
112+
COLLECT_LIST, ///< collect values into a list
113+
COLLECT_SET, ///< collect values into a list without duplicate entries
114+
LEAD, ///< window function, accesses row at specified offset following current row
115+
LAG, ///< window function, accesses row at specified offset preceding current row
116+
PTX, ///< PTX UDF based reduction
117+
CUDA, ///< CUDA UDF based reduction
118+
MERGE_LISTS, ///< merge multiple lists values into one list
119+
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
120+
MERGE_M2, ///< merge partial values of M2 aggregation,
121+
COVARIANCE, ///< covariance between two sets of elements
122+
CORRELATION, ///< correlation between two sets of elements
123+
TDIGEST, ///< create a tdigest from a set of input values
124+
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
125+
HISTOGRAM, ///< compute frequency of each element
126+
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation
127+
HOST_UDF ///< host side UDF aggregation
124128
};
125129

126130
aggregation() = delete;
@@ -770,5 +774,28 @@ std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids = 1000);
770774
template <typename Base>
771775
std::unique_ptr<Base> make_merge_tdigest_aggregation(int max_centroids = 1000);
772776

777+
// We should pass as many parameters as possible to this function pointer,
778+
// thus the UDF can have anything it needs to perform its operations.
779+
// Currently (modify if needed):
780+
// column_view const& input,
781+
// cudf::device_span<size_type const> group_offsets,
782+
// cudf::device_span<size_type const> group_labels,
783+
// size_type num_groups,
784+
// int max_centroids,
785+
// rmm::cuda_stream_view stream,
786+
// rmm::device_async_resource_ref mr
787+
using host_udf_func_type = std::function<std::unique_ptr<column>(column_view const&,
788+
device_span<size_type const>,
789+
device_span<size_type const>,
790+
size_type,
791+
rmm::cuda_stream_view,
792+
rmm::device_async_resource_ref)>;
793+
/**
794+
* @brief make_host_udf_aggregation
795+
* @return
796+
*/
797+
template <typename Base>
798+
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_);
799+
773800
/** @} */ // end of group
774801
} // namespace CUDF_EXPORT cudf

cpp/include/cudf/detail/aggregation/aggregation.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cudf/detail/utilities/assert.cuh>
2121
#include <cudf/types.hpp>
2222
#include <cudf/utilities/error.hpp>
23+
#include <cudf/utilities/span.hpp>
2324
#include <cudf/utilities/traits.hpp>
2425

2526
#include <functional>
@@ -104,6 +105,8 @@ class simple_aggregations_collector { // Declares the interface for the simple
104105
class tdigest_aggregation const& agg);
105106
virtual std::vector<std::unique_ptr<aggregation>> visit(
106107
data_type col_type, class merge_tdigest_aggregation const& agg);
108+
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
109+
class host_udf_aggregation const& agg);
107110
};
108111

109112
class aggregation_finalizer { // Declares the interface for the finalizer
@@ -144,6 +147,7 @@ class aggregation_finalizer { // Declares the interface for the finalizer
144147
virtual void visit(class tdigest_aggregation const& agg);
145148
virtual void visit(class merge_tdigest_aggregation const& agg);
146149
virtual void visit(class ewma_aggregation const& agg);
150+
virtual void visit(class host_udf_aggregation const& agg);
147151
};
148152

149153
/**
@@ -1186,6 +1190,30 @@ class merge_tdigest_aggregation final : public groupby_aggregation, public reduc
11861190
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
11871191
};
11881192

1193+
/**
1194+
* @brief
1195+
*/
1196+
class host_udf_aggregation final : public groupby_aggregation, public reduce_aggregation {
1197+
public:
1198+
host_udf_func_type host_udf_ptr;
1199+
1200+
explicit host_udf_aggregation(host_udf_func_type host_udf_ptr_)
1201+
: aggregation{HOST_UDF}, host_udf_ptr{std::move(host_udf_ptr_)}
1202+
{
1203+
}
1204+
1205+
[[nodiscard]] std::unique_ptr<aggregation> clone() const override
1206+
{
1207+
return std::make_unique<host_udf_aggregation>(*this);
1208+
}
1209+
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
1210+
data_type col_type, simple_aggregations_collector& collector) const override
1211+
{
1212+
return collector.visit(col_type, *this);
1213+
}
1214+
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
1215+
};
1216+
11891217
/**
11901218
* @brief Sentinel value used for `ARGMAX` aggregation.
11911219
*
@@ -1462,6 +1490,11 @@ struct target_type_impl<Source,
14621490
using type = struct_view;
14631491
};
14641492

1493+
template <typename SourceType>
1494+
struct target_type_impl<SourceType, aggregation::HOST_UDF> {
1495+
using type = struct_view;
1496+
};
1497+
14651498
/**
14661499
* @brief Helper alias to get the accumulator type for performing aggregation
14671500
* `k` on elements of type `Source`
@@ -1579,6 +1612,8 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
15791612
return f.template operator()<aggregation::MERGE_TDIGEST>(std::forward<Ts>(args)...);
15801613
case aggregation::EWMA:
15811614
return f.template operator()<aggregation::EWMA>(std::forward<Ts>(args)...);
1615+
case aggregation::HOST_UDF:
1616+
return f.template operator()<aggregation::HOST_UDF>(std::forward<Ts>(args)...);
15821617
default: {
15831618
#ifndef __CUDA_ARCH__
15841619
CUDF_FAIL("Unsupported aggregation.");

cpp/src/aggregation/aggregation.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
237237
return visit(col_type, static_cast<aggregation const&>(agg));
238238
}
239239

240+
/**
 * @brief Collector overload for `host_udf_aggregation`.
 *
 * No dedicated handling exists for this aggregation kind: the argument is viewed through its
 * `aggregation` base so that the call resolves to the overload taking `aggregation const&`.
 */
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
  data_type col_type, host_udf_aggregation const& agg)
{
  aggregation const& base_agg = agg;
  return visit(col_type, base_agg);
}
245+
240246
// aggregation_finalizer ----------------------------------------
241247

242248
void aggregation_finalizer::visit(aggregation const& agg) {}
@@ -410,6 +416,11 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg)
410416
visit(static_cast<aggregation const&>(agg));
411417
}
412418

419+
/**
 * @brief Finalizer overload for `host_udf_aggregation`.
 *
 * Views the argument through its `aggregation` base so that the call resolves to the overload
 * taking `aggregation const&`.
 */
void aggregation_finalizer::visit(host_udf_aggregation const& agg)
{
  aggregation const& base_agg = agg;
  visit(base_agg);
}
423+
413424
} // namespace detail
414425

415426
std::vector<std::unique_ptr<aggregation>> aggregation::get_simple_aggregations(
@@ -917,6 +928,18 @@ make_merge_tdigest_aggregation<groupby_aggregation>(int max_centroids);
917928
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
918929
make_merge_tdigest_aggregation<reduce_aggregation>(int max_centroids);
919930

931+
template <typename Base>
932+
std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_)
933+
{
934+
return std::make_unique<detail::host_udf_aggregation>(udf_func_);
935+
}
936+
template CUDF_EXPORT std::unique_ptr<aggregation> make_host_udf_aggregation<aggregation>(
937+
host_udf_func_type);
938+
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
939+
make_host_udf_aggregation<groupby_aggregation>(host_udf_func_type);
940+
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
941+
make_host_udf_aggregation<reduce_aggregation>(host_udf_func_type);
942+
920943
namespace detail {
921944
namespace {
922945
struct target_type_functor {

cpp/src/groupby/sort/aggregate.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,23 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
791791
mr));
792792
}
793793

794+
template <>
795+
void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation const& agg)
796+
{
797+
// TODO: Add a name string to the aggregation so that we can look up different host UDFs.
798+
if (cache.has_result(values, agg)) { return; }
799+
auto const udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).host_udf_ptr;
800+
CUDF_EXPECTS(udf_ptr != nullptr, "errrrrrrrrr");
801+
cache.add_result(values,
802+
agg,
803+
udf_ptr(get_grouped_values(),
804+
helper.group_offsets(stream),
805+
helper.group_labels(stream),
806+
helper.num_groups(stream),
807+
stream,
808+
mr));
809+
}
810+
794811
} // namespace detail
795812

796813
// Sort-based groupby

cpp/src/reductions/reductions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ struct reduce_dispatch_functor {
144144
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
145145
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
146146
}
147+
case aggregation::HOST_UDF: {
148+
CUDF_FAIL("Host UDF aggregation is not implemented in `reduction`");
149+
}
147150
default: CUDF_FAIL("Unsupported reduction operator");
148151
}
149152
}

0 commit comments

Comments
 (0)