Skip to content

Commit 4c8a546

Browse files
committed
Should now work with vector types and workgroupsize builtin
1 parent cb7b2b9 commit 4c8a546

File tree

3 files changed

+150
-63
lines changed

3 files changed

+150
-63
lines changed

common/output_stream.cpp

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,13 @@ void ParseBlockMembersToTextLines(
663663
// dim = 0 means it's an unbounded array
664664
//
665665
if (dim > 0) {
666+
if (dim == 0xFFFFFFFF) {
667+
SpvReflectValue val{};
668+
SpvReflectResult res = obj.EvaluateResult(member.array.spec_constant_op_ids[array_dim_index], val);
669+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
670+
dim = val.values[0].value.uint32_bool_value;
671+
}
672+
}
666673
ss_array << "[" << dim << "]";
667674
}
668675
else {
@@ -702,8 +709,9 @@ void ParseBlockMembersToTextLines(
702709
if (dim == 0xFFFFFFFF) {
703710
SpvReflectValue val{};
704711
SpvReflectResult res = obj.EvaluateResult(member.array.spec_constant_op_ids[array_dim_index], val);
705-
if (res != SPV_REFLECT_RESULT_SUCCESS) throw;
706-
dim = val.values[0].value.uint32_bool_value;
712+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
713+
dim = val.values[0].value.uint32_bool_value;
714+
}
707715
}
708716
ss_array << "[" << dim << "]";
709717
}
@@ -1009,7 +1017,7 @@ void StreamWriteDescriptorBinding(std::ostream& os, const spv_reflect::ShaderMod
10091017
}
10101018
}
10111019

1012-
void StreamWriteInterfaceVariable(std::ostream& os, const SpvReflectInterfaceVariable& obj, const char* indent)
1020+
void StreamWriteInterfaceVariable(std::ostream& os, const spv_reflect::ShaderModule& shader, const SpvReflectInterfaceVariable& obj, const char* indent)
10131021
{
10141022
const char* t = indent;
10151023
os << t << "spirv id : " << obj.spirv_id << "\n";
@@ -1028,6 +1036,14 @@ void StreamWriteInterfaceVariable(std::ostream& os, const SpvReflectInterfaceVar
10281036
if (obj.array.dims_count > 0) {
10291037
os << t << "array : ";
10301038
for (uint32_t dim_index = 0; dim_index < obj.array.dims_count; ++dim_index) {
1039+
uint32_t dim = obj.array.dims[dim_index];
1040+
if (dim == 0xFFFFFFFF) {
1041+
SpvReflectValue val{};
1042+
SpvReflectResult res = shader.EvaluateResult(obj.array.spec_constant_op_ids[dim_index], val);
1043+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
1044+
dim = val.values[0].value.uint32_bool_value;
1045+
}
1046+
}
10311047
os << "[" << obj.array.dims[dim_index] << "]";
10321048
}
10331049
os << "\n";
@@ -1108,28 +1124,74 @@ void StreamWriteSpecializationConstant(std::ostream& os, const SpvReflectSpecial
11081124
}
11091125
}
11101126

1111-
void StreamWriteEntryPoint(std::ostream& os, const SpvReflectEntryPoint& obj, const char* indent)
1127+
void StreamWriteEntryPoint(std::ostream& os, const spv_reflect::ShaderModule& shader, const SpvReflectEntryPoint& obj, const char* indent)
11121128
{
11131129
os << indent << "entry point : " << obj.name;
11141130
os << " (stage=" << ToStringShaderStage(obj.shader_stage) << ")";
11151131
if (obj.shader_stage == SPV_REFLECT_SHADER_STAGE_COMPUTE_BIT) {
11161132
os << "\n";
1117-
os << "local size : " << "(" << obj.local_size.x << ", " << obj.local_size.y << ", " << obj.local_size.z << ")";
1133+
if (obj.local_size.flags & 2) {
1134+
os << "local size : ";
1135+
}
1136+
else {
1137+
os << "local size hint : ";
1138+
}
1139+
if(obj.local_size.flags & 4) {
1140+
SpvReflectValue val{};
1141+
SpvReflectResult res = shader.EvaluateResult(obj.local_size.x, val);
1142+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == (SPV_REFLECT_TYPE_FLAG_INT | SPV_REFLECT_TYPE_FLAG_VECTOR))
1143+
&& (val.type->traits.numeric.scalar.width == 32)) {
1144+
os << "(" << val.values[0].value.uint32_bool_value << ", " << val.values[1].value.uint32_bool_value << ", " << val.values[2].value.uint32_bool_value << ")";
1145+
}
1146+
else {
1147+
os << "(failed evaluation of WorkGroupSize Builtin)";
1148+
}
1149+
}
1150+
else if(obj.local_size.flags & 1) {
1151+
os << "(";
1152+
SpvReflectValue val = {0};
1153+
SpvReflectResult res = shader.EvaluateResult(obj.local_size.x, val);
1154+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
1155+
os << val.values[0].value.uint32_bool_value;
1156+
}
1157+
else {
1158+
os << "unknown";
1159+
}
1160+
os << ", ";
1161+
res = shader.EvaluateResult(obj.local_size.y, val);
1162+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
1163+
os << val.values[0].value.uint32_bool_value;
1164+
}
1165+
else {
1166+
os << "unknown";
1167+
} os << ", ";
1168+
res = shader.EvaluateResult(obj.local_size.z, val);
1169+
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
1170+
os << val.values[0].value.uint32_bool_value;
1171+
}
1172+
else {
1173+
os << "unknown";
1174+
}
1175+
os << ")";
1176+
}
1177+
else{
1178+
os << "(" << obj.local_size.x << ", " << obj.local_size.y << ", " << obj.local_size.z << ")";
1179+
}
11181180
}
11191181
}
11201182

