|
18 | 18 |
|
19 | 19 | #include <cudf/types.hpp> |
20 | 20 | #include <cudf/utilities/export.hpp> |
| 21 | +#include <cudf/utilities/span.hpp> |
| 22 | + |
| 23 | +#include <rmm/cuda_stream_view.hpp> |
21 | 24 |
|
22 | 25 | #include <functional> |
23 | 26 | #include <memory> |
@@ -84,43 +87,44 @@ class aggregation { |
84 | 87 | * @brief Possible aggregation operations |
85 | 88 | */ |
86 | 89 | enum Kind { |
87 | | - SUM, ///< sum reduction |
88 | | - PRODUCT, ///< product reduction |
89 | | - MIN, ///< min reduction |
90 | | - MAX, ///< max reduction |
91 | | - COUNT_VALID, ///< count number of valid elements |
92 | | - COUNT_ALL, ///< count number of elements |
93 | | - ANY, ///< any reduction |
94 | | - ALL, ///< all reduction |
95 | | - SUM_OF_SQUARES, ///< sum of squares reduction |
96 | | - MEAN, ///< arithmetic mean reduction |
97 | | - M2, ///< sum of squares of differences from the mean |
98 | | - VARIANCE, ///< variance |
99 | | - STD, ///< standard deviation |
100 | | - MEDIAN, ///< median reduction |
101 | | - QUANTILE, ///< compute specified quantile(s) |
102 | | - ARGMAX, ///< Index of max element |
103 | | - ARGMIN, ///< Index of min element |
104 | | - NUNIQUE, ///< count number of unique elements |
105 | | - NTH_ELEMENT, ///< get the nth element |
106 | | - ROW_NUMBER, ///< get row-number of current index (relative to rolling window) |
107 | | - EWMA, ///< get exponential weighted moving average at current index |
108 | | - RANK, ///< get rank of current index |
109 | | - COLLECT_LIST, ///< collect values into a list |
110 | | - COLLECT_SET, ///< collect values into a list without duplicate entries |
111 | | - LEAD, ///< window function, accesses row at specified offset following current row |
112 | | - LAG, ///< window function, accesses row at specified offset preceding current row |
113 | | - PTX, ///< PTX UDF based reduction |
114 | | - CUDA, ///< CUDA UDF based reduction |
115 | | - MERGE_LISTS, ///< merge multiple lists values into one list |
116 | | - MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries |
117 | | - MERGE_M2, ///< merge partial values of M2 aggregation, |
118 | | - COVARIANCE, ///< covariance between two sets of elements |
119 | | - CORRELATION, ///< correlation between two sets of elements |
120 | | - TDIGEST, ///< create a tdigest from a set of input values |
121 | | - MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together |
122 | | - HISTOGRAM, ///< compute frequency of each element |
123 | | - MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation, |
| 90 | + SUM, ///< sum reduction |
| 91 | + PRODUCT, ///< product reduction |
| 92 | + MIN, ///< min reduction |
| 93 | + MAX, ///< max reduction |
| 94 | + COUNT_VALID, ///< count number of valid elements |
| 95 | + COUNT_ALL, ///< count number of elements |
| 96 | + ANY, ///< any reduction |
| 97 | + ALL, ///< all reduction |
| 98 | + SUM_OF_SQUARES, ///< sum of squares reduction |
| 99 | + MEAN, ///< arithmetic mean reduction |
| 100 | + M2, ///< sum of squares of differences from the mean |
| 101 | + VARIANCE, ///< variance |
| 102 | + STD, ///< standard deviation |
| 103 | + MEDIAN, ///< median reduction |
| 104 | + QUANTILE, ///< compute specified quantile(s) |
| 105 | + ARGMAX, ///< Index of max element |
| 106 | + ARGMIN, ///< Index of min element |
| 107 | + NUNIQUE, ///< count number of unique elements |
| 108 | + NTH_ELEMENT, ///< get the nth element |
| 109 | + ROW_NUMBER, ///< get row-number of current index (relative to rolling window) |
| 110 | + EWMA, ///< get exponential weighted moving average at current index |
| 111 | + RANK, ///< get rank of current index |
| 112 | + COLLECT_LIST, ///< collect values into a list |
| 113 | + COLLECT_SET, ///< collect values into a list without duplicate entries |
| 114 | + LEAD, ///< window function, accesses row at specified offset following current row |
| 115 | + LAG, ///< window function, accesses row at specified offset preceding current row |
| 116 | + PTX, ///< PTX UDF based reduction |
| 117 | + CUDA, ///< CUDA UDF based reduction |
| 118 | + MERGE_LISTS, ///< merge multiple lists values into one list |
| 119 | + MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries |
| 120 | + MERGE_M2, ///< merge partial values of M2 aggregation |
| 121 | + COVARIANCE, ///< covariance between two sets of elements |
| 122 | + CORRELATION, ///< correlation between two sets of elements |
| 123 | + TDIGEST, ///< create a tdigest from a set of input values |
| 124 | + MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together |
| 125 | + HISTOGRAM, ///< compute frequency of each element |
| 126 | + MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation |
| 127 | + HOST_UDF ///< host-side UDF aggregation |
124 | 128 | }; |
125 | 129 |
|
126 | 130 | aggregation() = delete; |
@@ -770,5 +774,28 @@ std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids = 1000); |
770 | 774 | template <typename Base> |
771 | 775 | std::unique_ptr<Base> make_merge_tdigest_aggregation(int max_centroids = 1000); |
772 | 776 |
|
| 777 | +// Callable type (std::function) for a host-side UDF aggregation. The intent is to pass |
| 778 | +// as many parameters as possible so the UDF has everything it needs to do its work. |
| 779 | +// Parameters, in order (modify if needed): |
| 780 | +// column_view const& input, |
| 781 | +// cudf::device_span<size_type const> group_offsets, |
| 782 | +// cudf::device_span<size_type const> group_labels, |
| 783 | +// size_type num_groups, |
| 784 | +// rmm::cuda_stream_view stream, |
| 785 | +// rmm::device_async_resource_ref mr |
| 786 | +// The callable returns its result as a std::unique_ptr<column>. |
| 787 | +using host_udf_func_type = std::function<std::unique_ptr<column>(column_view const&, |
| 788 | + device_span<size_type const>, |
| 789 | + device_span<size_type const>, |
| 790 | + size_type, |
| 791 | + rmm::cuda_stream_view, |
| 792 | + rmm::device_async_resource_ref)>; |
| 793 | +/** |
| 794 | + * @brief Factory for a host-side UDF aggregation built from the callable `udf_func_`. |
| 795 | + * @return A std::unique_ptr<Base> owning the newly created aggregation |
| 796 | + */ |
| 797 | +template <typename Base> |
| 798 | +std::unique_ptr<Base> make_host_udf_aggregation(host_udf_func_type udf_func_); |
| 799 | + |
773 | 800 | /** @} */ // end of group |
774 | 801 | } // namespace CUDF_EXPORT cudf |
0 commit comments