diff --git a/spirv_reflect.c b/spirv_reflect.c index 7cbeda2..bc43485 100644 --- a/spirv_reflect.c +++ b/spirv_reflect.c @@ -225,6 +225,7 @@ typedef struct SpvReflectPrvParser { uint32_t type_count; uint32_t descriptor_count; uint32_t push_constant_count; + uint32_t spec_constant_count; SpvReflectTypeDescription* physical_pointer_check[MAX_RECURSIVE_PHYSICAL_POINTER_CHECK]; uint32_t physical_pointer_count; @@ -1420,7 +1421,7 @@ static bool UserTypeMatches(const char* user_type, const char* pattern) { return false; } -static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module) { +static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser) { uint32_t spec_constant_count = 0; for (uint32_t i = 0; i < p_parser->node_count; ++i) { SpvReflectPrvNode* p_node = &(p_parser->nodes[i]); @@ -1722,31 +1723,7 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser, SpvRefle } } - if (spec_constant_count > 0) { - p_module->spec_constants = (SpvReflectSpecializationConstant*)calloc(spec_constant_count, sizeof(*p_module->spec_constants)); - if (IsNull(p_module->spec_constants)) { - return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; - } - } - for (uint32_t i = 0; i < p_parser->node_count; ++i) { - SpvReflectPrvNode* p_node = &(p_parser->nodes[i]); - if (p_node->op == SpvOpDecorate) { - uint32_t decoration = (uint32_t)INVALID_VALUE; - CHECKED_READU32(p_parser, p_node->word_offset + 2, decoration); - if (decoration == SpvDecorationSpecId) { - const uint32_t count = p_module->spec_constant_count; - CHECKED_READU32(p_parser, p_node->word_offset + 1, p_module->spec_constants[count].spirv_id); - CHECKED_READU32(p_parser, p_node->word_offset + 3, p_module->spec_constants[count].constant_id); - // If being used for a OpSpecConstantComposite (ex. LocalSizeId), there won't be a name - SpvReflectPrvNode* target_node = FindNode(p_parser, p_module->spec_constants[count].spirv_id); - if (IsNotNull(target_node)) { - p_module->spec_constants[count].name = target_node->name; - } - p_module->spec_constant_count++; - } - } - } - + p_parser->spec_constant_count = spec_constant_count; return SPV_REFLECT_RESULT_SUCCESS; } @@ -4055,6 +4032,63 @@ static SpvReflectResult ParsePushConstantBlocks(SpvReflectPrvParser* p_parser, S return SPV_REFLECT_RESULT_SUCCESS; } +static SpvReflectResult ParseSpecConstants(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module) { + if (p_parser->spec_constant_count > 0) { + p_module->spec_constants = (SpvReflectSpecializationConstant*)calloc(p_parser->spec_constant_count, sizeof(*p_module->spec_constants)); + if (IsNull(p_module->spec_constants)) { + return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; + } + } else { + return SPV_REFLECT_RESULT_SUCCESS; + } + for (uint32_t i = 0; i < p_parser->node_count; ++i) { + SpvReflectPrvNode* p_node = &(p_parser->nodes[i]); + if (p_node->op == SpvOpDecorate) { + uint32_t decoration = (uint32_t)INVALID_VALUE; + CHECKED_READU32(p_parser, p_node->word_offset + 2, decoration); + if (decoration == SpvDecorationSpecId) { + SpvReflectSpecializationConstant* p_spec_constant = &(p_module->spec_constants[p_module->spec_constant_count]); + CHECKED_READU32(p_parser, p_node->word_offset + 1, p_spec_constant->spirv_id); + CHECKED_READU32(p_parser, p_node->word_offset + 3, p_spec_constant->constant_id); + SpvReflectPrvNode* target_node = FindNode(p_parser, p_spec_constant->spirv_id); + if (IsNotNull(target_node)) { + // If being used for a OpSpecConstantComposite (ex. LocalSizeId), there won't be a name + p_spec_constant->name = target_node->name; + + // During external specialization, Boolean values are true if non-zero and false if zero. + static uint32_t true_value = 1; + static uint32_t false_value = 0; + switch (target_node->op) { + default: + // Unexpected, since Spec states: + // (SpecId) Apply only to a scalar specialization constant + SPV_REFLECT_ASSERT(false); + return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION; + case SpvOpSpecConstantTrue: + p_spec_constant->default_value_size = sizeof(uint32_t); + p_spec_constant->default_value = &true_value; + break; + case SpvOpSpecConstantFalse: + p_spec_constant->default_value_size = sizeof(uint32_t); + p_spec_constant->default_value = &false_value; + break; + case SpvOpSpecConstant: + p_spec_constant->default_value_size = (target_node->word_count - 3) * sizeof(uint32_t); + p_spec_constant->default_value = p_parser->spirv_code + target_node->word_offset + 3; + break; + } + p_spec_constant->type_description = FindType(p_module, target_node->result_type_id); + } else { + // Decoration target not found + return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; + } + p_module->spec_constant_count++; + } + } + } + return SPV_REFLECT_RESULT_SUCCESS; +} + static int SortCompareDescriptorSet(const void* a, const void* b) { const SpvReflectDescriptorSet* p_elem_a = (const SpvReflectDescriptorSet*)a; const SpvReflectDescriptorSet* p_elem_b = (const SpvReflectDescriptorSet*)b; @@ -4307,7 +4341,7 @@ static SpvReflectResult CreateShaderModule(uint32_t flags, size_t size, const vo SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseDecorations(&parser, p_module); + result = ParseDecorations(&parser); SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); } @@ -4348,6 +4382,10 @@ static SpvReflectResult CreateShaderModule(uint32_t flags, size_t size, const vo result = ParsePushConstantBlocks(&parser, p_module); SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); } + if (result == SPV_REFLECT_RESULT_SUCCESS) { + result = ParseSpecConstants(&parser, p_module); + SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); + } if (result == SPV_REFLECT_RESULT_SUCCESS) { result = ParseEntryPoints(&parser, p_module); SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); diff --git a/spirv_reflect.h b/spirv_reflect.h index 2c0c73d..35a4e4a 100644 --- a/spirv_reflect.h +++ b/spirv_reflect.h @@ -580,6 +580,21 @@ typedef struct SpvReflectSpecializationConstant { uint32_t spirv_id; uint32_t constant_id; const char* name; + SpvReflectTypeDescription* type_description; + + // Size of the default value in bytes (always a multiple of 4). + // Will be 4 for 8/16/32-bit constants and 8 for 64-bit constants. + uint32_t default_value_size; + + // Pointer to the raw default value data. + // The interpretation of this data depends on type_description->op: + // - SpvOpSpecConstantTrue: size = 4, data = uint32_t(1) + // - SpvOpSpecConstantFalse: size = 4, data = uint32_t(0) + // - SpvOpSpecConstant: data contains the bit pattern of the default value + // * The type will be a scalar integer or float. + // * Types 32 bits wide or smaller take one word. + // * Larger types take multiple words, with low-order words appearing first. + void* default_value; } SpvReflectSpecializationConstant; /*! @struct SpvReflectShaderModule