
Commit 7ff6bf6

Env-based reduce API
1 parent 01a0b2a commit 7ff6bf6

9 files changed, +1267 -2 lines changed

cub/cub/detail/launcher/cuda_runtime.cuh

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ struct TripleChevronFactory
  }
};

+#ifndef CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER
+#  define CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER detail::TripleChevronFactory
+#endif
+
} // namespace detail

CUB_NAMESPACE_END
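
The #ifndef guard above turns the kernel launcher into an override point: a translation unit that defines CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER before including CUB headers keeps its own factory, and the TripleChevron default only fills the gap otherwise. A minimal sketch of that pattern, assuming a hypothetical my_launcher_factory type that mirrors the detail::TripleChevronFactory interface (the name is not part of this commit):

    // Hypothetical override; my_launcher_factory is assumed to expose the same
    // interface as cub::detail::TripleChevronFactory and to be visible where the
    // macro expands.
    #define CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER my_launcher_factory
    #include <cub/device/device_reduce.cuh> // the #ifndef above leaves the override in place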

cub/cub/device/device_reduce.cuh

Lines changed: 274 additions & 0 deletions
@@ -45,11 +45,18 @@
#include <cub/detail/choose_offset.cuh>
#include <cub/device/dispatch/dispatch_reduce.cuh>
#include <cub/device/dispatch/dispatch_reduce_by_key.cuh>
+#include <cub/device/dispatch/dispatch_reduce_deterministic.cuh>
#include <cub/device/dispatch/dispatch_streaming_reduce.cuh>
#include <cub/util_type.cuh>

#include <thrust/iterator/tabulate_output_iterator.h>

+#include <cuda/__execution/determinism.h>
+#include <cuda/__execution/require.h>
+#include <cuda/__execution/tune.h>
+#include <cuda/__memory_resource/get_memory_resource.h>
+#include <cuda/__stream/get_stream.h>
+#include <cuda/std/__execution/env.h>
#include <cuda/std/limits>

CUB_NAMESPACE_BEGIN
@@ -58,6 +65,31 @@ namespace detail
{
namespace reduce
{
+
+struct get_reduce_tuning_query_t
+{};
+
+template <class Derived>
+struct tuning
+{
+  [[nodiscard]] _CCCL_TRIVIAL_API constexpr auto query(const get_reduce_tuning_query_t&) const noexcept -> Derived
+  {
+    return static_cast<const Derived&>(*this);
+  }
+};
+
+struct default_tuning : tuning<default_tuning>
+{
+  template <class AccumT, class Offset, class OpT>
+  using fn = policy_hub<AccumT, Offset, OpT>;
+};
+
+struct default_rfa_tuning : tuning<default_tuning>
+{
+  template <class AccumT, class Offset, class OpT>
+  using fn = detail::rfa::policy_hub<AccumT, Offset, OpT>;
+};
+
template <typename ExtremumOutIteratorT, typename IndexOutIteratorT>
struct unzip_and_write_arg_extremum_op
{
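
The structs above define the tuning protocol used by the env-based API: a tuning object answers get_reduce_tuning_query_t with itself, and its fn alias names the policy hub that the dispatch layer instantiates. A hedged sketch of a custom tuning that satisfies this protocol; everything under cub::detail is internal and subject to change, my_reduce_tuning is hypothetical, and it simply reuses the default policy_hub to stay self-contained:

    #include <cub/device/device_reduce.cuh>

    // Hypothetical custom tuning. Inheriting tuning<my_reduce_tuning> provides the
    // query(get_reduce_tuning_query_t) member; `fn` selects the policy hub that the
    // dispatch layer instantiates for a given accumulator/offset/operator triple.
    struct my_reduce_tuning : cub::detail::reduce::tuning<my_reduce_tuning>
    {
      template <class AccumT, class Offset, class OpT>
      using fn = cub::detail::reduce::policy_hub<AccumT, Offset, OpT>;
    };

How such a tuning reaches the algorithm is the __get_tuning_t query on the execution environment (see the tuning_t alias in the new Reduce overload below); this commit does not show a public spelling for attaching it, so that part is omitted here.
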
@@ -72,6 +104,41 @@ struct unzip_and_write_arg_extremum_op
  }
};
} // namespace reduce
+
+// TODO(gevtushenko): move cudax `device_memory_resource` to `cuda::__device_memory_resource` and use it here
+struct device_memory_resource
+{
+  void* allocate(size_t bytes, size_t /* alignment */)
+  {
+    void* ptr{nullptr};
+    _CCCL_TRY_CUDA_API(::cudaMalloc, "allocate failed to allocate with cudaMalloc", &ptr, bytes);
+    return ptr;
+  }
+
+  void deallocate(void* ptr, size_t /* bytes */)
+  {
+    _CCCL_ASSERT_CUDA_API(::cudaFree, "deallocate failed", ptr);
+  }
+
+  void* allocate_async(size_t bytes, size_t /* alignment */, ::cuda::stream_ref stream)
+  {
+    return allocate_async(bytes, stream);
+  }
+
+  void* allocate_async(size_t bytes, ::cuda::stream_ref stream)
+  {
+    void* ptr{nullptr};
+    _CCCL_TRY_CUDA_API(
+      ::cudaMallocAsync, "allocate_async failed to allocate with cudaMallocAsync", &ptr, bytes, stream.get());
+    return ptr;
+  }
+
+  void deallocate_async(void* ptr, size_t /* bytes */, const ::cuda::stream_ref stream)
+  {
+    _CCCL_ASSERT_CUDA_API(::cudaFreeAsync, "deallocate_async failed", ptr, stream.get());
+  }
+};
+
} // namespace detail

//! @rst
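
device_memory_resource is the fallback the env-based Reduce overload below uses when the environment does not carry a memory resource: temporary storage becomes a stream-ordered cudaMallocAsync/cudaFreeAsync pair instead of the caller-managed d_temp_storage buffer of the classic API. A small illustration of that contract (this is an internal type, shown only for orientation; the size and stream here are placeholders):

    #include <cub/device/device_reduce.cuh>

    #include <cuda/stream_ref>

    int main()
    {
      cudaStream_t raw{};
      cudaStreamCreate(&raw);

      cub::detail::device_memory_resource mr{};
      void* tmp = mr.allocate_async(1024, cuda::stream_ref{raw}); // cudaMallocAsync under the hood
      // ... temporary storage would be consumed by kernels enqueued on `raw` ...
      mr.deallocate_async(tmp, 1024, cuda::stream_ref{raw});      // cudaFreeAsync under the hood

      cudaStreamDestroy(raw);
    }
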
@@ -102,6 +169,85 @@ struct unzip_and_write_arg_extremum_op
//! @endrst
struct DeviceReduce
{
+private:
+  // TODO(gevtushenko): dispatch to atomic reduce once merged
+  template <typename TuningEnvT,
+            typename InputIteratorT,
+            typename OutputIteratorT,
+            typename ReductionOpT,
+            typename T,
+            typename NumItemsT,
+            ::cuda::execution::determinism::__determinism_t Determinism>
+  CUB_RUNTIME_FUNCTION static cudaError_t reduce_impl(
+    void* d_temp_storage,
+    size_t& temp_storage_bytes,
+    InputIteratorT d_in,
+    OutputIteratorT d_out,
+    NumItemsT num_items,
+    ReductionOpT reduction_op,
+    T init,
+    ::cuda::execution::determinism::__determinism_holder_t<Determinism>,
+    cudaStream_t stream)
+  {
+    using offset_t    = detail::choose_offset_t<NumItemsT>;
+    using accum_t     = ::cuda::std::__accumulator_t<ReductionOpT, detail::it_value_t<InputIteratorT>, T>;
+    using transform_t = ::cuda::std::identity;
+    using reduce_tuning_t = ::cuda::std::execution::
+      __query_result_or_t<TuningEnvT, detail::reduce::get_reduce_tuning_query_t, detail::reduce::default_tuning>;
+    using policy_t = typename reduce_tuning_t::template fn<accum_t, offset_t, ReductionOpT>;
+    using dispatch_t =
+      DispatchReduce<InputIteratorT, OutputIteratorT, offset_t, ReductionOpT, T, accum_t, transform_t, policy_t>;
+
+    return dispatch_t::Dispatch(
+      d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<offset_t>(num_items), reduction_op, init, stream);
+  }
+
+  template <typename TuningEnvT,
+            typename InputIteratorT,
+            typename OutputIteratorT,
+            typename ReductionOpT,
+            typename T,
+            typename NumItemsT>
+  CUB_RUNTIME_FUNCTION static cudaError_t reduce_impl(
+    void* d_temp_storage,
+    size_t& temp_storage_bytes,
+    InputIteratorT d_in,
+    OutputIteratorT d_out,
+    NumItemsT num_items,
+    ReductionOpT,
+    T init,
+    ::cuda::execution::determinism::gpu_to_gpu_t,
+    cudaStream_t stream)
+  {
+    using offset_t = detail::choose_offset_t<NumItemsT>;
+    using accum_t  = ::cuda::std::__accumulator_t<ReductionOpT, detail::it_value_t<InputIteratorT>, T>;
+
+    // RFA is only supported for float and double accumulators
+    constexpr bool is_float_or_double = _CUDA_VSTD::is_same_v<accum_t, float> || _CUDA_VSTD::is_same_v<accum_t, double>;
+    constexpr bool is_sum             = _CUDA_VSTD::is_same_v<ReductionOpT, ::cuda::std::plus<>>;
+    constexpr bool is_supported       = is_float_or_double && is_sum;
+
+    static_assert(is_supported, "gpu-to-gpu deterministic reduction supports only float and double sum.");
+
+    if constexpr (is_supported)
+    {
+      using transform_t     = ::cuda::std::identity;
+      using reduce_tuning_t = ::cuda::std::execution::
+        __query_result_or_t<TuningEnvT, detail::reduce::get_reduce_tuning_query_t, detail::reduce::default_rfa_tuning>;
+      using policy_t   = typename reduce_tuning_t::template fn<accum_t, offset_t, ReductionOpT>;
+      using dispatch_t =
+        detail::DispatchReduceDeterministic<InputIteratorT, OutputIteratorT, offset_t, T, accum_t, transform_t, policy_t>;
+
+      return dispatch_t::Dispatch(
+        d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<offset_t>(num_items), init, stream);
+    }
+    else
+    {
+      return cudaErrorNotSupported;
+    }
+  }
+
+public:
  //! @rst
  //! Computes a device-wide reduction using the specified binary ``reduction_op`` functor and initial value ``init``.
  //!
@@ -225,6 +371,134 @@ struct DeviceReduce
      d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<OffsetT>(num_items), reduction_op, init, stream);
  }

+  //! @rst
+  //! Computes a device-wide reduction using the specified binary ``reduction_op`` functor and initial value ``init``.
+  //!
+  //! - Does not support binary reduction operators that are non-commutative.
+  //! - By default, provides "run-to-run" determinism for pseudo-associative reduction
+  //!   (e.g., addition of floating point types) on the same GPU device.
+  //!   However, results for pseudo-associative reduction may be inconsistent
+  //!   from one device to another device of a different compute-capability
+  //!   because CUB can employ different tile-sizing for different architectures.
+  //!   To request "gpu-to-gpu" determinism, pass `cuda::execution::require(cuda::execution::determinism::gpu_to_gpu)`
+  //!   as the `env` parameter.
+  //! - The range ``[d_in, d_in + num_items)`` shall not overlap ``d_out``.
+  //!
+  //! Snippet
+  //! +++++++++++++++++++++++++++++++++++++++++++++
+  //!
+  //! The code snippet below illustrates a user-defined min-reduction of a
+  //! device vector of ``int`` data elements.
+  //!
+  //! .. literalinclude:: ../../../cub/test/catch2_test_device_reduce_env_api.cu
+  //!     :language: c++
+  //!     :dedent:
+  //!     :start-after: example-begin reduce-env-determinism
+  //!     :end-before: example-end reduce-env-determinism
+  //!
+  //! @endrst
+  //!
+  //! @tparam InputIteratorT
+  //!   **[inferred]** Random-access input iterator type for reading input items @iterator
+  //!
+  //! @tparam OutputIteratorT
+  //!   **[inferred]** Output iterator type for recording the reduced aggregate @iterator
+  //!
+  //! @tparam ReductionOpT
+  //!   **[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)`
+  //!
+  //! @tparam T
+  //!   **[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT`
+  //!
+  //! @tparam NumItemsT
+  //!   **[inferred]** Type of num_items
+  //!
+  //! @tparam EnvT
+  //!   **[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`.
+  //!
+  //! @param[in] d_in
+  //!   Pointer to the input sequence of data items
+  //!
+  //! @param[out] d_out
+  //!   Pointer to the output aggregate
+  //!
+  //! @param[in] num_items
+  //!   Total number of input items (i.e., length of `d_in`)
+  //!
+  //! @param[in] reduction_op
+  //!   Binary reduction functor
+  //!
+  //! @param[in] init
+  //!   Initial value of the reduction
+  //!
+  //! @param[in] env
+  //!   @rst
+  //!   **[optional]** Execution environment. Default is `cuda::std::execution::env{}`.
+  //!   @endrst
+  template <typename InputIteratorT,
+            typename OutputIteratorT,
+            typename ReductionOpT,
+            typename T,
+            typename NumItemsT,
+            typename EnvT = ::cuda::std::execution::env<>>
+  CUB_RUNTIME_FUNCTION static cudaError_t Reduce(
+    InputIteratorT d_in, OutputIteratorT d_out, NumItemsT num_items, ReductionOpT reduction_op, T init, EnvT env = {})
+  {
+    _CCCL_NVTX_RANGE_SCOPE("cub::DeviceReduce::Reduce");
+
+    static_assert(!_CUDA_STD_EXEC::__queryable_with<EnvT, _CUDA_EXEC::determinism::__get_determinism_t>,
+                  "Determinism should be used inside requires to have an effect.");
+    using requirements_t =
+      _CUDA_STD_EXEC::__query_result_or_t<EnvT, _CUDA_EXEC::__get_requirements_t, _CUDA_STD_EXEC::env<>>;
+    using determinism_t =
+      _CUDA_STD_EXEC::__query_result_or_t<requirements_t, //
+                                          _CUDA_EXEC::determinism::__get_determinism_t,
+                                          _CUDA_EXEC::determinism::run_to_run_t>;
+
+    // Query relevant properties from the environment
+    auto stream = _CUDA_STD_EXEC::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{});
+    auto mr     = _CUDA_STD_EXEC::__query_or(env, ::cuda::mr::__get_memory_resource, detail::device_memory_resource{});
+
+    void* d_temp_storage      = nullptr;
+    size_t temp_storage_bytes = 0;
+
+    using tuning_t = _CUDA_STD_EXEC::__query_result_or_t<EnvT, _CUDA_EXEC::__get_tuning_t, _CUDA_STD_EXEC::env<>>;
+
+    // Query the required temporary storage size
+    cudaError_t error = reduce_impl<tuning_t>(
+      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, determinism_t{}, stream.get());
+    if (error != cudaSuccess)
+    {
+      return error;
+    }
+
+    NV_IF_ELSE_TARGET(
+      NV_IS_HOST,
+      (
+        try { d_temp_storage = mr.allocate_async(temp_storage_bytes, stream); } catch (...) {
+          return cudaErrorMemoryAllocation;
+        }),
+      (d_temp_storage = mr.allocate_async(temp_storage_bytes, stream);));
+
+    // Run the algorithm
+    error = reduce_impl<tuning_t>(
+      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, determinism_t{}, stream.get());
+    if (error != cudaSuccess)
+    {
+      return error;
+    }
+
+    NV_IF_ELSE_TARGET(
+      NV_IS_HOST,
+      (
+        try { mr.deallocate_async(d_temp_storage, temp_storage_bytes, stream); } catch (...) {
+          return cudaErrorMemoryAllocation;
+        }),
+      (mr.deallocate_async(d_temp_storage, temp_storage_bytes, stream);));
+
+    return cudaSuccess;
+  }
+
  //! @rst
  //! Computes a device-wide sum using the addition (``+``) operator.
  //!
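
Taken together, the overload added above drops the two-phase temp-storage protocol of the classic API: the implementation sizes the temporary storage, allocates it through the environment's memory resource (or the cudaMallocAsync-based fallback) on the environment's stream, runs the reduction, and frees the storage before returning. A usage sketch that follows the call pattern documented in this hunk; the literal snippet lives in cub/test/catch2_test_device_reduce_env_api.cu and is not reproduced here, so the buffer setup below is an assumption:

    #include <cub/device/device_reduce.cuh>

    #include <thrust/device_vector.h>

    #include <cuda/std/functional>

    int main()
    {
      thrust::device_vector<float> in(1 << 20, 1.0f);
      thrust::device_vector<float> out(1);

      // Default environment: run-to-run determinism, default stream, fallback
      // cudaMallocAsync-based temporary storage. No temp-storage query phase.
      cub::DeviceReduce::Reduce(
        thrust::raw_pointer_cast(in.data()), thrust::raw_pointer_cast(out.data()), in.size(), cuda::std::plus<>{}, 0.0f);

      // Request gpu-to-gpu determinism through the environment, as documented above.
      // This path accepts only float/double sums (see the gpu_to_gpu reduce_impl overload).
      auto env = cuda::execution::require(cuda::execution::determinism::gpu_to_gpu);
      cub::DeviceReduce::Reduce(
        thrust::raw_pointer_cast(in.data()), thrust::raw_pointer_cast(out.data()), in.size(), cuda::std::plus<>{}, 0.0f, env);
    }

A stream or memory resource can be supplied through the same environment object (the get_stream and get_memory_resource queries in the implementation); when absent, the defaults queried above apply.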
