Skip to content

Commit cfb508a

Browse files
[NFCI][SYCL] Unify Dims impls for sycl_exp::info::device::max_work_groups (#18114)
1 parent 79c26f3 commit cfb508a

File tree

1 file changed

+11
-38
lines changed

1 file changed

+11
-38
lines changed

sycl/source/detail/device_info.hpp

+11-38
Original file line numberDiff line numberDiff line change
@@ -1087,55 +1087,28 @@ struct get_device_info_impl<
10871087
return static_cast<size_t>((std::numeric_limits<int>::max)());
10881088
}
10891089
};
1090-
template <>
1090+
template <int Dims>
10911091
struct get_device_info_impl<
1092-
id<1>, ext::oneapi::experimental::info::device::max_work_groups<1>> {
1093-
static id<1> get(const DeviceImplPtr &Dev) {
1094-
size_t result[3];
1092+
id<Dims>, ext::oneapi::experimental::info::device::max_work_groups<Dims>> {
1093+
static id<Dims> get(const DeviceImplPtr &Dev) {
10951094
size_t Limit =
10961095
get_device_info_impl<size_t, ext::oneapi::experimental::info::device::
10971096
max_global_work_groups>::get(Dev);
1098-
Dev->getAdapter()->call<UrApiKind::urDeviceGetInfo>(
1099-
Dev->getHandleRef(),
1100-
UrInfoCode<
1101-
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
1102-
sizeof(result), &result, nullptr);
1103-
return id<1>(std::min(Limit, result[0]));
1104-
}
1105-
};
11061097

1107-
template <>
1108-
struct get_device_info_impl<
1109-
id<2>, ext::oneapi::experimental::info::device::max_work_groups<2>> {
1110-
static id<2> get(const DeviceImplPtr &Dev) {
11111098
size_t result[3];
1112-
size_t Limit =
1113-
get_device_info_impl<size_t, ext::oneapi::experimental::info::device::
1114-
max_global_work_groups>::get(Dev);
1115-
Dev->getAdapter()->call<UrApiKind::urDeviceGetInfo>(
1116-
Dev->getHandleRef(),
1117-
UrInfoCode<
1118-
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
1119-
sizeof(result), &result, nullptr);
1120-
return id<2>(std::min(Limit, result[1]), std::min(Limit, result[0]));
1121-
}
1122-
};
1123-
1124-
template <>
1125-
struct get_device_info_impl<
1126-
id<3>, ext::oneapi::experimental::info::device::max_work_groups<3>> {
1127-
static id<3> get(const DeviceImplPtr &Dev) {
1128-
size_t result[3];
1129-
size_t Limit =
1130-
get_device_info_impl<size_t, ext::oneapi::experimental::info::device::
1131-
max_global_work_groups>::get(Dev);
11321099
Dev->getAdapter()->call<UrApiKind::urDeviceGetInfo>(
11331100
Dev->getHandleRef(),
11341101
UrInfoCode<
11351102
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
11361103
sizeof(result), &result, nullptr);
1137-
return id<3>(std::min(Limit, result[2]), std::min(Limit, result[1]),
1138-
std::min(Limit, result[0]));
1104+
static_assert(1 <= Dims && Dims <= 3);
1105+
if constexpr (Dims == 1)
1106+
return id<1>(std::min(Limit, result[0]));
1107+
else if constexpr (Dims == 2)
1108+
return id<2>(std::min(Limit, result[1]), std::min(Limit, result[0]));
1109+
else
1110+
return id<3>(std::min(Limit, result[2]), std::min(Limit, result[1]),
1111+
std::min(Limit, result[0]));
11391112
}
11401113
};
11411114

0 commit comments

Comments
 (0)