Skip to content

Commit 04e2bda

Browse files
committed
Add test
Signed-off-by: Nghia Truong <[email protected]>
1 parent bba150c commit 04e2bda

File tree

2 files changed

+112
-38
lines changed

2 files changed

+112
-38
lines changed

cpp/tests/CMakeLists.txt

+1-38
Original file line numberDiff line numberDiff line change
@@ -122,44 +122,7 @@ ConfigureTest(TIMESTAMPS_TEST wrappers/timestamps_test.cu)
122122
# * groupby tests ---------------------------------------------------------------------------------
123123
ConfigureTest(
124124
GROUPBY_TEST
125-
groupby/argmin_tests.cpp
126-
groupby/argmax_tests.cpp
127-
groupby/collect_list_tests.cpp
128-
groupby/collect_set_tests.cpp
129-
groupby/correlation_tests.cpp
130-
groupby/count_scan_tests.cpp
131-
groupby/count_tests.cpp
132-
groupby/covariance_tests.cpp
133-
groupby/groupby_test_util.cpp
134-
groupby/groups_tests.cpp
135-
groupby/histogram_tests.cpp
136-
groupby/keys_tests.cpp
137-
groupby/lists_tests.cpp
138-
groupby/m2_tests.cpp
139-
groupby/min_tests.cpp
140-
groupby/max_scan_tests.cpp
141-
groupby/max_tests.cpp
142-
groupby/mean_tests.cpp
143-
groupby/median_tests.cpp
144-
groupby/merge_m2_tests.cpp
145-
groupby/merge_lists_tests.cpp
146-
groupby/merge_sets_tests.cpp
147-
groupby/min_scan_tests.cpp
148-
groupby/nth_element_tests.cpp
149-
groupby/nunique_tests.cpp
150-
groupby/product_scan_tests.cpp
151-
groupby/product_tests.cpp
152-
groupby/quantile_tests.cpp
153-
groupby/rank_scan_tests.cpp
154-
groupby/replace_nulls_tests.cpp
155-
groupby/shift_tests.cpp
156-
groupby/std_tests.cpp
157-
groupby/structs_tests.cpp
158-
groupby/sum_of_squares_tests.cpp
159-
groupby/sum_scan_tests.cpp
160-
groupby/sum_tests.cpp
161-
groupby/tdigest_tests.cu
162-
groupby/var_tests.cpp
125+
groupby/host_udf_tests.cu
163126
GPUS 1
164127
PERCENT 100
165128
)

cpp/tests/groupby/host_udf_tests.cu

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <tests/groupby/groupby_test_util.hpp>
18+
19+
#include <cudf_test/base_fixture.hpp>
20+
#include <cudf_test/column_wrapper.hpp>
21+
#include <cudf_test/debug_utilities.hpp>
22+
#include <cudf_test/iterator_utilities.hpp>
23+
#include <cudf_test/type_lists.hpp>
24+
25+
#include <cudf/column/column_factories.hpp>
26+
#include <cudf/detail/aggregation/aggregation.hpp>
27+
28+
#include <rmm/exec_policy.hpp>
29+
30+
#include <thrust/iterator/counting_iterator.h>
31+
#include <thrust/transform.h>
32+
33+
using namespace cudf::test::iterators;
34+
35+
struct test : public cudf::test::BaseFixture {};
36+
37+
std::unique_ptr<cudf::column> double_sqr(cudf::column_view const& values,
38+
cudf::device_span<cudf::size_type const> group_offsets,
39+
cudf::device_span<cudf::size_type const> group_labels,
40+
cudf::size_type num_groups,
41+
rmm::cuda_stream_view stream,
42+
rmm::device_async_resource_ref mr)
43+
{
44+
auto output = cudf::make_numeric_column(
45+
cudf::data_type{cudf::type_id::INT32}, values.size(), cudf::mask_state::UNALLOCATED, stream);
46+
thrust::transform(rmm::exec_policy(stream),
47+
thrust::make_counting_iterator(0),
48+
thrust::make_counting_iterator(values.size()),
49+
output->mutable_view().begin<int>(),
50+
[values = values.begin<int>()] __device__(int idx) -> int {
51+
return 2 * values[idx] * values[idx];
52+
});
53+
return output;
54+
}
55+
56+
std::unique_ptr<cudf::column> triple_sqr(cudf::column_view const& values,
57+
cudf::device_span<cudf::size_type const> group_offsets,
58+
cudf::device_span<cudf::size_type const> group_labels,
59+
cudf::size_type num_groups,
60+
rmm::cuda_stream_view stream,
61+
rmm::device_async_resource_ref mr)
62+
{
63+
auto output = cudf::make_numeric_column(
64+
cudf::data_type{cudf::type_id::INT32}, values.size(), cudf::mask_state::UNALLOCATED, stream);
65+
thrust::transform(rmm::exec_policy(stream),
66+
thrust::make_counting_iterator(0),
67+
thrust::make_counting_iterator(values.size()),
68+
output->mutable_view().begin<int>(),
69+
[values = values.begin<int>()] __device__(int idx) -> int {
70+
return 3 * values[idx] * values[idx];
71+
});
72+
return output;
73+
}
74+
75+
TEST_F(test, double_sqr)
76+
{
77+
cudf::test::fixed_width_column_wrapper<int> keys{1, 1, 1, 1, 1};
78+
cudf::test::fixed_width_column_wrapper<int> vals{0, 1, 2, 3, 4};
79+
80+
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(double_sqr);
81+
std::vector<cudf::groupby::aggregation_request> requests;
82+
requests.emplace_back();
83+
requests[0].values = vals;
84+
requests[0].aggregations.push_back(std::move(agg));
85+
cudf::groupby::groupby gb_obj(
86+
cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {});
87+
88+
auto result = gb_obj.aggregate(requests, cudf::test::get_default_stream());
89+
90+
// Got output: 0,2,8,18,32
91+
cudf::test::print(*result.second[0].results[0]);
92+
}
93+
94+
TEST_F(test, triple_sqr)
95+
{
96+
cudf::test::fixed_width_column_wrapper<int> keys{1, 1, 1, 1, 1};
97+
cudf::test::fixed_width_column_wrapper<int> vals{0, 1, 2, 3, 4};
98+
99+
auto agg = cudf::make_host_udf_aggregation<cudf::groupby_aggregation>(triple_sqr);
100+
std::vector<cudf::groupby::aggregation_request> requests;
101+
requests.emplace_back();
102+
requests[0].values = vals;
103+
requests[0].aggregations.push_back(std::move(agg));
104+
cudf::groupby::groupby gb_obj(
105+
cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {});
106+
107+
auto result = gb_obj.aggregate(requests, cudf::test::get_default_stream());
108+
109+
// Got output: 0,3,12,27,48
110+
cudf::test::print(*result.second[0].results[0]);
111+
}

0 commit comments

Comments
 (0)