Skip to content

Commit

Permalink
[SYCL][clang] Emit default template arguments in integration header (#…
Browse files Browse the repository at this point in the history
…16005)

For free function kernels support clang forward declares the kernel
itself as well as its parameter types. In case a free function kernel
has a parameter that is templated and has a default template argument,
all template arguments including arguments that match default arguments
must be printed in kernel's forward declarations, for example

```
template <typename T, typename = int> struct Arg {
  T val;
};

// For the kernel
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
    (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<int> arg) {
  arg.val = 42;
}

// Integration header must contain
void foo(Arg<int, int> arg);
```

Unfortunately, even though integration header emission already has
extensive support for forward declarations priting, some modifications
to clang's type printing are still required, since neither of existing
PrintingPolicy flags help to reach the correct result.
Using `SuppressDefaultTemplateArgs = true` doesn't help without printing
canonical types, printing canonical types for the case like
```
template <typename T>
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
    (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<T> arg) {
  arg.val = 42;
}
// Printing canonical types is causing the following integration header
template <typename T>
void foo(Arg<type-parameter-0-0, int> arg);
```

Using `SkipCanonicalizationOfTemplateTypeParms` field of printing policy
doesn't help here since at the one point where it is checked we take
canonical type of `Arg`, not its parameters and it will contain template
argument types in canonical type after that.
  • Loading branch information
