@@ -28,14 +28,15 @@ namespace {
28
28
/* *
29
29
* @brief A host-based UDF implementation used for unit tests.
30
30
*/
31
- template <typename cudf_aggregation, int test_location_line>
32
- struct host_udf_test : cudf::host_udf_base {
33
- static_assert (std::is_same_v<cudf_aggregation, cudf::groupby_aggregation>);
34
-
35
- bool * const test_run; // to check if the test is accidentally skipped
31
+ struct host_udf_test_base : cudf::host_udf_base {
32
+ int const test_location_line; // the location where testing is called
33
+ bool * const test_run; // to check if the test is accidentally skipped
36
34
data_attributes_set_t const input_attrs;
37
- host_udf_test (bool * test_run_, data_attributes_set_t input_attrs_ = {})
38
- : test_run{test_run_}, input_attrs(std::move(input_attrs_))
35
+
36
+ host_udf_test_base (int test_location_line_, bool * test_run_, data_attributes_set_t input_attrs_)
37
+ : test_location_line{test_location_line_},
38
+ test_run{test_run_},
39
+ input_attrs (std::move(input_attrs_))
39
40
{
40
41
}
41
42
@@ -47,8 +48,53 @@ struct host_udf_test : cudf::host_udf_base {
47
48
rmm::cuda_stream_view stream,
48
49
rmm::device_async_resource_ref mr) const override
49
50
{
50
- SCOPED_TRACE (" Original line of failure: " + std::to_string (test_location_line));
51
+ SCOPED_TRACE (" Test instance created at line: " + std::to_string (test_location_line));
52
+
53
+ test_data_attributes (input, stream, mr);
54
+
55
+ *test_run = true ; // test is run successfully
56
+ return get_empty_output (std::nullopt, stream, mr);
57
+ }
58
+
59
+ [[nodiscard]] output_t get_empty_output (
60
+ [[maybe_unused]] std::optional<cudf::data_type> output_dtype,
61
+ [[maybe_unused]] rmm::cuda_stream_view stream,
62
+ [[maybe_unused]] rmm::device_async_resource_ref mr) const override
63
+ {
64
+ // Unused function - dummy output.
65
+ return cudf::make_empty_column (cudf::data_type{cudf::type_id::INT32});
66
+ }
67
+
68
+ [[nodiscard]] std::size_t do_hash () const override { return 0 ; }
69
+ [[nodiscard]] bool is_equal (host_udf_base const & other) const override { return true ; }
70
+
71
+ // The main test function, which must be implemented for each kind of aggregations
72
+ // (groupby/reduction/segmented_reduction).
73
+ virtual void test_data_attributes (host_udf_input const & input,
74
+ rmm::cuda_stream_view stream,
75
+ rmm::device_async_resource_ref mr) const = 0;
76
+ };
77
+
78
+ /* *
79
+ * @brief A host-based UDF implementation used for unit tests for groupby aggregation.
80
+ */
81
+ struct host_udf_groupby_test : host_udf_test_base {
82
+ host_udf_groupby_test (int test_location_line_,
83
+ bool * test_run_,
84
+ data_attributes_set_t input_attrs_ = {})
85
+ : host_udf_test_base(test_location_line_, test_run_, std::move(input_attrs_))
86
+ {
87
+ }
88
+
89
+ [[nodiscard]] std::unique_ptr<host_udf_base> clone () const override
90
+ {
91
+ return std::make_unique<host_udf_groupby_test>(test_location_line, test_run, input_attrs);
92
+ }
51
93
94
+ void test_data_attributes (host_udf_input const & input,
95
+ rmm::cuda_stream_view stream,
96
+ rmm::device_async_resource_ref mr) const override
97
+ {
52
98
data_attributes_set_t check_attrs = input_attrs;
53
99
if (check_attrs.empty ()) {
54
100
check_attrs = data_attributes_set_t {groupby_data_attribute::INPUT_VALUES,
@@ -91,24 +137,6 @@ struct host_udf_test : cudf::host_udf_base {
91
137
EXPECT_TRUE (std::holds_alternative<cudf::column_view>(input.at (attr)));
92
138
}
93
139
}
94
-
95
- *test_run = true ; // test is run successfully
96
- return get_empty_output (std::nullopt, stream, mr);
97
- }
98
-
99
- [[nodiscard]] output_t get_empty_output (
100
- [[maybe_unused]] std::optional<cudf::data_type> output_dtype,
101
- [[maybe_unused]] rmm::cuda_stream_view stream,
102
- [[maybe_unused]] rmm::device_async_resource_ref mr) const override
103
- {
104
- return cudf::make_empty_column (cudf::data_type{cudf::type_id::INT32});
105
- }
106
-
107
- [[nodiscard]] std::size_t do_hash () const override { return 0 ; }
108
- [[nodiscard]] bool is_equal (host_udf_base const & other) const override { return true ; }
109
- [[nodiscard]] std::unique_ptr<host_udf_base> clone () const override
110
- {
111
- return std::make_unique<host_udf_test>(test_run, input_attrs);
112
140
}
113
141
};
114
142
@@ -168,7 +196,7 @@ TEST_F(HostUDFTest, GroupbyAllInput)
168
196
auto const keys = int32s_col{0 , 1 , 2 };
169
197
auto const vals = int32s_col{0 , 1 , 2 };
170
198
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(
171
- std::make_unique<host_udf_test<cudf::groupby_aggregation, __LINE__>>( &test_run));
199
+ std::make_unique<host_udf_groupby_test>(__LINE__, &test_run));
172
200
173
201
std::vector<cudf::groupby::aggregation_request> requests;
174
202
requests.emplace_back ();
@@ -197,8 +225,7 @@ TEST_F(HostUDFTest, GroupbySomeInput)
197
225
auto input_attrs = get_subset (all_attrs);
198
226
input_attrs.insert (get_random_agg ());
199
227
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(
200
- std::make_unique<host_udf_test<cudf::groupby_aggregation, __LINE__>>(&test_run,
201
- std::move (input_attrs)));
228
+ std::make_unique<host_udf_groupby_test>(__LINE__, &test_run, std::move (input_attrs)));
202
229
203
230
std::vector<cudf::groupby::aggregation_request> requests;
204
231
requests.emplace_back ();
0 commit comments