Skip to content

Commit 4b96a99

Browse files
[SYCL] Add support for eliminated arg masks in SYCLBIN kernel bundles (#19163)
This commit fixes an issue where kernel bundles created from SYCLBIN would not correctly filter based on the DAE-produced kernel arg masks. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 9805c6c commit 4b96a99

File tree

8 files changed

+178
-9
lines changed

8 files changed

+178
-9
lines changed

sycl/source/detail/device_image_impl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include <detail/device_image_impl.hpp>
10+
#include <detail/kernel_arg_mask.hpp>
1011
#include <detail/kernel_bundle_impl.hpp>
1112

1213
namespace sycl {
@@ -47,10 +48,14 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
4748
&UrKernel);
4849
// Kernel created by urKernelCreate is implicitly retained.
4950

51+
const KernelArgMask *ArgMask = nullptr;
52+
if (auto ArgMaskIt = MEliminatedKernelArgMasks.find(AdjustedName);
53+
ArgMaskIt != MEliminatedKernelArgMasks.end())
54+
ArgMask = &ArgMaskIt->second;
55+
5056
return std::make_shared<kernel_impl>(
5157
UrKernel, *detail::getSyclObjImpl(Context), shared_from_this(),
52-
OwnerBundle,
53-
/*ArgMask=*/nullptr, UrProgram, /*CacheMutex=*/nullptr);
58+
OwnerBundle, ArgMask, UrProgram, /*CacheMutex=*/nullptr);
5459
}
5560

5661
} // namespace detail

sycl/source/detail/device_image_impl.hpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class ManagedDeviceBinaries {
147147

148148
using MangledKernelNameMapT = std::map<std::string, std::string, std::less<>>;
149149
using KernelNameSetT = std::set<std::string, std::less<>>;
150+
using KernelNameToArgMaskMap = std::unordered_map<std::string, KernelArgMask>;
150151

151152
// Information unique to images compiled at runtime through the
152153
// ext_oneapi_kernel_compiler extension.
@@ -255,12 +256,23 @@ class device_image_impl
255256
MKernelIDs(std::move(KernelIDs)),
256257
MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(Origins) {
257258
updateSpecConstSymMap();
258-
// SYCLBIN files have the kernel names embedded in the binaries, so we
259-
// collect them.
260-
if (BinImage && (MOrigins & ImageOriginSYCLBIN))
259+
if (BinImage && (MOrigins & ImageOriginSYCLBIN)) {
260+
// SYCLBIN files have the kernel names embedded in the binaries, so we
261+
// collect them.
261262
for (const sycl_device_binary_property &KNProp :
262263
BinImage->getKernelNames())
263264
MKernelNames.insert(KNProp->Name);
265+
266+
KernelArgMask ArgMask;
267+
if (BinImage->getKernelParamOptInfo().isAvailable()) {
268+
// Extract argument mask from the image.
269+
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
270+
BinImage->getKernelParamOptInfo();
271+
for (const auto &Info : KPOIRange)
272+
MEliminatedKernelArgMasks[Info->Name] =
273+
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
274+
}
275+
}
264276
}
265277

266278
device_image_impl(
@@ -271,10 +283,12 @@ class device_image_impl
271283
const std::vector<unsigned char> &SpecConstsBlob, uint8_t Origins,
272284
std::optional<KernelCompilerBinaryInfo> &&RTCInfo,
273285
KernelNameSetT &&KernelNames,
286+
KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
274287
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
275288
: MBinImage(BinImage), MContext(std::move(Context)),
276289
MDevices(std::move(Devices)), MState(State), MProgram(Program),
277290
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
291+
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
278292
MSpecConstsBlob(SpecConstsBlob),
279293
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
280294
MSpecConstSymMap(SpecConstMap), MOrigins(Origins),
@@ -284,11 +298,14 @@ class device_image_impl
284298
device_image_impl(const RTDeviceBinaryImage *BinImage, const context &Context,
285299
const std::vector<device> &Devices, bundle_state State,
286300
ur_program_handle_t Program, syclex::source_language Lang,
287-
KernelNameSetT &&KernelNames, private_tag)
301+
KernelNameSetT &&KernelNames,
302+
KernelNameToArgMaskMap &&EliminatedKernelArgMasks,
303+
private_tag)
288304
: MBinImage(BinImage), MContext(std::move(Context)),
289305
MDevices(std::move(Devices)), MState(State), MProgram(Program),
290306
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
291307
MKernelNames{std::move(KernelNames)},
308+
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
292309
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
293310
MOrigins(ImageOriginKernelCompiler),
294311
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
@@ -669,6 +686,10 @@ class device_image_impl
669686

670687
const KernelNameSetT &getKernelNames() const noexcept { return MKernelNames; }
671688

689+
const KernelNameToArgMaskMap &getEliminatedKernelArgMasks() const noexcept {
690+
return MEliminatedKernelArgMasks;
691+
}
692+
672693
bool isNonSYCLSourceBased() const noexcept {
673694
return (getOriginMask() & ImageOriginKernelCompiler) &&
674695
!isFromSourceLanguage(syclex::source_language::sycl);
@@ -1261,6 +1282,10 @@ class device_image_impl
12611282
// List of known kernel names.
12621283
KernelNameSetT MKernelNames;
12631284

1285+
// Map for storing kernel argument masks for kernels. This is currently only
1286+
// used for images created from SYCLBIN.
1287+
KernelNameToArgMaskMap MEliminatedKernelArgMasks;
1288+
12641289
// A mutex for sycnhronizing access to spec constants blob. Mutable because
12651290
// needs to be locked in the const method for getting spec constant value.
12661291
mutable std::mutex MSpecConstAccessMtx;

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,6 +2871,8 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28712871
setSpecializationConstants(InputImpl, Prog, Adapter);
28722872

28732873
KernelNameSetT KernelNames = InputImpl.getKernelNames();
2874+
std::unordered_map<std::string, KernelArgMask> EliminatedKernelArgMasks =
2875+
InputImpl.getEliminatedKernelArgMasks();
28742876

28752877
std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
28762878
InputImpl.getRTCInfo();
@@ -2881,7 +2883,7 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28812883
InputImpl.get_spec_const_data_ref(),
28822884
InputImpl.get_spec_const_blob_ref(), InputImpl.getOriginMask(),
28832885
std::move(RTCInfo), std::move(KernelNames),
2884-
/*MergedImageStorage = */ nullptr);
2886+
std::move(EliminatedKernelArgMasks), nullptr);
28852887