Fznamznon authored Nov 18, 2024
1 parent 01f7e44 commit 08a2edc
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 27 deletions.
20 changes: 15 additions & 5 deletions clang/include/clang/AST/PrettyPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,18 @@ struct PrintingPolicy {
SuppressStrongLifetime(false), SuppressLifetimeQualifiers(false),
SuppressTypedefs(false), SuppressFinalSpecifier(false),
SuppressTemplateArgsInCXXConstructors(false),
SuppressDefaultTemplateArgs(true), Bool(LO.Bool),
Nullptr(LO.CPlusPlus11 || LO.C23), NullptrTypeInNamespace(LO.CPlusPlus),
Restrict(LO.C99), Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
SuppressDefaultTemplateArgs(true), EnforceDefaultTemplateArgs(false),
Bool(LO.Bool), Nullptr(LO.CPlusPlus11 || LO.C23),
NullptrTypeInNamespace(LO.CPlusPlus), Restrict(LO.C99),
Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
UseVoidForZeroParams(!LO.CPlusPlus),
SplitTemplateClosers(!LO.CPlusPlus11), TerseOutput(false),
PolishForDeclaration(false), Half(LO.Half),
MSWChar(LO.MicrosoftExt && !LO.WChar), IncludeNewlines(true),
MSVCFormatting(false), ConstantsAsWritten(false),
SuppressImplicitBase(false), FullyQualifiedName(false),
SuppressDefinition(false), SuppressDefaultTemplateArguments(false),
PrintCanonicalTypes(false),
EnforceScopeForElaboratedTypes(false), SuppressDefinition(false),
SuppressDefaultTemplateArguments(false), PrintCanonicalTypes(false),
SkipCanonicalizationOfTemplateTypeParms(false),
PrintInjectedClassNameWithArguments(true), UsePreferredNames(true),
AlwaysIncludeTypeForTemplateArgument(false),
Expand Down Expand Up @@ -241,6 +242,11 @@ struct PrintingPolicy {
LLVM_PREFERRED_TYPE(bool)
unsigned SuppressDefaultTemplateArgs : 1;

/// When true, print template arguments that match the default argument for
/// the parameter, even if they're not specified in the source.
LLVM_PREFERRED_TYPE(bool)
unsigned EnforceDefaultTemplateArgs : 1;

/// Whether we can use 'bool' rather than '_Bool' (even if the language
/// doesn't actually have 'bool', because, e.g., it is defined as a macro).
LLVM_PREFERRED_TYPE(bool)
Expand Down Expand Up @@ -339,6 +345,10 @@ struct PrintingPolicy {
LLVM_PREFERRED_TYPE(bool)
unsigned FullyQualifiedName : 1;

/// Enforce fully qualified name printing for elaborated types.
LLVM_PREFERRED_TYPE(bool)
unsigned EnforceScopeForElaboratedTypes : 1;

/// When true does not print definition of a type. E.g.
/// \code
/// template<typename T> class C0 : public C1 {...}
Expand Down
60 changes: 39 additions & 21 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class ElaboratedTypePolicyRAII {
SuppressTagKeyword = Policy.SuppressTagKeyword;
SuppressScope = Policy.SuppressScope;
Policy.SuppressTagKeyword = true;
Policy.SuppressScope = true;
Policy.SuppressScope = !Policy.EnforceScopeForElaboratedTypes;
}

~ElaboratedTypePolicyRAII() {
Expand Down Expand Up @@ -1728,8 +1728,10 @@ void TypePrinter::printElaboratedBefore(const ElaboratedType *T,
Policy.SuppressScope = OldSupressScope;
return;
}
if (Qualifier && !(Policy.SuppressTypedefs &&
T->getNamedType()->getTypeClass() == Type::Typedef))
if (Qualifier &&
!(Policy.SuppressTypedefs &&
T->getNamedType()->getTypeClass() == Type::Typedef) &&
!Policy.EnforceScopeForElaboratedTypes)
Qualifier->print(OS, Policy);
}

Expand Down Expand Up @@ -2220,15 +2222,6 @@ static void printArgument(const TemplateArgument &A, const PrintingPolicy &PP,
A.print(PP, OS, IncludeType);
}

static void printArgument(const TemplateArgumentLoc &A,
const PrintingPolicy &PP, llvm::raw_ostream &OS,
bool IncludeType) {
const TemplateArgument::ArgKind &Kind = A.getArgument().getKind();
if (Kind == TemplateArgument::ArgKind::Type)
return A.getTypeSourceInfo()->getType().print(OS, PP);
return A.getArgument().print(PP, OS, IncludeType);
}

static bool isSubstitutedTemplateArgument(ASTContext &Ctx, TemplateArgument Arg,
TemplateArgument Pattern,
ArrayRef<TemplateArgument> Args,
Expand Down Expand Up @@ -2399,15 +2392,40 @@ template <typename TA>
static void
printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
const TemplateParameterList *TPL, bool IsPack, unsigned ParmIndex) {
// Drop trailing template arguments that match default arguments.
if (TPL && Policy.SuppressDefaultTemplateArgs &&
!Policy.PrintCanonicalTypes && !Args.empty() && !IsPack &&
llvm::SmallVector<TemplateArgument, 8> ArgsToPrint;
for (const TA &A : Args)
ArgsToPrint.push_back(getArgument(A));
if (TPL && !Policy.PrintCanonicalTypes && !IsPack &&
Args.size() <= TPL->size()) {
llvm::SmallVector<TemplateArgument, 8> OrigArgs;
for (const TA &A : Args)
OrigArgs.push_back(getArgument(A));
while (!Args.empty() && getArgument(Args.back()).getIsDefaulted())
Args = Args.drop_back();
// Drop trailing template arguments that match default arguments.
if (Policy.SuppressDefaultTemplateArgs) {
while (!ArgsToPrint.empty() &&
getArgument(ArgsToPrint.back()).getIsDefaulted())
ArgsToPrint.pop_back();
} else if (Policy.EnforceDefaultTemplateArgs) {
for (unsigned I = Args.size(); I < TPL->size(); ++I) {
auto Param = TPL->getParam(I);
if (auto *TTPD = dyn_cast<TemplateTypeParmDecl>(Param)) {
// If we met a non default-argument past provided list of arguments,
// it is either a pack which must be the last arguments, or provided
// argument list was problematic. Bail out either way. Do the same
// for each kind of template argument.
if (!TTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
} else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) {
if (!TTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
} else if (auto *NTTPD = dyn_cast<NonTypeTemplateParmDecl>(Param)) {
if (!NTTPD->hasDefaultArgument())
break;
ArgsToPrint.push_back(getArgument(NTTPD->getDefaultArgument()));
} else {
llvm_unreachable("unexpected template parameter");
}
}
}
}

const char *Comma = Policy.MSVCFormatting ? "," : ", ";
Expand All @@ -2416,7 +2434,7 @@ printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,

bool NeedSpace = false;
bool FirstArg = true;
for (const auto &Arg : Args) {
for (const auto &Arg : ArgsToPrint) {
// Print the argument into a string.
SmallString<128> Buf;
llvm::raw_svector_ostream ArgOS(Buf);
Expand Down
32 changes: 31 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6509,16 +6509,46 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "extern \"C\" ";
std::string ParmList;
bool FirstParam = true;
Policy.SuppressDefaultTemplateArgs = false;
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
if (FirstParam)
FirstParam = false;
else
ParmList += ", ";
ParmList += Param->getType().getCanonicalType().getAsString();
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
}
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
Policy.SuppressDefinition = true;
Policy.PolishForDeclaration = true;
Policy.FullyQualifiedName = true;
Policy.EnforceScopeForElaboratedTypes = true;

// Now we need to print the declaration of the kernel itself.
// Example:
// template <typename T, typename = int> struct Arg {
// T val;
// };
// For the following free function kernel:
// template <typename = T>
// SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
// (ext::oneapi::experimental::nd_range_kernel<1>))
// void foo(Arg<int> arg) {}
// Integration header must contain the following declaration:
// template <typename>
// void foo(Arg<int, int> arg);
// SuppressDefaultTemplateArguments is a downstream addition that suppresses
// default template arguments in the function declaration. It should be set
// to true to emit function declaration that won't cause any compilation
// errors when present in the integration header.
// To print Arg<int, int> in the function declaration and shim functions we
// need to disable default arguments printing suppression via community flag
// SuppressDefaultTemplateArgs, otherwise they will be suppressed even for
// canonical types or if even written in the original source code.
Policy.SuppressDefaultTemplateArguments = true;
// EnforceDefaultTemplateArgs is a downstream addition that forces printing
// template arguments that match default template arguments while printing
// template-ids, even if the source code doesn't reference them.
Policy.EnforceDefaultTemplateArgs = true;
if (FTD) {
FTD->print(O, Policy);
} else {
Expand Down
100 changes: 100 additions & 0 deletions clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s

// This test checks integration header contents for free functions kernels with
// parameter types that have default template arguments.

#include "mock_properties.hpp"
#include "sycl.hpp"

namespace ns {

struct notatuple {
int a;
};

namespace ns1 {
template <typename A = notatuple>
class hasDefaultArg {

};
}

template <typename T, typename = int, int a = 12, typename = notatuple, typename ...TS> struct Arg {
T val;
};

[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
2)]] void
simple(Arg<char>){
}

}

[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
2)]] void
simple1(ns::Arg<ns::ns1::hasDefaultArg<>>){
}


template <typename T>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated(ns::Arg<T, float, 3>, T end) {
}

template void templated(ns::Arg<int, float, 3>, int);

using namespace ns;

template <typename T>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated2(Arg<T, notatuple>, T end) {
}

template void templated2(Arg<int, notatuple>, int);

template <typename T, int a = 3>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
templated3(Arg<T, notatuple, a, ns1::hasDefaultArg<>, int, int>, T end) {
}

template void templated3(Arg<int, notatuple, 3, ns1::hasDefaultArg<>, int, int>, int);

// CHECK: Forward declarations of kernel and its argument types:
// CHECK-NEXT: namespace ns {
// CHECK-NEXT: struct notatuple;
// CHECK-NEXT: }
// CHECK-NEXT: namespace ns {
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg;
// CHECK-NEXT: }

// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>);
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple;
// CHECK-NEXT: }

// CHECK: Forward declarations of kernel and its argument types:
// CHECK: namespace ns {
// CHECK: namespace ns1 {
// CHECK-NEXT: template <typename A> class hasDefaultArg;
// CHECK-NEXT: }

// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>);
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1;
// CHECK-NEXT: }

// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim3() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, float, 3, struct ns::notatuple>, int))templated<int>;
// CHECK-NEXT: }

// CHECK: template <typename T> void templated2(ns::Arg<T, ns::notatuple, 12, ns::notatuple>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim4() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 12, struct ns::notatuple>, int))templated2<int>;
// CHECK-NEXT: }

// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int>, T end);
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 3, class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, int>, int))templated3<int, 3>;
// CHECK-NEXT: }

0 comments on commit 08a2edc

Please sign in to comment.