Skip to content

Commit 0272ec2

Browse files
[SYCL] Fix LastDeviceIds assignment for Platform w/o device (#5695)
When platform has no devices LastDeviceIds[] for the platform will stay 0 that will affect device enumeration for the next platforms. Now we do adjustment if no devices for platform was obtained from backend. Testing: issue is found/partly covered by E2E test in llvm-test_suite/SYCL/Regression/device_num.cpp Signed-off-by: Tikhomirova, Kseniya <[email protected]>
1 parent 271ef40 commit 0272ec2

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

sycl/source/detail/platform_impl.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,23 @@ platform_impl::get_devices(info::device_type DeviceType) const {
255255
MPlatform, pi::cast<RT::PiDeviceType>(DeviceType), 0,
256256
pi::cast<RT::PiDevice *>(nullptr), &NumDevices);
257257

258-
if (NumDevices == 0)
258+
if (NumDevices == 0) {
259+
// If platform doesn't have devices (even without filter)
260+
// LastDeviceIds[PlatformId] stay 0 that affects next platform devices num
261+
// analysis. Doing adjustment by simple copy of last device num from
262+
// previous platform.
263+
// Needs non const plugin reference.
264+
std::vector<plugin> &Plugins = RT::initialize();
265+
auto It = std::find_if(Plugins.begin(), Plugins.end(),
266+
[&Platform = MPlatform](plugin &Plugin) {
267+
return Plugin.containsPiPlatform(Platform);
268+
});
269+
if (It != Plugins.end()) {
270+
std::lock_guard<std::mutex> Guard(*(It->getPluginMutex()));
271+
(*It).adjustLastDeviceId(MPlatform);
272+
}
259273
return Res;
274+
}
260275

261276
std::vector<RT::PiDevice> PiDevices(NumDevices);
262277
// TODO catch an exception and put it to list of asynchronous exceptions

sycl/source/detail/plugin.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ class plugin {
238238
LastDeviceIds[PlatformId] = Id;
239239
}
240240

241+
// Adjust the id of the last device for the given platform.
242+
// Involved when there is no device on that platform at all.
243+
// The function is expected to be called in a thread safe manner.
244+
void adjustLastDeviceId(RT::PiPlatform Platform) {
245+
int PlatformId = getPlatformId(Platform);
246+
if (PlatformId > 0 &&
247+
LastDeviceIds[PlatformId] < LastDeviceIds[PlatformId - 1])
248+
LastDeviceIds[PlatformId] = LastDeviceIds[PlatformId - 1];
249+
}
250+
241251
bool containsPiPlatform(RT::PiPlatform Platform) {
242252
auto It = std::find(PiPlatforms.begin(), PiPlatforms.end(), Platform);
243253
return It != PiPlatforms.end();

0 commit comments

Comments
 (0)