
Reduce-then-scan path for SYCL scan-by-segment implementations #2315


Open
wants to merge 19 commits into
base: main
from


mmichel11
Contributor

@mmichel11 mmichel11 commented Jun 18, 2025

This PR adds a path for inclusive_scan_by_segment and exclusive_scan_by_segment that uses our reduce-then-scan infrastructure to improve performance. Additionally, some reorganization of the existing "fallback" implementations is required.

Summary of changes:

  • __parallel_scan_by_segment_reduce_then_scan and associated reduce-then-scan building blocks are added to implement scan-by-segment. The implementation is similar to, but distinct from, the reduce-by-segment implementation.
  • The high level scan-by-segment fallback has been re-implemented at the parallel pattern level to consolidate the location of SYCL backend implementations.
  • __segmented_scan_fun and __replace_if_fun have been moved to our general utilities as they are shared between implementations in pstl/hetero/dpcpp and internal directories.
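
For readers unfamiliar with the pattern, here is a minimal sequential sketch of how a scan-by-segment can be expressed as an ordinary scan over (flag, value) pairs. It is an illustration only: `segmented_op_sketch` and `inclusive_scan_by_segment_sketch` are hypothetical names, the flag convention (1 marks the first element of a segment) is an assumption, and none of this is the oneDPL `__segmented_scan_fun` code itself.

```cpp
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a segmented-scan operator (not the oneDPL functor).
// Each element is a (flag, value) pair; flag == 1 is assumed to mark a segment start.
template <typename T, typename BinaryOp>
struct segmented_op_sketch
{
    BinaryOp op;

    std::pair<int, T>
    operator()(const std::pair<int, T>& lhs, const std::pair<int, T>& rhs) const
    {
        // A right operand that starts a new segment discards the running total.
        if (rhs.first)
            return rhs;
        // Otherwise keep accumulating and propagate the left flag.
        return {lhs.first, op(lhs.second, rhs.second)};
    }
};

// Sequential reference: an inclusive scan that restarts whenever the key changes.
template <typename T, typename KeyT, typename BinaryOp>
std::vector<T>
inclusive_scan_by_segment_sketch(const std::vector<KeyT>& keys, const std::vector<T>& values, BinaryOp op)
{
    std::vector<T> out(values.size());
    segmented_op_sketch<T, BinaryOp> seg_op{op};
    std::pair<int, T> carry{};
    for (std::size_t i = 0; i < values.size(); ++i)
    {
        const int flag = (i == 0 || keys[i] != keys[i - 1]) ? 1 : 0;
        const std::pair<int, T> current{flag, values[i]};
        carry = (i == 0) ? current : seg_op(carry, current);
        out[i] = carry.second;
    }
    return out;
}
```

Because the pair operator is associative whenever `op` is, the same operator can be handed to a parallel scan, which is the property a reduce-then-scan path relies on.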

@mmichel11 mmichel11 requested a review from Copilot June 18, 2025 15:38

@Copilot Copilot AI left a comment


Pull Request Overview

This draft adds a reduce-then-scan execution path and supporting functors for SYCL-based segmented scans, updates device‐copyable trait specializations, and adapts internal APIs to use the new functors.

  • Introduced __segmented_scan_fun and __replace_if_fun in utils.h and removed legacy duplicates.
  • Extended SYCL traits and tests for device‐copyable checks of the new functors and segment‐scan inputs/writes.
  • Added reduce-then-scan and fallback implementations in the SYCL backends and updated the high-level __pattern_scan_by_segment API.

Reviewed Changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| test/general/implementation_details/device_copyable.pass.cpp | Added device-copyable assertions for new functors and inputs. |
| include/oneapi/dpl/pstl/utils.h | Added __segmented_scan_fun & __replace_if_fun, included tuple_impl. |
| include/oneapi/dpl/pstl/tuple_impl.h | Removed circular include of utils.h. |
| include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h | Forward-declared new functors and added SYCL trait specializations. |
| include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_scan_by_segment.h | Split __parallel_scan_by_segment into reduce-then-scan and fallback paths. |
| include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h | Implemented write functors for inclusive/exclusive segment scans. |
| include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h | Updated __pattern_scan_by_segment to use new init wrappers. |
| include/oneapi/dpl/internal/reduce_by_segment_impl.h | Switched to __segmented_scan_fun. |
| include/oneapi/dpl/internal/inclusive_scan_by_segment_impl.h | Switched to __segmented_scan_fun. |
| include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h | Switched to __replace_if_fun and __segmented_scan_fun. |
| include/oneapi/dpl/internal/function.h | Removed legacy replace_if_fun & segmented_scan_fun. |
Comments suppressed due to low confidence (2)

include/oneapi/dpl/pstl/utils.h:985

  • [nitpick] Add a brief doc comment above __segmented_scan_fun to explain its role and tuple layout, improving maintainability.
template <typename _ValueType, typename _FlagType, typename _BinaryOp>

include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h:2206

  • The static_assert requiring a known identity was removed. Confirm that all binary operators without a known identity are still handled correctly or consider reintroducing a guard to prevent misuse.
    __bknd::__parallel_scan_by_segment<_Inclusive::value>(

@mmichel11 mmichel11 marked this pull request as ready for review June 18, 2025 20:05
@mmichel11 mmichel11 changed the title [Draft] Reduce-then-scan path for SYCL scan-by-segment implementations Reduce-then-scan path for SYCL scan-by-segment implementations Jun 18, 2025
@mmichel11 mmichel11 force-pushed the dev/mmichel11/rts_scan_by_segment branch from e584025 to 5578d5b Compare June 30, 2025 14:27
{
oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>>
__wrapped_init;
using _WriteOp = __write_scan_by_seg<__is_inclusive, decltype(__wrapped_init), _BinaryOperator>;
Contributor

This type alias is duplicated in both branches.
Maybe it makes sense to move it up?

Contributor Author

It's a little more complicated as we need a std::conditional_t, but I agree and made this change.

@SergeyKopienko
Contributor

@mmichel11 could you please find "overruning" in the code (in comments) and replace it with "overrunning"?
It's not from this PR, but this is a good chance to fix this spell-check error.

// This is because init handling must occur on a per-segment basis and functions differently than the typical scan init.
oneapi::dpl::unseq_backend::__init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>> __wrapped_init{
{0, __init.__value}};
using _WriteOp = __write_scan_by_seg<__is_inclusive, decltype(__wrapped_init), _BinaryOperator>;
Member

Just a preference thing, but it might be cleaner to define a type alias for oneapi::dpl::unseq_backend::__init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>> and use that both in the declaration above this line and in this template.

Contributor Author

Good point, I combined this with Sergey's suggestion to move it outside of the branch and added an alias for the tuple type since it is used elsewhere as well.
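
A rough sketch of what hoisting the aliases out of the branch could look like is shown below. Every name here is a hypothetical stand-in rather than the oneDPL type, and the sketch assumes the branch selects on inclusive versus exclusive, which the quoted snippets do not show directly.

```cpp
#include <functional>
#include <tuple>
#include <type_traits>

// Hypothetical stand-ins for the two init wrappers; the real oneDPL types differ.
template <typename T>
struct init_value_sketch
{
    using value_type = T;
    T value;
};

template <typename T>
struct no_init_value_sketch
{
    using value_type = T;
};

// Hypothetical stand-in for the write functor named in the review.
template <bool is_inclusive, typename InitType, typename BinaryOp>
struct write_scan_by_seg_sketch
{
    InitType init;
    BinaryOp op;
};

// Aliases hoisted out of the branch: one alias for the packed tuple type, and
// std::conditional_t to pick the init wrapper (assumed to depend on inclusivity).
template <bool is_inclusive, typename FlagType, typename ValueType, typename BinaryOp>
struct scan_by_segment_types_sketch
{
    using packed_type = std::tuple<FlagType, ValueType>;
    using wrapped_init_type =
        std::conditional_t<is_inclusive, no_init_value_sketch<packed_type>, init_value_sketch<packed_type>>;
    using write_op_type = write_scan_by_seg_sketch<is_inclusive, wrapped_init_type, BinaryOp>;
};
```

With the aliases defined once, each branch only needs to construct the wrapped init object; the write-op type no longer has to be spelled out twice.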

@mmichel11 mmichel11 added this to the 2022.10.0 milestone Jul 1, 2025
mmichel11 added 11 commits July 13, 2025 20:34
* Re-implements the high level fallback with parallel patterns and moves
  implementation to dpcpp backend
* Move all of the path selection logic to the dpcpp backend
* Add policy wrappers where needed to prevent duplicate kernel names

Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
@mmichel11 mmichel11 force-pushed the dev/mmichel11/rts_scan_by_segment branch from c209173 to a121fe2 Compare July 14, 2025 01:35

template <typename _T1, typename _T2>
_T
operator()(_T1&& __a, _T2&& __s) const
Contributor

Suggested change
- operator()(_T1&& __a, _T2&& __s) const
+ operator()(const _T1& __a, _T2&& __s) const

Contributor Author

This is just relocated code, but I don't have a problem changing it.

From my perspective, operator()(_T1&& __a, const _T2& __s) const makes the most sense. __s is just passed to a unary predicate, so a const ref is all we need here. __a may be returned, so I think a forwarding reference makes sense for this as it can be moved in and out of the function object.
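
To make the proposed signature concrete, here is a small sketch of a replace-if functor written that way. `replace_if_fun_sketch` and its member names are hypothetical; this is not the relocated oneDPL code, only an illustration of taking the predicate argument by const reference and the pass-through value by forwarding reference.

```cpp
#include <string>
#include <utility>

// Hypothetical sketch of a replace-if functor (not the oneDPL implementation).
template <typename T, typename Predicate>
struct replace_if_fun_sketch
{
    Predicate pred;
    T new_value;

    template <typename T1, typename T2>
    T
    operator()(T1&& a, const T2& s) const
    {
        // s is only inspected by the predicate, so a const reference suffices.
        if (pred(s))
            return new_value;
        // a may become the result, so forward it to allow a move from rvalues.
        return std::forward<T1>(a);
    }
};
```

An `if`/`return` pair is used instead of the conditional operator so that an rvalue `a` can be moved into the result rather than copied.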

_T
operator()(_T1&& __a, _T2&& __s) const
{
return __pred(__s) ? __new_value : __a;
Contributor

Suggested change
- return __pred(__s) ? __new_value : __a;
+ return __pred(std::forward<_T2>(__s)) ? __new_value : __a;

Contributor Author

See the above comment for the change I made here.

template <typename _CustomName, bool __is_inclusive, typename _Range1, typename _Range2, typename _Range3,
typename _BinaryPredicate, typename _BinaryOperator, typename _InitType>
__future<sycl::event, __result_and_scratch_storage<
oneapi::dpl::__internal::tuple<std::uint32_t, oneapi::dpl::__internal::__value_t<_Range2>>>>
Contributor

Are you sure about this return type?
For example, the return type of __parallel_transform_reduce_then_scan is __future<sycl::event, __result_and_scratch_storage<typename _InitType::__value_type>>.
So what is it in this case?

Contributor Author

Yes, because we are using the _WrappedInitType defined within this function to call transform reduce-then-scan. The init passed to __parallel_transform_reduce_then_scan ends up being a tuple of the flag and value types, and that is what the return type reflects.

using _TempData = __noop_temp_data;
_InitType __init_value;
_BinaryOp __binary_op;
template <typename _OutRng, typename _ValueType>
Contributor

Suggested change
template <typename _OutRng, typename _ValueType>
template <typename _OutRng, typename _ValueType>

Contributor Author

Done

Comment on lines 241 to 247
if (get<1>(__v))
    __out_rng[__id] = static_cast<_ConvertedTupleType>(get<1>(__init_value.__value));
else
{
    __out_rng[__id] =
        static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v))));
}
Contributor

What about a ternary operator here?

Suggested change
- if (get<1>(__v))
-     __out_rng[__id] = static_cast<_ConvertedTupleType>(get<1>(__init_value.__value));
- else
- {
-     __out_rng[__id] =
-         static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v))));
- }
+ __out_rng[__id] = get<1>(__v)
+     ? static_cast<_ConvertedTupleType>(get<1>(__init_value.__value))
+     : static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v))));

Contributor Author

I agree it's cleaner, done.
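
As an illustration of the write step being discussed, the following sequential sketch applies the same rule to an exclusive scan-by-segment: the first element of a segment receives the init value, and every other element receives the init combined with the running total of the preceding elements in its segment. It assumes that `get<1>(__v)` in the snippet flags a segment start, which is an interpretation of the quoted code, and `exclusive_scan_by_segment_sketch` is a hypothetical name.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Sequential reference for an exclusive scan-by-segment (illustrative only).
template <typename T, typename KeyT, typename BinaryOp>
std::vector<T>
exclusive_scan_by_segment_sketch(const std::vector<KeyT>& keys, const std::vector<T>& values, T init, BinaryOp op)
{
    std::vector<T> out(values.size());
    T carry{}; // inclusive total of the current segment up to the previous element
    for (std::size_t i = 0; i < values.size(); ++i)
    {
        const bool starts_segment = (i == 0 || keys[i] != keys[i - 1]);
        // Same shape as the ternary in the review: init at a segment start,
        // otherwise init combined with the carried value.
        out[i] = starts_segment ? init : op(init, carry);
        carry = starts_segment ? values[i] : op(carry, values[i]);
    }
    return out;
}
```

For keys `{1, 1, 2, 2, 2, 3}`, values `{1, 2, 3, 4, 5, 6}`, an init of `10`, and `std::plus<int>`, this yields `{10, 11, 10, 13, 17, 10}`.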

3 participants