Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaning up GPU executable flatbuffers prior to a larger reworking. #18208

Merged
merged 6 commits on
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions build_tools/bazel/iree_flatcc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ def iree_flatbuffer_c_library(
name,
srcs,
flatcc_args = ["--common", "--reader"],
includes = [],
testonly = False,
**kwargs):
flatcc = "@com_github_dvidelabs_flatcc//:flatcc"

flags = [
"-o$(RULEDIR)",
"-I runtime/src",
] + flatcc_args

out_stem = "%s" % (srcs[0].replace(".fbs", ""))
Expand All @@ -34,10 +36,10 @@ def iree_flatbuffer_c_library(

native.genrule(
name = name + "_gen",
srcs = srcs,
srcs = srcs + includes,
outs = outs,
tools = [flatcc],
cmd = "$(location %s) %s $(SRCS)" % (flatcc, " ".join(flags)),
cmd = "$(location %s) %s %s" % (flatcc, " ".join(flags), " ".join(["$(location {})".format(src) for src in srcs])),
benvanik marked this conversation as resolved.
Show resolved Hide resolved
testonly = testonly,
)
native.cc_library(
Expand Down
4 changes: 3 additions & 1 deletion build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,16 +662,18 @@ def iree_bytecode_module(
f" PUBLIC\n)\n\n"
)

def iree_flatbuffer_c_library(self, name, srcs, flatcc_args=None):
def iree_flatbuffer_c_library(self, name, srcs, flatcc_args=None, includes=None):
name_block = self._convert_string_arg_block("NAME", name, quote=False)
srcs_block = self._convert_srcs_block(srcs)
flatcc_args_block = self._convert_string_list_block("FLATCC_ARGS", flatcc_args)
includes_block = self._convert_srcs_block(includes, block_name="INCLUDES")

self._converter.body += (
f"flatbuffer_c_library(\n"
f"{name_block}"
f"{srcs_block}"
f"{flatcc_args_block}"
f"{includes_block}"
f" PUBLIC\n)\n\n"
)

Expand Down
3 changes: 2 additions & 1 deletion build_tools/cmake/flatbuffer_c_library.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function(flatbuffer_c_library)
cmake_parse_arguments(_RULE
"PUBLIC;TESTONLY"
"NAME"
"SRCS;FLATCC_ARGS"
"SRCS;FLATCC_ARGS;INCLUDES"
${ARGN}
)

Expand Down Expand Up @@ -94,6 +94,7 @@ function(flatbuffer_c_library)
iree-flatcc-cli
-o "${CMAKE_CURRENT_BINARY_DIR}"
-I "${IREE_ROOT_DIR}"
-I "${IREE_ROOT_DIR}/runtime/src"
${_RULE_FLATCC_ARGS}
"${_RULE_SRCS}"
WORKING_DIRECTORY
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/CUDA/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:ExecutableDebugInfoUtils",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/base/internal/flatcc:building",
"//runtime/src/iree/schemas:cuda_executable_def_c_fbs",
"//runtime/src/iree/schemas:executable_debug_info_c_fbs",
"@iree_cuda//:libdevice_embedded",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:BitReader",
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/CUDA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ iree_cc_library(
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Utils::ExecutableDebugInfoUtils
iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::cuda_executable_def_c_fbs
iree::schemas::executable_debug_info_c_fbs
iree_cuda::libdevice_embedded
PUBLIC
)
Expand Down
16 changes: 11 additions & 5 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
Expand Down Expand Up @@ -517,9 +518,13 @@ class CUDATargetBackend final : public TargetBackend {
FlatbufferBuilder builder;
iree_hal_cuda_ExecutableDef_start_as_root(builder);

// Attach embedded source file contents.
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

SmallVector<std::string> entryPointNames;
std::string ptxImage;
SmallVector<iree_hal_cuda_FileLineLocDef_ref_t> sourceLocationRefs;
SmallVector<iree_hal_debug_FileLineLocDef_ref_t> sourceLocationRefs;
if (variantOp.isExternal()) {
if (!variantOp.getObjects().has_value()) {
return variantOp.emitOpError()
Expand Down Expand Up @@ -590,7 +595,7 @@ class CUDATargetBackend final : public TargetBackend {
if (serOptions.debugLevel >= 1) {
if (auto loc = findFirstFileLoc(exportOp.getLoc())) {
auto filenameRef = builder.createString(loc->getFilename());
sourceLocationRefs.push_back(iree_hal_cuda_FileLineLocDef_create(
sourceLocationRefs.push_back(iree_hal_debug_FileLineLocDef_create(
builder, filenameRef, loc->getLine()));
}
}
Expand Down Expand Up @@ -665,12 +670,12 @@ class CUDATargetBackend final : public TargetBackend {
std::string gpuImage = produceGpuImage(options, targetArch, ptxImage);
auto gpuImageRef =
flatbuffers_string_create(builder, gpuImage.c_str(), gpuImage.size());
iree_hal_cuda_BlockSizeDef_vec_start(builder);
iree_hal_cuda_BlockSize_vec_start(builder);
for (const auto &workgroupSize : workgroupSizes) {
iree_hal_cuda_BlockSizeDef_vec_push_create(
iree_hal_cuda_BlockSize_vec_push_create(
builder, workgroupSize[0], workgroupSize[1], workgroupSize[2]);
}
auto blockSizesRef = iree_hal_cuda_BlockSizeDef_vec_end(builder);
auto blockSizesRef = iree_hal_cuda_BlockSize_vec_end(builder);
auto workgroupLocalMemoriesRef =
builder.createInt32Vec(workgroupLocalMemories);
auto entryPointsRef = builder.createStringVec(entryPointNames);
Expand All @@ -686,6 +691,7 @@ class CUDATargetBackend final : public TargetBackend {
iree_hal_cuda_ExecutableDef_source_locations_add(builder,
sourceLocationsRef);
}
iree_hal_cuda_ExecutableDef_source_files_add(builder, sourceFilesRef);
iree_hal_cuda_ExecutableDef_end_as_root(builder);

// Add the binary data to the target executable.
Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
// TODO(#18154): pack per-binding information bitfields.
dispatchAttrs.constantCount = exportOp.getLayout().getPushConstants();
dispatchAttrs.bindingCount =
exportOp.getLayout().getSetLayout(0).getBindings().size();
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/MetalSPIRV/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:ExecutableDebugInfoUtils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:executable_debug_info_c_fbs",
"//runtime/src/iree/schemas:metal_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/MetalSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ iree_cc_library(
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Utils::ExecutableDebugInfoUtils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::executable_debug_info_c_fbs
iree::schemas::metal_executable_def_c_fbs
PUBLIC
)
Expand Down
7 changes: 7 additions & 0 deletions compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/metal_executable_def_builder.h"
Expand Down Expand Up @@ -212,6 +213,10 @@ class MetalSPIRVTargetBackend : public TargetBackend {
FlatbufferBuilder builder;
iree_hal_metal_ExecutableDef_start_as_root(builder);

// Attach embedded source file contents.
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

auto entryPointNamesRef = builder.createStringVec(mslEntryPointNames);
iree_hal_metal_ExecutableDef_entry_points_add(builder, entryPointNamesRef);

Expand Down Expand Up @@ -243,6 +248,8 @@ class MetalSPIRVTargetBackend : public TargetBackend {
iree_hal_metal_ExecutableDef_shader_libraries_add(builder, libsRef);
}

iree_hal_metal_ExecutableDef_source_files_add(builder, sourceFilesRef);

iree_hal_metal_ExecutableDef_end_as_root(builder);

// 5. Add the binary data to the target executable.
Expand Down
4 changes: 3 additions & 1 deletion compiler/plugins/target/ROCM/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:ExecutableDebugInfoUtils",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:rocm_executable_def_c_fbs",
"//runtime/src/iree/schemas:executable_debug_info_c_fbs",
"//runtime/src/iree/schemas:hip_executable_def_c_fbs",
"@llvm-project//llvm:AMDGPUCodeGen",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:BitWriter",
Expand Down
4 changes: 3 additions & 1 deletion compiler/plugins/target/ROCM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ iree_cc_library(
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Utils::ExecutableDebugInfoUtils
iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::rocm_executable_def_c_fbs
iree::schemas::executable_debug_info_c_fbs
iree::schemas::hip_executable_def_c_fbs
PUBLIC
)

Expand Down
72 changes: 28 additions & 44 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "iree/compiler/Utils/ToolUtils.h"
#include "iree/schemas/rocm_executable_def_builder.h"
#include "iree/schemas/hip_executable_def_builder.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
Expand Down Expand Up @@ -572,29 +573,14 @@ class ROCMTargetBackend final : public TargetBackend {
}

iree_compiler::FlatbufferBuilder builder;
iree_hal_rocm_ExecutableDef_start_as_root(builder);
iree_hal_hip_ExecutableDef_start_as_root(builder);

// Attach embedded source file contents.
SmallVector<iree_hal_rocm_SourceFileDef_ref_t> sourceFileRefs;
if (auto sourcesAttr = variantOp.getSourcesAttr()) {
for (auto sourceAttr : llvm::reverse(sourcesAttr.getValue())) {
if (auto resourceAttr = dyn_cast_if_present<DenseResourceElementsAttr>(
sourceAttr.getValue())) {
auto filenameRef = builder.createString(sourceAttr.getName());
auto contentRef = builder.streamUint8Vec([&](llvm::raw_ostream &os) {
auto blobData = resourceAttr.getRawHandle().getBlob()->getData();
os.write(blobData.data(), blobData.size());
return true;
});
sourceFileRefs.push_back(iree_hal_rocm_SourceFileDef_create(
builder, filenameRef, contentRef));
}
}
std::reverse(sourceFileRefs.begin(), sourceFileRefs.end());
}
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

SmallVector<StringRef> entryPointNames;
SmallVector<iree_hal_rocm_FileLineLocDef_ref_t> sourceLocationRefs;
SmallVector<iree_hal_debug_FileLineLocDef_ref_t> sourceLocationRefs;
entryPointNames.resize(exportOps.size());
for (auto exportOp : exportOps) {
auto ordinalAttr = exportOp.getOrdinalAttr();
Expand All @@ -612,27 +598,28 @@ class ROCMTargetBackend final : public TargetBackend {
// be kept as-is.
sourceLocationRefs.resize(exportOps.size());
auto filenameRef = builder.createString(loc->getFilename());
sourceLocationRefs[ordinal] = iree_hal_rocm_FileLineLocDef_create(
sourceLocationRefs[ordinal] = iree_hal_debug_FileLineLocDef_create(
builder, filenameRef, loc->getLine());
}
}
}

// Optional compilation stage source files.
SmallVector<iree_hal_rocm_StageLocationsDef_ref_t> stageLocationsRefs;
SmallVector<iree_hal_debug_StageLocationsDef_ref_t> stageLocationsRefs;
if (serOptions.debugLevel >= 3) {
for (auto exportOp : exportOps) {
SmallVector<iree_hal_rocm_StageLocationDef_ref_t> stageLocationRefs;
SmallVector<iree_hal_debug_StageLocationDef_ref_t> stageLocationRefs;
if (auto locsAttr = exportOp.getSourceLocsAttr()) {
for (auto locAttr : locsAttr.getValue()) {
if (auto loc =
findFirstFileLoc(cast<LocationAttr>(locAttr.getValue()))) {
auto stageNameRef = builder.createString(locAttr.getName());
auto filenameRef = builder.createString(loc->getFilename());
stageLocationRefs.push_back(iree_hal_rocm_StageLocationDef_create(
builder, stageNameRef,
iree_hal_rocm_FileLineLocDef_create(builder, filenameRef,
loc->getLine())));
stageLocationRefs.push_back(
iree_hal_debug_StageLocationDef_create(
builder, stageNameRef,
iree_hal_debug_FileLineLocDef_create(builder, filenameRef,
loc->getLine())));
}
}
}
Expand All @@ -641,7 +628,7 @@ class ROCMTargetBackend final : public TargetBackend {
// be kept as-is.
stageLocationsRefs.resize(exportOps.size());
int64_t ordinal = exportOp.getOrdinalAttr().getInt();
stageLocationsRefs[ordinal] = iree_hal_rocm_StageLocationsDef_create(
stageLocationsRefs[ordinal] = iree_hal_debug_StageLocationsDef_create(
builder, builder.createOffsetVecDestructive(stageLocationRefs));
}
}
Expand All @@ -651,38 +638,35 @@ class ROCMTargetBackend final : public TargetBackend {
targetHSACO.size());

auto entryPointsRef = builder.createStringVec(entryPointNames);
iree_hal_rocm_BlockSizeDef_vec_start(builder);
iree_hal_hip_BlockSize_vec_start(builder);
auto blockSizes = workgroupSizes.begin();
for (int i = 0, e = entryPointNames.size(); i < e; ++i) {
iree_hal_rocm_BlockSizeDef_vec_push_create(
iree_hal_hip_BlockSize_vec_push_create(
builder, (*blockSizes)[0], (*blockSizes)[1], (*blockSizes)[2]);
++blockSizes;
}
auto workgroupLocalMemoriesRef =
builder.createInt32Vec(workgroupLocalMemories);
auto blockSizesRef = iree_hal_rocm_BlockSizeDef_vec_end(builder);
iree_hal_rocm_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_rocm_ExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_hal_rocm_ExecutableDef_shared_memory_sizes_add(
auto blockSizesRef = iree_hal_hip_BlockSize_vec_end(builder);
iree_hal_hip_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_hip_ExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_hal_hip_ExecutableDef_shared_memory_sizes_add(
builder, workgroupLocalMemoriesRef);
iree_hal_rocm_ExecutableDef_hsaco_image_add(builder, hsacoRef);
iree_hal_hip_ExecutableDef_hsaco_image_add(builder, hsacoRef);
if (!sourceLocationRefs.empty()) {
auto sourceLocationsRef =
builder.createOffsetVecDestructive(sourceLocationRefs);
iree_hal_rocm_ExecutableDef_source_locations_add(builder,
sourceLocationsRef);
iree_hal_hip_ExecutableDef_source_locations_add(builder,
sourceLocationsRef);
}
if (!stageLocationsRefs.empty()) {
auto stageLocationsRef =
builder.createOffsetVecDestructive(stageLocationsRefs);
iree_hal_rocm_ExecutableDef_stage_locations_add(builder,
stageLocationsRef);
}
if (!sourceFileRefs.empty()) {
auto sourceFilesRef = builder.createOffsetVecDestructive(sourceFileRefs);
iree_hal_rocm_ExecutableDef_source_files_add(builder, sourceFilesRef);
iree_hal_hip_ExecutableDef_stage_locations_add(builder,
stageLocationsRef);
}
iree_hal_rocm_ExecutableDef_end_as_root(builder);
iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef);
iree_hal_hip_ExecutableDef_end_as_root(builder);

// Add the binary data to the target executable.
executableBuilder.create<iree_compiler::IREE::HAL::ExecutableBinaryOp>(
Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class VMVXTargetBackend final : public TargetBackend {

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
// TODO(#18154): pack per-binding information bitfields.
if (auto layoutAttr = exportOp.getLayout()) {
int64_t constantCount = layoutAttr.getPushConstants();
if (constantCount > 0) {
Expand Down
Loading
Loading