diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index 9547039d45b6..808421e0540b 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> reduce_over_group(Group g, T x, BinaryOperation binary_op) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -291,9 +301,18 @@ std::enable_if_t< std::is_convertible_v), T> reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ return binary_op(init, reduce_over_group(g, T(x), binary_op)); #else @@ -341,9 +360,18 @@ std::enable_if_t< detail::is_native_op::value), T> joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ T partial = detail::identity_for_ga_op(); sycl::detail::for_each( @@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif T result; typename detail::get_scalar_binary_op::type scalar_binary_op{}; @@ -775,8 +819,17 @@ std::enable_if_t< std::is_convertible_v), T> exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ typename Group::linear_id_type local_linear_id = sycl::detail::get_local_linear_id(g); @@ -831,8 +884,17 @@ std::enable_if_t< OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ ptrdiff_t offset = sycl::detail::get_local_linear_id(g); ptrdiff_t stride = sycl::detail::get_local_linear_range(g); @@ -883,9 +945,22 @@ std::enable_if_t< OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + using binary_op_t = + std::remove_cv_t>; + static_assert( + (std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v::type>, "Result type of binary_op must match scan accumulation type."); +#endif using T = typename detail::remove_pointer::type; T init = detail::identity_for_ga_op(); return joint_exclusive_scan(g, first, last, result, init, binary_op); @@ -903,8 +978,19 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else + static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -972,8 +1058,18 @@ std::enable_if_t< std::is_convertible_v), T> inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ T y = x; if (sycl::detail::get_local_linear_id(g) == 0) { @@ -1022,8 +1118,17 @@ std::enable_if_t< OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op, T init) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ ptrdiff_t offset = sycl::detail::get_local_linear_id(g); ptrdiff_t stride = sycl::detail::get_local_linear_range(g); @@ -1071,9 +1176,22 @@ std::enable_if_t< OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + using binary_op_t = + std::remove_cv_t>; + static_assert( + (std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v::type>, "Result type of binary_op must match scan accumulation type."); +#endif using T = typename detail::remove_pointer::type; T init = detail::identity_for_ga_op();