Skip to content

Commit

Permalink
[Bugfix][Unity] Recover MSVC/NVCC/ROCm/Vulkan
Browse files Browse the repository at this point in the history
This PR upstreams a few commits that recover the unity branch from
broken wheel packages. It includes the following changes:

- Fix the MSVC build in `pipe.h`, where `DWORD` is not cast to the proper
  return type (mlc-ai/relax#306);
- Fix MSVC build warnings caused by MSVC not recognizing "#pragma GCC"
  (mlc-ai/relax#307);
- Fix NVCC build warnings where NVCC fails to infer that a "[[noreturn]]"
  function actually does not return (mlc-ai/relax#308);
- Fix the ROCm/Vulkan backends, which fail compilation for operators like
  group GEMM, paged attention, etc. (apache#16404,
  apache#16405)

Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Lesheng Jin <[email protected]>
  • Loading branch information
3 people committed Jan 17, 2024
1 parent a2a1b53 commit 7471cd1
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 45 deletions.
18 changes: 16 additions & 2 deletions include/tvm/runtime/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ class LogFatal {
#pragma disagnostic push
#pragma warning(disable : 4722)
#endif
[[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION { GetEntry().Finalize(); }
[[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION {
GetEntry().Finalize();
throw;
}
#ifdef _MSC_VER
#pragma disagnostic pop
#endif
Expand All @@ -366,7 +369,7 @@ class LogFatal {
this->file_ = file;
this->lineno_ = lineno;
}
[[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() {
[[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() TVM_THROW_EXCEPTION {
InternalError error(file_, lineno_, stream_.str());
#if DMLC_LOG_BEFORE_THROW
std::cerr << error.what() << std::endl;
Expand Down Expand Up @@ -560,15 +563,26 @@ std::unique_ptr<std::string> LogCheckFormat(const X& x, const Y& y) {
return LogCheck##name<int, int>(x, y); \
}

// Suppress signed/unsigned comparison diagnostics only around the comparison
// helpers generated by TVM_CHECK_FUNC below, then restore the previous
// diagnostic state. (Presumably the helpers can be instantiated with operands
// of mixed signedness -- confirm against the macro definition.)
#if defined(__GNUC__) || defined(__clang__)  // GCC and Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#elif defined(_MSC_VER)  // MSVC
#pragma warning(push)
// MSVC reports ordering comparisons and equality comparisons under two
// different warning numbers; both must be disabled to match -Wsign-compare.
#pragma warning(disable : 4018)  // '<', '>', '<=', '>=' : signed/unsigned mismatch
#pragma warning(disable : 4389)  // '==', '!=' : signed/unsigned mismatch
#endif

TVM_CHECK_FUNC(_LT, <)
TVM_CHECK_FUNC(_GT, >)
TVM_CHECK_FUNC(_LE, <=)
TVM_CHECK_FUNC(_GE, >=)
TVM_CHECK_FUNC(_EQ, ==)
TVM_CHECK_FUNC(_NE, !=)

// Restore whatever diagnostic settings were active before the helpers.
#if defined(__GNUC__) || defined(__clang__)  // GCC and Clang
#pragma GCC diagnostic pop
#elif defined(_MSC_VER)  // MSVC
#pragma warning(pop)
#endif

} // namespace detail

Expand Down
14 changes: 8 additions & 6 deletions src/support/pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ class Pipe : public dmlc::Stream {
size_t Read(void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
auto fread = [&]() {
auto fread = [&]() -> ssize_t {
DWORD nread;
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
return nread;
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
return static_cast<ssize_t>(-1);
return static_cast<ssize_t>(nread);
};
DWORD nread = static_cast<DWORD>(RetryCallOnEINTR(fread, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nread), size) << "Read Error: " << GetLastError();
Expand All @@ -99,10 +100,11 @@ class Pipe : public dmlc::Stream {
void Write(const void* ptr, size_t size) final {
if (size == 0) return;
#ifdef _WIN32
auto fwrite = [&]() {
auto fwrite = [&]() -> ssize_t {
DWORD nwrite;
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
return nwrite;
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr))
return static_cast<ssize_t>(-1);
return static_cast<ssize_t>(nwrite);
};
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
Expand Down
2 changes: 2 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::assume())) {
llvm::Value* cond = MakeValue(op->args[0]);
return builder_->CreateAssumption(cond);
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
}
Expand Down
87 changes: 52 additions & 35 deletions src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
index = self + delta;
index = Select((self & (width - 1)) + delta >= width, self, index);
}
// reinterpret var as int32
bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32;
PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var);
PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(),
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source});
if (!is_int32) {
res = reinterpret(var.dtype(), res);
}
return res;
}

Expand All @@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);

TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_OP("tir.ceil")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_OP("tir.round")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

TVM_REGISTER_OP("tir.nearbyint")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tir.trunc")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_OP("tir.fabs")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);

TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);

TVM_REGISTER_OP("tir.exp10")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.exp10")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);

TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchPureExternOCML);

TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);

TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);

TVM_REGISTER_OP("tir.log2")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);

TVM_REGISTER_OP("tir.log10")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);

TVM_REGISTER_OP("tir.sqrt")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.tanh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchPureExternOCML);

TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

TVM_REGISTER_OP("tir.cosh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.cosh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

TVM_REGISTER_OP("tir.sinh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.sinh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.atan")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.atan")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

} // namespace llvm
} // namespace codegen
Expand Down
2 changes: 2 additions & 0 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
ICHECK(var_map_.count(buffer_node));
return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down
3 changes: 3 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ TVM_REGISTER_OP("tir.fabs")
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Exp>);

TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Exp2>);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Sin>);

Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if ((ty.is_vector()) || !ty.is_int()) return true;
if (ty.is_vector()) return ty.bits() * ty.lanes() != 32;
return ty.bits() != 32;
}))) {
return false;
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ namespace transform {
Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable this pass for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
return f;
}
auto* n = f.CopyOnWrite();
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
return f;
Expand Down
7 changes: 6 additions & 1 deletion src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1705,8 +1705,13 @@ namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable merge_static_smem for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
merge_static_smem = false;
}
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, merge_static_smem);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
Expand Down
53 changes: 53 additions & 0 deletions tests/python/codegen/test_target_codegen_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tvm import te
import numpy as np
import unittest
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
ty = te.thread_axis("threadIdx.y")
Expand Down Expand Up @@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes):
check_rocm("float16", 64, 2)


