|
14 | 14 | namespace sycl { |
15 | 15 | __SYCL_INLINE_VER_NAMESPACE(_V1) { |
16 | 16 | namespace ext::oneapi::experimental { |
17 | | - |
18 | | -// ---- reduce_over_group |
| 17 | +namespace detail { |
19 | 18 | template <typename GroupHelper, typename T, typename BinaryOperation> |
20 | | -sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T> |
21 | | -reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) { |
22 | | - if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) { |
23 | | - return sycl::reduce_over_group(group_helper.get_group(), x, binary_op); |
24 | | - } |
| 19 | +T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements, |
| 20 | + BinaryOperation binary_op) { |
25 | 21 | #ifdef __SYCL_DEVICE_ONLY__ |
26 | 22 | T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data()); |
27 | 23 | auto g = group_helper.get_group(); |
28 | 24 | Memory[g.get_local_linear_id()] = x; |
29 | 25 | group_barrier(g); |
30 | 26 | T result = Memory[0]; |
31 | 27 | if (g.leader()) { |
32 | | - for (int i = 1; i < g.get_local_linear_range(); i++) { |
| 28 | + for (int i = 1; i < num_elements; i++) { |
33 | 29 | result = binary_op(result, Memory[i]); |
34 | 30 | } |
35 | 31 | } |
36 | 32 | group_barrier(g); |
37 | 33 | return group_broadcast(g, result); |
38 | 34 | #else |
39 | 35 | std::ignore = group_helper; |
| 36 | + std::ignore = x; |
| 37 | + std::ignore = num_elements; |
| 38 | + std::ignore = binary_op; |
| 39 | + throw runtime_error("Group algorithms are not supported on host.", |
| 40 | + PI_ERROR_INVALID_DEVICE); |
| 41 | +#endif |
| 42 | +} |
| 43 | +} // namespace detail |
| 44 | + |
| 45 | +// ---- reduce_over_group |
| 46 | +template <typename GroupHelper, typename T, typename BinaryOperation> |
| 47 | +sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T> |
| 48 | +reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) { |
| 49 | + if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) { |
| 50 | + return sycl::reduce_over_group(group_helper.get_group(), x, binary_op); |
| 51 | + } |
| 52 | +#ifdef __SYCL_DEVICE_ONLY__ |
| 53 | + return detail::reduce_over_group_impl( |
| 54 | + group_helper, x, group_helper.get_group().get_local_linear_range(), |
| 55 | + binary_op); |
| 56 | +#else |
40 | 57 | throw runtime_error("Group algorithms are not supported on host.", |
41 | 58 | PI_ERROR_INVALID_DEVICE); |
42 | 59 | #endif |
@@ -84,7 +101,10 @@ joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, |
84 | 101 | sycl::detail::for_each(g, second, last, |
85 | 102 | [&](const T &x) { partial = binary_op(partial, x); }); |
86 | 103 | group_barrier(g); |
87 | | - return reduce_over_group(group_helper, partial, binary_op); |
| 104 | + size_t num_elements = last - first; |
| 105 | + num_elements = std::min(num_elements, g.get_local_linear_range()); |
| 106 | + return detail::reduce_over_group_impl(group_helper, partial, num_elements, |
| 107 | + binary_op); |
88 | 108 | #else |
89 | 109 | std::ignore = group_helper; |
90 | 110 | std::ignore = first; |
|
0 commit comments