Skip to content

Commit 3bb751d

Browse files
committed
Thread stream and mr through empty output construction
Since we call make_lists_column in some cases, we need a stream and mr around.
1 parent 4591523 commit 3bb751d

File tree

4 files changed

+25
-10
lines changed

4 files changed

+25
-10
lines changed

cpp/src/rolling/detail/rolling.cuh

+15-6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
#include <rmm/cuda_stream_view.hpp>
5353
#include <rmm/exec_policy.hpp>
54+
#include <rmm/resource_ref.hpp>
5455

5556
#include <cuda/std/climits>
5657
#include <cuda/std/limits>
@@ -453,7 +454,10 @@ struct DeviceRollingRowNumber {
453454

454455
struct agg_specific_empty_output {
455456
template <typename InputType, aggregation::Kind op>
456-
std::unique_ptr<column> operator()(column_view const& input, rolling_aggregation const&) const
457+
std::unique_ptr<column> operator()(column_view const& input,
458+
rolling_aggregation const&,
459+
rmm::cuda_stream_view stream,
460+
rmm::device_async_resource_ref mr) const
457461
{
458462
using target_type = cudf::detail::target_type_t<InputType, op>;
459463

@@ -467,15 +471,18 @@ struct agg_specific_empty_output {
467471

468472
if constexpr (op == aggregation::COLLECT_LIST) {
469473
return cudf::make_lists_column(
470-
0, make_empty_column(type_to_id<size_type>()), empty_like(input), 0, {});
474+
0, make_empty_column(type_to_id<size_type>()), empty_like(input), 0, {}, stream, mr);
471475
}
472476

473477
return empty_like(input);
474478
}
475479
};
476480

477-
static std::unique_ptr<column> empty_output_for_rolling_aggregation(column_view const& input,
478-
rolling_aggregation const& agg)
481+
static std::unique_ptr<column> empty_output_for_rolling_aggregation(
482+
column_view const& input,
483+
rolling_aggregation const& agg,
484+
rmm::cuda_stream_view stream,
485+
rmm::device_async_resource_ref mr)
479486
{
480487
// TODO:
481488
// Ideally, for UDF aggregations, the returned column would match
@@ -490,7 +497,7 @@ static std::unique_ptr<column> empty_output_for_rolling_aggregation(column_view
490497
return agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX
491498
? empty_like(input)
492499
: cudf::detail::dispatch_type_and_aggregation(
493-
input.type(), agg.kind, agg_specific_empty_output{}, input, agg);
500+
input.type(), agg.kind, agg_specific_empty_output{}, input, agg, stream, mr);
494501
}
495502

496503
/**
@@ -1326,7 +1333,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
13261333
static_assert(warp_size == cudf::detail::size_in_bits<cudf::bitmask_type>(),
13271334
"bitmask_type size does not match CUDA warp size");
13281335

1329-
if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }
1336+
if (input.is_empty()) {
1337+
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
1338+
}
13301339

13311340
if (cudf::is_dictionary(input.type())) {
13321341
CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID ||

cpp/src/rolling/detail/rolling_fixed_window.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
4040
{
4141
CUDF_FUNC_RANGE();
4242

43-
if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }
43+
if (input.is_empty()) {
44+
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
45+
}
4446

4547
CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative");
4648

cpp/src/rolling/detail/rolling_variable_window.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
3939
CUDF_FUNC_RANGE();
4040

4141
if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) {
42-
return cudf::detail::empty_output_for_rolling_aggregation(input, agg);
42+
return cudf::detail::empty_output_for_rolling_aggregation(input, agg, stream, mr);
4343
}
4444

4545
CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&

cpp/src/rolling/grouped_rolling.cu

+6-2
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
158158
{
159159
CUDF_FUNC_RANGE();
160160

161-
if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); }
161+
if (input.is_empty()) {
162+
return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr);
163+
}
162164

163165
CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
164166
"Size mismatch between group_keys and input vector.");
@@ -1152,7 +1154,9 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
11521154
{
11531155
CUDF_FUNC_RANGE();
11541156

1155-
if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, aggr); }
1157+
if (input.is_empty()) {
1158+
return cudf::detail::empty_output_for_rolling_aggregation(input, aggr, stream, mr);
1159+
}
11561160

11571161
CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
11581162
"Size mismatch between group_keys and input vector.");

0 commit comments

Comments
 (0)