Skip to content

Commit 23b68db

Browse files
[SYCL] Fix ext::oneapi::experimental::joint_reduce (#7781)
This patch fixes ext::oneapi::experimental::joint_reduce, it missed the case when WG size is bigger than size of input data. Test: intel/llvm-test-suite#1452
1 parent 751acf3 commit 23b68db

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

sycl/include/sycl/ext/oneapi/experimental/user_defined_reductions.hpp

+29-9
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,46 @@
1414
namespace sycl {
1515
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1616
namespace ext::oneapi::experimental {
17-
18-
// ---- reduce_over_group
17+
namespace detail {
1918
template <typename GroupHelper, typename T, typename BinaryOperation>
20-
sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T>
21-
reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) {
22-
if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
23-
return sycl::reduce_over_group(group_helper.get_group(), x, binary_op);
24-
}
19+
T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements,
20+
BinaryOperation binary_op) {
2521
#ifdef __SYCL_DEVICE_ONLY__
2622
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
2723
auto g = group_helper.get_group();
2824
Memory[g.get_local_linear_id()] = x;
2925
group_barrier(g);
3026
T result = Memory[0];
3127
if (g.leader()) {
32-
for (int i = 1; i < g.get_local_linear_range(); i++) {
28+
for (int i = 1; i < num_elements; i++) {
3329
result = binary_op(result, Memory[i]);
3430
}
3531
}
3632
group_barrier(g);
3733
return group_broadcast(g, result);
3834
#else
3935
std::ignore = group_helper;
36+
std::ignore = x;
37+
std::ignore = num_elements;
38+
std::ignore = binary_op;
39+
throw runtime_error("Group algorithms are not supported on host.",
40+
PI_ERROR_INVALID_DEVICE);
41+
#endif
42+
}
43+
} // namespace detail
44+
45+
// ---- reduce_over_group
46+
template <typename GroupHelper, typename T, typename BinaryOperation>
47+
sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T>
48+
reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) {
49+
if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
50+
return sycl::reduce_over_group(group_helper.get_group(), x, binary_op);
51+
}
52+
#ifdef __SYCL_DEVICE_ONLY__
53+
return detail::reduce_over_group_impl(
54+
group_helper, x, group_helper.get_group().get_local_linear_range(),
55+
binary_op);
56+
#else
4057
throw runtime_error("Group algorithms are not supported on host.",
4158
PI_ERROR_INVALID_DEVICE);
4259
#endif
@@ -84,7 +101,10 @@ joint_reduce(GroupHelper group_helper, Ptr first, Ptr last,
84101
sycl::detail::for_each(g, second, last,
85102
[&](const T &x) { partial = binary_op(partial, x); });
86103
group_barrier(g);
87-
return reduce_over_group(group_helper, partial, binary_op);
104+
size_t num_elements = last - first;
105+
num_elements = std::min(num_elements, g.get_local_linear_range());
106+
return detail::reduce_over_group_impl(group_helper, partial, num_elements,
107+
binary_op);
88108
#else
89109
std::ignore = group_helper;
90110
std::ignore = first;

0 commit comments

Comments
 (0)