Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement specialization constant support in numthreads / local_size #5963

Merged
merged 15 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions source/slang/slang-ast-modifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
IntVal* x;
IntVal* y;
IntVal* z;
IntVal* extents[3];

bool axisIsSpecConstId[3];

// References to specialization constants, for defining the number of
// threads with them. If set, the corresponding axis is set to nullptr
// above.
DeclRef<VarDeclBase> specConstExtents[3];
};

class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute
Expand Down Expand Up @@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
IntVal* x;
IntVal* y;
IntVal* z;
IntVal* extents[3];

// References to specialization constants, for defining the number of
// threads with them. If set, the corresponding axis is set to nullptr
// above.
DeclRef<VarDeclBase> specConstExtents[3];
juliusikkala marked this conversation as resolved.
Show resolved Hide resolved
};

class WaveSizeAttribute : public Attribute
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,8 @@ struct SemanticsVisitor : public SemanticsContext

void visitModifier(Modifier*);

DeclRef<VarDeclBase> tryGetIntSpecializationConstant(Expr* expr);

AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope);

bool hasIntArgs(Attribute* attr, int numArgs);
Expand Down
108 changes: 91 additions & 17 deletions source/slang/slang-check-modifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*)
// Do nothing with modifiers for now
}

DeclRef<VarDeclBase> SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr)
{
// First type-check the expression as normal
expr = CheckExpr(expr);

if (IsErrorExpr(expr))
return DeclRef<VarDeclBase>();

if (!isScalarIntegerType(expr->type))
return DeclRef<VarDeclBase>();

auto specConstVar = as<VarExpr>(expr);
if (!specConstVar || !specConstVar->declRef)
return DeclRef<VarDeclBase>();

auto decl = specConstVar->declRef.getDecl();
if (!decl)
return DeclRef<VarDeclBase>();

for (auto modifier : decl->modifiers)
{
if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier))
{
return specConstVar->declRef.as<VarDeclBase>();
}
}

return DeclRef<VarDeclBase>();
}

