From 7a95289fbf5a406a5ef61a5e640d48ec183e53aa Mon Sep 17 00:00:00 2001 From: Gergely Meszaros Date: Fri, 9 May 2025 01:06:08 -0700 Subject: [PATCH 1/2] [UR][L0] Refactor IL code handling allowing future extension Refactor the code to make it easier to add support for different IL formats (besides SPIR-V) in the future. The only functional change is that SPIR-V binaries with invalid magic number are now rejected by returning UR_RESULT_ERROR_INVALID_BINARY from urProgramCreateWithIL. --- .../source/adapters/level_zero/program.cpp | 69 ++++++++++++++++--- .../source/adapters/level_zero/program.hpp | 35 ++++++++-- 2 files changed, 87 insertions(+), 17 deletions(-) diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index a39274ea9bc10..fce41600c1d9f 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule, } } // extern "C" +static ur_program_handle_t_::CodeFormat matchILCodeFormat(const void *Input, + size_t Length) { + const auto MatchMagicNumber = [&](uint32_t Number) { + return Length >= sizeof(Number) && + std::memcmp(Input, &Number, sizeof(Number)) == 0; + }; + + // SPIR-V Specification: 3.1 Magic Number + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic + if (MatchMagicNumber(0x07230203)) { + return ur_program_handle_t_::CodeFormat::SPIRV; + } + + return ur_program_handle_t_::CodeFormat::Unknown; +} + +static bool isCodeFormatIL(ur_program_handle_t_::CodeFormat CodeFormat) { + return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV; +} + namespace ur::level_zero { ur_result_t urProgramCreateWithIL( @@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL( ur_program_handle_t *Program) { UR_ASSERT(Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE); UR_ASSERT(IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER); + const ur_program_handle_t_::CodeFormat CodeFormat = + matchILCodeFormat(IL, Length); + UR_ASSERT(isCodeFormatIL(CodeFormat), UR_RESULT_ERROR_INVALID_BINARY); try { - ur_program_handle_t_ *UrProgram = - new ur_program_handle_t_(ur_program_handle_t_::IL, Context, IL, Length); + ur_program_handle_t_ *UrProgram = new ur_program_handle_t_( + ur_program_handle_t_::IL, Context, IL, Length, CodeFormat); *Program = reinterpret_cast(UrProgram); } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; @@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp( auto Code = hProgram->getCode(ZeDevice); UR_ASSERT(Code, UR_RESULT_ERROR_INVALID_PROGRAM); - ZeModuleDesc.format = (State == ur_program_handle_t_::IL) - ? ZE_MODULE_FORMAT_IL_SPIRV - : ZE_MODULE_FORMAT_NATIVE; + switch (hProgram->getCodeFormat(ZeDevice)) { + case ur_program_handle_t_::CodeFormat::SPIRV: + ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + break; + case ur_program_handle_t_::CodeFormat::Native: + ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE; + break; + default: + assert(false && "Unknown code format"); + return UR_RESULT_ERROR_INVALID_PROGRAM; + } ZeModuleDesc.inputSize = hProgram->getCodeSize(ZeDevice); ZeModuleDesc.pInputModule = Code; ze_context_handle_t ZeContext = hProgram->Context->getZeHandle(); @@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp( // locks simultaneously with "exclusive" access. However, there is no such // code like that, so this is also not a danger. std::vector> Guards(count); + const ur_program_handle_t_::CodeFormat CommonCodeFormat = + phPrograms[0]->getCodeFormat(); for (uint32_t I = 0; I < count; I++) { std::shared_lock Guard(phPrograms[I]->Mutex); Guards[I].swap(Guard); @@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp( return UR_RESULT_ERROR_INVALID_OPERATION; } } + + // The L0 API has no way to represent mixed format modules, + // even though it could be possible to implement linking + // of mixed format modules. + if (phPrograms[I]->getCodeFormat() != CommonCodeFormat) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } } // Previous calls to urProgramCompile did not actually compile the SPIR-V. @@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp( ZeStruct ZeModuleDesc; ZeModuleDesc.pNext = &ZeExtModuleDesc; - ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + switch(CommonCodeFormat) { + case ur_program_handle_t_::CodeFormat::SPIRV: + ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + break; + default: + assert(false && "Unexpected code format"); + return UR_RESULT_ERROR_INVALID_PROGRAM; + } // This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1", // the driver does validation of the API calls, and it expects @@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants( ur_program_handle_t_::ur_program_handle_t_(state St, ur_context_handle_t Context, - const void *Input, size_t Length) + const void *Input, size_t Length, + CodeFormat CodeFormat) : Context{Context}, NativeProperties{nullptr}, OwnZeModule{true}, - AssociatedDevices(Context->getDevices()), SpirvCode{new uint8_t[Length]}, - SpirvCodeLength{Length} { - std::memcpy(SpirvCode.get(), Input, Length); + AssociatedDevices(Context->getDevices()), ILCode{new uint8_t[Length]}, + ILCodeLength{Length}, ILCodeFormat(CodeFormat) { + assert(isCodeFormatIL(CodeFormat)); + std::memcpy(ILCode.get(), Input, Length); // All devices have the program in IL state. for (auto &Device : Context->getDevices()) { DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice]; diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index d3af4ce9ddd04..789daf052ba0c 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -41,6 +41,12 @@ struct ur_program_handle_t_ : ur_object { Invalid } state; + enum class CodeFormat : uint8_t { + Native, + SPIRV, + Unknown, + }; + // A utility class that converts specialization constants into the form // required by the Level Zero driver. class SpecConstantShim { @@ -68,7 +74,7 @@ struct ur_program_handle_t_ : ur_object { // Construct a program in IL. ur_program_handle_t_(state St, ur_context_handle_t Context, const void *Input, - size_t Length); + size_t Length, CodeFormat Format); // Construct a program in NATIVE for multiple devices. ur_program_handle_t_(state St, ur_context_handle_t Context, @@ -113,28 +119,42 @@ struct ur_program_handle_t_ : ur_object { return DeviceDataMap[ZeDevice].ZeModule; } + CodeFormat getCodeFormat(ze_device_handle_t ZeDevice = nullptr) const { + if (!ZeDevice) + return ILCodeFormat; + + auto It = DeviceDataMap.find(ZeDevice); + if (It == DeviceDataMap.end()) + return ILCodeFormat; + + if (It->second.State == state::IL) + return ILCodeFormat; + else + return CodeFormat::Native; + } + uint8_t *getCode(ze_device_handle_t ZeDevice = nullptr) { if (!ZeDevice) - return SpirvCode.get(); + return ILCode.get(); if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end()) return nullptr; if (DeviceDataMap[ZeDevice].State == state::IL) - return SpirvCode.get(); + return ILCode.get(); else return DeviceDataMap[ZeDevice].Binary.first.get(); } size_t getCodeSize(ze_device_handle_t ZeDevice = nullptr) { if (ZeDevice == nullptr) - return SpirvCodeLength; + return ILCodeLength; if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end()) return 0; if (DeviceDataMap[ZeDevice].State == state::IL) - return SpirvCodeLength; + return ILCodeLength; else return DeviceDataMap[ZeDevice].Binary.second; } @@ -233,8 +253,9 @@ struct ur_program_handle_t_ : ur_object { // In IL and Object states, this contains the SPIR-V representation of the // module. - std::unique_ptr SpirvCode; // Array containing raw IL code. - size_t SpirvCodeLength = 0; // Size (bytes) of the array. + std::unique_ptr ILCode; // Array containing raw IL code. + size_t ILCodeLength = 0; // Size (bytes) of the array. + CodeFormat ILCodeFormat = CodeFormat::Unknown; // Format of the IL code. // The Level Zero module handle for interoperability. // This module handle is either initialized with the handle provided to From 701a320a521b0475b6fe4b78ecab6a8e6ff37869 Mon Sep 17 00:00:00 2001 From: Gergely Meszaros Date: Wed, 14 May 2025 00:19:43 -0700 Subject: [PATCH 2/2] Apply clang-format --- unified-runtime/source/adapters/level_zero/program.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index fce41600c1d9f..f3c4b2a512e44 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -446,7 +446,7 @@ ur_result_t urProgramLinkExp( ZeStruct ZeModuleDesc; ZeModuleDesc.pNext = &ZeExtModuleDesc; - switch(CommonCodeFormat) { + switch (CommonCodeFormat) { case ur_program_handle_t_::CodeFormat::SPIRV: ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; break;