diff --git a/scripts/templates/libapi.cpp.mako b/scripts/templates/libapi.cpp.mako index d0f8487c..36ee5ae9 100644 --- a/scripts/templates/libapi.cpp.mako +++ b/scripts/templates/libapi.cpp.mako @@ -86,6 +86,10 @@ ${th.make_func_name(n, tags, obj)}( return result; }); + if (result != ${X}_RESULT_SUCCESS) { + return result; + } + if(ze_lib::context->inTeardown) { return ${X}_RESULT_ERROR_UNINITIALIZED; } diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako index 5a7dc9af..7ff709a0 100644 --- a/scripts/templates/nullddi.cpp.mako +++ b/scripts/templates/nullddi.cpp.mako @@ -130,6 +130,16 @@ ${tbl['export']['name']}( ${x}_result_t result = ${X}_RESULT_SUCCESS; +% if tbl['name'] == 'Global' and n == 'ze': + pDdiTable->pfnInit = driver::zeInit; + + auto missing_api = getenv_string( "ZEL_TEST_MISSING_API" ); + if (std::strcmp(missing_api.c_str(), "zeInitDrivers") == 0) { + pDdiTable->pfnInitDrivers = nullptr; + } else { + pDdiTable->pfnInitDrivers = driver::zeInitDrivers; + } +%else: %for obj in tbl['functions']: %if 'condition' in obj: #if ${th.subt(n, tags, obj['condition'])} @@ -142,6 +152,7 @@ ${tbl['export']['name']}( %endif %endfor +%endif return result; } diff --git a/source/drivers/null/ze_nullddi.cpp b/source/drivers/null/ze_nullddi.cpp index a0e7e30c..e4fd220e 100644 --- a/source/drivers/null/ze_nullddi.cpp +++ b/source/drivers/null/ze_nullddi.cpp @@ -5142,8 +5142,12 @@ zeGetGlobalProcAddrTable( pDdiTable->pfnInit = driver::zeInit; - pDdiTable->pfnInitDrivers = driver::zeInitDrivers; - + auto missing_api = getenv_string( "ZEL_TEST_MISSING_API" ); + if (std::strcmp(missing_api.c_str(), "zeInitDrivers") == 0) { + pDdiTable->pfnInitDrivers = nullptr; + } else { + pDdiTable->pfnInitDrivers = driver::zeInitDrivers; + } return result; } diff --git a/source/lib/ze_libapi.cpp b/source/lib/ze_libapi.cpp index cdeea8db..c93e19c8 100644 --- a/source/lib/ze_libapi.cpp +++ b/source/lib/ze_libapi.cpp @@ -197,6 +197,10 @@ zeInitDrivers( return result; }); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + if(ze_lib::context->inTeardown) { return ZE_RESULT_ERROR_UNINITIALIZED; } diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index 62d28a30..32efb194 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -145,6 +145,14 @@ namespace loader debug_trace_message(message, ""); } } + // If zeInitDrivers is not supported by this driver, but zeInitDrivers is called first, then return uninitialized. + if (desc && !loader::context->initDriversSupport) { + std::string message = "zeInitDrivers called first, but not supported by driver, returning uninitialized."; + debug_trace_message(message, ""); + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + bool return_first_driver_result=false; std::string initName = "zeInit"; driver_vector_t *drivers = &zeDrivers; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6dca3648..41eb2dda 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -37,4 +37,6 @@ set_property(TEST tests_both_succeed PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER add_test(NAME tests_both_gpu COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithGPUTypes*) set_property(TEST tests_both_gpu PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") add_test(NAME tests_both_npu COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeInitThenZeInitDriversThenBothCallsSucceedWithNPUTypes*) -set_property(TEST tests_both_npu PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") \ No newline at end of file +set_property(TEST tests_both_npu PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") +add_test(NAME tests_missing_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitWithzeInitDriversUnsupportedOnTheDriverThenzeInitSucceeds*) +set_property(TEST tests_missing_api PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1") \ No newline at end of file diff --git a/test/loader_api.cpp b/test/loader_api.cpp index 83b15add..6e62ef44 100644 --- a/test/loader_api.cpp +++ b/test/loader_api.cpp @@ -143,6 +143,19 @@ TEST( EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(ZE_INIT_FLAG_GPU_ONLY)); } +TEST( + LoaderInit, + GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitWithzeInitDriversUnsupportedOnTheDriverThenzeInitSucceeds) { + + uint32_t pCount = 0; + ze_init_driver_type_desc_t desc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC}; + desc.flags = UINT32_MAX; + desc.pNext = nullptr; + putenv_safe( const_cast( "ZEL_TEST_MISSING_API=zeInitDrivers" ) ); + EXPECT_EQ(ZE_RESULT_ERROR_UNINITIALIZED, zeInitDrivers(&pCount, nullptr, &desc)); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0)); +} + TEST( LoaderInit, GivenLevelZeroLoaderPresentWhenCallingZeInitDriversThenzeInitThenBothCallsSucceedWithNPUTypes) {