Skip to content

[SYCL][Fusion] Kernel Fusion support for CUDA backend #8747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5b2d0a0
[SYCL][Fusion] Embed LLVM IR for SYCL for Nvidia
sommerlukas Jan 9, 2023
ee8c29b
Enable LLVM IR as alternative fusion input format;
sommerlukas Jan 10, 2023
5d1d778
[SYCL][Fusion] Add result translation to PTX
sommerlukas Jan 11, 2023
d32c758
[SYCL][Fusion] Provide correct target spec;
sommerlukas Jan 12, 2023
b994781
[SYCL][Fusion] Avoid removing dependencies
sommerlukas Jan 13, 2023
651847e
[SYCL][Fusion] Set device binary image format
sommerlukas Jan 13, 2023
7fddbd5
[SYCL][Fusion] Refactor target-specific processing
sommerlukas Jan 16, 2023
894ee89
[SYCL][Fusion] Do not require null terminator
sommerlukas Jan 17, 2023
a909a75
[SYCL][Fusion] Refactor more target-specific code
sommerlukas Jan 17, 2023
32172f0
[SYCL][Fusion] Handle attributes for CUDA fusion
sommerlukas Feb 14, 2023
1559b85
[SYCL][Fusion] Cache and groom input binaries
sommerlukas Feb 14, 2023
fd34124
[SYCL][Fusion] Disable heterogeneous ND ranges on CUDA
sommerlukas Feb 14, 2023
fc5efbc
[SYCL][Fusion] Enable JIT caching for CUDA fusion
sommerlukas Feb 14, 2023
6c14311
[SYCL][Fusion] Catch empty standard arguments
sommerlukas Feb 15, 2023
980d36d
[SYCL][Fusion] Rebase and address feedback
sommerlukas Mar 7, 2023
4ba8e44
[SYCL][Fusion] Update linkage graph diagram
sommerlukas Mar 22, 2023
8fcd4c7
Don't compile NVPTX-specifics if not supported
sommerlukas Mar 28, 2023
75a77fd
Migrate test changes from intel/llvm-test-suite
sommerlukas Mar 28, 2023
a8afe1d
Address more PR feedback
sommerlukas Apr 5, 2023
f7df423
Add test for kernel fusion with math function
sommerlukas Apr 5, 2023
bc32fad
Document CUDA kernel fusion in design documentation
sommerlukas Apr 5, 2023
b4d3968
Update kernel fusion design document
sommerlukas Apr 18, 2023
a7e1369
Fix formatting for test
sommerlukas Apr 20, 2023
88b4ada
Rebase on branch 'sycl'
sommerlukas Apr 27, 2023
4877a40
Address PR feedback and formatting
sommerlukas May 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion clang/include/clang/Driver/Action.h
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,14 @@ class OffloadUnbundlingJobAction final : public JobAction {
class OffloadWrapperJobAction : public JobAction {
void anchor() override;

bool EmbedIR;

public:
OffloadWrapperJobAction(ActionList &Inputs, types::ID Type);
OffloadWrapperJobAction(Action *Input, types::ID OutputType);
OffloadWrapperJobAction(Action *Input, types::ID OutputType,
bool EmbedIR = false);

bool isEmbeddedIR() const { return EmbedIR; }

static bool classof(const Action *A) {
return A->getKind() == OffloadWrapperJobClass;
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -2976,6 +2976,8 @@ def fintelfpga : Flag<["-"], "fintelfpga">, Group<f_Group>,
HelpText<"Perform ahead-of-time compilation for FPGA">;
def fsycl_device_only : Flag<["-"], "fsycl-device-only">, Flags<[CoreOption]>,
HelpText<"Compile SYCL kernels for device">;
def fsycl_embed_ir : Flag<["-"], "fsycl-embed-ir">, Flags<[CoreOption]>,
HelpText<"Embed LLVM IR for runtime kernel fusion">;
defm sycl_esimd_force_stateless_mem : BoolFOption<"sycl-esimd-force-stateless-mem",
LangOpts<"SYCLESIMDForceStatelessMem">, DefaultFalse,
PosFlag<SetTrue, [], "Enforce using stateless memory accesses. "
Expand Down
8 changes: 4 additions & 4 deletions clang/lib/Driver/Action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,11 @@ void OffloadWrapperJobAction::anchor() {}

OffloadWrapperJobAction::OffloadWrapperJobAction(ActionList &Inputs,
types::ID Type)
: JobAction(OffloadWrapperJobClass, Inputs, Type) {}
: JobAction(OffloadWrapperJobClass, Inputs, Type), EmbedIR(false) {}

OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input,
types::ID Type)
: JobAction(OffloadWrapperJobClass, Input, Type) {}
OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input, types::ID Type,
bool IsEmbeddedIR)
: JobAction(OffloadWrapperJobClass, Input, Type), EmbedIR(IsEmbeddedIR) {}

void OffloadPackagerJobAction::anchor() {}

Expand Down
116 changes: 64 additions & 52 deletions clang/lib/Driver/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5517,6 +5517,8 @@ class OffloadingActionBuilder final {
// s - device code split requested
// r - relocatable device code is requested
// f - link object output type is TY_Tempfilelist (fat archive)
// e - Embedded IR for fusion (-fsycl-embed-ir) was requested
// and target is NVPTX.
// * - "all other cases"
// - no condition means output/input is "always" present
// First symbol indicates output/input type
Expand All @@ -5536,58 +5538,58 @@ class OffloadingActionBuilder final {
// | |
// | |
// .---------------------------------------.
// | PostLink |
// .---------------------------------------.
// [+*] [+]
// | |
// | |
// |--------- |
// | | |
// | | |
// | [+!rf] |
// | .-------------. |
// | | llvm-foreach| |
// | .-------------. |
// | | |
// [+*] [+!rf] |
// .-----------------. |
// | FileTableTform | |
// | (extract "Code")| |
// .-----------------. |
// [-] |-----------
// --------------------| |
// | | |
// | |----------------- |
// | | | |
// | | [-!rf] |
// | | .--------------. |
// | | |FileTableTform| |
// | | | (merge) | |
// | | .--------------. |
// | | [-] |-------
// | | | | |
// | | | ------| |
// | | --------| | |
// [.] [-*] [-!rf] [+!rf] |
// .---------------. .-------------------. .--------------. |
// | finalizeNVPTX | | SPIRVTranslator | |FileTableTform| |
// | finalizeAMDGCN | | | | (merge) | |
// .---------------. .-------------------. . -------------. |
// [.] [-as] [-!a] | |
// | | | | |
// | [-s] | | |
// | .----------------. | | |
// | | BackendCompile | | | |
// | .----------------. | ------| |
// | [-s] | | |
// | | | | |
// | [-a] [-!a] [-!rf] |
// | .--------------------. |
// -----------[-n]| FileTableTform |[+*]--------------|
// | (replace "Code") |
// .--------------------.
// |
// [+*]
// | PostLink |[+e]----------------
// .---------------------------------------. |
// [+*] [+] |
// | | |
// | | |
// |--------- | |
// | | | |
// | | | |
// | [+!rf] | |
// | .-------------. | |
// | | llvm-foreach| | |
// | .-------------. | |
// | | | |
// [+*] [+!rf] | |
// .-----------------. | |
// | FileTableTform | | |
// | (extract "Code")| | |
// .-----------------. | |
// [-] |----------- |
// --------------------| | |
// | | | |
// | |----------------- | |
// | | | | |
// | | [-!rf] | |
// | | .--------------. | |
// | | |FileTableTform| | |
// | | | (merge) | | |
// | | .--------------. | |
// | | [-] |------- |
// | | | | | |
// | | | ------| | |
// | | --------| | | |
// [.] [-*] [-!rf] [+!rf] | |
// .---------------. .-------------------. .--------------. | |
// | finalizeNVPTX | | SPIRVTranslator | |FileTableTform| | |
// | finalizeAMDGCN | | | | (merge) | | |
// .---------------. .-------------------. . -------------. | |
// [.] [-as] [-!a] | | |
// | | | | | |
// | [-s] | | | |
// | .----------------. | | | |
// | | BackendCompile | | | | |
// | .----------------. | ------| | |
// | [-s] | | | |
// | | | | | |
// | [-a] [-!a] [-!rf] | |
// | .--------------------. | |
// -----------[-n]| FileTableTform |[+*]--------------| |
// | (replace "Code") | |
// .--------------------. |
// | -------------------------
// [+*] | [+e]
// .--------------------------------------.
// | OffloadWrapper |
// .--------------------------------------.
Expand Down Expand Up @@ -5694,6 +5696,16 @@ class OffloadingActionBuilder final {
return TypedPostLinkAction;
};
Action *PostLinkAction = createPostLinkAction();
if (isNVPTX && Args.hasArg(options::OPT_fsycl_embed_ir)) {
// When compiling for Nvidia/CUDA devices and the user requested the
// IR to be embedded in the application (via option), run the output
// of sycl-post-link (filetable referencing LLVM Bitcode + symbols)
// through the offload wrapper and link the resulting object to the
// application.
auto *WrapBitcodeAction = C.MakeAction<OffloadWrapperJobAction>(
PostLinkAction, types::TY_Object, true);
DA.add(*WrapBitcodeAction, *TC, BoundArch, Action::OFK_SYCL);
}
bool NoRDCFatStaticArchive =
!IsRDC &&
FullDeviceLinkAction->getType() == types::TY_Tempfilelist;
Expand Down
10 changes: 9 additions & 1 deletion clang/lib/Driver/ToolChains/Clang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9253,6 +9253,14 @@ void OffloadWrapper::ConstructJob(Compilation &C, const JobAction &JA,
createArgString("-link-opts=");
}

bool IsEmbeddedIR = cast<OffloadWrapperJobAction>(JA).isEmbeddedIR();
if (IsEmbeddedIR) {
// When the offload-wrapper is called to embed LLVM IR, add a prefix to
// the target triple to distinguish the LLVM IR from the actual device
// binary for that target.
TargetTripleOpt = ("llvm_" + TargetTripleOpt).str();
}

WrapperArgs.push_back(
C.getArgs().MakeArgString(Twine("-target=") + TargetTripleOpt));

Expand All @@ -9274,7 +9282,7 @@ void OffloadWrapper::ConstructJob(Compilation &C, const JobAction &JA,
assert(I.isFilename() && "Invalid input.");

if (I.getType() == types::TY_Tempfiletable ||
I.getType() == types::TY_Tempfilelist)
I.getType() == types::TY_Tempfilelist || IsEmbeddedIR)
// wrapper actual input files are passed via the batch job file table:
WrapperArgs.push_back(C.getArgs().MakeArgString("-batch"));
WrapperArgs.push_back(C.getArgs().MakeArgString(I.getFilename()));
Expand Down
2 changes: 1 addition & 1 deletion sycl-fusion/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ enum class ParameterKind : uint32_t {
};

/// Different binary formats supported as input to the JIT compiler.
enum class BinaryFormat : uint32_t { INVALID, LLVM, SPIRV };
enum class BinaryFormat : uint32_t { INVALID, LLVM, SPIRV, PTX };

/// Information about a device intermediate representation module (e.g., SPIR-V,
/// LLVM IR) from DPC++.
Expand Down
1 change: 1 addition & 0 deletions sycl-fusion/common/lib/KernelIO.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ template <> struct ScalarEnumerationTraits<jit_compiler::BinaryFormat> {
static void enumeration(IO &IO, jit_compiler::BinaryFormat &BF) {
IO.enumCase(BF, "LLVM", jit_compiler::BinaryFormat::LLVM);
IO.enumCase(BF, "SPIRV", jit_compiler::BinaryFormat::SPIRV);
IO.enumCase(BF, "PTX", jit_compiler::BinaryFormat::PTX);
IO.enumCase(BF, "INVALID", jit_compiler::BinaryFormat::INVALID);
}
};
Expand Down
12 changes: 11 additions & 1 deletion sycl-fusion/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
add_llvm_library(sycl-fusion
lib/KernelFusion.cpp
lib/JITContext.cpp
lib/translation/KernelTranslation.cpp
lib/translation/SPIRVLLVMTranslation.cpp
lib/fusion/FusionPipeline.cpp
lib/fusion/FusionHelper.cpp
lib/fusion/ModuleHelper.cpp
lib/helper/ConfigHelper.cpp

LINK_COMPONENTS
LINK_COMPONENTS
BitReader
Core
Support
Analysis
Expand All @@ -18,6 +20,10 @@ add_llvm_library(sycl-fusion
Linker
ScalarOpts
InstCombine
Target
TargetParser
MC
${LLVM_TARGETS_TO_BUILD}
)

target_include_directories(sycl-fusion
Expand All @@ -40,6 +46,10 @@ target_link_libraries(sycl-fusion
${CMAKE_THREAD_LIBS_INIT}
)

if("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
target_compile_definitions(sycl-fusion PRIVATE FUSION_JIT_SUPPORT_PTX)
endif()

if (BUILD_SHARED_LIBS)
if(NOT MSVC AND NOT APPLE)
# Manage symbol visibility through the linker to make sure no LLVM symbols
Expand Down
17 changes: 12 additions & 5 deletions sycl-fusion/jit-compiler/include/JITContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ using CacheKeyT =
std::optional<std::vector<NDRange>>>;

///
/// Wrapper around a SPIR-V binary.
class SPIRVBinary {
/// Wrapper around a kernel binary.
class KernelBinary {
public:
explicit SPIRVBinary(std::string Binary);
explicit KernelBinary(std::string &&Binary, BinaryFormat Format);

jit_compiler::BinaryAddress address() const;

size_t size() const;

BinaryFormat format() const;

private:
std::string Blob;

BinaryFormat Format;
};

///
Expand All @@ -61,7 +65,10 @@ class JITContext {

llvm::LLVMContext *getLLVMContext();

SPIRVBinary &emplaceSPIRVBinary(std::string Binary);
template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
WriteLockT WriteLock{BinariesMutex};
return Binaries.emplace_back(std::forward<Ts>(Args)...);
}

std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;

Expand All @@ -79,7 +86,7 @@ class JITContext {

MutexT BinariesMutex;

std::vector<SPIRVBinary> Binaries;
std::vector<KernelBinary> Binaries;

mutable MutexT CacheMutex;

Expand Down
7 changes: 6 additions & 1 deletion sycl-fusion/jit-compiler/include/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
#ifndef SYCL_FUSION_JIT_COMPILER_OPTIONS_H
#define SYCL_FUSION_JIT_COMPILER_OPTIONS_H

#include "Kernel.h"

#include <memory>
#include <unordered_map>

namespace jit_compiler {

enum OptionID { VerboseOutput, EnableCaching };
enum OptionID { VerboseOutput, EnableCaching, TargetFormat };

class OptionPtrBase {};

Expand Down Expand Up @@ -78,6 +80,9 @@ struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};

struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};

struct JITTargetFormat
: public OptionBase<OptionID::TargetFormat, BinaryFormat> {};

} // namespace option
} // namespace jit_compiler

Expand Down
17 changes: 6 additions & 11 deletions sycl-fusion/jit-compiler/lib/JITContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,24 @@

using namespace jit_compiler;

SPIRVBinary::SPIRVBinary(std::string Binary) : Blob{std::move(Binary)} {}
KernelBinary::KernelBinary(std::string &&Binary, BinaryFormat Fmt)
: Blob{std::move(Binary)}, Format{Fmt} {}

jit_compiler::BinaryAddress SPIRVBinary::address() const {
jit_compiler::BinaryAddress KernelBinary::address() const {
// FIXME: Verify it's a good idea to perform this reinterpret_cast here.
return reinterpret_cast<jit_compiler::BinaryAddress>(Blob.c_str());
}

size_t SPIRVBinary::size() const { return Blob.size(); }
size_t KernelBinary::size() const { return Blob.size(); }

BinaryFormat KernelBinary::format() const { return Format; }

JITContext::JITContext() : LLVMCtx{new llvm::LLVMContext}, Binaries{} {}

JITContext::~JITContext() = default;

llvm::LLVMContext *JITContext::getLLVMContext() { return LLVMCtx.get(); }

SPIRVBinary &JITContext::emplaceSPIRVBinary(std::string Binary) {
WriteLockT WriteLock{BinariesMutex};
// NOTE: With C++17, which returns a reference from emplace_back, the
// following code would be even simpler.
Binaries.emplace_back(std::move(Binary));
return Binaries.back();
}

std::optional<SYCLKernelInfo>
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
ReadLockT ReadLock{CacheMutex};
Expand Down
Loading