28862888
std::string CompileOptions;
28872889
applyCompileOptionsFromEnvironment(CompileOptions);
@@ -3070,20 +3072,25 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
30703072
RTCInfoPtrs;
30713073
RTCInfoPtrs.reserve(Imgs.size());
30723074
KernelNameSetT MergedKernelNames;
3075+
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
30733076
for (const device_image_plain &DevImg : Imgs) {
30743077
const DeviceImageImplPtr &DevImgImpl = getSyclObjImpl(DevImg);
30753078
CombinedOrigins |= DevImgImpl->getOriginMask();
30763079
RTCInfoPtrs.emplace_back(&(DevImgImpl->getRTCInfo()));
30773080
MergedKernelNames.insert(DevImgImpl->getKernelNames().begin(),
30783081
DevImgImpl->getKernelNames().end());
3082+
MergedEliminatedKernelArgMasks.insert(
3083+
DevImgImpl->getEliminatedKernelArgMasks().begin(),
3084+
DevImgImpl->getEliminatedKernelArgMasks().end());
30793085
}
30803086
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
30813087

30823088
DeviceImageImplPtr ExecutableImpl = device_image_impl::create(
30833089
NewBinImg, Context, std::vector<device>{Devs}, bundle_state::executable,
30843090
std::move(KernelIDs), LinkedProg, std::move(NewSpecConstMap),
30853091
std::move(NewSpecConstBlob), CombinedOrigins, std::move(MergedRTCInfo),
3086-
std::move(MergedKernelNames), std::move(MergedImageStorage));
3092+
std::move(MergedKernelNames), std::move(MergedEliminatedKernelArgMasks),
3093+
std::move(MergedImageStorage));
30873094

30883095
// TODO: Make multiple sets of device images organized by devices they are
30893096
// compiled for.
@@ -3151,11 +3158,15 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31513158
RTCInfoPtrs;
31523159
RTCInfoPtrs.reserve(DevImgWithDeps.size());
31533160
KernelNameSetT MergedKernelNames;
3161+
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
31543162
for (const device_image_plain &DevImg : DevImgWithDeps) {
31553163
const auto &DevImgImpl = getSyclObjImpl(DevImg);
31563164
RTCInfoPtrs.emplace_back(&(DevImgImpl->getRTCInfo()));
31573165
MergedKernelNames.insert(DevImgImpl->getKernelNames().begin(),
31583166
DevImgImpl->getKernelNames().end());
3167+
MergedEliminatedKernelArgMasks.insert(
3168+
DevImgImpl->getEliminatedKernelArgMasks().begin(),
3169+
DevImgImpl->getEliminatedKernelArgMasks().end());
31593170
}
31603171
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
31613172

@@ -3164,7 +3175,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31643175
bundle_state::executable, std::move(KernelIDs), ResProgram,
31653176
std::move(SpecConstMap), std::move(SpecConstBlob), CombinedOrigins,
31663177
std::move(MergedRTCInfo), std::move(MergedKernelNames),
3167-
std::move(MergedImageStorage));
3178+
std::move(MergedEliminatedKernelArgMasks), std::move(MergedImageStorage));
31683179
return createSyclObjFromImpl<device_image_plain>(std::move(ExecImpl));
31693180
}
31703181

