Skip to content

Commit 9b1cfc1

Browse files
[Vulkan] Some fixes of Vulkan codegen (#309)
- add exp2 dispatch - handle thread invariant op - disable smem merge for vulkan Co-authored-by: spectrometerHBH <[email protected]>
1 parent 5556d56 commit 9b1cfc1

File tree

4 files changed

+16
-1
lines changed

4 files changed

+16
-1
lines changed

src/target/spirv/codegen_spirv.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
509509
spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
510510
ICHECK(var_map_.count(buffer_node));
511511
return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
512+
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
513+
return MakeValue(op->args[0]);
512514
} else {
513515
LOG(FATAL) << "Unresolved call " << op->op;
514516
}

src/target/spirv/intrin_rule_spirv.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ TVM_REGISTER_OP("tir.fabs")
8282
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
8383
DispatchGLSLPureIntrin<GLSLstd450Exp>);
8484

85+
TVM_REGISTER_OP("tir.exp2")
86+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Exp2>);
87+
8588
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
8689
DispatchGLSLPureIntrin<GLSLstd450Sin>);
8790

src/tir/transforms/merge_shared_memory_allocations.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ namespace transform {
662662
Pass MergeSharedMemoryAllocations() {
663663
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
664664
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
665+
// disable this pass for Vulkan
666+
auto target = Target::Current(true);
667+
if (target.defined() && target->kind->name == "vulkan") {
668+
return f;
669+
}
665670
auto* n = f.CopyOnWrite();
666671
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
667672
return f;

src/tir/transforms/storage_rewrite.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1705,8 +1705,13 @@ namespace transform {
17051705
Pass StorageRewrite() {
17061706
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
17071707
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
1708+
// disable merge_static_smem for Vulkan
1709+
auto target = Target::Current(true);
1710+
if (target.defined() && target->kind->name == "vulkan") {
1711+
merge_static_smem = false;
1712+
}
17081713
auto* n = f.CopyOnWrite();
1709-
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
1714+
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, merge_static_smem);
17101715
// Parameters may not be rewritten, but internal allocations may.
17111716
// Vectorization of AllocateConst is currently disabled, as it has
17121717
// indexing issues for types that include padding (e.g. int8x3

0 commit comments

Comments
 (0)