From 7123400c3a974cd8a6eab2b2f38b763811214b03 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Mon, 30 Dec 2024 20:31:45 +0200 Subject: [PATCH 01/12] Allow using specialization constants in numthreads attribute --- source/slang/slang-ast-modifier.h | 7 ++++ source/slang/slang-check-impl.h | 4 +++ source/slang/slang-check-modifier.cpp | 52 ++++++++++++++++++++++++++- source/slang/slang-emit-c-like.cpp | 28 ++++++++++++++- source/slang/slang-emit-c-like.h | 11 +++++- source/slang/slang-emit-glsl.cpp | 16 +++++++-- source/slang/slang-emit-spirv.cpp | 43 +++++++++++++--------- source/slang/slang-ir-insts.h | 14 +++++--- source/slang/slang-ir-util.cpp | 13 +++++++ source/slang/slang-ir-util.h | 2 ++ source/slang/slang-lower-to-ir.cpp | 32 +++++++++++++++-- 11 files changed, 193 insertions(+), 29 deletions(-) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f5dd86df15..d68501b54c 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1041,6 +1041,13 @@ class NumThreadsAttribute : public Attribute IntVal* x; IntVal* y; IntVal* z; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. + DeclRef xSpecConst; + DeclRef ySpecConst; + DeclRef zSpecConst; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 300596caa9..94a459b3db 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1651,6 +1651,10 @@ struct SemanticsVisitor : public SemanticsContext // ensure that it has a literal (not just compile-time constant) value. bool checkLiteralStringVal(Expr* expr, String* outVal); + // Check that an expression is a specialization constant integer and return + // the declaration. + DeclRef checkSpecializationConstantInt(Expr* expr); + bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName); void visitModifier(Modifier*); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 3723c98f86..221ac69749 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -85,6 +85,38 @@ bool SemanticsVisitor::checkLiteralStringVal(Expr* expr, String* outVal) return false; } +DeclRef SemanticsVisitor::checkSpecializationConstantInt(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef(); + + auto specConstVar = as(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef(); + + for (auto modifier : decl->modifiers) + { + if (as(modifier) || as(modifier)) + { + return specConstVar->declRef.as(); + } + } + + // TODO: Diagnostics should report that an integer specialization constant + // was expected. + return DeclRef(); +} + bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName) { if (auto varExpr = as(expr)) @@ -350,7 +382,8 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + IntVal* values[3] = {}; + DeclRef specIds[3] = {}; for (int i = 0; i < 3; ++i) { @@ -359,6 +392,19 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { + auto specConstVar = as(arg); + if (specConstVar) + { + auto specConstDecl = checkSpecializationConstantInt(arg); + if (specConstDecl) + { + specIds[i] = specConstDecl; + continue; + } + else + return nullptr; + } + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -396,6 +442,10 @@ Modifier* SemanticsVisitor::validateAttribute( numThreadsAttr->x = values[0]; numThreadsAttr->y = values[1]; numThreadsAttr->z = values[2]; + + numThreadsAttr->xSpecConst = specIds[0]; + numThreadsAttr->ySpecConst = specIds[1]; + numThreadsAttr->zSpecConst = specIds[2]; } else if (auto waveSizeAttr = as(attr)) { diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 3175f1b073..fad63144b1 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -298,11 +298,37 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) /* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) +{ + // TODO: Warn user that the selected emitter doesn't support setting work + // group sizes with specialization constants (yet). They're currently just + // ignored and '1' is returned in their place. + Int specializationConstantIds[kThreadGroupAxisCount]; + return getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); +} + +/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); for (int i = 0; i < 3; ++i) { - outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; + if (!decor) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = -1; + } + else if (auto specConst = as(decor->getOperand(i))) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); + } + else + { + outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); + outSpecializationConstantIds[i] = -1; + } } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index dd8e276740..3923e21f0f 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1 static IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1. If specialization constants are used for an axis, their + /// IDs is reported in non-negative entries of outSpecializationConstantIds. + static IRNumThreadsDecoration* getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]); + /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index a863e7eb1a..7ba92c43f7 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis); + Int specializationConstantIds[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1343,10 +1344,19 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( { if (ii != 0) m_writer->emit(", "); + m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); + if (specializationConstantIds[ii] >= 0) + { + m_writer->emit("_id = "); + m_writer->emit(specializationConstantIds[ii]); + } + else + { + m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); + } } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 676a16228c..c7930de835 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4332,23 +4332,34 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { - // TODO: The `LocalSize` execution mode option requires - // literal values for the X,Y,Z thread-group sizes. - // There is a `LocalSizeId` variant that takes ``s - // for those sizes, and we should consider using that - // and requiring the appropriate capabilities - // if any of the operands to the decoration are not - // literals (in a future where we support non-literals - // in those positions in the Slang IR). - // auto numThreads = cast(decoration); - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || + numThreads->getZSpecConst()) + { + // If any of the dimensions needs an ID, we need to emit + // all dimensions as an ID due to how LocalSizeId works. + int32_t ids[3]; + for (int i = 0; i < 3; ++i) + ids[i] = ensureInst(numThreads->getOperand(i))->id; + + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSizeId, + SpvLiteralInteger::from32(int32_t(ids[0])), + SpvLiteralInteger::from32(int32_t(ids[1])), + SpvLiteralInteger::from32(int32_t(ids[2]))); + } + else + { + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + } } break; case kIROp_MaxVertexCountDecoration: diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 53adce87a8..f89b9b3f52 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,6 +570,7 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; +struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -578,11 +579,16 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return cast(getOperand(0)); } - IRIntLit* getY() { return cast(getOperand(1)); } - IRIntLit* getZ() { return cast(getOperand(2)); } + IRIntLit* getX() { return as(getOperand(0)); } + IRIntLit* getY() { return as(getOperand(1)); } + IRIntLit* getZ() { return as(getOperand(2)); } - IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } + IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } + IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } + IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } + + IRIntLit* getExtentAlongAxis(int axis) { return as(getOperand(axis)); } + IRGlobalParam* getSpecConstAlongAxis(int axis) { return as(getOperand(axis)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 7788a50d5d..f5129f4541 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1948,4 +1948,17 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } +Int getSpecializationConstantId(IRGlobalParam* param) +{ + auto layout = findVarLayout(param); + if (!layout) + return 0; + + auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); + if (!offset) + return 0; + + return offset->getOffset(); +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 9a712ba961..c97300b6fe 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -371,6 +371,8 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); +Int getSpecializationConstantId(IRGlobalParam* param); + } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5bbe44e9ba..819bac175b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10227,11 +10227,37 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { + LoweredValInfo x, y, z; + x = numThreadsAttr->xSpecConst + ? emitDeclRef( + context, + numThreadsAttr->xSpecConst, + lowerType( + context, + getType(context->astBuilder, numThreadsAttr->xSpecConst))) + : lowerVal(context, numThreadsAttr->x); + y = numThreadsAttr->ySpecConst + ? emitDeclRef( + context, + numThreadsAttr->ySpecConst, + lowerType( + context, + getType(context->astBuilder, numThreadsAttr->ySpecConst))) + : lowerVal(context, numThreadsAttr->y); + z = numThreadsAttr->zSpecConst + ? emitDeclRef( + context, + numThreadsAttr->zSpecConst, + lowerType( + context, + getType(context->astBuilder, numThreadsAttr->zSpecConst))) + : lowerVal(context, numThreadsAttr->z); + numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); + getSimpleVal(context, x), + getSimpleVal(context, y), + getSimpleVal(context, z))); } else if (auto waveSizeAttr = as(modifier)) { From 56a0b1f4584654382f9e2dd369ea68a423d0b60c Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 15:45:54 +0200 Subject: [PATCH 02/12] Add support for GLSL local_size_x_id syntax --- source/slang/slang-ast-modifier.h | 9 ++++ source/slang/slang-check-modifier.cpp | 61 ++++++++++++++++++++++++++- source/slang/slang-lower-to-ir.cpp | 32 ++++++++++++-- source/slang/slang-parser.cpp | 10 ++++- 4 files changed, 106 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index d68501b54c..5ea0579bd2 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -976,6 +976,15 @@ class GLSLLayoutLocalSizeAttribute : public Attribute IntVal* x; IntVal* y; IntVal* z; + + bool axisIsSpecConstId[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. + DeclRef xSpecConst; + DeclRef ySpecConst; + DeclRef zSpecConst; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 221ac69749..077aeb5c63 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1881,7 +1881,12 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + IntVal* values[3] = {}; + DeclRef specIds[3] = {}; + + // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. + auto decl = as(syntaxNode); + SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { @@ -1890,6 +1895,19 @@ Modifier* SemanticsVisitor::checkModifier( auto arg = attr->args[i]; if (arg) { + auto specConstVar = as(arg); + if (specConstVar) + { + auto specConstDecl = checkSpecializationConstantInt(arg); + if (specConstDecl) + { + specIds[i] = specConstDecl; + continue; + } + else + return nullptr; + } + auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1897,7 +1915,42 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (cintVal->getValue() < 1) + if (attr->axisIsSpecConstId[i]) + { + // This integer should actually be a reference to a + // specialization constant with this ID. + Int specConstId = cintVal->getValue(); + + for(auto member: decl->parentDecl->members) + { + auto constantId = member->findModifier(); + if (constantId) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto id = checkConstantIntVal(attr->args[0]); + if (id->getValue() == specConstId) + { + specIds[i] = DeclRef(member->getDefaultDeclRef()); + break; + } + } + } + + // If not found, we need to create a new specialization + // constant with this ID. + if (!specIds[i]) + { + auto specConstVarDecl = getASTBuilder()->create(); + auto constantIdModifier = getASTBuilder()->create(); + constantIdModifier->location = specConstId; + specConstVarDecl->type.type = getASTBuilder()->getIntType(); + addModifier(specConstVarDecl, constantIdModifier); + decl->parentDecl->addMember(specConstVarDecl); + specIds[i] = DeclRef(specConstVarDecl->getDefaultDeclRef()); + } + continue; + } + else if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1918,6 +1971,10 @@ Modifier* SemanticsVisitor::checkModifier( attr->x = values[0]; attr->y = values[1]; attr->z = values[2]; + + attr->xSpecConst = specIds[0]; + attr->ySpecConst = specIds[1]; + attr->zSpecConst = specIds[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 819bac175b..a9bbd81a51 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7576,12 +7576,38 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); + LoweredValInfo x, y, z; + x = layoutLocalSizeAttr->xSpecConst + ? emitDeclRef( + context, + layoutLocalSizeAttr->xSpecConst, + lowerType( + context, + getType(context->astBuilder, layoutLocalSizeAttr->xSpecConst))) + : lowerVal(context, layoutLocalSizeAttr->x); + y = layoutLocalSizeAttr->ySpecConst + ? emitDeclRef( + context, + layoutLocalSizeAttr->ySpecConst, + lowerType( + context, + getType(context->astBuilder, layoutLocalSizeAttr->ySpecConst))) + : lowerVal(context, layoutLocalSizeAttr->y); + z = layoutLocalSizeAttr->zSpecConst + ? emitDeclRef( + context, + layoutLocalSizeAttr->zSpecConst, + lowerType( + context, + getType(context->astBuilder, layoutLocalSizeAttr->zSpecConst))) + : lowerVal(context, layoutLocalSizeAttr->z); + for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); + getSimpleVal(context, x), + getSimpleVal(context, y), + getSimpleVal(context, z))); } else if (as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 22491c848b..c0db4aca1b 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8442,7 +8442,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || + (nameText.endsWith("_id") && (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8456,6 +8457,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; + for (auto& b : numThreadsAttrib->axisIsSpecConstId) + b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8472,6 +8475,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; + + // We can't resolve the specialization constant declaration + // here, because it may not even exist. IDs pointing to unnamed + // specialization constants are allowed in GLSL. + numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") From dcb67e477bfb041123b8a1528e7bd458f679743c Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 16:06:12 +0200 Subject: [PATCH 03/12] Fix overeager specialization constant parsing --- source/slang/slang-check-impl.h | 6 +- source/slang/slang-check-modifier.cpp | 90 +++++++++++---------------- source/slang/slang-emit-spirv.cpp | 2 + 3 files changed, 42 insertions(+), 56 deletions(-) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 94a459b3db..9205119458 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1651,14 +1651,12 @@ struct SemanticsVisitor : public SemanticsContext // ensure that it has a literal (not just compile-time constant) value. bool checkLiteralStringVal(Expr* expr, String* outVal); - // Check that an expression is a specialization constant integer and return - // the declaration. - DeclRef checkSpecializationConstantInt(Expr* expr); - bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName); void visitModifier(Modifier*); + DeclRef tryGetIntSpecializationConstant(Expr* expr); + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 077aeb5c63..f5ec6a9d8f 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -85,38 +85,6 @@ bool SemanticsVisitor::checkLiteralStringVal(Expr* expr, String* outVal) return false; } -DeclRef SemanticsVisitor::checkSpecializationConstantInt(Expr* expr) -{ - // First type-check the expression as normal - expr = CheckExpr(expr); - - if (IsErrorExpr(expr)) - return DeclRef(); - - if (!isScalarIntegerType(expr->type)) - return DeclRef(); - - auto specConstVar = as(expr); - if (!specConstVar || !specConstVar->declRef) - return DeclRef(); - - auto decl = specConstVar->declRef.getDecl(); - if (!decl) - return DeclRef(); - - for (auto modifier : decl->modifiers) - { - if (as(modifier) || as(modifier)) - { - return specConstVar->declRef.as(); - } - } - - // TODO: Diagnostics should report that an integer specialization constant - // was expected. - return DeclRef(); -} - bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName) { if (auto varExpr = as(expr)) @@ -146,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } +DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef(); + + auto specConstVar = as(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef(); + + for (auto modifier : decl->modifiers) + { + if (as(modifier) || as(modifier)) + { + return specConstVar->declRef.as(); + } + } + + return DeclRef(); +} + static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -392,17 +390,11 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { - auto specConstVar = as(arg); - if (specConstVar) + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) { - auto specConstDecl = checkSpecializationConstantInt(arg); - if (specConstDecl) - { - specIds[i] = specConstDecl; - continue; - } - else - return nullptr; + specIds[i] = specConstDecl; + continue; } auto intValue = checkLinkTimeConstantIntVal(arg); @@ -1895,17 +1887,11 @@ Modifier* SemanticsVisitor::checkModifier( auto arg = attr->args[i]; if (arg) { - auto specConstVar = as(arg); - if (specConstVar) + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) { - auto specConstDecl = checkSpecializationConstantInt(arg); - if (specConstDecl) - { - specIds[i] = specConstDecl; - continue; - } - else - return nullptr; + specIds[i] = specConstDecl; + continue; } auto intValue = checkConstantIntVal(arg); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index c7930de835..90f1c15027 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4342,6 +4342,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex for (int i = 0; i < 3; ++i) ids[i] = ensureInst(numThreads->getOperand(i))->id; + // LocalSizeId is supported from SPIR-V 1.2 onwards without + // any extra capabilities. requireSPIRVExecutionMode( decoration, dstID, From fe4d668f4754ef46fbce242dcb1a1e658916eb73 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 19:25:40 +0200 Subject: [PATCH 04/12] Add diagnostics for specialization constant numthreads --- source/slang/slang-check-modifier.cpp | 11 +++++--- source/slang/slang-diagnostic-defs.h | 6 +++++ source/slang/slang-emit-c-like.cpp | 17 +++++++++--- source/slang/slang-emit-c-like.h | 2 +- .../slang-ir-collect-global-uniforms.cpp | 10 +++++++ .../slang-ir-legalize-varying-params.cpp | 14 ++++++++++ source/slang/slang-ir-metal-legalize.cpp | 26 ++++++++++++++----- source/slang/slang-lower-to-ir.cpp | 2 ++ source/slang/slang-parser.cpp | 3 ++- 9 files changed, 76 insertions(+), 15 deletions(-) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index f5ec6a9d8f..9bd856da85 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1907,9 +1907,10 @@ Modifier* SemanticsVisitor::checkModifier( // specialization constant with this ID. Int specConstId = cintVal->getValue(); - for(auto member: decl->parentDecl->members) + for (auto member : decl->parentDecl->members) { - auto constantId = member->findModifier(); + auto constantId = + member->findModifier(); if (constantId) { SLANG_ASSERT(attr->args.getCount() == 1); @@ -1927,12 +1928,14 @@ Modifier* SemanticsVisitor::checkModifier( if (!specIds[i]) { auto specConstVarDecl = getASTBuilder()->create(); - auto constantIdModifier = getASTBuilder()->create(); + auto constantIdModifier = + getASTBuilder()->create(); constantIdModifier->location = specConstId; specConstVarDecl->type.type = getASTBuilder()->getIntType(); addModifier(specConstVarDecl, constantIdModifier); decl->parentDecl->addMember(specConstVarDecl); - specIds[i] = DeclRef(specConstVarDecl->getDefaultDeclRef()); + specIds[i] = + DeclRef(specConstVarDecl->getDefaultDeclRef()); } continue; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 9d08d5b73c..44c142ac4d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2459,6 +2459,12 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") +DIAGNOSTIC( + 55205, + Error, + unsupportedSpecializationConstantForNumThreads, + "Specialization constants are not supported in the 'numthreads' attribute for the current " + "target.") DIAGNOSTIC( 56001, Error, diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fad63144b1..2f6b0fe107 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,7 +295,7 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( +IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) { @@ -303,7 +303,18 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) // group sizes with specialization constants (yet). They're currently just // ignored and '1' is returned in their place. Int specializationConstantIds[kThreadGroupAxisCount]; - return getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); + IRNumThreadsDecoration* decor = + getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); + + for (auto id : specializationConstantIds) + { + if (id >= 0) + { + getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); + break; + } + } + return decor; } /* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( @@ -312,7 +323,7 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) Int outSpecializationConstantIds[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { if (!decor) { diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 3923e21f0f..9779ef21cc 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -502,7 +502,7 @@ class CLikeSourceEmitter : public SourceEmitterBase /// Finds the IRNumThreadsDecoration and gets the size from that or sets all /// dimensions to 1 - static IRNumThreadsDecoration* getComputeThreadGroupSize( + IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 1c833a2948..d2d4d9f07a 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,6 +279,16 @@ struct CollectGlobalUniformParametersContext continue; } + // NumThreadsDecoration may sometimes be the user for a global + // parameter. This occurs when the parameter was supposed to be + // a specialization constant, but isn't due to that not being + // supported for the target. These can be skipped here and + // diagnosed later. + if (auto layoutAttr = as(user)) + { + continue; + } + // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 025bcf1b85..f6180a7a60 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1430,6 +1430,20 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); + if (!groupExtents) + { + m_sink->diagnose( + m_entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); + groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } + dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 835041a592..c0701b5c47 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1827,12 +1827,26 @@ struct LegalizeMetalEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); + auto uint3Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3)); + auto computeExtent = + emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); + if (!computeExtent) + { + m_sink->diagnose( + entryPoint.entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = + builder.getIntValue(uint3Type->getElementType(), 1); + computeExtent = + builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a9bbd81a51..81f40f917d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10254,6 +10254,7 @@ struct DeclLoweringVisitor : DeclVisitor else if (auto numThreadsAttr = as(modifier)) { LoweredValInfo x, y, z; + x = numThreadsAttr->xSpecConst ? emitDeclRef( context, @@ -10284,6 +10285,7 @@ struct DeclLoweringVisitor : DeclVisitor getSimpleVal(context, x), getSimpleVal(context, y), getSimpleVal(context, z))); + numThreadsDecor->sourceLoc = numThreadsAttr->loc; } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c0db4aca1b..a134019cf0 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8443,7 +8443,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || - (nameText.endsWith("_id") && (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) + (nameText.endsWith("_id") && + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; From 8b4106c29c46c8b51aaee870ba9fdb077d8c43ac Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 19:35:43 +0200 Subject: [PATCH 05/12] Remove unused variable --- source/slang/slang-ir-collect-global-uniforms.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index d2d4d9f07a..372ef298e7 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -284,7 +284,7 @@ struct CollectGlobalUniformParametersContext // a specialization constant, but isn't due to that not being // supported for the target. These can be skipped here and // diagnosed later. - if (auto layoutAttr = as(user)) + if (as(user)) { continue; } From ebf1d0b8c78237e8af642c48dbdc6a7bdf430c95 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 20:30:27 +0200 Subject: [PATCH 06/12] Fix local_size_x_id not finding existing specialization constant --- source/slang/slang-check-modifier.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 9bd856da85..ddacd1b4a0 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1910,11 +1910,11 @@ Modifier* SemanticsVisitor::checkModifier( for (auto member : decl->parentDecl->members) { auto constantId = - member->findModifier(); + member->findModifier(); if (constantId) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto id = checkConstantIntVal(attr->args[0]); + SLANG_ASSERT(constantId->args.getCount() == 1); + auto id = checkConstantIntVal(constantId->args[0]); if (id->getValue() == specConstId) { specIds[i] = DeclRef(member->getDefaultDeclRef()); From f729bce46aff8614c348f0b44e2add5fab115ab0 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 31 Dec 2024 21:37:22 +0200 Subject: [PATCH 07/12] Allow materializeGetWorkGroupSize to reference specialization constants --- .../slang-ir-translate-glsl-global-var.cpp | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 65cb8f64fd..c5c63e58d8 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -280,10 +280,14 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* values[3] = {}; + for (int i = 0; i < 3; ++i) + { + values[i] = numthreadsDecor->getExtentAlongAxis(i); + if (!values[i]) + values[i] = numthreadsDecor->getSpecConstAlongAxis(i); + } + auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -326,10 +330,13 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* args[3] = {}; + for (int i = 0; i < 3; ++i) + { + args[i] = numthreadsDecor->getExtentAlongAxis(i); + if (!args[i]) + args[i] = numthreadsDecor->getSpecConstAlongAxis(i); + } auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); From 464982471b2ee5a5023c357cbcff3491dd4ceea5 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Sun, 5 Jan 2025 13:56:34 +0200 Subject: [PATCH 08/12] Use SpvOpExecutionModeId for modes that require it --- source/slang/slang-check-modifier.cpp | 3 +-- source/slang/slang-emit-spirv.cpp | 10 +++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index ddacd1b4a0..dc3fad5f70 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1909,8 +1909,7 @@ Modifier* SemanticsVisitor::checkModifier( for (auto member : decl->parentDecl->members) { - auto constantId = - member->findModifier(); + auto constantId = member->findModifier(); if (constantId) { SLANG_ASSERT(constantId->args.getCount() == 1); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 90f1c15027..67b112ddd3 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -7967,10 +7967,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { + SpvOp execModeOp = SpvOpExecutionMode; + if (executionMode == SpvExecutionModeLocalSizeId || + executionMode == SpvExecutionModeLocalSizeHintId || + executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) + { + execModeOp = SpvOpExecutionModeId; + } + emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - SpvOpExecutionMode, + execModeOp, entryPoint, executionMode, ops...); From 1c7b6cc3a34063834759aadbb4ec79a8b4cafea1 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Mon, 6 Jan 2025 14:33:13 +0200 Subject: [PATCH 09/12] Cleanup specialization constant numthreads code --- source/slang/slang-ast-modifier.h | 16 +--- source/slang/slang-check-modifier.cpp | 45 +++------ source/slang/slang-emit-c-like.cpp | 3 - source/slang/slang-emit-glsl.cpp | 2 +- source/slang/slang-ir-insts.h | 3 - .../slang-ir-legalize-varying-params.cpp | 2 +- .../slang-ir-translate-glsl-global-var.cpp | 22 ++--- source/slang/slang-lower-to-ir.cpp | 94 ++++++++----------- source/slang/slang-reflection-api.cpp | 20 ++-- 9 files changed, 71 insertions(+), 136 deletions(-) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 5ea0579bd2..ee29750a6a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,18 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; bool axisIsSpecConstId[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef xSpecConst; - DeclRef ySpecConst; - DeclRef zSpecConst; + DeclRef specConstExtents[3]; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1047,16 +1043,12 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; // References to specialization constants, for defining the number of // threads with them. If set, the corresponding axis is set to nullptr // above. - DeclRef xSpecConst; - DeclRef ySpecConst; - DeclRef zSpecConst; + DeclRef specConstExtents[3]; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index dc3fad5f70..f1b795baea 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -380,9 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3] = {}; - DeclRef specIds[3] = {}; - for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -393,7 +390,8 @@ Modifier* SemanticsVisitor::validateAttribute( auto specConstDecl = tryGetIntSpecializationConstant(arg); if (specConstDecl) { - specIds[i] = specConstDecl; + numThreadsAttr->extents[i] = nullptr; + numThreadsAttr->specConstExtents[i] = specConstDecl; continue; } @@ -428,16 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; + numThreadsAttr->extents[i] = value; } - - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; - - numThreadsAttr->xSpecConst = specIds[0]; - numThreadsAttr->ySpecConst = specIds[1]; - numThreadsAttr->zSpecConst = specIds[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1873,16 +1863,13 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3] = {}; - DeclRef specIds[3] = {}; - // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. - auto decl = as(syntaxNode); + auto decl = as(syntaxNode); SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { - IntVal* value = nullptr; + attr->extents[i] = nullptr; auto arg = attr->args[i]; if (arg) @@ -1890,7 +1877,7 @@ Modifier* SemanticsVisitor::checkModifier( auto specConstDecl = tryGetIntSpecializationConstant(arg); if (specConstDecl) { - specIds[i] = specConstDecl; + attr->specConstExtents[i] = specConstDecl; continue; } @@ -1916,7 +1903,8 @@ Modifier* SemanticsVisitor::checkModifier( auto id = checkConstantIntVal(constantId->args[0]); if (id->getValue() == specConstId) { - specIds[i] = DeclRef(member->getDefaultDeclRef()); + attr->specConstExtents[i] = + DeclRef(member->getDefaultDeclRef()); break; } } @@ -1924,7 +1912,7 @@ Modifier* SemanticsVisitor::checkModifier( // If not found, we need to create a new specialization // constant with this ID. - if (!specIds[i]) + if (!attr->specConstExtents[i]) { auto specConstVarDecl = getASTBuilder()->create(); auto constantIdModifier = @@ -1933,7 +1921,7 @@ Modifier* SemanticsVisitor::checkModifier( specConstVarDecl->type.type = getASTBuilder()->getIntType(); addModifier(specConstVarDecl, constantIdModifier); decl->parentDecl->addMember(specConstVarDecl); - specIds[i] = + attr->specConstExtents[i] = DeclRef(specConstVarDecl->getDefaultDeclRef()); } continue; @@ -1947,22 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - value = intValue; + attr->extents[i] = intValue; } else { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; - - attr->xSpecConst = specIds[0]; - attr->ySpecConst = specIds[1]; - attr->zSpecConst = specIds[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 2f6b0fe107..3a1f22ffcd 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -299,9 +299,6 @@ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) { - // TODO: Warn user that the selected emitter doesn't support setting work - // group sizes with specialization constants (yet). They're currently just - // ignored and '1' is returned in their place. Int specializationConstantIds[kThreadGroupAxisCount]; IRNumThreadsDecoration* decor = getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 7ba92c43f7..9e29c87f09 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1344,9 +1344,9 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( { if (ii != 0) m_writer->emit(", "); - m_writer->emit("local_size_"); m_writer->emit(axes[ii]); + if (specializationConstantIds[ii] >= 0) { m_writer->emit("_id = "); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f89b9b3f52..23cc61935a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -586,9 +586,6 @@ struct IRNumThreadsDecoration : IRDecoration IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } - - IRIntLit* getExtentAlongAxis(int axis) { return as(getOperand(axis)); } - IRGlobalParam* getSpecConstAlongAxis(int axis) { return as(getOperand(axis)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index f6180a7a60..6290f5a1fd 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -188,7 +188,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); + auto litValue = as(numThreadsDecor->getOperand(axis)); if (!litValue) return nullptr; diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index c5c63e58d8..2e1e143f56 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -280,13 +280,10 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[3] = {}; - for (int i = 0; i < 3; ++i) - { - values[i] = numthreadsDecor->getExtentAlongAxis(i); - if (!values[i]) - values[i] = numthreadsDecor->getSpecConstAlongAxis(i); - } + IRInst* values[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), @@ -330,13 +327,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[3] = {}; - for (int i = 0; i < 3; ++i) - { - args[i] = numthreadsDecor->getExtentAlongAxis(i); - if (!args[i]) - args[i] = numthreadsDecor->getSpecConstAlongAxis(i); - } + IRInst* args[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 81f40f917d..050f1c395e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7576,38 +7576,29 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); - LoweredValInfo x, y, z; - x = layoutLocalSizeAttr->xSpecConst - ? emitDeclRef( - context, - layoutLocalSizeAttr->xSpecConst, - lowerType( - context, - getType(context->astBuilder, layoutLocalSizeAttr->xSpecConst))) - : lowerVal(context, layoutLocalSizeAttr->x); - y = layoutLocalSizeAttr->ySpecConst - ? emitDeclRef( - context, - layoutLocalSizeAttr->ySpecConst, - lowerType( - context, - getType(context->astBuilder, layoutLocalSizeAttr->ySpecConst))) - : lowerVal(context, layoutLocalSizeAttr->y); - z = layoutLocalSizeAttr->zSpecConst - ? emitDeclRef( - context, - layoutLocalSizeAttr->zSpecConst, - lowerType( - context, - getType(context->astBuilder, layoutLocalSizeAttr->zSpecConst))) - : lowerVal(context, layoutLocalSizeAttr->z); + + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = layoutLocalSizeAttr->specConstExtents[i] + ? emitDeclRef( + context, + layoutLocalSizeAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + layoutLocalSizeAttr->specConstExtents[i]))) + : lowerVal(context, layoutLocalSizeAttr->extents[i]); + } for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, x), - getSimpleVal(context, y), - getSimpleVal(context, z))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); } else if (as(modifier)) { @@ -10253,38 +10244,27 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { - LoweredValInfo x, y, z; - - x = numThreadsAttr->xSpecConst - ? emitDeclRef( - context, - numThreadsAttr->xSpecConst, - lowerType( - context, - getType(context->astBuilder, numThreadsAttr->xSpecConst))) - : lowerVal(context, numThreadsAttr->x); - y = numThreadsAttr->ySpecConst - ? emitDeclRef( - context, - numThreadsAttr->ySpecConst, - lowerType( - context, - getType(context->astBuilder, numThreadsAttr->ySpecConst))) - : lowerVal(context, numThreadsAttr->y); - z = numThreadsAttr->zSpecConst - ? emitDeclRef( - context, - numThreadsAttr->zSpecConst, - lowerType( - context, - getType(context->astBuilder, numThreadsAttr->zSpecConst))) - : lowerVal(context, numThreadsAttr->z); + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = numThreadsAttr->specConstExtents[i] + ? emitDeclRef( + context, + numThreadsAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + numThreadsAttr->specConstExtents[i]))) + : lowerVal(context, numThreadsAttr->extents[i]); + } numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, x), - getSimpleVal(context, y), - getSimpleVal(context, z))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); numThreadsDecor->sourceLoc = numThreadsAttr->loc; } else if (auto waveSizeAttr = as(modifier)) diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d235c82703..71f9825fea 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->x) - sizeAlongAxis[0] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) - sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->y) - sizeAlongAxis[1] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) - sizeAlongAxis[2] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->z) - sizeAlongAxis[2] = 0; + for (int i = 0; i < 3; ++i) + { + if (auto cint = + entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) + sizeAlongAxis[0] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->extents[i]) + sizeAlongAxis[0] = 0; + } } // From b2bf3ccccc62ba4f055d1b41733076d396986a13 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Wed, 8 Jan 2025 13:10:22 +0200 Subject: [PATCH 10/12] Add tests for specialization constant work group sizes --- tests/glsl/compute-shader-layout-id.slang | 19 ++++++++++++ tests/spirv/spec-constant-numthreads.slang | 35 ++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/glsl/compute-shader-layout-id.slang create mode 100644 tests/spirv/spec-constant-numthreads.slang diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang new file mode 100644 index 0000000000..bee8137d82 --- /dev/null +++ b/tests/glsl/compute-shader-layout-id.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl +#version 450 + +[vk::constant_id(1)] +const int constValue1 = 0; + +[vk::constant_id(2)] +const int constValue3 = 5; + +// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 + +layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; +void main() +{ +} + diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang new file mode 100644 index 0000000000..5c133219cf --- /dev/null +++ b/tests/spirv/spec-constant-numthreads.slang @@ -0,0 +1,35 @@ +//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl +//TEST:SIMPLE(filecheck=GLSL): -target glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpExecutionModeId %computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: %[[C2]] = OpConstant %int 4 +// CHECK-DAG: OpStore %{{.*}} %[[C0]] +// CHECK-DAG: OpStore %{{.*}} %[[C1]] +// CHECK-DAG: OpStore %{{.*}} %[[C2]] + +// GLSL-DAG: layout(constant_id = 1) +// GLSL-DAG: int constValue0_0 = 0; +// GLSL-DAG: layout(constant_id = 0) +// GLSL-DAG: int constValue1_0 = 0; +// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; + +[vk::specialization_constant] +const int constValue0 = 0; + +[vk::constant_id(0)] +const int constValue1 = 0; + +RWStructuredBuffer outputBuffer; + +[numthreads(constValue0, constValue1, 4)] +void computeMain1() +{ + int3 size = WorkgroupSize(); + outputBuffer[0] = size.x; + outputBuffer[1] = size.y; + outputBuffer[2] = size.z; +} From 736af8f4b0022cff6054fc0263fb7484ea076181 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Thu, 9 Jan 2025 19:23:41 +0200 Subject: [PATCH 11/12] Fix implicit Slang::Int -> int32_t cast --- source/slang/slang-check-modifier.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index f1b795baea..6e451b5cf9 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1917,7 +1917,7 @@ Modifier* SemanticsVisitor::checkModifier( auto specConstVarDecl = getASTBuilder()->create(); auto constantIdModifier = getASTBuilder()->create(); - constantIdModifier->location = specConstId; + constantIdModifier->location = (int32_t)specConstId; specConstVarDecl->type.type = getASTBuilder()->getIntType(); addModifier(specConstVarDecl, constantIdModifier); decl->parentDecl->addMember(specConstVarDecl); From c912deaf86a878bf822aee8925a91c47983c9260 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Sun, 12 Jan 2025 17:56:36 +0200 Subject: [PATCH 12/12] Fix querying thread group size in reflection API --- source/slang/slang-reflection-api.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 71f9825fea..d1adfedc0b 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4037,9 +4037,9 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( { if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); + sizeAlongAxis[i] = (SlangUInt)cint->getValue(); else if (numThreadsAttribute->extents[i]) - sizeAlongAxis[0] = 0; + sizeAlongAxis[i] = 0; } }