Skip to content

Commit 71ebbc1

Browse files
authored
Merge pull request #1134 from steffenlarsen/steffen/add_platform_get_devices_checks
Add tests for platform::get_devices()
2 parents 6985423 + f05dd9d commit 71ebbc1

File tree

1 file changed

+99
-3
lines changed

1 file changed

+99
-3
lines changed

tests/platform/platform_api.cpp

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,54 @@
2626
namespace platform_api__ {
2727
using namespace sycl_cts;
2828

29+
// Compares two devices by their hash value.
30+
struct DeviceHashLessT {
31+
bool operator()(const sycl::device& lDevice, const sycl::device& rDevice) {
32+
std::hash<sycl::device> hasher;
33+
return hasher(lDevice) < hasher(rDevice);
34+
}
35+
};
36+
37+
// Checks that all devices in a vector are unique.
38+
inline bool AllDevicesUnique(const std::vector<sycl::device>& devices) {
39+
std::vector<sycl::device> devicesCopy = devices;
40+
std::sort(devicesCopy.begin(), devicesCopy.end(), DeviceHashLessT{});
41+
return std::unique(devicesCopy.begin(), devicesCopy.end()) ==
42+
devicesCopy.end();
43+
}
44+
45+
// Checks that all devices are in the list devices returned by the platform.
46+
inline bool AllDevicesAreInPlatform(const std::vector<sycl::device>& devices,
47+
const sycl::platform& platform) {
48+
std::vector<sycl::device> devicesCopy = devices;
49+
std::vector<sycl::device> allDevices = platform.get_devices();
50+
std::sort(devicesCopy.begin(), devicesCopy.end(), DeviceHashLessT{});
51+
std::sort(allDevices.begin(), allDevices.end(), DeviceHashLessT{});
52+
return std::includes(allDevices.begin(), allDevices.end(),
53+
devicesCopy.begin(), devicesCopy.end(),
54+
DeviceHashLessT{});
55+
}
56+
57+
// Checks that all devices return the specified device_type when queried.
58+
inline bool AllDevicesHaveType(const std::vector<sycl::device>& devices,
59+
sycl::info::device_type devType) {
60+
return std::all_of(
61+
devices.begin(), devices.end(), [devType](const sycl::device& device) {
62+
return device.get_info<sycl::info::device::device_type>() == devType;
63+
});
64+
}
65+
66+
// Returns the number of devices in the platform with the specified device type.
67+
inline size_t CountPlatformDevicesWithType(const sycl::platform& platform,
68+
sycl::info::device_type devType) {
69+
std::vector<sycl::device> allDevices = platform.get_devices();
70+
return std::count_if(
71+
allDevices.begin(), allDevices.end(),
72+
[devType](const sycl::device& device) {
73+
return device.get_info<sycl::info::device::device_type>() == devType;
74+
});
75+
}
76+
2977
/** tests the api for sycl::platform
3078
*/
3179
class TEST_NAME : public util::test_base {
@@ -43,23 +91,71 @@ class TEST_NAME : public util::test_base {
4391
/** check get_devices() member function
4492
*/
4593
{
94+
INFO("Checking platform::get_devices()");
4695
auto plt = util::get_cts_object::platform(cts_selector);
4796
auto devs = plt.get_devices();
4897
check_return_type<std::vector<sycl::device>>(log, devs,
4998
"platform::get_devices()");
99+
CHECK(AllDevicesUnique(devs));
50100
}
51101

52102
/** check get_devices(info::device_type::all) member function
53103
*/
54104
{
105+
INFO("Checking platform::get_devices(info::device_type::all)");
55106
auto plt = util::get_cts_object::platform(cts_selector);
56107
auto devs = plt.get_devices(sycl::info::device_type::all);
57-
if (devs.size() != 0) {
58-
check_return_type<std::vector<sycl::device>>(
59-
log, devs, "platform::get_devices(info::device_type::all)");
108+
check_return_type<std::vector<sycl::device>>(
109+
log, devs, "platform::get_devices(info::device_type::all)");
110+
CHECK(AllDevicesUnique(devs));
111+
}
112+
113+
/** check get_devices(info::device_type::automatic) member function
114+
*/
115+
{
116+
INFO("Checking platform::get_devices(info::device_type::automatic)");
117+
auto plt = util::get_cts_object::platform(cts_selector);
118+
auto devs = plt.get_devices(sycl::info::device_type::automatic);
119+
check_return_type<std::vector<sycl::device>>(
120+
log, devs, "platform::get_devices(info::device_type::automatic)");
121+
if (devs.size() == 0) {
122+
CHECK(plt.get_devices().size() == 0);
123+
} else {
124+
CHECK(AllDevicesAreInPlatform(devs, plt));
60125
}
61126
}
62127

128+
/** check get_devices(info::device_type::<cpu|gpu|accelerator|custom>)
129+
* member function
130+
*/
131+
for (sycl::info::device_type devType :
132+
{sycl::info::device_type::cpu, sycl::info::device_type::gpu,
133+
sycl::info::device_type::accelerator,
134+
sycl::info::device_type::custom}) {
135+
std::string devTypeName = [devType]() {
136+
switch (devType) {
137+
case sycl::info::device_type::cpu:
138+
return "sycl::info::device_type::cpu";
139+
case sycl::info::device_type::gpu:
140+
return "sycl::info::device_type::gpu";
141+
case sycl::info::device_type::accelerator:
142+
return "sycl::info::device_type::accelerator";
143+
case sycl::info::device_type::custom:
144+
return "sycl::info::device_type::custom";
145+
default:
146+
assert(false && "Missing enumeration!");
147+
}
148+
}();
149+
INFO("Checking platform::get_devices(" + devTypeName + ")");
150+
auto plt = util::get_cts_object::platform(cts_selector);
151+
auto devs = plt.get_devices(devType);
152+
check_return_type<std::vector<sycl::device>>(
153+
log, devs, "platform::get_devices("+ devTypeName + ")");
154+
CHECK(AllDevicesAreInPlatform(devs, plt));
155+
CHECK(AllDevicesHaveType(devs, devType));
156+
CHECK(devs.size() == CountPlatformDevicesWithType(plt, devType));
157+
}
158+
63159
/** check has() member function
64160
*/
65161
{

0 commit comments

Comments
 (0)