static bool _isDeclAllowedAsAttribute(DeclRef<Decl> declRef)
{
if (as<AttributeDecl>(declRef.getDecl()))
Expand Down Expand Up @@ -350,15 +380,21 @@ Modifier* SemanticsVisitor::validateAttribute(
{
SLANG_ASSERT(attr->args.getCount() == 3);

IntVal* values[3];

for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;

auto arg = attr->args[i];
if (arg)
{
auto specConstDecl = tryGetIntSpecializationConstant(arg);
if (specConstDecl)
{
numThreadsAttr->extents[i] = nullptr;
numThreadsAttr->specConstExtents[i] = specConstDecl;
continue;
}

auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
Expand Down Expand Up @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute(
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
values[i] = value;
numThreadsAttr->extents[i] = value;
}

numThreadsAttr->x = values[0];
numThreadsAttr->y = values[1];
numThreadsAttr->z = values[2];
}
else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
{
Expand Down Expand Up @@ -1831,23 +1863,70 @@ Modifier* SemanticsVisitor::checkModifier(
{
SLANG_ASSERT(attr->args.getCount() == 3);

IntVal* values[3];
// GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl.
auto decl = as<EmptyDecl>(syntaxNode);
SLANG_ASSERT(decl);

for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;
attr->extents[i] = nullptr;

auto arg = attr->args[i];
if (arg)
{
auto specConstDecl = tryGetIntSpecializationConstant(arg);
if (specConstDecl)
{
attr->specConstExtents[i] = specConstDecl;
continue;
}

auto intValue = checkConstantIntVal(arg);
if (!intValue)
{
return nullptr;
}
if (auto cintVal = as<ConstantIntVal>(intValue))
{
if (cintVal->getValue() < 1)
if (attr->axisIsSpecConstId[i])
{
// This integer should actually be a reference to a
// specialization constant with this ID.
Int specConstId = cintVal->getValue();

for (auto member : decl->parentDecl->members)
{
auto constantId = member->findModifier<VkConstantIdAttribute>();
if (constantId)
{
SLANG_ASSERT(constantId->args.getCount() == 1);
auto id = checkConstantIntVal(constantId->args[0]);
if (id->getValue() == specConstId)
{
attr->specConstExtents[i] =
DeclRef<VarDeclBase>(member->getDefaultDeclRef());
break;
}
}
}

// If not found, we need to create a new specialization
// constant with this ID.
if (!attr->specConstExtents[i])
{
auto specConstVarDecl = getASTBuilder()->create<VarDecl>();
auto constantIdModifier =
getASTBuilder()->create<VkConstantIdAttribute>();
constantIdModifier->location = (int32_t)specConstId;
specConstVarDecl->type.type = getASTBuilder()->getIntType();
addModifier(specConstVarDecl, constantIdModifier);
decl->parentDecl->addMember(specConstVarDecl);
attr->specConstExtents[i] =
DeclRef<VarDeclBase>(specConstVarDecl->getDefaultDeclRef());
}
continue;
}
else if (cintVal->getValue() < 1)
{
getSink()->diagnose(
attr,
Expand All @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier(
return nullptr;
}
}
value = intValue;
attr->extents[i] = intValue;
}
else
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
values[i] = value;
}

attr->x = values[0];
attr->y = values[1];
attr->z = values[2];
}

// Default behavior is to leave things as they are,
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,12 @@ DIAGNOSTIC(
Error,
unsupportedTargetIntrinsic,
"intrinsic operation '$0' is not supported for the current target.")
DIAGNOSTIC(
55205,
Error,
unsupportedSpecializationConstantForNumThreads,
"Specialization constants are not supported in the 'numthreads' attribute for the current "
"target.")
DIAGNOSTIC(
56001,
Error,
Expand Down
40 changes: 37 additions & 3 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
}


/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount])
{
Int specializationConstantIds[kThreadGroupAxisCount];
IRNumThreadsDecoration* decor =
getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds);

for (auto id : specializationConstantIds)
{
if (id >= 0)
{
getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads);
break;
}
}
return decor;
}

/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount],
Int outSpecializationConstantIds[kThreadGroupAxisCount])
{
IRNumThreadsDecoration* decor = func->findDecoration<IRNumThreadsDecoration>();
for (int i = 0; i < 3; ++i)
for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1;
if (!decor)
{
outNumThreads[i] = 1;
outSpecializationConstantIds[i] = -1;
}
else if (auto specConst = as<IRGlobalParam>(decor->getOperand(i)))
{
outNumThreads[i] = 1;
outSpecializationConstantIds[i] = getSpecializationConstantId(specConst);
}
else
{
outNumThreads[i] = Int(getIntVal(decor->getOperand(i)));
outSpecializationConstantIds[i] = -1;
}
}
return decor;
}
Expand Down
13 changes: 11 additions & 2 deletions source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase
/// different. Returns an empty slice if not a built in type
static UnownedStringSlice getDefaultBuiltinTypeName(IROp op);

/// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
static IRNumThreadsDecoration* getComputeThreadGroupSize(
/// Finds the IRNumThreadsDecoration and gets the size from that or sets all
/// dimensions to 1
IRNumThreadsDecoration* getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount]);

/// Finds the IRNumThreadsDecoration and gets the size from that or sets all
/// dimensions to 1. If specialization constants are used for an axis, their
/// IDs is reported in non-negative entries of outSpecializationConstantIds.
static IRNumThreadsDecoration* getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount],
Int outSpecializationConstantIds[kThreadGroupAxisCount]);

/// Finds the IRWaveSizeDecoration and gets the size from that.
static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize);

Expand Down
16 changes: 13 additions & 3 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
auto emitLocalSizeLayout = [&]()
{
Int sizeAlongAxis[kThreadGroupAxisCount];
getComputeThreadGroupSize(irFunc, sizeAlongAxis);
Int specializationConstantIds[kThreadGroupAxisCount];
getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds);

m_writer->emit("layout(");
char const* axes[] = {"x", "y", "z"};
Expand All @@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
m_writer->emit(", ");
m_writer->emit("local_size_");
m_writer->emit(axes[ii]);
m_writer->emit(" = ");
m_writer->emit(sizeAlongAxis[ii]);

if (specializationConstantIds[ii] >= 0)
{
m_writer->emit("_id = ");
m_writer->emit(specializationConstantIds[ii]);
}
else
{
m_writer->emit(" = ");
m_writer->emit(sizeAlongAxis[ii]);
}
}
m_writer->emit(") in;\n");
};
Expand Down
Loading