1121-
void StreamWriteShaderModule(std::ostream& os, const SpvReflectShaderModule& obj, const char* indent)
1183+
void StreamWriteShaderModule(std::ostream& os, const spv_reflect::ShaderModule& obj, const char* indent)
11221184
{
11231185
(void)indent;
1124-
os << "generator : " << ToStringGenerator(obj.generator) << "\n";
1125-
os << "source lang : " << spvReflectSourceLanguage(obj.source_language) << "\n";
1126-
os << "source lang ver : " << obj.source_language_version << "\n";
1127-
os << "source file : " << (obj.source_file != NULL ? obj.source_file : "") << "\n";
1186+
os << "generator : " << ToStringGenerator(obj.GetShaderModule().generator) << "\n";
1187+
os << "source lang : " << spvReflectSourceLanguage(obj.GetShaderModule().source_language) << "\n";
1188+
os << "source lang ver : " << obj.GetShaderModule().source_language_version << "\n";
1189+
os << "source file : " << (obj.GetShaderModule().source_file != NULL ? obj.GetShaderModule().source_file : "") << "\n";
11281190
//os << "shader stage : " << ToStringShaderStage(obj.shader_stage) << "\n";
11291191

1130-
for (uint32_t i = 0; i < obj.entry_point_count; ++i) {
1131-
StreamWriteEntryPoint(os, obj.entry_points[i], "");
1132-
if (i < (obj.entry_point_count - 1)) {
1192+
for (uint32_t i = 0; i < obj.GetShaderModule().entry_point_count; ++i) {
1193+
StreamWriteEntryPoint(os, obj, obj.GetShaderModule().entry_points[i], "");
1194+
if (i < (obj.GetShaderModule().entry_point_count - 1)) {
11331195
os << "\n";
11341196
}
11351197
}
@@ -1149,7 +1211,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
11491211
const char* tt = " ";
11501212
const char* ttt = " ";
11511213

1152-
StreamWriteShaderModule(os, obj.GetShaderModule(), "");
1214+
StreamWriteShaderModule(os, obj, "");
11531215

11541216
uint32_t count = 0;
11551217
std::vector<SpvReflectInterfaceVariable*> variables;
@@ -1195,7 +1257,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
11951257
auto p_var = variables[i];
11961258
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
11971259
os << tt << i << ":" << "\n";
1198-
StreamWriteInterfaceVariable(os, *p_var, ttt);
1260+
StreamWriteInterfaceVariable(os, obj,*p_var, ttt);
11991261
if (i < (count - 1)) {
12001262
os << "\n";
12011263
}
@@ -1217,7 +1279,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
12171279
auto p_var = variables[i];
12181280
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
12191281
os << tt << i << ":" << "\n";
1220-
StreamWriteInterfaceVariable(os, *p_var, ttt);
1282+
StreamWriteInterfaceVariable(os, obj, *p_var, ttt);
12211283
if (i < (count - 1)) {
12221284
os << "\n";
12231285
}

spirv_reflect.c

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,6 +1445,7 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
14451445
case SpvDecorationBuiltIn: {
14461446
p_target_decorations->is_built_in = true;
14471447
uint32_t word_offset = p_node->word_offset + member_offset + 3;
1448+
// no rule specifies a result cannot be decorated twice. But let's assume this for now...
14481449
CHECKED_READU32_CAST(p_parser, word_offset, SpvBuiltIn, p_target_decorations->built_in);
14491450
}
14501451
break;
@@ -3397,34 +3398,16 @@ SpvReflectResult GetTypeByTypeId(const SpvReflectShaderModule* p_module, uint32_
33973398
#define COMPOSITE_TYPE_FLAGS (VECTOR_TYPE_FLAGS|SPV_REFLECT_TYPE_FLAG_MATRIX|SPV_REFLECT_TYPE_FLAG_STRUCT|SPV_REFLECT_TYPE_FLAG_ARRAY)
33983399
#define COMPOSITE_DISALLOWED_FLAGS (~0 ^ COMPOSITE_TYPE_FLAGS)
33993400

3400-
static SpvReflectScalarType ScalarGeneralTypeFromType(SpvReflectTypeDescription* type)
3401-
{
3402-
if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_BOOL) {
3403-
return SPV_REFLECT_SCALAR_TYPE_BOOL;
3404-
}
3405-
else if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_INT) {
3406-
return SPV_REFLECT_SCALAR_TYPE_INT;
3407-
}
3408-
else if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_FLOAT) {
3409-
return SPV_REFLECT_SCALAR_TYPE_FLOAT;
3410-
}
3411-
else {
3412-
return SPV_REFLECT_SCALAR_TYPE_UNKNOWN;
3413-
}
3414-
}
3415-
34163401
static SpvReflectResult GetScalarConstant(const SpvReflectShaderModule* p_module, SpvReflectPrvNode* p_node,
3417-
SpvReflectScalarValue* result, SpvReflectScalarType* general_type, SpvReflectTypeDescription** type)
3402+
SpvReflectScalarValue* result, SpvReflectTypeDescription** type)
34183403
{
34193404
SpvReflectPrvParser* p_parser = p_module->_internal->parser;
34203405

3421-
SpvReflectScalarType g_type;
34223406
SpvReflectTypeDescription* d_type;
34233407
SpvReflectResult res = GetTypeByTypeId(p_module, p_node->result_type_id, &d_type);
34243408
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
34253409

34263410
if(d_type->type_flags & SCALAR_DISALLOWED_FLAGS) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
3427-
g_type = ScalarGeneralTypeFromType(d_type);
34283411
uint32_t low_word;
34293412
CHECKED_READU32(p_parser, p_node->word_offset + 3, low_word);
34303413
// There is no alignment requirements in c/cpp for unions
@@ -3440,7 +3423,6 @@ static SpvReflectResult GetScalarConstant(const SpvReflectShaderModule* p_module
34403423
else {
34413424
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
34423425
}
3443-
*general_type = g_type;
34443426
*type = d_type;
34453427
return SPV_REFLECT_RESULT_SUCCESS;
34463428
}
@@ -3466,6 +3448,18 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars
34663448

34673449
for (size_t i = 0; i < p_parser->node_count; ++i) {
34683450
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
3451+
// check first if it's WorkGroupSize builtin
3452+
// maybe handling builtin as global map may be better.
3453+
if (p_node->decorations.built_in == SpvBuiltInWorkgroupSize) {
3454+
// WorkGroupSize builtin's target is all ExecutionMode instructions.
3455+
for(uint32_t j = 0; j<p_module->entry_point_count; ++j) {
3456+
if(p_module->entry_points[j].spirv_execution_model == SpvExecutionModelKernel||
3457+
p_module->entry_points[j].spirv_execution_model == SpvExecutionModelGLCompute){
3458+
p_module->entry_points[j].local_size.flags = 4;
3459+
p_module->entry_points[j].local_size.x = p_node->result_id;
3460+
}
3461+
}
3462+
}
34693463
// Specconstants with no id means constant
34703464
switch(p_node->op) {
34713465
default: continue;
@@ -3482,19 +3476,19 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars
34823476
case SpvOpSpecConstant: {
34833477
SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
34843478
SpvReflectScalarValue default_value = { 0 };
3485-
result = GetScalarConstant(p_module, p_node, &default_value, &p_module->specialization_constants[index].general_type, &p_module->specialization_constants[index].type);
3479+
result = GetScalarConstant(p_module, p_node, &default_value, &p_module->specialization_constants[index].type);
34863480
if (result != SPV_REFLECT_RESULT_SUCCESS) return result;
34873481
p_module->specialization_constants[index].default_value = default_value;
34883482
p_module->specialization_constants[index].current_value = p_module->specialization_constants[index].default_value;
34893483
} break;
34903484
}
34913485
// spec constant id cannot be the same, at least for valid values. (invalid value is just constant?)
34923486
if (p_node->decorations.specialization_constant.value != (uint32_t)INVALID_VALUE) {
3493-
for (uint32_t j = 0; j < index; ++j) {
3494-
if (p_module->specialization_constants[j].constant_id == p_node->decorations.specialization_constant.value) {
3495-
return SPV_REFLECT_RESULT_ERROR_SPIRV_DUPLICATE_SPEC_CONSTANT_NAME;
3496-
}
3487+
for (uint32_t j = 0; j < index; ++j) {
3488+
if (p_module->specialization_constants[j].constant_id == p_node->decorations.specialization_constant.value) {
3489+
return SPV_REFLECT_RESULT_ERROR_SPIRV_DUPLICATE_SPEC_CONSTANT_NAME;
34973490
}
3491+
}
34983492
}
34993493

35003494
p_module->specialization_constants[index].name = p_node->name;
@@ -3895,10 +3889,6 @@ static SpvReflectResult CreateShaderModule(
38953889
result = ParsePushConstantBlocks(parser, p_module);
38963890
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
38973891
}
3898-
if (result == SPV_REFLECT_RESULT_SUCCESS) {
3899-
result = ParseSpecializationConstants(parser, p_module);
3900-
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
3901-
}
39023892
if (result == SPV_REFLECT_RESULT_SUCCESS) {
39033893
result = ParseEntryPoints(parser, p_module);
39043894
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
@@ -3928,6 +3918,11 @@ static SpvReflectResult CreateShaderModule(
39283918
result = ParseExecutionModes(parser, p_module);
39293919
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
39303920
}
3921+
// WorkGroupSize builtin needs to update entry point localsize member
3922+
if (result == SPV_REFLECT_RESULT_SUCCESS) {
3923+
result = ParseSpecializationConstants(parser, p_module);
3924+
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
3925+
}
39313926

39323927
// Destroy module if parse was not successful
39333928
if (result != SPV_REFLECT_RESULT_SUCCESS) {
@@ -5723,22 +5718,20 @@ SpvReflectResult EvaluateResultImpl(const SpvReflectShaderModule* p_module, uint
57235718
if (!p_node) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
57245719
switch (p_node->op) {
57255720
default:
5726-
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
5721+
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION;
57275722
case SpvOpConstantTrue:
57285723
{
5729-
result->general_type = SPV_REFLECT_SCALAR_TYPE_BOOL;
57305724
result->values[0].value.uint32_bool_value = 1;
57315725
}
57325726
return SPV_REFLECT_RESULT_SUCCESS;
57335727
case SpvOpConstantFalse:
57345728
{
5735-
result->general_type = SPV_REFLECT_SCALAR_TYPE_BOOL;
57365729
result->values[0].value.uint32_bool_value = 0;
57375730
}
57385731
return SPV_REFLECT_RESULT_SUCCESS;
57395732
case SpvOpConstant:
57405733
CONSTANT_RESULT:
5741-
return GetScalarConstant(p_module, p_node, &result->values[0], &result->general_type, &result->type);
5734+
return GetScalarConstant(p_module, p_node, &result->values[0], &result->type);
57425735
case SpvOpSpecConstantTrue: case SpvOpSpecConstantFalse:
57435736
case SpvOpSpecConstant:
57445737
{
@@ -5748,32 +5741,58 @@ SpvReflectResult EvaluateResultImpl(const SpvReflectShaderModule* p_module, uint
57485741
SpvReflectSpecializationConstant* p_constant;
57495742
res = GetSpecContantById(p_module, p_node->decorations.specialization_constant.value, &p_constant);
57505743
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
5751-
result->general_type = p_constant->general_type;
57525744
result->type = p_constant->type;
57535745
result->values[0] = p_constant->current_value;
57545746
}
57555747
return SPV_REFLECT_RESULT_SUCCESS;
57565748
case SpvOpSpecConstantComposite:
57575749
{
5758-
// only support scalar types for now...
5750+
// only support compositing vector types for now...
5751+
// vectors are needed for spv compiled to WorkgroupSize builtin
5752+
// in expressing actual localsize
5753+
res = GetTypeByTypeId(p_module, p_node->result_type_id, &result->type);
5754+
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
5755+
// compositing types
5756+
if (result->type->type_flags & VECTOR_DISALLOWED_FLAGS) {
5757+
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
5758+
}
5759+
uint32_t vec_size = 1;
5760+
// should always have, since scalars do not need composite
5761+
if (result->type->type_flags & SPV_REFLECT_TYPE_FLAG_VECTOR) {
5762+
vec_size = result->type->traits.numeric.vector.component_count;
5763+
}
5764+
// check instruction size
5765+
if (p_node->word_count != 3 + vec_size) {
5766+
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION;
5767+
}
5768+
for (uint32_t i = 0; i < vec_size; ++i) {
5769+
SpvReflectValue operandi = {0};
5770+
GET_OPERAND(p_module, p_node, 3 + i, &operandi, maxRecursion);
5771+
// check type compatibility
5772+
if (operandi.type && (operandi.type->type_flags & SCALAR_DISALLOWED_FLAGS)) {
5773+
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
5774+
}
5775+
if ((!operandi.type && !(result->type->type_flags& SPV_REFLECT_TYPE_FLAG_BOOL))
5776+
||(operandi.type && ((operandi.type->type_flags & SCALAR_TYPE_FLAGS) != (result->type->type_flags & SCALAR_TYPE_FLAGS)))) {
5777+
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
5778+
}
5779+
result->values[i] = operandi.values[0];
5780+
}
57595781
}
5760-
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
5782+
return SPV_REFLECT_RESULT_SUCCESS;
57615783
case SpvOpSpecConstantOp:
57625784
{
57635785
// operation has result type id, thus must be typed
57645786
res = GetTypeByTypeId(p_module, p_node->result_type_id, &result->type);
57655787
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
57665788

5767-
// no support for vectors yet...
5768-
if (result->type->type_flags & SPV_REFLECT_TYPE_FLAG_VECTOR) {
5789+
// only vector and scalar types of int/bool/float types implemented
5790+
// only OpSelect, OpUndef and access chain instructions can work with non-vector or scalar types
5791+
// they are not currently supported... (likely never will)
5792+
if (result->type->type_flags & VECTOR_DISALLOWED_FLAGS) {
57695793
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
57705794
}
57715795

5772-
// only vector and scalar types of int/bool/float types allowed
5773-
CHECK_VECTOR_OR_SCALAR_TYPE(result)
5774-
5775-
result->general_type = ScalarGeneralTypeFromType(result->type);
5776-
57775796
// evaluate op
57785797
uint32_t spec_op;
57795798
CHECKED_READU32(p_parser, p_node->word_offset + 3, spec_op);

0 commit comments

Comments
 (0)