|
14 | 14 | namespace sycl {
|
15 | 15 | __SYCL_INLINE_VER_NAMESPACE(_V1) {
|
16 | 16 | namespace ext::oneapi::experimental {
|
17 |
| - |
18 |
| -// ---- reduce_over_group |
| 17 | +namespace detail { |
19 | 18 | template <typename GroupHelper, typename T, typename BinaryOperation>
|
20 |
| -sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T> |
21 |
| -reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) { |
22 |
| - if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) { |
23 |
| - return sycl::reduce_over_group(group_helper.get_group(), x, binary_op); |
24 |
| - } |
| 19 | +T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements, |
| 20 | + BinaryOperation binary_op) { |
25 | 21 | #ifdef __SYCL_DEVICE_ONLY__
|
26 | 22 | T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
|
27 | 23 | auto g = group_helper.get_group();
|
28 | 24 | Memory[g.get_local_linear_id()] = x;
|
29 | 25 | group_barrier(g);
|
30 | 26 | T result = Memory[0];
|
31 | 27 | if (g.leader()) {
|
32 |
| - for (int i = 1; i < g.get_local_linear_range(); i++) { |
| 28 | + for (int i = 1; i < num_elements; i++) { |
33 | 29 | result = binary_op(result, Memory[i]);
|
34 | 30 | }
|
35 | 31 | }
|
36 | 32 | group_barrier(g);
|
37 | 33 | return group_broadcast(g, result);
|
38 | 34 | #else
|
39 | 35 | std::ignore = group_helper;
|
| 36 | + std::ignore = x; |
| 37 | + std::ignore = num_elements; |
| 38 | + std::ignore = binary_op; |
| 39 | + throw runtime_error("Group algorithms are not supported on host.", |
| 40 | + PI_ERROR_INVALID_DEVICE); |
| 41 | +#endif |
| 42 | +} |
| 43 | +} // namespace detail |
| 44 | + |
| 45 | +// ---- reduce_over_group |
| 46 | +template <typename GroupHelper, typename T, typename BinaryOperation> |
| 47 | +sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T> |
| 48 | +reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) { |
| 49 | + if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) { |
| 50 | + return sycl::reduce_over_group(group_helper.get_group(), x, binary_op); |
| 51 | + } |
| 52 | +#ifdef __SYCL_DEVICE_ONLY__ |
| 53 | + return detail::reduce_over_group_impl( |
| 54 | + group_helper, x, group_helper.get_group().get_local_linear_range(), |
| 55 | + binary_op); |
| 56 | +#else |
40 | 57 | throw runtime_error("Group algorithms are not supported on host.",
|
41 | 58 | PI_ERROR_INVALID_DEVICE);
|
42 | 59 | #endif
|
@@ -84,7 +101,10 @@ joint_reduce(GroupHelper group_helper, Ptr first, Ptr last,
|
84 | 101 | sycl::detail::for_each(g, second, last,
|
85 | 102 | [&](const T &x) { partial = binary_op(partial, x); });
|
86 | 103 | group_barrier(g);
|
87 |
| - return reduce_over_group(group_helper, partial, binary_op); |
| 104 | + size_t num_elements = last - first; |
| 105 | + num_elements = std::min(num_elements, g.get_local_linear_range()); |
| 106 | + return detail::reduce_over_group_impl(group_helper, partial, num_elements, |
| 107 | + binary_op); |
88 | 108 | #else
|
89 | 109 | std::ignore = group_helper;
|
90 | 110 | std::ignore = first;
|
|
0 commit comments