Skip to content

[UR][L0] Refactor IL code handling allowing future extension #18441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions unified-runtime/source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule,
}
} // extern "C"

static ur_program_handle_t_::CodeFormat matchILCodeFormat(const void *Input,
size_t Length) {
const auto MatchMagicNumber = [&](uint32_t Number) {
return Length >= sizeof(Number) &&
std::memcmp(Input, &Number, sizeof(Number)) == 0;
};

// SPIR-V Specification: 3.1 Magic Number
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic
if (MatchMagicNumber(0x07230203)) {
return ur_program_handle_t_::CodeFormat::SPIRV;
}

return ur_program_handle_t_::CodeFormat::Unknown;
}

static bool isCodeFormatIL(ur_program_handle_t_::CodeFormat CodeFormat) {
return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV;
}

namespace ur::level_zero {

ur_result_t urProgramCreateWithIL(
Expand All @@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL(
ur_program_handle_t *Program) {
UR_ASSERT(Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
UR_ASSERT(IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
const ur_program_handle_t_::CodeFormat CodeFormat =
matchILCodeFormat(IL, Length);
UR_ASSERT(isCodeFormatIL(CodeFormat), UR_RESULT_ERROR_INVALID_BINARY);
try {
ur_program_handle_t_ *UrProgram =
new ur_program_handle_t_(ur_program_handle_t_::IL, Context, IL, Length);
ur_program_handle_t_ *UrProgram = new ur_program_handle_t_(
ur_program_handle_t_::IL, Context, IL, Length, CodeFormat);
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
Expand Down Expand Up @@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp(
auto Code = hProgram->getCode(ZeDevice);
UR_ASSERT(Code, UR_RESULT_ERROR_INVALID_PROGRAM);

ZeModuleDesc.format = (State == ur_program_handle_t_::IL)
? ZE_MODULE_FORMAT_IL_SPIRV
: ZE_MODULE_FORMAT_NATIVE;
switch (hProgram->getCodeFormat(ZeDevice)) {
case ur_program_handle_t_::CodeFormat::SPIRV:
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
break;
case ur_program_handle_t_::CodeFormat::Native:
ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
break;
default:
assert(false && "Unknown code format");
return UR_RESULT_ERROR_INVALID_PROGRAM;
}
ZeModuleDesc.inputSize = hProgram->getCodeSize(ZeDevice);
ZeModuleDesc.pInputModule = Code;
ze_context_handle_t ZeContext = hProgram->Context->getZeHandle();
Expand Down Expand Up @@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp(
// locks simultaneously with "exclusive" access. However, there is no such
// code like that, so this is also not a danger.
std::vector<std::shared_lock<ur_shared_mutex>> Guards(count);
const ur_program_handle_t_::CodeFormat CommonCodeFormat =
phPrograms[0]->getCodeFormat();
for (uint32_t I = 0; I < count; I++) {
std::shared_lock<ur_shared_mutex> Guard(phPrograms[I]->Mutex);
Guards[I].swap(Guard);
Expand All @@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp(
return UR_RESULT_ERROR_INVALID_OPERATION;
}
}

// The L0 API has no way to represent mixed format modules,
// even though it could be possible to implement linking
// of mixed format modules.
if (phPrograms[I]->getCodeFormat() != CommonCodeFormat) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
}

// Previous calls to urProgramCompile did not actually compile the SPIR-V.
Expand Down Expand Up @@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp(

ZeStruct<ze_module_desc_t> ZeModuleDesc;
ZeModuleDesc.pNext = &ZeExtModuleDesc;
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
switch (CommonCodeFormat) {
case ur_program_handle_t_::CodeFormat::SPIRV:
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
break;
default:
assert(false && "Unexpected code format");
return UR_RESULT_ERROR_INVALID_PROGRAM;
}

// This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1",
// the driver does validation of the API calls, and it expects
Expand Down Expand Up @@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants(

ur_program_handle_t_::ur_program_handle_t_(state St,
ur_context_handle_t Context,
const void *Input, size_t Length)
const void *Input, size_t Length,
CodeFormat CodeFormat)
: Context{Context}, NativeProperties{nullptr}, OwnZeModule{true},
AssociatedDevices(Context->getDevices()), SpirvCode{new uint8_t[Length]},
SpirvCodeLength{Length} {
std::memcpy(SpirvCode.get(), Input, Length);
AssociatedDevices(Context->getDevices()), ILCode{new uint8_t[Length]},
ILCodeLength{Length}, ILCodeFormat(CodeFormat) {
assert(isCodeFormatIL(CodeFormat));
std::memcpy(ILCode.get(), Input, Length);
// All devices have the program in IL state.
for (auto &Device : Context->getDevices()) {
DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice];
Expand Down
35 changes: 28 additions & 7 deletions unified-runtime/source/adapters/level_zero/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ struct ur_program_handle_t_ : ur_object {
Invalid
} state;

enum class CodeFormat : uint8_t {
Native,
SPIRV,
Unknown,
};

// A utility class that converts specialization constants into the form
// required by the Level Zero driver.
class SpecConstantShim {
Expand Down Expand Up @@ -68,7 +74,7 @@ struct ur_program_handle_t_ : ur_object {

// Construct a program in IL.
ur_program_handle_t_(state St, ur_context_handle_t Context, const void *Input,
size_t Length);
size_t Length, CodeFormat Format);

// Construct a program in NATIVE for multiple devices.
ur_program_handle_t_(state St, ur_context_handle_t Context,
Expand Down Expand Up @@ -113,28 +119,42 @@ struct ur_program_handle_t_ : ur_object {
return DeviceDataMap[ZeDevice].ZeModule;
}

CodeFormat getCodeFormat(ze_device_handle_t ZeDevice = nullptr) const {
if (!ZeDevice)
return ILCodeFormat;

auto It = DeviceDataMap.find(ZeDevice);
if (It == DeviceDataMap.end())
return ILCodeFormat;

if (It->second.State == state::IL)
return ILCodeFormat;
else
return CodeFormat::Native;
}

uint8_t *getCode(ze_device_handle_t ZeDevice = nullptr) {
if (!ZeDevice)
return SpirvCode.get();
return ILCode.get();

if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end())
return nullptr;

if (DeviceDataMap[ZeDevice].State == state::IL)
return SpirvCode.get();
return ILCode.get();
else
return DeviceDataMap[ZeDevice].Binary.first.get();
}

size_t getCodeSize(ze_device_handle_t ZeDevice = nullptr) {
if (ZeDevice == nullptr)
return SpirvCodeLength;
return ILCodeLength;

if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end())
return 0;

if (DeviceDataMap[ZeDevice].State == state::IL)
return SpirvCodeLength;
return ILCodeLength;
else
return DeviceDataMap[ZeDevice].Binary.second;
}
Expand Down Expand Up @@ -233,8 +253,9 @@ struct ur_program_handle_t_ : ur_object {

// In IL and Object states, this contains the SPIR-V representation of the
// module.
std::unique_ptr<uint8_t[]> SpirvCode; // Array containing raw IL code.
size_t SpirvCodeLength = 0; // Size (bytes) of the array.
std::unique_ptr<uint8_t[]> ILCode; // Array containing raw IL code.
size_t ILCodeLength = 0; // Size (bytes) of the array.
CodeFormat ILCodeFormat = CodeFormat::Unknown; // Format of the IL code.

// The Level Zero module handle for interoperability.
// This module handle is either initialized with the handle provided to
Expand Down
Loading