[HLSL] Implement dot2add intrinsic #131237
Thank you for submitting a Pull Request (PR) to the LLVM Project!
This PR will be automatically labeled and the relevant teams will be notified.
If you wish to, you can add reviewers by using the "Reviewers" section on this page.
If this is not working for you, it is probably because you do not have write permissions for the repository. In that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.
If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.
If you have further questions, they may be answered by the LLVM GitHub User Guide.
You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
Force-pushed the branch from a9e4cc4 to 0c5a6af.
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-llvm-ir
Author: Sumit Agarwal (sumitsays)
Changes
Resolves #99221
Full diff: https://github.com/llvm/llvm-project/pull/131237.diff
11 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 72a5e495c4059..76ab463ca0ed6 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4891,6 +4891,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_dot2add"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index c126f88b9e3a5..b3d9db5be7d8d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19681,6 +19681,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
}
+ case Builtin::BI__builtin_hlsl_dot2add: {
+ llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch();
+ if (Arch != llvm::Triple::dxil) {
+ llvm_unreachable("Intrinsic dot2add can be executed as a builtin only on dxil");
+ }
+ Value *A = EmitScalarExpr(E->getArg(0));
+ Value *B = EmitScalarExpr(E->getArg(1));
+ Value *C = EmitScalarExpr(E->getArg(2));
+
+ //llvm::Intrinsic::dx_##IntrinsicPostfix
+ Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
+ "hlsl.dot2add");
+ }
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 5f7c047dbf340..46653d7b295b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
return length_vec_impl(X - Y);
}
+constexpr float dot2add_impl(half2 a, half2 b, float c) {
+#if defined(__DIRECTX__)
+ return __builtin_hlsl_dot2add(a, b, c);
+#else
+ return dot(a, b) + c;
+#endif
+}
+
template <typename T> constexpr T reflect_impl(T I, T N) {
return I - 2 * N * I * N;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5459cbeb34fd0..b1c1335ce3328 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -117,6 +117,18 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
return __detail::distance_vec_impl(X, Y);
}
+//===----------------------------------------------------------------------===//
+// dot2add builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn float dot2add(half2 a, half2 b, float c)
+/// \brief Dot product of 2 vector of type half and add a float scalar value.
+
+_HLSL_AVAILABILITY(shadermodel, 6.4)
+const inline float dot2add(half2 a, half2 b, float c) {
+ return __detail::dot2add_impl(a, b, c);
+}
+
//===----------------------------------------------------------------------===//
// fmod builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 36de110e75e8a..399371c4ae2f6 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
}
// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall, unsigned NumArgs) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
@@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
bool AllBArgAreVectors = true;
- for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
+ for (unsigned i = 1; i < NumArgs; ++i) {
ExprResult B = TheCall->getArg(i);
QualType ArgTyB = B.get()->getType();
auto *VecTyB = ArgTyB->getAs<VectorType>();
@@ -2049,6 +2049,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return false;
}
+static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+ return CheckVectorElementCallArgs(S, TheCall, TheCall->getNumArgs());
+}
+
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2091,10 +2095,10 @@ static bool CheckArgTypeIsCorrect(
return false;
}
-static bool CheckAllArgTypesAreCorrect(
- Sema *S, CallExpr *TheCall, QualType ExpectedType,
+static bool CheckArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
- for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
+ for (unsigned i = 0; i < NumArgs; ++i) {
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
@@ -2103,6 +2107,13 @@ static bool CheckAllArgTypesAreCorrect(
return false;
}
+static bool CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall, QualType ExpectedType,
+ llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+ return CheckArgTypesAreCorrect(S, TheCall, TheCall->getNumArgs(),
+ ExpectedType, Check);
+}
+
static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
@@ -2146,15 +2157,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
return true;
}
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
+static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall,
+ unsigned NumArgs, QualType ExpectedType) {
auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
if (const auto *VecTy = PassedType->getAs<VectorType>())
return VecTy->getElementType()->isDoubleType();
return false;
};
- return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
- checkDoubleVector);
+ return CheckArgTypesAreCorrect(S, TheCall, NumArgs,
+ ExpectedType, checkDoubleVector);
}
+
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasIntegerRepresentation() &&
@@ -2468,8 +2481,36 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (SemaRef.BuiltinVectorToScalarMath(TheCall))
return true;
- if (CheckNoDoubleVectors(&SemaRef, TheCall))
+ if (CheckNoDoubleVectors(&SemaRef, TheCall,
+ TheCall->getNumArgs(), SemaRef.Context.FloatTy))
+ return true;
+ break;
+ }
+ case Builtin::BI__builtin_hlsl_dot2add: {
+ // Check number of arguments should be 3
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ // Check first two arguments are vector of length 2 with half data type
+ auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
+ if (const auto *VecTy = PassedType->getAs<VectorType>())
+ return !(VecTy->getNumElements() == 2 &&
+ VecTy->getElementType()->isHalfType());
+ return true;
+ };
+ if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(0),
+ SemaRef.getASTContext().HalfTy,
+ checkHalfVectorOfSize2))
+ return true;
+ if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(1),
+ SemaRef.getASTContext().HalfTy,
+ checkHalfVectorOfSize2))
+ return true;
+
+ // Check third argument is a float
+ if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), SemaRef.getASTContext().FloatTy))
return true;
+ TheCall->setType(TheCall->getArg(2)->getType());
break;
}
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
new file mode 100644
index 0000000000000..ce325327a01b5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+float test(half2 p1, half2 p2, float p3) {
+ // CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} float @llvm.spv.fdot.v2f32(<2 x float> %1, <2 x float> %2)
+ // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr, align 4
+ // CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[MUL]], %[[C]]
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f32(<2 x float> %0, <2 x float> %1, float %2)
+ // CHECK: ret float %[[RES]]
+ return dot2add(p1, p2, p3);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
new file mode 100644
index 0000000000000..61282a319dafd
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_dot2add();
+ // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+}
+
+bool test_too_many_arg(half2 p1, half2 p2, float p3) {
+ return __builtin_hlsl_dot2add(p1, p2, p3, p1);
+ // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ead7286f4311c..775d325feeb14 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -100,6 +100,10 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
+def int_dx_dot2add :
+ DefaultAttrsIntrinsic<[llvm_float_ty],
+ [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
+ [IntrNoMem, Commutative]>;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index ebe1d876d58b1..193b592a525a0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1098,6 +1098,17 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
let stages = [Stages<DXIL1_2, [all_stages]>];
}
+def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
+ let Doc = "dot product of 2 vectors of half having size = 2, returns "
+ "float";
+ let intrinsics = [IntrinSelect<int_dx_dot2add>];
+ let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
+ let result = FloatTy;
+ let overloads = [Overloads<DXIL1_0, []>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
"accumulate to i32";
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index dff9f3e03079e..f7ed0c5071d75 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -54,10 +54,36 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
return ExtractedElements;
}
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+ IRBuilder<> &Builder,
+ unsigned NumOperands) {
+ assert(NumOperands > 0);
+ Value *Arg0 = Orig->getOperand(0);
+ [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+ assert(VecArg0);
+ SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+ for (unsigned I = 1; I < NumOperands; ++I) {
+ Value *Arg = Orig->getOperand(I);
+ [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ assert(VecArg);
+ assert(VecArg0->getElementType() == VecArg->getElementType());
+ assert(VecArg0->getNumElements() == VecArg->getNumElements());
+ auto NextOperandList = populateOperands(Arg, Builder);
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+ }
+ return NewOperands;
+}
+
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
IRBuilder<> &Builder) {
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
- unsigned NumOperands = Orig->getNumOperands() - 1;
+ return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
+}
+/*
+static SmallVector<Value *> argVectorFlattenExcludeLastElement(CallInst *Orig,
+ IRBuilder<> &Builder) {
+ // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ unsigned NumOperands = Orig->getNumOperands() - 2;
assert(NumOperands > 0);
Value *Arg0 = Orig->getOperand(0);
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
@@ -74,7 +100,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
}
return NewOperands;
}
-
+*/
namespace {
class OpLowerer {
Module &M;
@@ -168,6 +194,25 @@ class OpLowerer {
}
} else if (IsVectorArgExpansion) {
Args = argVectorFlatten(CI, OpBuilder.getIRB());
+ } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
+ // arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ // arg[NumOperands-2] also does not need to be flattened because it is a scalar.
+ unsigned NumOperands = CI->getNumOperands() - 2;
+ Args.push_back(CI->getArgOperand(NumOperands));
+ Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
+
+ /*unsigned NumOperands = CI->getNumOperands() - 1;
+ assert(NumOperands > 0);
+ Value *LastArg = CI->getOperand(NumOperands - 1);
+
+ Args.push_back(LastArg);
+
+ //dbgs() << "Value of LastArg" << LastArg->getName() << "\n";
+
+
+ //Args = populateOperands(LastArg, OpBuilder.getIRB());
+ Args.append(argVectorFlattenExcludeLastElement(CI, OpBuilder.getIRB()));
+ */
} else {
Args.append(CI->arg_begin(), CI->arg_end());
}
diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll
new file mode 100644
index 0000000000000..b1019c36b56e8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2add.ll
@@ -0,0 +1,8 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) {
+entry:
+; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
+ %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
+ ret float %ret
+}
\ No newline at end of file
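For readers following the lowering, here is a minimal C++ sketch of what the patch appears to compute and of the operand order that the `dot2add.ll` test expects for `dx.op.dot2AddHalf`. It is an illustration only, not code from the PR: `Half2`, `dot2addReference`, and `flattenDot2AddOperands` are hypothetical names, and `float` stands in for `half` because standard C++ has no portable half type, so rounding will differ from a real half-precision implementation.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical stand-in for HLSL half2. float is used in place of half,
// which is an assumption made only to keep the sketch portable.
using Half2 = std::array<float, 2>;

// Reference semantics mirrored from the non-DirectX fallback in the patch,
// `dot(a, b) + c`: a two-element dot product accumulated into a float.
float dot2addReference(const Half2 &a, const Half2 &b, float c) {
  return a[0] * b[0] + a[1] * b[1] + c;
}

// Operand order of the lowered call as this revision's test checks it:
// the scalar accumulator first, then the elements of a, then those of b.
// (The leading i32 opcode 162 is omitted here.)
std::vector<float> flattenDot2AddOperands(const Half2 &a, const Half2 &b,
                                          float c) {
  return {c, a[0], a[1], b[0], b[1]};
}
```

With `a = {1, 2}`, `b = {3, 4}`, and `c = 0.5`, the reference function returns 11.5 and the flattened operand list is `{0.5, 1, 2, 3, 4}`, matching the `float %c, half %0, half %1, half %2, half %3` shape in the CHECK line.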
@llvm/pr-subscribers-backend-x86
Author: Sumit Agarwal (sumitsays)
Changes
Resolves #99221
Full diff: https://github.com/llvm/llvm-project/pull/131237.diff
11 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 72a5e495c4059..76ab463ca0ed6 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4891,6 +4891,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_dot2add"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index c126f88b9e3a5..b3d9db5be7d8d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19681,6 +19681,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
}
+ case Builtin::BI__builtin_hlsl_dot2add: {
+ llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch();
+ if (Arch != llvm::Triple::dxil) {
+ llvm_unreachable("Intrinsic dot2add can be executed as a builtin only on dxil");
+ }
+ Value *A = EmitScalarExpr(E->getArg(0));
+ Value *B = EmitScalarExpr(E->getArg(1));
+ Value *C = EmitScalarExpr(E->getArg(2));
+
+ //llvm::Intrinsic::dx_##IntrinsicPostfix
+ Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
+ "hlsl.dot2add");
+ }
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 5f7c047dbf340..46653d7b295b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
return length_vec_impl(X - Y);
}
+constexpr float dot2add_impl(half2 a, half2 b, float c) {
+#if defined(__DIRECTX__)
+ return __builtin_hlsl_dot2add(a, b, c);
+#else
+ return dot(a, b) + c;
+#endif
+}
+
template <typename T> constexpr T reflect_impl(T I, T N) {
return I - 2 * N * I * N;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5459cbeb34fd0..b1c1335ce3328 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -117,6 +117,18 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
return __detail::distance_vec_impl(X, Y);
}
+//===----------------------------------------------------------------------===//
+// dot2add builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn float dot2add(half2 a, half2 b, float c)
+/// \brief Dot product of 2 vector of type half and add a float scalar value.
+
+_HLSL_AVAILABILITY(shadermodel, 6.4)
+const inline float dot2add(half2 a, half2 b, float c) {
+ return __detail::dot2add_impl(a, b, c);
+}
+
//===----------------------------------------------------------------------===//
// fmod builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 36de110e75e8a..399371c4ae2f6 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
}
// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall, unsigned NumArgs) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
@@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
bool AllBArgAreVectors = true;
- for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
+ for (unsigned i = 1; i < NumArgs; ++i) {
ExprResult B = TheCall->getArg(i);
QualType ArgTyB = B.get()->getType();
auto *VecTyB = ArgTyB->getAs<VectorType>();
@@ -2049,6 +2049,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return false;
}
+static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+ return CheckVectorElementCallArgs(S, TheCall, TheCall->getNumArgs());
+}
+
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2091,10 +2095,10 @@ static bool CheckArgTypeIsCorrect(
return false;
}
-static bool CheckAllArgTypesAreCorrect(
- Sema *S, CallExpr *TheCall, QualType ExpectedType,
+static bool CheckArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
- for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
+ for (unsigned i = 0; i < NumArgs; ++i) {
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
@@ -2103,6 +2107,13 @@ static bool CheckAllArgTypesAreCorrect(
return false;
}
+static bool CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall, QualType ExpectedType,
+ llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+ return CheckArgTypesAreCorrect(S, TheCall, TheCall->getNumArgs(),
+ ExpectedType, Check);
+}
+
static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
@@ -2146,15 +2157,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
return true;
}
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
+static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall,
+ unsigned NumArgs, QualType ExpectedType) {
auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
if (const auto *VecTy = PassedType->getAs<VectorType>())
return VecTy->getElementType()->isDoubleType();
return false;
};
- return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
- checkDoubleVector);
+ return CheckArgTypesAreCorrect(S, TheCall, NumArgs,
+ ExpectedType, checkDoubleVector);
}
+
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasIntegerRepresentation() &&
@@ -2468,8 +2481,36 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (SemaRef.BuiltinVectorToScalarMath(TheCall))
return true;
- if (CheckNoDoubleVectors(&SemaRef, TheCall))
+ if (CheckNoDoubleVectors(&SemaRef, TheCall,
+ TheCall->getNumArgs(), SemaRef.Context.FloatTy))
+ return true;
+ break;
+ }
+ case Builtin::BI__builtin_hlsl_dot2add: {
+ // Check number of arguments should be 3
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ // Check first two arguments are vector of length 2 with half data type
+ auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
+ if (const auto *VecTy = PassedType->getAs<VectorType>())
+ return !(VecTy->getNumElements() == 2 &&
+ VecTy->getElementType()->isHalfType());
+ return true;
+ };
+ if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(0),
+ SemaRef.getASTContext().HalfTy,
+ checkHalfVectorOfSize2))
+ return true;
+ if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(1),
+ SemaRef.getASTContext().HalfTy,
+ checkHalfVectorOfSize2))
+ return true;
+
+ // Check third argument is a float
+ if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), SemaRef.getASTContext().FloatTy))
return true;
+ TheCall->setType(TheCall->getArg(2)->getType());
break;
}
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
new file mode 100644
index 0000000000000..ce325327a01b5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+float test(half2 p1, half2 p2, float p3) {
+ // CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} float @llvm.spv.fdot.v2f32(<2 x float> %1, <2 x float> %2)
+ // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr, align 4
+ // CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[MUL]], %[[C]]
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f32(<2 x float> %0, <2 x float> %1, float %2)
+ // CHECK: ret float %[[RES]]
+ return dot2add(p1, p2, p3);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
new file mode 100644
index 0000000000000..61282a319dafd
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_dot2add();
+ // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+}
+
+bool test_too_many_arg(half2 p1, half2 p2, float p3) {
+ return __builtin_hlsl_dot2add(p1, p2, p3, p1);
+ // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ead7286f4311c..775d325feeb14 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -100,6 +100,10 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
+def int_dx_dot2add :
+ DefaultAttrsIntrinsic<[llvm_float_ty],
+ [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
+ [IntrNoMem, Commutative]>;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index ebe1d876d58b1..193b592a525a0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1098,6 +1098,17 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
let stages = [Stages<DXIL1_2, [all_stages]>];
}
+def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
+ let Doc = "dot product of 2 vectors of half having size = 2, returns "
+ "float";
+ let intrinsics = [IntrinSelect<int_dx_dot2add>];
+ let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
+ let result = FloatTy;
+ let overloads = [Overloads<DXIL1_0, []>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
"accumulate to i32";
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index dff9f3e03079e..f7ed0c5071d75 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -54,10 +54,36 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
return ExtractedElements;
}
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+ IRBuilder<> &Builder,
+ unsigned NumOperands) {
+ assert(NumOperands > 0);
+ Value *Arg0 = Orig->getOperand(0);
+ [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+ assert(VecArg0);
+ SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+ for (unsigned I = 1; I < NumOperands; ++I) {
+ Value *Arg = Orig->getOperand(I);
+ [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ assert(VecArg);
+ assert(VecArg0->getElementType() == VecArg->getElementType());
+ assert(VecArg0->getNumElements() == VecArg->getNumElements());
+ auto NextOperandList = populateOperands(Arg, Builder);
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+ }
+ return NewOperands;
+}
+
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
IRBuilder<> &Builder) {
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
- unsigned NumOperands = Orig->getNumOperands() - 1;
+ return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
+}
+/*
+static SmallVector<Value *> argVectorFlattenExcludeLastElement(CallInst *Orig,
+ IRBuilder<> &Builder) {
+ // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ unsigned NumOperands = Orig->getNumOperands() - 2;
assert(NumOperands > 0);
Value *Arg0 = Orig->getOperand(0);
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
@@ -74,7 +100,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
}
return NewOperands;
}
-
+*/
namespace {
class OpLowerer {
Module &M;
@@ -168,6 +194,25 @@ class OpLowerer {
}
} else if (IsVectorArgExpansion) {
Args = argVectorFlatten(CI, OpBuilder.getIRB());
+ } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
+ // arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ // arg[NumOperands-2] also does not need to be flattened because it is a scalar.
+ unsigned NumOperands = CI->getNumOperands() - 2;
+ Args.push_back(CI->getArgOperand(NumOperands));
+ Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
+
+ /*unsigned NumOperands = CI->getNumOperands() - 1;
+ assert(NumOperands > 0);
+ Value *LastArg = CI->getOperand(NumOperands - 1);
+
+ Args.push_back(LastArg);
+
+ //dbgs() << "Value of LastArg" << LastArg->getName() << "\n";
+
+
+ //Args = populateOperands(LastArg, OpBuilder.getIRB());
+ Args.append(argVectorFlattenExcludeLastElement(CI, OpBuilder.getIRB()));
+ */
} else {
Args.append(CI->arg_begin(), CI->arg_end());
}
diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll
new file mode 100644
index 0000000000000..b1019c36b56e8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2add.ll
@@ -0,0 +1,8 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) {
+entry:
+; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
+ %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
+ ret float %ret
+}
\ No newline at end of file
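The operand order that the CHECK line above expects (the scalar accumulator first, then the elements of each vector argument) can be modelled outside LLVM. The following is a minimal sketch, not the code from this PR: `Arg` and `flattenScalarFirst` are hypothetical names, and element names stand in for the extracted IR values.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a call argument: one entry for a scalar,
// N entries for an N-element vector.
using Arg = std::vector<std::string>;

// Builds a dot2AddHalf-style operand list: the trailing scalar goes first,
// then every preceding vector argument is expanded element by element,
// in its original order.
std::vector<std::string> flattenScalarFirst(const std::vector<Arg> &Args) {
  assert(!Args.empty() && Args.back().size() == 1 &&
         "last argument must be a scalar");
  std::vector<std::string> Out;
  Out.push_back(Args.back().front());
  for (std::size_t I = 0; I + 1 < Args.size(); ++I)
    Out.insert(Out.end(), Args[I].begin(), Args[I].end());
  return Out;
}
```

In this model, calling the helper with a = {a0, a1}, b = {b0, b1} and c yields c, a0, a1, b0, b1, which mirrors `float %c, half %0, half %1, half %2, half %3` in the test.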
|
LGTM, please address Finn’s requests before merging.
Please run clang-format.
You have 19 commits, so going back one from HEAD isn't going to help. Type git log and pick the commit ID before your 19 commits, then run
If you run that will also work
LGTM once formatting is resolved
ddafc82 to 6d0eb98 Compare
@sumitsays Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
Resolves #99221
Key points: For the SPIR-V backend, it decomposes into a dot followed by an add.
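The decomposition described above (a dot followed by an add) matches the non-DirectX fallback in the header, `dot(a, b) + c`. A minimal host-side sketch of those reference semantics follows. It assumes `float` as a stand-in for HLSL `half`, since standard C++ has no portable half type, and the names `Half2`, `dot2` and `dot2addReference` are hypothetical.

```cpp
#include <array>

// Stand-in for HLSL half2. float is used instead of half, which changes
// precision but not the structure of the operation.
using Half2 = std::array<float, 2>;

// Two-element dot product, playing the role of the `dot` step.
constexpr float dot2(const Half2 &A, const Half2 &B) {
  return A[0] * B[0] + A[1] * B[1];
}

// Reference semantics: dot product of the two vectors, then add the
// float accumulator.
constexpr float dot2addReference(const Half2 &A, const Half2 &B, float C) {
  return dot2(A, B) + C;
}
```

Because the sketch accumulates in float throughout, it is only a semantic reference; it says nothing about the rounding behaviour of the DXIL `dot2AddHalf` op.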