@@ -1087,55 +1087,28 @@ struct get_device_info_impl<
1087
1087
return static_cast <size_t >((std::numeric_limits<int >::max)());
1088
1088
}
1089
1089
};
1090
- template <>
1090
+ template <int Dims >
1091
1091
struct get_device_info_impl <
1092
- id<1 >, ext::oneapi::experimental::info::device::max_work_groups<1 >> {
1093
- static id<1 > get (const DeviceImplPtr &Dev) {
1094
- size_t result[3 ];
1092
+ id<Dims>, ext::oneapi::experimental::info::device::max_work_groups<Dims>> {
1093
+ static id<Dims> get (const DeviceImplPtr &Dev) {
1095
1094
size_t Limit =
1096
1095
get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
1097
1096
max_global_work_groups>::get (Dev);
1098
- Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1099
- Dev->getHandleRef (),
1100
- UrInfoCode<
1101
- ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
1102
- sizeof (result), &result, nullptr );
1103
- return id<1 >(std::min (Limit, result[0 ]));
1104
- }
1105
- };
1106
1097
1107
- template <>
1108
- struct get_device_info_impl <
1109
- id<2 >, ext::oneapi::experimental::info::device::max_work_groups<2 >> {
1110
- static id<2 > get (const DeviceImplPtr &Dev) {
1111
1098
size_t result[3 ];
1112
- size_t Limit =
1113
- get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
1114
- max_global_work_groups>::get (Dev);
1115
- Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1116
- Dev->getHandleRef (),
1117
- UrInfoCode<
1118
- ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
1119
- sizeof (result), &result, nullptr );
1120
- return id<2 >(std::min (Limit, result[1 ]), std::min (Limit, result[0 ]));
1121
- }
1122
- };
1123
-
1124
- template <>
1125
- struct get_device_info_impl <
1126
- id<3 >, ext::oneapi::experimental::info::device::max_work_groups<3 >> {
1127
- static id<3 > get (const DeviceImplPtr &Dev) {
1128
- size_t result[3 ];
1129
- size_t Limit =
1130
- get_device_info_impl<size_t , ext::oneapi::experimental::info::device::
1131
- max_global_work_groups>::get (Dev);
1132
1099
Dev->getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1133
1100
Dev->getHandleRef (),
1134
1101
UrInfoCode<
1135
1102
ext::oneapi::experimental::info::device::max_work_groups<3 >>::value,
1136
1103
sizeof (result), &result, nullptr );
1137
- return id<3 >(std::min (Limit, result[2 ]), std::min (Limit, result[1 ]),
1138
- std::min (Limit, result[0 ]));
1104
+ static_assert (1 <= Dims && Dims <= 3 );
1105
+ if constexpr (Dims == 1 )
1106
+ return id<1 >(std::min (Limit, result[0 ]));
1107
+ else if constexpr (Dims == 2 )
1108
+ return id<2 >(std::min (Limit, result[1 ]), std::min (Limit, result[0 ]));
1109
+ else
1110
+ return id<3 >(std::min (Limit, result[2 ]), std::min (Limit, result[1 ]),
1111
+ std::min (Limit, result[0 ]));
1139
1112
}
1140
1113
};
1141
1114
0 commit comments