Skip to content

Commit ff89b77

Browse files
committed
Rewrite test
Signed-off-by: Nghia Truong <[email protected]>
1 parent 9a9f738 commit ff89b77

File tree

1 file changed

+56
-29
lines changed

1 file changed

+56
-29
lines changed

cpp/tests/groupby/host_udf_tests.cpp

+56-29
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ namespace {
2828
/**
2929
* @brief A host-based UDF implementation used for unit tests.
3030
*/
31-
template <typename cudf_aggregation, int test_location_line>
32-
struct host_udf_test : cudf::host_udf_base {
33-
static_assert(std::is_same_v<cudf_aggregation, cudf::groupby_aggregation>);
34-
35-
bool* const test_run; // to check if the test is accidentally skipped
31+
struct host_udf_test_base : cudf::host_udf_base {
32+
int const test_location_line; // the location where testing is called
33+
bool* const test_run; // to check if the test is accidentally skipped
3634
data_attributes_set_t const input_attrs;
37-
host_udf_test(bool* test_run_, data_attributes_set_t input_attrs_ = {})
38-
: test_run{test_run_}, input_attrs(std::move(input_attrs_))
35+
36+
host_udf_test_base(int test_location_line_, bool* test_run_, data_attributes_set_t input_attrs_)
37+
: test_location_line{test_location_line_},
38+
test_run{test_run_},
39+
input_attrs(std::move(input_attrs_))
3940
{
4041
}
4142

@@ -47,8 +48,53 @@ struct host_udf_test : cudf::host_udf_base {
4748
rmm::cuda_stream_view stream,
4849
rmm::device_async_resource_ref mr) const override
4950
{
50-
SCOPED_TRACE("Original line of failure: " + std::to_string(test_location_line));
51+
SCOPED_TRACE("Test instance created at line: " + std::to_string(test_location_line));
52+
53+
test_data_attributes(input, stream, mr);
54+
55+
*test_run = true; // test is run successfully
56+
return get_empty_output(std::nullopt, stream, mr);
57+
}
58+
59+
[[nodiscard]] output_t get_empty_output(
60+
[[maybe_unused]] std::optional<cudf::data_type> output_dtype,
61+
[[maybe_unused]] rmm::cuda_stream_view stream,
62+
[[maybe_unused]] rmm::device_async_resource_ref mr) const override
63+
{
64+
// Unused function - dummy output.
65+
return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32});
66+
}
67+
68+
[[nodiscard]] std::size_t do_hash() const override { return 0; }
69+
[[nodiscard]] bool is_equal(host_udf_base const& other) const override { return true; }
70+
71+
// The main test function, which must be implemented for each kind of aggregations
72+
// (groupby/reduction/segmented_reduction).
73+
virtual void test_data_attributes(host_udf_input const& input,
74+
rmm::cuda_stream_view stream,
75+
rmm::device_async_resource_ref mr) const = 0;
76+
};
77+
78+
/**
79+
* @brief A host-based UDF implementation used for unit tests for groupby aggregation.
80+
*/
81+
struct host_udf_groupby_test : host_udf_test_base {
82+
host_udf_groupby_test(int test_location_line_,
83+
bool* test_run_,
84+
data_attributes_set_t input_attrs_ = {})
85+
: host_udf_test_base(test_location_line_, test_run_, std::move(input_attrs_))
86+
{
87+
}
88+
89+
[[nodiscard]] std::unique_ptr<host_udf_base> clone() const override
90+
{
91+
return std::make_unique<host_udf_groupby_test>(test_location_line, test_run, input_attrs);
92+
}
5193

94+
void test_data_attributes(host_udf_input const& input,
95+
rmm::cuda_stream_view stream,
96+
rmm::device_async_resource_ref mr) const override
97+
{
5298
data_attributes_set_t check_attrs = input_attrs;
5399
if (check_attrs.empty()) {
54100
check_attrs = data_attributes_set_t{groupby_data_attribute::INPUT_VALUES,
@@ -91,24 +137,6 @@ struct host_udf_test : cudf::host_udf_base {
91137
EXPECT_TRUE(std::holds_alternative<cudf::column_view>(input.at(attr)));
92138
}
93139
}
94-
95-
*test_run = true; // test is run successfully
96-
return get_empty_output(std::nullopt, stream, mr);
97-
}
98-
99-
[[nodiscard]] output_t get_empty_output(
100-
[[maybe_unused]] std::optional<cudf::data_type> output_dtype,
101-
[[maybe_unused]] rmm::cuda_stream_view stream,
102-
[[maybe_unused]] rmm::device_async_resource_ref mr) const override
103-
{
104-
return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32});
105-
}
106-
107-
[[nodiscard]] std::size_t do_hash() const override { return 0; }
108-
[[nodiscard]] bool is_equal(host_udf_base const& other) const override { return true; }
109-
[[nodiscard]] std::unique_ptr<host_udf_base> clone() const override
110-
{
111-
return std::make_unique<host_udf_test>(test_run, input_attrs);
112140
}
113141
};
114142

@@ -168,7 +196,7 @@ TEST_F(HostUDFTest, GroupbyAllInput)
168196
auto const keys = int32s_col{0, 1, 2};
169197
auto const vals = int32s_col{0, 1, 2};
170198
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(
171-
std::make_unique<host_udf_test<cudf::groupby_aggregation, __LINE__>>(&test_run));
199+
std::make_unique<host_udf_groupby_test>(__LINE__, &test_run));
172200

173201
std::vector<cudf::groupby::aggregation_request> requests;
174202
requests.emplace_back();
@@ -197,8 +225,7 @@ TEST_F(HostUDFTest, GroupbySomeInput)
197225
auto input_attrs = get_subset(all_attrs);
198226
input_attrs.insert(get_random_agg());
199227
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(
200-
std::make_unique<host_udf_test<cudf::groupby_aggregation, __LINE__>>(&test_run,
201-
std::move(input_attrs)));
228+
std::make_unique<host_udf_groupby_test>(__LINE__, &test_run, std::move(input_attrs)));
202229

203230
std::vector<cudf::groupby::aggregation_request> requests;
204231
requests.emplace_back();

0 commit comments

Comments
 (0)