sycl/test-e2e/SYCLBIN/Inputs/dae.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "common.hpp"
2+
3+
#include <sycl/usm.hpp>
4+
5+
static constexpr size_t NUM = 1024;
6+
static constexpr size_t WGSIZE = 16;
7+
static constexpr float EPS = 0.001;
8+
9+
int main(int argc, char *argv[]) {
10+
assert(argc == 2);
11+
12+
sycl::queue Q;
13+
14+
int Failed = CommonLoadCheck(Q.get_context(), argv[1]);
15+
16+
#if defined(SYCLBIN_INPUT_STATE)
17+
auto KBInput = syclexp::get_kernel_bundle<sycl::bundle_state::input>(
18+
Q.get_context(), std::string{argv[1]});
19+
auto KBExe = sycl::build(KBInput);
20+
#elif defined(SYCLBIN_OBJECT_STATE)
21+
auto KBObj = syclexp::get_kernel_bundle<sycl::bundle_state::object>(
22+
Q.get_context(), std::string{argv[1]});
23+
auto KBExe = sycl::link(KBObj);
24+
#else // defined(SYCLBIN_EXECUTABLE_STATE)
25+
auto KBExe = syclexp::get_kernel_bundle<sycl::bundle_state::executable>(
26+
Q.get_context(), std::string{argv[1]});
27+
#endif
28+
29+
assert(KBExe.ext_oneapi_has_kernel("iota"));
30+
sycl::kernel IotaKern = KBExe.ext_oneapi_get_kernel("iota");
31+
32+
float *Ptr = sycl::malloc_shared<float>(NUM, Q);
33+
Q.submit([&](sycl::handler &CGH) {
34+
// First arugment is unused, but should still be passed, even if eliminated
35+
// by DAE.
36+
CGH.set_args(3.14f, Ptr);
37+
CGH.parallel_for(sycl::nd_range{{NUM}, {WGSIZE}}, IotaKern);
38+
}).wait_and_throw();
39+
40+
for (int I = 0; I < NUM; I++) {
41+
const float Truth = static_cast<float>(I);
42+
if (std::abs(Ptr[I] - Truth) > EPS) {
43+
std::cout << "Result: " << Ptr[I] << " expected " << I << "\n";
44+
++Failed;
45+
}
46+
}
47+
sycl::free(Ptr, Q);
48+
return Failed;
49+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#include <sycl/sycl.hpp>
2+
3+
namespace syclexp = sycl::ext::oneapi::experimental;
4+
namespace syclext = sycl::ext::oneapi;
5+
6+
extern "C" SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
7+
(syclexp::nd_range_kernel<1>)) void iota(float, float *ptr) {
8+
size_t id = syclext::this_work_item::get_nd_item<1>().get_global_linear_id();
9+
ptr[id] = static_cast<float>(id);
10+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//==----------- dae_executable.cpp --- SYCLBIN extension tests -------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// REQUIRES: aspect-usm_device_allocations
10+
11+
// -- Test for using a kernel from a SYCLBIN with a dead argument.
12+
13+
// SYCLBIN currently only properly detects SPIR-V binaries.
14+
// XFAIL: !target-spir
15+
// XFAIL-TRACKER: CMPLRLLVM-68811
16+
17+
// RUN: %clangxx --offload-new-driver -fsyclbin=executable %{sycl_target_opts} %S/Inputs/dae_kernel.cpp -o %t.syclbin
18+
// RUN: %{build} -o %t.out
19+
// RUN: %{l0_leak_check} %{run} %t.out %t.syclbin
20+
21+
#define SYCLBIN_EXECUTABLE_STATE
22+
23+
#include "Inputs/dae.hpp"

sycl/test-e2e/SYCLBIN/dae_input.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//==----------- dae_input.cpp --- SYCLBIN extension tests ------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// REQUIRES: aspect-usm_device_allocations
10+
11+
// -- Test for using a kernel from a SYCLBIN with a dead argument.
12+
13+
// SYCLBIN currently only properly detects SPIR-V binaries.
14+
// XFAIL: !target-spir
15+
// XFAIL-TRACKER: CMPLRLLVM-68811
16+
17+
// RUN: %clangxx --offload-new-driver -fsyclbin=input %{sycl_target_opts} %S/Inputs/dae_kernel.cpp -o %t.syclbin
18+
// RUN: %{build} -o %t.out
19+
// RUN: %{l0_leak_check} %{run} %t.out %t.syclbin
20+
21+
#define SYCLBIN_INPUT_STATE
22+
23+
#include "Inputs/dae.hpp"

sycl/test-e2e/SYCLBIN/dae_object.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//==----------- dae_object.cpp --- SYCLBIN extension tests -----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// REQUIRES: aspect-usm_device_allocations
10+
11+
// -- Test for using a kernel from a SYCLBIN with a dead argument.
12+
13+
// SYCLBIN currently only properly detects SPIR-V binaries.
14+
// XFAIL: !target-spir
15+
// XFAIL-TRACKER: CMPLRLLVM-68811
16+
17+
// RUN: %clangxx --offload-new-driver -fsyclbin=object %{sycl_target_opts} %S/Inputs/dae_kernel.cpp -o %t.syclbin
18+
// RUN: %{build} -o %t.out
19+
// RUN: %{l0_leak_check} %{run} %t.out %t.syclbin
20+
21+
#define SYCLBIN_OBJECT_STATE
22+
23+
#include "Inputs/dae.hpp"

0 commit comments

Comments
 (0)