26
26
namespace platform_api__ {
27
27
using namespace sycl_cts ;
28
28
29
+ // Compares two devices by their hash value.
30
+ struct DeviceHashLessT {
31
+ bool operator ()(const sycl::device& lDevice, const sycl::device& rDevice) {
32
+ std::hash<sycl::device> hasher;
33
+ return hasher (lDevice) < hasher (rDevice);
34
+ }
35
+ };
36
+
37
+ // Checks that all devices in a vector are unique.
38
+ inline bool AllDevicesUnique (const std::vector<sycl::device>& devices) {
39
+ std::vector<sycl::device> devicesCopy = devices;
40
+ std::sort (devicesCopy.begin (), devicesCopy.end (), DeviceHashLessT{});
41
+ return std::unique (devicesCopy.begin (), devicesCopy.end ()) ==
42
+ devicesCopy.end ();
43
+ }
44
+
45
+ // Checks that all devices are in the list devices returned by the platform.
46
+ inline bool AllDevicesAreInPlatform (const std::vector<sycl::device>& devices,
47
+ const sycl::platform& platform) {
48
+ std::vector<sycl::device> devicesCopy = devices;
49
+ std::vector<sycl::device> allDevices = platform.get_devices ();
50
+ std::sort (devicesCopy.begin (), devicesCopy.end (), DeviceHashLessT{});
51
+ std::sort (allDevices.begin (), allDevices.end (), DeviceHashLessT{});
52
+ return std::includes (allDevices.begin (), allDevices.end (),
53
+ devicesCopy.begin (), devicesCopy.end (),
54
+ DeviceHashLessT{});
55
+ }
56
+
57
+ // Checks that all devices return the specified device_type when queried.
58
+ inline bool AllDevicesHaveType (const std::vector<sycl::device>& devices,
59
+ sycl::info::device_type devType) {
60
+ return std::all_of (
61
+ devices.begin (), devices.end (), [devType](const sycl::device& device) {
62
+ return device.get_info <sycl::info::device::device_type>() == devType;
63
+ });
64
+ }
65
+
66
+ // Returns the number of devices in the platform with the specified device type.
67
+ inline size_t CountPlatformDevicesWithType (const sycl::platform& platform,
68
+ sycl::info::device_type devType) {
69
+ std::vector<sycl::device> allDevices = platform.get_devices ();
70
+ return std::count_if (
71
+ allDevices.begin (), allDevices.end (),
72
+ [devType](const sycl::device& device) {
73
+ return device.get_info <sycl::info::device::device_type>() == devType;
74
+ });
75
+ }
76
+
29
77
/* * tests the api for sycl::platform
30
78
*/
31
79
class TEST_NAME : public util ::test_base {
@@ -43,23 +91,71 @@ class TEST_NAME : public util::test_base {
43
91
/* * check get_devices() member function
44
92
*/
45
93
{
94
+ INFO (" Checking platform::get_devices()" );
46
95
auto plt = util::get_cts_object::platform (cts_selector);
47
96
auto devs = plt.get_devices ();
48
97
check_return_type<std::vector<sycl::device>>(log, devs,
49
98
" platform::get_devices()" );
99
+ CHECK (AllDevicesUnique (devs));
50
100
}
51
101
52
102
/* * check get_devices(info::device_type::all) member function
53
103
*/
54
104
{
105
+ INFO (" Checking platform::get_devices(info::device_type::all)" );
55
106
auto plt = util::get_cts_object::platform (cts_selector);
56
107
auto devs = plt.get_devices (sycl::info::device_type::all);
57
- if (devs.size () != 0 ) {
58
- check_return_type<std::vector<sycl::device>>(
59
- log, devs, " platform::get_devices(info::device_type::all)" );
108
+ check_return_type<std::vector<sycl::device>>(
109
+ log, devs, " platform::get_devices(info::device_type::all)" );
110
+ CHECK (AllDevicesUnique (devs));
111
+ }
112
+
113
+ /* * check get_devices(info::device_type::automatic) member function
114
+ */
115
+ {
116
+ INFO (" Checking platform::get_devices(info::device_type::automatic)" );
117
+ auto plt = util::get_cts_object::platform (cts_selector);
118
+ auto devs = plt.get_devices (sycl::info::device_type::automatic);
119
+ check_return_type<std::vector<sycl::device>>(
120
+ log, devs, " platform::get_devices(info::device_type::automatic)" );
121
+ if (devs.size () == 0 ) {
122
+ CHECK (plt.get_devices ().size () == 0 );
123
+ } else {
124
+ CHECK (AllDevicesAreInPlatform (devs, plt));
60
125
}
61
126
}
62
127
128
+ /* * check get_devices(info::device_type::<cpu|gpu|accelerator|custom>)
129
+ * member function
130
+ */
131
+ for (sycl::info::device_type devType :
132
+ {sycl::info::device_type::cpu, sycl::info::device_type::gpu,
133
+ sycl::info::device_type::accelerator,
134
+ sycl::info::device_type::custom}) {
135
+ std::string devTypeName = [devType]() {
136
+ switch (devType) {
137
+ case sycl::info::device_type::cpu:
138
+ return " sycl::info::device_type::cpu" ;
139
+ case sycl::info::device_type::gpu:
140
+ return " sycl::info::device_type::gpu" ;
141
+ case sycl::info::device_type::accelerator:
142
+ return " sycl::info::device_type::accelerator" ;
143
+ case sycl::info::device_type::custom:
144
+ return " sycl::info::device_type::custom" ;
145
+ default :
146
+ assert (false && " Missing enumeration!" );
147
+ }
148
+ }();
149
+ INFO (" Checking platform::get_devices(" + devTypeName + " )" );
150
+ auto plt = util::get_cts_object::platform (cts_selector);
151
+ auto devs = plt.get_devices (devType);
152
+ check_return_type<std::vector<sycl::device>>(
153
+ log, devs, " platform::get_devices(" + devTypeName + " )" );
154
+ CHECK (AllDevicesAreInPlatform (devs, plt));
155
+ CHECK (AllDevicesHaveType (devs, devType));
156
+ CHECK (devs.size () == CountPlatformDevicesWithType (plt, devType));
157
+ }
158
+
63
159
/* * check has() member function
64
160
*/
65
161
{
0 commit comments