Skip to content

Commit

Permalink
Fix DriverGet to handle failed drivers and avoid layer init during ch…
Browse files Browse the repository at this point in the history
…ecks

- During init, instrumentation may call driver get before
  uinitialized drivers can be removed. DriverGet has been
  updated to return the driver count only for valid drivers
  and update drivers that are not init to be skipped in
  subsequent DriverGet calls.
- Dont run layer init during checks for driver init to avoid creating
  invalid layer ddi table calls.

Signed-off-by: Neil R. Spruit <[email protected]>
  • Loading branch information
nrspruit committed Jul 9, 2024
1 parent e299374 commit 64ca4f9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 50 deletions.
14 changes: 12 additions & 2 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ namespace loader
uint32_t library_driver_handle_count = 0;

result = drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( &library_driver_handle_count, nullptr );
if( ${X}_RESULT_SUCCESS != result ) break;
if( ${X}_RESULT_SUCCESS != result ) {
// If Get Drivers fails with Uninitialized, then update the driver init status to prevent reporting this driver in the next get call.
if (${X}_RESULT_ERROR_UNINITIALIZED == result) {
drv.initStatus = result;
}
break;
}

if( nullptr != ${obj['params'][1]['name']} && *${obj['params'][0]['name']} !=0)
{
Expand All @@ -109,8 +115,12 @@ namespace loader
total_driver_handle_count += library_driver_handle_count;
}

if( ${X}_RESULT_SUCCESS == result )
// If the last driver get failed, but at least one driver succeeded, then return success with total count.
if( ${X}_RESULT_SUCCESS == result || total_driver_handle_count > 0)
*${obj['params'][0]['name']} = total_driver_handle_count;
if (total_driver_handle_count > 0) {
result = ${X}_RESULT_SUCCESS;
}

%else:
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
Expand Down
6 changes: 0 additions & 6 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,6 @@ namespace ze_lib
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
result = zelLoaderDriverCheck(flags);
// reInit the ze ddi tables after verifying the zeInit() with dummy tables.
// This ensures the tracing and validation layers are pointing to the correct function pointers after init.
if( ZE_RESULT_SUCCESS == result )
{
result = zeInit();
}
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
14 changes: 12 additions & 2 deletions source/loader/ze_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ namespace loader
uint32_t library_driver_handle_count = 0;

result = drv.dditable.ze.Driver.pfnGet( &library_driver_handle_count, nullptr );
if( ZE_RESULT_SUCCESS != result ) break;
if( ZE_RESULT_SUCCESS != result ) {
// If Get Drivers fails with Uninitialized, then update the driver init status to prevent reporting this driver in the next get call.
if (ZE_RESULT_ERROR_UNINITIALIZED == result) {
drv.initStatus = result;
}
break;
}

if( nullptr != phDrivers && *pCount !=0)
{
Expand All @@ -116,8 +122,12 @@ namespace loader
total_driver_handle_count += library_driver_handle_count;
}

if( ZE_RESULT_SUCCESS == result )
// If the last driver get failed, but at least one driver succeeded, then return success with total count.
if( ZE_RESULT_SUCCESS == result || total_driver_handle_count > 0)
*pCount = total_driver_handle_count;
if (total_driver_handle_count > 0) {
result = ZE_RESULT_SUCCESS;
}

return result;
}
Expand Down
40 changes: 0 additions & 40 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,46 +200,6 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;
}

if(nullptr != validationLayer) {
getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(validationLayer, "zeGetGlobalProcAddrTable") );
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null with validation layer. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
getTableResult = getTable( version, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() with validation layer failed with ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}

if(nullptr != tracingLayer) {
getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(tracingLayer, "zeGetGlobalProcAddrTable") );
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null with tracing layer. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
getTableResult = getTable( version, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() with tracing layer failed with ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
}

auto pfnInit = global.pfnInit;
if(nullptr == pfnInit) {
if (debugTraceEnabled) {
Expand Down

0 comments on commit 64ca4f9

Please sign in to comment.