diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index a39274ea9bc1..f3c4b2a512e4 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule, } } // extern "C" +static ur_program_handle_t_::CodeFormat matchILCodeFormat(const void *Input, + size_t Length) { + const auto MatchMagicNumber = [&](uint32_t Number) { + return Length >= sizeof(Number) && + std::memcmp(Input, &Number, sizeof(Number)) == 0; + }; + + // SPIR-V Specification: 3.1 Magic Number + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic + if (MatchMagicNumber(0x07230203)) { + return ur_program_handle_t_::CodeFormat::SPIRV; + } + + return ur_program_handle_t_::CodeFormat::Unknown; +} + +static bool isCodeFormatIL(ur_program_handle_t_::CodeFormat CodeFormat) { + return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV; +} + namespace ur::level_zero { ur_result_t urProgramCreateWithIL( @@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL( ur_program_handle_t *Program) { UR_ASSERT(Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE); UR_ASSERT(IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER); + const ur_program_handle_t_::CodeFormat CodeFormat = + matchILCodeFormat(IL, Length); + UR_ASSERT(isCodeFormatIL(CodeFormat), UR_RESULT_ERROR_INVALID_BINARY); try { - ur_program_handle_t_ *UrProgram = - new ur_program_handle_t_(ur_program_handle_t_::IL, Context, IL, Length); + ur_program_handle_t_ *UrProgram = new ur_program_handle_t_( + ur_program_handle_t_::IL, Context, IL, Length, CodeFormat); *Program = reinterpret_cast(UrProgram); } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; @@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp( auto Code = hProgram->getCode(ZeDevice); UR_ASSERT(Code, UR_RESULT_ERROR_INVALID_PROGRAM); - ZeModuleDesc.format = (State == ur_program_handle_t_::IL) - ? ZE_MODULE_FORMAT_IL_SPIRV - : ZE_MODULE_FORMAT_NATIVE; + switch (hProgram->getCodeFormat(ZeDevice)) { + case ur_program_handle_t_::CodeFormat::SPIRV: + ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + break; + case ur_program_handle_t_::CodeFormat::Native: + ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE; + break; + default: + assert(false && "Unknown code format"); + return UR_RESULT_ERROR_INVALID_PROGRAM; + } ZeModuleDesc.inputSize = hProgram->getCodeSize(ZeDevice); ZeModuleDesc.pInputModule = Code; ze_context_handle_t ZeContext = hProgram->Context->getZeHandle(); @@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp( // locks simultaneously with "exclusive" access. However, there is no such // code like that, so this is also not a danger. std::vector> Guards(count); + const ur_program_handle_t_::CodeFormat CommonCodeFormat = + phPrograms[0]->getCodeFormat(); for (uint32_t I = 0; I < count; I++) { std::shared_lock Guard(phPrograms[I]->Mutex); Guards[I].swap(Guard); @@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp( return UR_RESULT_ERROR_INVALID_OPERATION; } } + + // The L0 API has no way to represent mixed format modules, + // even though it could be possible to implement linking + // of mixed format modules. + if (phPrograms[I]->getCodeFormat() != CommonCodeFormat) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } } // Previous calls to urProgramCompile did not actually compile the SPIR-V. @@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp( ZeStruct ZeModuleDesc; ZeModuleDesc.pNext = &ZeExtModuleDesc; - ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + switch (CommonCodeFormat) { + case ur_program_handle_t_::CodeFormat::SPIRV: + ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV; + break; + default: + assert(false && "Unexpected code format"); + return UR_RESULT_ERROR_INVALID_PROGRAM; + } // This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1", // the driver does validation of the API calls, and it expects @@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants( ur_program_handle_t_::ur_program_handle_t_(state St, ur_context_handle_t Context, - const void *Input, size_t Length) + const void *Input, size_t Length, + CodeFormat CodeFormat) : Context{Context}, NativeProperties{nullptr}, OwnZeModule{true}, - AssociatedDevices(Context->getDevices()), SpirvCode{new uint8_t[Length]}, - SpirvCodeLength{Length} { - std::memcpy(SpirvCode.get(), Input, Length); + AssociatedDevices(Context->getDevices()), ILCode{new uint8_t[Length]}, + ILCodeLength{Length}, ILCodeFormat(CodeFormat) { + assert(isCodeFormatIL(CodeFormat)); + std::memcpy(ILCode.get(), Input, Length); // All devices have the program in IL state. for (auto &Device : Context->getDevices()) { DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice]; diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index d3af4ce9ddd0..789daf052ba0 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -41,6 +41,12 @@ struct ur_program_handle_t_ : ur_object { Invalid } state; + enum class CodeFormat : uint8_t { + Native, + SPIRV, + Unknown, + }; + // A utility class that converts specialization constants into the form // required by the Level Zero driver. class SpecConstantShim { @@ -68,7 +74,7 @@ struct ur_program_handle_t_ : ur_object { // Construct a program in IL. ur_program_handle_t_(state St, ur_context_handle_t Context, const void *Input, - size_t Length); + size_t Length, CodeFormat Format); // Construct a program in NATIVE for multiple devices. ur_program_handle_t_(state St, ur_context_handle_t Context, @@ -113,28 +119,42 @@ struct ur_program_handle_t_ : ur_object { return DeviceDataMap[ZeDevice].ZeModule; } + CodeFormat getCodeFormat(ze_device_handle_t ZeDevice = nullptr) const { + if (!ZeDevice) + return ILCodeFormat; + + auto It = DeviceDataMap.find(ZeDevice); + if (It == DeviceDataMap.end()) + return ILCodeFormat; + + if (It->second.State == state::IL) + return ILCodeFormat; + else + return CodeFormat::Native; + } + uint8_t *getCode(ze_device_handle_t ZeDevice = nullptr) { if (!ZeDevice) - return SpirvCode.get(); + return ILCode.get(); if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end()) return nullptr; if (DeviceDataMap[ZeDevice].State == state::IL) - return SpirvCode.get(); + return ILCode.get(); else return DeviceDataMap[ZeDevice].Binary.first.get(); } size_t getCodeSize(ze_device_handle_t ZeDevice = nullptr) { if (ZeDevice == nullptr) - return SpirvCodeLength; + return ILCodeLength; if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end()) return 0; if (DeviceDataMap[ZeDevice].State == state::IL) - return SpirvCodeLength; + return ILCodeLength; else return DeviceDataMap[ZeDevice].Binary.second; } @@ -233,8 +253,9 @@ struct ur_program_handle_t_ : ur_object { // In IL and Object states, this contains the SPIR-V representation of the // module. - std::unique_ptr SpirvCode; // Array containing raw IL code. - size_t SpirvCodeLength = 0; // Size (bytes) of the array. + std::unique_ptr ILCode; // Array containing raw IL code. + size_t ILCodeLength = 0; // Size (bytes) of the array. + CodeFormat ILCodeFormat = CodeFormat::Unknown; // Format of the IL code. // The Level Zero module handle for interoperability. // This module handle is either initialized with the handle provided to