From 118830c1e53b317f9a4a8bd117bab6ff10aad928 Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Mon, 12 May 2025 13:23:06 +0200 Subject: [PATCH 1/2] [SYCL] fix asserts after logical operation changes --- sycl/include/sycl/group_algorithm.hpp | 140 ++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index 9547039d45b69..c475329340dc3 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> reduce_over_group(Group g, T x, BinaryOperation binary_op) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -291,9 +301,18 @@ std::enable_if_t< std::is_convertible_v), T> reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ return binary_op(init, reduce_over_group(g, T(x), binary_op)); #else @@ -341,9 +360,18 @@ std::enable_if_t< detail::is_native_op::value), T> joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert( std::is_same_v, "Result type of binary_op must match reduction accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ T partial = detail::identity_for_ga_op(); sycl::detail::for_each( @@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif T result; typename detail::get_scalar_binary_op::type scalar_binary_op{}; @@ -775,8 +819,17 @@ std::enable_if_t< std::is_convertible_v), T> exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ typename Group::linear_id_type local_linear_id = sycl::detail::get_local_linear_id(g); @@ -831,8 +884,17 @@ std::enable_if_t< OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ ptrdiff_t offset = sycl::detail::get_local_linear_id(g); ptrdiff_t stride = sycl::detail::get_local_linear_range(g); @@ -883,9 +945,33 @@ std::enable_if_t< OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert( + (std::is_same_v>>> || + std::is_same_v>>>) + ? std::is_same_v>(), + std::remove_cv_t>())), + bool> + : std::is_same_v< + decltype(binary_op( + std::remove_cv_t< + std::remove_reference_t>(), + std::remove_cv_t< + std::remove_reference_t>())), + std::remove_cv_t>>, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v::type>, "Result type of binary_op must match scan accumulation type."); +#endif using T = typename detail::remove_pointer::type; T init = detail::identity_for_ga_op(); return joint_exclusive_scan(g, first, last, result, init, binary_op); @@ -903,8 +989,19 @@ std::enable_if_t<(is_group_v> && detail::is_native_op::value), T> inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else + static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ #if defined(__NVPTX__) if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { @@ -972,8 +1069,18 @@ std::enable_if_t< std::is_convertible_v), T> inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) { + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ T y = x; if (sycl::detail::get_local_linear_id(g) == 0) { @@ -1022,8 +1129,17 @@ std::enable_if_t< OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op, T init) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert((std::is_same_v> || + std::is_same_v>) + ? std::is_same_v + : std::is_same_v, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v, "Result type of binary_op must match scan accumulation type."); +#endif + #ifdef __SYCL_DEVICE_ONLY__ ptrdiff_t offset = sycl::detail::get_local_linear_id(g); ptrdiff_t stride = sycl::detail::get_local_linear_range(g); @@ -1071,9 +1187,33 @@ std::enable_if_t< OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + static_assert( + (std::is_same_v>>> || + std::is_same_v>>>) + ? std::is_same_v>(), + std::remove_cv_t>())), + bool> + : std::is_same_v< + decltype(binary_op( + std::remove_cv_t< + std::remove_reference_t>(), + std::remove_cv_t< + std::remove_reference_t>())), + std::remove_cv_t>>, + "Result type of binary_op must match scan accumulation type."); +#else static_assert(std::is_same_v::type>, "Result type of binary_op must match scan accumulation type."); +#endif using T = typename detail::remove_pointer::type; T init = detail::identity_for_ga_op(); From 3f34653efd74a0c78e1333946efc3003a34893eb Mon Sep 17 00:00:00 2001 From: "Klochkov, Denis" Date: Tue, 13 May 2025 11:59:59 +0200 Subject: [PATCH 2/2] [SYCL] use alias for readability --- sycl/include/sycl/group_algorithm.hpp | 50 ++++++++------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index c475329340dc3..808421e0540be 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -946,26 +946,15 @@ std::enable_if_t< joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { #ifdef __INTEL_PREVIEW_BREAKING_CHANGES + using binary_op_t = + std::remove_cv_t>; static_assert( - (std::is_same_v>>> || - std::is_same_v>>>) - ? std::is_same_v>(), - std::remove_cv_t>())), + (std::is_same_v> || + std::is_same_v>) + ? std::is_same_v - : std::is_same_v< - decltype(binary_op( - std::remove_cv_t< - std::remove_reference_t>(), - std::remove_cv_t< - std::remove_reference_t>())), - std::remove_cv_t>>, + : std::is_same_v, "Result type of binary_op must match scan accumulation type."); #else static_assert(std::is_same_v>; static_assert( - (std::is_same_v>>> || - std::is_same_v>>>) - ? std::is_same_v>(), - std::remove_cv_t>())), + (std::is_same_v> || + std::is_same_v>) + ? std::is_same_v - : std::is_same_v< - decltype(binary_op( - std::remove_cv_t< - std::remove_reference_t>(), - std::remove_cv_t< - std::remove_reference_t>())), - std::remove_cv_t>>, + : std::is_same_v, "Result type of binary_op must match scan accumulation type."); #else static_assert(std::is_same_v