@tvm.testing.requires_rocm
def test_rocm_warp_shuffle():
    """Build and run a float32 warp-shuffle kernel on ROCm.

    After the kernel has run, every element of the 32-element buffer is
    expected to equal element 0 of that same buffer.
    """

    @T.prim_func
    def func(A_handle: T.handle):
        A = T.match_buffer(A_handle, (32,), dtype="float32")

        for bx in T.thread_binding(1, thread="blockIdx.x"):
            for tx in T.thread_binding(32, thread="threadIdx.x"):
                with T.block("test"):
                    A_local = T.alloc_buffer((1,), "float32", scope="local")
                    mask = T.alloc_buffer((1,), "uint32", scope="local")
                    t0 = T.alloc_buffer((1,), "float32", scope="local")

                    # Stage the per-thread value, shuffle it, then write it back.
                    A_local[0] = A[tx]
                    A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32)
                    A[tx] = A_local[0]

    rocm_mod = tvm.build(func, target="rocm")
    rocm_dev = tvm.rocm(0)
    host_values = np.random.uniform(size=(32,)).astype("float32")
    device_buf = tvm.nd.array(host_values, rocm_dev)
    rocm_mod(device_buf)

    # The reference is derived from the post-kernel contents: all lanes must
    # hold whatever value ended up in lane 0.
    result = device_buf.numpy()
    expected = np.ones((32,)) * result[0]
    tvm.testing.assert_allclose(result, expected)


@tvm.testing.requires_rocm
def test_rocm_vectorized_exp():
    """Build and run a 4-wide vectorized ``exp2`` kernel on ROCm.

    The kernel output is compared against ``np.exp2`` applied to the input.
    """

    @T.prim_func
    def func(A_handle: T.handle, B_handle: T.handle):
        A = T.match_buffer(A_handle, (4,), dtype="float32")
        B = T.match_buffer(B_handle, (4,), dtype="float32")

        for bx in T.thread_binding(1, thread="blockIdx.x"):
            for tx in T.thread_binding(1, thread="threadIdx.x"):
                with T.block("test"):
                    # The whole 4-element computation is a single vectorized loop.
                    for i in T.vectorized(0, 4):
                        B[i] = T.exp2(A[i])

    rocm_mod = tvm.build(func, target="rocm")
    rocm_dev = tvm.rocm(0)
    src = tvm.nd.array(np.ones((4,)).astype("float32"), rocm_dev)
    dst = tvm.nd.array(np.zeros((4,)).astype("float32"), rocm_dev)
    rocm_mod(src, dst)

    expected = np.exp2(src.numpy())
    tvm.testing.assert_allclose(dst.numpy(), expected)


if __name__ == "__main__":
    # Run every ROCm codegen test, in declaration order, when executed as a script.
    for rocm_test in (
        test_rocm_cross_thread_reduction,
        test_rocm_inf_nan,
        test_rocm_reduction_binding,
        test_rocm_copy,
        test_rocm_vectorize_add,
        test_rocm_warp_shuffle,
        test_rocm_vectorized_exp,
    ):
        rocm_test()

0 comments on commit 7471cd1

Please sign in to comment.