-
Notifications
You must be signed in to change notification settings - Fork 116
Reduce-then-scan path for SYCL scan-by-segment implementations #2315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This draft adds a reduce-then-scan execution path and supporting functors for SYCL-based segmented scans, updates device‐copyable trait specializations, and adapts internal APIs to use the new functors.
- Introduced
__segmented_scan_fun
and__replace_if_fun
inutils.h
and removed legacy duplicates. - Extended SYCL traits and tests for device‐copyable checks of the new functors and segment‐scan inputs/writes.
- Added reduce-then-scan and fallback implementations in the SYCL backends and updated the high-level
__pattern_scan_by_segment
API.
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.
Show a summary per file
File | Description |
---|---|
test/general/implementation_details/device_copyable.pass.cpp | Added device-copyable assertions for new functors and inputs. |
include/oneapi/dpl/pstl/utils.h | Added __segmented_scan_fun & __replace_if_fun , included tuple_impl. |
include/oneapi/dpl/pstl/tuple_impl.h | Removed circular include of utils.h. |
include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h | Forward-declared new functors and added SYCL trait specializations. |
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_scan_by_segment.h | Split __parallel_scan_by_segment into reduce-then-scan and fallback paths. |
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h | Implemented write functors for inclusive/exclusive segment scans. |
include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h | Updated __pattern_scan_by_segment to use new init wrappers. |
include/oneapi/dpl/internal/reduce_by_segment_impl.h | Switched to __segmented_scan_fun . |
include/oneapi/dpl/internal/inclusive_scan_by_segment_impl.h | Switched to __segmented_scan_fun . |
include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h | Switched to __replace_if_fun and __segmented_scan_fun . |
include/oneapi/dpl/internal/function.h | Removed legacy replace_if_fun & segmented_scan_fun . |
Comments suppressed due to low confidence (2)
include/oneapi/dpl/pstl/utils.h:985
- [nitpick] Add a brief doc comment above
__segmented_scan_fun
to explain its role and tuple layout, improving maintainability.
template <typename _ValueType, typename _FlagType, typename _BinaryOp>
include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h:2206
- The static_assert requiring a known identity was removed. Confirm that all binary operators without a known identity are still handled correctly or consider reintroducing a guard to prevent misuse.
__bknd::__parallel_scan_by_segment<_Inclusive::value>(
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_scan_by_segment.h
Show resolved
Hide resolved
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h
Outdated
Show resolved
Hide resolved
e584025
to
5578d5b
Compare
{ | ||
oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>> | ||
__wrapped_init; | ||
using _WriteOp = __write_scan_by_seg<__is_inclusive, decltype(__wrapped_init), _BinaryOperator>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This type alias duplicates in two branches.
May be make sense to move it up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a little more complicated as we need a std::conditional_t
, but I agree and made this change.
@mmichel11 could you please find in the code |
// This is because init handling must occur on a per-segment basis and functions differently than the typical scan init. | ||
oneapi::dpl::unseq_backend::__init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>> __wrapped_init{ | ||
{0, __init.__value}}; | ||
using _WriteOp = __write_scan_by_seg<__is_inclusive, decltype(__wrapped_init), _BinaryOperator>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a preference thing, but it might be cleaner to define a type alias for oneapi::dpl::unseq_backend::__init_value<oneapi::dpl::__internal::tuple<_FlagType, _ValueType>>
and use that both in the declaration above this line and in this template.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I combined this with Sergey's suggestion to move it outside of the branch and added an alias for the tuple type since it is used elsewhere as well.
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
* Re-implements the high level fallback with parallel patterns and moves implementation to dpcpp backend * Move all of the path selection logic to the dpcpp backend * Add policy wrappers where needed to prevent duplicate kernel names Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
Signed-off-by: Matthew Michel <[email protected]>
…ug fix Signed-off-by: Matthew Michel <[email protected]>
c209173
to
a121fe2
Compare
include/oneapi/dpl/pstl/utils.h
Outdated
|
||
template <typename _T1, typename _T2> | ||
_T | ||
operator()(_T1&& __a, _T2&& __s) const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
operator()(_T1&& __a, _T2&& __s) const | |
operator()(const _T1& __a, _T2&& __s) const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just relocated code, but I don't have a problem changing it.
From my perspective, operator()(_T1&& __a, const _T2& __s) const
makes the most sense. __s
is just passed to a unary predicate, so a const ref is all we need here. __a
may be returned, so I think a forwarding reference makes sense for this as it can be moved in and out of the function object.
include/oneapi/dpl/pstl/utils.h
Outdated
_T | ||
operator()(_T1&& __a, _T2&& __s) const | ||
{ | ||
return __pred(__s) ? __new_value : __a; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return __pred(__s) ? __new_value : __a; | |
return __pred(std::forward<_T2>(__s)) ? __new_value : __a; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the above comment for the change I made here.
template <typename _CustomName, bool __is_inclusive, typename _Range1, typename _Range2, typename _Range3, | ||
typename _BinaryPredicate, typename _BinaryOperator, typename _InitType> | ||
__future<sycl::event, __result_and_scratch_storage< | ||
oneapi::dpl::__internal::tuple<std::uint32_t, oneapi::dpl::__internal::__value_t<_Range2>>>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure about this return type?
For example, __parallel_transform_reduce_then_scan
return type is __future<sycl::event, __result_and_scratch_storage<typename _InitType::__value_type>>
So what is it in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, because we are using the _WrappedInitType
defined within this function to call transform reduce-then-scan. It ends up being a tuple of the flag and value type that is passed as init to __parallel_transform_reduce_then_scan
which is reflected in the return type.
using _TempData = __noop_temp_data; | ||
_InitType __init_value; | ||
_BinaryOp __binary_op; | ||
template <typename _OutRng, typename _ValueType> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
template <typename _OutRng, typename _ValueType> | |
template <typename _OutRng, typename _ValueType> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (get<1>(__v)) | ||
__out_rng[__id] = static_cast<_ConvertedTupleType>(get<1>(__init_value.__value)); | ||
else | ||
{ | ||
__out_rng[__id] = | ||
static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v)))); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about ternary operator here?
if (get<1>(__v)) | |
__out_rng[__id] = static_cast<_ConvertedTupleType>(get<1>(__init_value.__value)); | |
else | |
{ | |
__out_rng[__id] = | |
static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v)))); | |
} | |
__out_rng[__id] = get<1>(__v) | |
? static_cast<_ConvertedTupleType>(get<1>(__init_value.__value)) | |
: static_cast<_ConvertedTupleType>(__binary_op(get<1>(__init_value.__value), get<1>(get<0>(__v)))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree it's cleaner, done.
Signed-off-by: Matthew Michel <[email protected]>
…ment Signed-off-by: Matthew Michel <[email protected]>
This PR adds a path for
inclusive_scan_by_segment
andexclusive_scan_by_segment
using our reduce-then-scan infrastructure to improve performance. Additionally, some reorganization of existing "fallback" implementations is requiredSummary of changes:
__parallel_scan_by_segment_reduce_then_scan
and associated reduce-then-scan building blocks are added to implement scan-by-segment. The implementation is similar but unique from reduce-by-segment.__segmented_scan_fun
and__replace_if_fun
have been moved to our general utilities as they are shared between implementations inpstl/hetero/dpcpp
andinternal
directories.