diff --git a/spirv_reflect.c b/spirv_reflect.c index 3df963d1..5fed4f88 100644 --- a/spirv_reflect.c +++ b/spirv_reflect.c @@ -27,6 +27,15 @@ #include #endif +static void* spv_calloc(const SpvAllocationCallbacks* p_allocator, size_t num, size_t size) +{ + if (p_allocator != NULL) { + return p_allocator->pfnAllocation(p_allocator->pUserData, num * size); + } + + return calloc(num, size); +} + // Temporary enums until these make it into SPIR-V/Vulkan // clang-format off enum { @@ -173,10 +182,10 @@ typedef struct AccessChain { // Pointing to the base of a composite object. // Generally the id of descriptor block variable uint32_t base_id; - // + // // From spec: - // The first index in Indexes will select the - // top-level member/element/component/element + // The first index in Indexes will select the + // top-level member/element/component/element // of the base composite uint32_t index_count; uint32_t* indexes; @@ -224,12 +233,16 @@ static uint32_t RoundUp(uint32_t value, uint32_t multiple) #define IsNotNull(ptr) \ (ptr != NULL) -#define SafeFree(ptr) \ - { \ - if (ptr != NULL) { \ - free((void*)ptr); \ - ptr = NULL; \ - } \ +#define SafeFree(ptr_allocator, ptr) \ + { \ + if (ptr != NULL) { \ + if (ptr_allocator != NULL) { \ + ptr_allocator->pfnFree(ptr_allocator->pUserData, (void*)ptr); \ + } else { \ + free((void*)ptr); \ + } \ + ptr = NULL; \ + } \ } static int SortCompareUint32(const void* a, const void* b) @@ -280,13 +293,14 @@ static bool SearchSortedUint32(const uint32_t* arr, size_t size, uint32_t target } static SpvReflectResult IntersectSortedUint32( - const uint32_t* p_arr0, + const uint32_t* p_arr0, size_t arr0_size, - const uint32_t* p_arr1, + const uint32_t* p_arr1, size_t arr1_size, uint32_t** pp_res, size_t* res_size -) +, + const SpvAllocationCallbacks* p_allocator) { *res_size = 0; const uint32_t* arr0_end = p_arr0 + arr0_size; @@ -308,7 +322,7 @@ static SpvReflectResult IntersectSortedUint32( *pp_res = NULL; if (*res_size > 0) { - *pp_res = (uint32_t*)calloc(*res_size, sizeof(**pp_res)); + *pp_res = (uint32_t*)spv_calloc(p_allocator, *res_size, sizeof(**pp_res)); if (IsNull(*pp_res)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -389,11 +403,11 @@ static SpvReflectResult ReadU32(Parser* p_parser, uint32_t word_offset, uint32_t } static SpvReflectResult ReadStr( - Parser* p_parser, - uint32_t word_offset, - uint32_t word_index, + Parser* p_parser, + uint32_t word_offset, + uint32_t word_index, uint32_t word_count, - uint32_t* p_buf_size, + uint32_t* p_buf_size, char* p_buf ) { @@ -524,41 +538,41 @@ static SpvReflectResult CreateParser(size_t size, void* p_code, Parser* p_parser return SPV_REFLECT_RESULT_SUCCESS; } -static void DestroyParser(Parser* p_parser) +static void DestroyParser(Parser* p_parser, const SpvAllocationCallbacks* p_allocator) { if (!IsNull(p_parser->nodes)) { // Free nodes for (size_t i = 0; i < p_parser->node_count; ++i) { Node* p_node = &(p_parser->nodes[i]); if (IsNotNull(p_node->member_names)) { - SafeFree(p_node->member_names); + SafeFree(p_allocator, p_node->member_names); } if (IsNotNull(p_node->member_decorations)) { - SafeFree(p_node->member_decorations); + SafeFree(p_allocator, p_node->member_decorations); } } // Free functions for (size_t i = 0; i < p_parser->function_count; ++i) { - SafeFree(p_parser->functions[i].callees); - SafeFree(p_parser->functions[i].callee_ptrs); - SafeFree(p_parser->functions[i].accessed_ptrs); + SafeFree(p_allocator, p_parser->functions[i].callees); + SafeFree(p_allocator, p_parser->functions[i].callee_ptrs); + SafeFree(p_allocator, p_parser->functions[i].accessed_ptrs); } // Free access chains for (uint32_t i = 0; i < p_parser->access_chain_count; ++i) { - SafeFree(p_parser->access_chains[i].indexes); + SafeFree(p_allocator, p_parser->access_chains[i].indexes); } - SafeFree(p_parser->nodes); - SafeFree(p_parser->strings); - SafeFree(p_parser->functions); - SafeFree(p_parser->access_chains); + SafeFree(p_allocator, p_parser->nodes); + SafeFree(p_allocator, p_parser->strings); + SafeFree(p_allocator, p_parser->functions); + SafeFree(p_allocator, p_parser->access_chains); p_parser->node_count = 0; } } -static SpvReflectResult ParseNodes(Parser* p_parser) +static SpvReflectResult ParseNodes(Parser* p_parser, const SpvAllocationCallbacks* p_allocator) { assert(IsNotNull(p_parser)); assert(IsNotNull(p_parser->spirv_code)); @@ -588,7 +602,7 @@ static SpvReflectResult ParseNodes(Parser* p_parser) // Allocate nodes p_parser->node_count = node_count; - p_parser->nodes = (Node*)calloc(p_parser->node_count, sizeof(*(p_parser->nodes))); + p_parser->nodes = (Node*)spv_calloc(p_allocator, p_parser->node_count, sizeof(*(p_parser->nodes))); if (IsNull(p_parser->nodes)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -611,7 +625,7 @@ static SpvReflectResult ParseNodes(Parser* p_parser) // Allocate access chain if (p_parser->access_chain_count > 0) { - p_parser->access_chains = (AccessChain*)calloc(p_parser->access_chain_count, sizeof(*(p_parser->access_chains))); + p_parser->access_chains = (AccessChain*)spv_calloc(p_allocator, p_parser->access_chain_count, sizeof(*(p_parser->access_chains))); if (IsNull(p_parser->access_chains)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -786,7 +800,7 @@ static SpvReflectResult ParseNodes(Parser* p_parser) // p_access_chain->index_count = (node_word_count - SPIRV_ACCESS_CHAIN_INDEX_OFFSET); if (p_access_chain->index_count > 0) { - p_access_chain->indexes = (uint32_t*)calloc(p_access_chain->index_count, sizeof(*(p_access_chain->indexes))); + p_access_chain->indexes = (uint32_t*)spv_calloc(p_allocator, p_access_chain->index_count, sizeof(*(p_access_chain->indexes))); if (IsNull( p_access_chain->indexes)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -849,7 +863,7 @@ static SpvReflectResult ParseNodes(Parser* p_parser) return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseStrings(Parser* p_parser) +static SpvReflectResult ParseStrings(Parser* p_parser, const SpvAllocationCallbacks* p_allocator) { assert(IsNotNull(p_parser)); assert(IsNotNull(p_parser->spirv_code)); @@ -862,7 +876,7 @@ static SpvReflectResult ParseStrings(Parser* p_parser) if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) { // Allocate string storage - p_parser->strings = (String*)calloc(p_parser->string_count, sizeof(*(p_parser->strings))); + p_parser->strings = (String*)spv_calloc(p_allocator, p_parser->string_count, sizeof(*(p_parser->strings))); uint32_t string_index = 0; for (size_t i = 0; i < p_parser->node_count; ++i) { @@ -914,7 +928,7 @@ static SpvReflectResult ParseSource(Parser* p_parser, SpvReflectShaderModule* p_ return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseFunction(Parser* p_parser, Node* p_func_node, Function* p_func, size_t first_label_index) +static SpvReflectResult ParseFunction(Parser* p_parser, Node* p_func_node, Function* p_func, size_t first_label_index, const SpvAllocationCallbacks* p_allocator) { p_func->id = p_func_node->result_id; @@ -954,7 +968,7 @@ static SpvReflectResult ParseFunction(Parser* p_parser, Node* p_func_node, Funct } if (p_func->callee_count > 0) { - p_func->callees = (uint32_t*)calloc(p_func->callee_count, + p_func->callees = (uint32_t*)spv_calloc(p_allocator, p_func->callee_count, sizeof(*(p_func->callees))); if (IsNull(p_func->callees)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -962,7 +976,7 @@ static SpvReflectResult ParseFunction(Parser* p_parser, Node* p_func_node, Funct } if (p_func->accessed_ptr_count > 0) { - p_func->accessed_ptrs = (uint32_t*)calloc(p_func->accessed_ptr_count, + p_func->accessed_ptrs = (uint32_t*)spv_calloc(p_allocator, p_func->accessed_ptr_count, sizeof(*(p_func->accessed_ptrs))); if (IsNull(p_func->accessed_ptrs)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -1042,7 +1056,7 @@ static int SortCompareFunctions(const void* a, const void* b) return (int)af->id - (int)bf->id; } -static SpvReflectResult ParseFunctions(Parser* p_parser) +static SpvReflectResult ParseFunctions(Parser* p_parser, const SpvAllocationCallbacks* p_allocator) { assert(IsNotNull(p_parser)); assert(IsNotNull(p_parser->spirv_code)); @@ -1053,7 +1067,7 @@ static SpvReflectResult ParseFunctions(Parser* p_parser) return SPV_REFLECT_RESULT_SUCCESS; } - p_parser->functions = (Function*)calloc(p_parser->function_count, + p_parser->functions = (Function*)spv_calloc(p_allocator, p_parser->function_count, sizeof(*(p_parser->functions))); if (IsNull(p_parser->functions)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -1085,7 +1099,7 @@ static SpvReflectResult ParseFunctions(Parser* p_parser) Function* p_function = &(p_parser->functions[function_index]); - SpvReflectResult result = ParseFunction(p_parser, p_node, p_function, i); + SpvReflectResult result = ParseFunction(p_parser, p_node, p_function, i, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -1103,7 +1117,7 @@ static SpvReflectResult ParseFunctions(Parser* p_parser) if (p_func->callee_count == 0) { continue; } - p_func->callee_ptrs = (Function**)calloc(p_func->callee_count, + p_func->callee_ptrs = (Function**)spv_calloc(p_allocator, p_func->callee_count, sizeof(*(p_func->callee_ptrs))); for (size_t j = 0, k = 0; j < p_func->callee_count; ++j) { while (p_parser->functions[k].id != p_func->callees[j]) { @@ -1121,7 +1135,7 @@ static SpvReflectResult ParseFunctions(Parser* p_parser) return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseMemberCounts(Parser* p_parser) +static SpvReflectResult ParseMemberCounts(Parser* p_parser, const SpvAllocationCallbacks* p_allocator) { assert(IsNotNull(p_parser)); assert(IsNotNull(p_parser->spirv_code)); @@ -1157,12 +1171,12 @@ static SpvReflectResult ParseMemberCounts(Parser* p_parser) continue; } - p_node->member_names = (const char **)calloc(p_node->member_count, sizeof(*(p_node->member_names))); + p_node->member_names = (const char **)spv_calloc(p_allocator, p_node->member_count, sizeof(*(p_node->member_names))); if (IsNull(p_node->member_names)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } - p_node->member_decorations = (Decorations*)calloc(p_node->member_count, sizeof(*(p_node->member_decorations))); + p_node->member_decorations = (Decorations*)spv_calloc(p_allocator, p_node->member_count, sizeof(*(p_node->member_decorations))); if (IsNull(p_node->member_decorations)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -1236,7 +1250,7 @@ static SpvReflectResult ParseDecorations(Parser* p_parser) default: { skip = true; } - break; + break; case SpvDecorationBlock: case SpvDecorationBufferBlock: case SpvDecorationColMajor: @@ -1256,13 +1270,13 @@ static SpvReflectResult ParseDecorations(Parser* p_parser) case SpvReflectDecorationHlslSemanticGOOGLE: { skip = false; } - break; + break; } if (skip) { continue; - } - - // Find target target node + } + + // Find target target node uint32_t target_id = 0; CHECKED_READU32(p_parser, p_node->word_offset + 1, target_id); Node* p_target_node = FindNode(p_parser, target_id); @@ -1389,16 +1403,16 @@ static SpvReflectResult ParseDecorations(Parser* p_parser) } static SpvReflectResult EnumerateAllUniforms( - SpvReflectShaderModule* p_module, - size_t* p_uniform_count, - uint32_t** pp_uniforms -) + SpvReflectShaderModule* p_module, + size_t* p_uniform_count, + uint32_t** pp_uniforms, + const SpvAllocationCallbacks* p_allocator) { *p_uniform_count = p_module->descriptor_binding_count; if (*p_uniform_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; } - *pp_uniforms = (uint32_t*)calloc(*p_uniform_count, sizeof(**pp_uniforms)); + *pp_uniforms = (uint32_t*)spv_calloc(p_allocator, *p_uniform_count, sizeof(**pp_uniforms)); if (IsNull(*pp_uniforms)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -1413,18 +1427,18 @@ static SpvReflectResult EnumerateAllUniforms( } static SpvReflectResult ParseType( - Parser* p_parser, - Node* p_node, - Decorations* p_struct_member_decorations, - SpvReflectShaderModule* p_module, + Parser* p_parser, + Node* p_node, + Decorations* p_struct_member_decorations, + SpvReflectShaderModule* p_module, SpvReflectTypeDescription* p_type -) +, const SpvAllocationCallbacks* p_allocator) { SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS; if (p_node->member_count > 0) { p_type->member_count = p_node->member_count; - p_type->members = (SpvReflectTypeDescription*)calloc(p_type->member_count, sizeof(*(p_type->members))); + p_type->members = (SpvReflectTypeDescription*)spv_calloc(p_allocator, p_type->member_count, sizeof(*(p_type->members))); if (IsNotNull(p_type->members)) { // Mark all members types with an invalid state for (size_t i = 0; i < p_type->members->member_count; ++i) { @@ -1449,7 +1463,7 @@ static SpvReflectResult ParseType( } // Top level types need to pick up decorations from all types below it. // Issue and fix here: https://github.com/chaoticbob/SPIRV-Reflect/issues/64 - p_type->decoration_flags = ApplyDecorations(&p_node->decorations); + p_type->decoration_flags |= ApplyDecorations(&p_node->decorations); switch (p_node->op) { default: break; @@ -1482,7 +1496,7 @@ static SpvReflectResult ParseType( // Parse component type Node* p_next_node = FindNode(p_parser, component_type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } else { result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; @@ -1497,7 +1511,7 @@ static SpvReflectResult ParseType( IF_READU32(result, p_parser, p_node->word_offset + 3, p_type->traits.numeric.matrix.column_count); Node* p_next_node = FindNode(p_parser, column_type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } else { result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; @@ -1533,7 +1547,7 @@ static SpvReflectResult ParseType( IF_READU32(result, p_parser, p_node->word_offset + 2, image_type_id); Node* p_next_node = FindNode(p_parser, image_type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } else { result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; @@ -1572,7 +1586,7 @@ static SpvReflectResult ParseType( // Parse next dimension or element type Node* p_next_node = FindNode(p_parser, element_type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } } else { @@ -1588,7 +1602,7 @@ static SpvReflectResult ParseType( // Parse next dimension or element type Node* p_next_node = FindNode(p_parser, element_type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } else { result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; @@ -1619,11 +1633,11 @@ static SpvReflectResult ParseType( SpvReflectTypeDescription* p_member_type = &(p_type->members[member_index]); p_member_type->id = member_id; p_member_type->op = p_member_node->op; - result = ParseType(p_parser, p_member_node, p_member_decorations, p_module, p_member_type); + result = ParseType(p_parser, p_member_node, p_member_decorations, p_module, p_member_type, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { break; } - // This looks wrong + // This looks wrong //p_member_type->type_name = p_member_node->name; p_member_type->struct_member_name = p_node->member_names[member_index]; } @@ -1639,7 +1653,7 @@ static SpvReflectResult ParseType( // Parse type Node* p_next_node = FindNode(p_parser, type_id); if (IsNotNull(p_next_node)) { - result = ParseType(p_parser, p_next_node, NULL, p_module, p_type); + result = ParseType(p_parser, p_next_node, NULL, p_module, p_type, p_allocator); } else { result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; @@ -1660,14 +1674,14 @@ static SpvReflectResult ParseType( return result; } -static SpvReflectResult ParseTypes(Parser* p_parser, SpvReflectShaderModule* p_module) +static SpvReflectResult ParseTypes(Parser* p_parser, SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { if (p_parser->type_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; } p_module->_internal->type_description_count = p_parser->type_count; - p_module->_internal->type_descriptions = (SpvReflectTypeDescription*)calloc(p_module->_internal->type_description_count, + p_module->_internal->type_descriptions = (SpvReflectTypeDescription*)spv_calloc(p_allocator, p_module->_internal->type_description_count, sizeof(*(p_module->_internal->type_descriptions))); if (IsNull(p_module->_internal->type_descriptions)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -1689,7 +1703,7 @@ static SpvReflectResult ParseTypes(Parser* p_parser, SpvReflectShaderModule* p_m } SpvReflectTypeDescription* p_type = &(p_module->_internal->type_descriptions[type_index]); - SpvReflectResult result = ParseType(p_parser, p_node, NULL, p_module, p_type); + SpvReflectResult result = ParseType(p_parser, p_node, NULL, p_module, p_type, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -1712,7 +1726,7 @@ static int SortCompareDescriptorBinding(const void* a, const void* b) return value; } -static SpvReflectResult ParseDescriptorBindings(Parser* p_parser, SpvReflectShaderModule* p_module) +static SpvReflectResult ParseDescriptorBindings(Parser* p_parser, SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { p_module->descriptor_binding_count = 0; for (size_t i = 0; i < p_parser->node_count; ++i) { @@ -1733,7 +1747,7 @@ static SpvReflectResult ParseDescriptorBindings(Parser* p_parser, SpvReflectShad return SPV_REFLECT_RESULT_SUCCESS; } - p_module->descriptor_bindings = (SpvReflectDescriptorBinding*)calloc(p_module->descriptor_binding_count, sizeof(*(p_module->descriptor_bindings))); + p_module->descriptor_bindings = (SpvReflectDescriptorBinding*)spv_calloc(p_allocator, p_module->descriptor_binding_count, sizeof(*(p_module->descriptor_bindings))); if (IsNull(p_module->descriptor_bindings)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -1948,7 +1962,7 @@ static SpvReflectResult ParseUAVCounterBindings(SpvReflectShaderModule* p_module else { const size_t descriptor_name_length = p_descriptor->name? strlen(p_descriptor->name): 0; - memset(name, 0, MAX_NODE_NAME_LENGTH); + memset(name, 0, MAX_NODE_NAME_LENGTH); memcpy(name, p_descriptor->name, descriptor_name_length); #if defined(WIN32) strcat_s(name, MAX_NODE_NAME_LENGTH, k_count_tag); @@ -1977,17 +1991,17 @@ static SpvReflectResult ParseUAVCounterBindings(SpvReflectShaderModule* p_module } static SpvReflectResult ParseDescriptorBlockVariable( - Parser* p_parser, - SpvReflectShaderModule* p_module, - SpvReflectTypeDescription* p_type, + Parser* p_parser, + SpvReflectShaderModule* p_module, + SpvReflectTypeDescription* p_type, SpvReflectBlockVariable* p_var -) +, const SpvAllocationCallbacks* p_allocator) { bool has_non_writable = false; if (IsNotNull(p_type->members) && (p_type->member_count > 0)) { p_var->member_count = p_type->member_count; - p_var->members = (SpvReflectBlockVariable*)calloc(p_var->member_count, sizeof(*p_var->members)); + p_var->members = (SpvReflectBlockVariable*)spv_calloc(p_allocator, p_var->member_count, sizeof(*p_var->members)); if (IsNull(p_var->members)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2017,14 +2031,14 @@ static SpvReflectResult ParseDescriptorBlockVariable( return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE; } } - + // Parse members for (uint32_t member_index = 0; member_index < p_type->member_count; ++member_index) { SpvReflectTypeDescription* p_member_type = &p_type->members[member_index]; SpvReflectBlockVariable* p_member_var = &p_var->members[member_index]; bool is_struct = (p_member_type->type_flags & SPV_REFLECT_TYPE_FLAG_STRUCT) == SPV_REFLECT_TYPE_FLAG_STRUCT; if (is_struct) { - SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_member_type, p_member_var); + SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_member_type, p_member_var, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2056,11 +2070,11 @@ static SpvReflectResult ParseDescriptorBlockVariable( } static SpvReflectResult ParseDescriptorBlockVariableSizes( - Parser* p_parser, - SpvReflectShaderModule* p_module, - bool is_parent_root, - bool is_parent_aos, - bool is_parent_rta, + Parser* p_parser, + SpvReflectShaderModule* p_module, + bool is_parent_root, + bool is_parent_aos, + bool is_parent_rta, SpvReflectBlockVariable* p_var ) { @@ -2203,7 +2217,7 @@ static SpvReflectResult ParseDescriptorBlockVariableUsage( // Clear the current variable's USED flag p_var->flags &= ~SPV_REFLECT_VARIABLE_FLAGS_UNUSED; - + // Parsing arrays requires overriding the op type for // for the lowest dim's element type. SpvOp op_type = p_var->type_description->op; @@ -2253,7 +2267,7 @@ static SpvReflectResult ParseDescriptorBlockVariableUsage( } uint32_t index = p_access_chain->indexes[index_index]; - + if (index >= p_var->member_count) { return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_BLOCK_MEMBER_REFERENCE; } @@ -2278,7 +2292,7 @@ static SpvReflectResult ParseDescriptorBlockVariableUsage( return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseDescriptorBlocks(Parser* p_parser, SpvReflectShaderModule* p_module) +static SpvReflectResult ParseDescriptorBlocks(Parser* p_parser, SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { if (p_module->descriptor_binding_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; @@ -2295,8 +2309,8 @@ static SpvReflectResult ParseDescriptorBlocks(Parser* p_parser, SpvReflectShader // Mark UNUSED p_descriptor->block.flags |= SPV_REFLECT_VARIABLE_FLAGS_UNUSED; - // Parse descriptor block - SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, &p_descriptor->block); + // Parse descriptor block + SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, &p_descriptor->block, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2318,7 +2332,7 @@ static SpvReflectResult ParseDescriptorBlocks(Parser* p_parser, SpvReflectShader return result; } } - + p_descriptor->block.name = p_descriptor->name; bool is_parent_rta = (p_descriptor->descriptor_type == SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER); @@ -2383,14 +2397,13 @@ static SpvReflectResult ParseFormat( return result; } -static SpvReflectResult ParseInterfaceVariable( - Parser* p_parser, - const Decorations* p_type_node_decorations, - SpvReflectShaderModule* p_module, - SpvReflectTypeDescription* p_type, - SpvReflectInterfaceVariable* p_var, - bool* p_has_built_in -) +static SpvReflectResult ParseInterfaceVariable(Parser* p_parser, + const Decorations* p_type_node_decorations, + SpvReflectShaderModule* p_module, + SpvReflectTypeDescription* p_type, + SpvReflectInterfaceVariable* p_var, + bool* p_has_built_in, + const SpvAllocationCallbacks* p_allocator) { Node* p_type_node = FindNode(p_parser, p_type->id); if (IsNull(p_type_node)) { @@ -2399,7 +2412,7 @@ static SpvReflectResult ParseInterfaceVariable( if (p_type->member_count > 0) { p_var->member_count = p_type->member_count; - p_var->members = (SpvReflectInterfaceVariable*)calloc(p_var->member_count, sizeof(*p_var->members)); + p_var->members = (SpvReflectInterfaceVariable*)spv_calloc(p_allocator, p_var->member_count, sizeof(*p_var->members)); if (IsNull(p_var->members)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2408,7 +2421,7 @@ static SpvReflectResult ParseInterfaceVariable( Decorations* p_member_decorations = &p_type_node->member_decorations[member_index]; SpvReflectTypeDescription* p_member_type = &p_type->members[member_index]; SpvReflectInterfaceVariable* p_member_var = &p_var->members[member_index]; - SpvReflectResult result = ParseInterfaceVariable(p_parser, p_member_decorations, p_module, p_member_type, p_member_var, p_has_built_in); + SpvReflectResult result = ParseInterfaceVariable(p_parser, p_member_decorations, p_module, p_member_type, p_member_var, p_has_built_in, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2436,12 +2449,12 @@ static SpvReflectResult ParseInterfaceVariable( } static SpvReflectResult ParseInterfaceVariables( - Parser* p_parser, - SpvReflectShaderModule* p_module, - SpvReflectEntryPoint* p_entry, - size_t io_var_count, - uint32_t* io_vars -) + Parser* p_parser, + SpvReflectShaderModule* p_module, + SpvReflectEntryPoint* p_entry, + size_t io_var_count, + uint32_t* io_vars, + const SpvAllocationCallbacks* p_allocator) { if (io_var_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; @@ -2465,7 +2478,7 @@ static SpvReflectResult ParseInterfaceVariables( } if (p_entry->input_variable_count > 0) { - p_entry->input_variables = (SpvReflectInterfaceVariable*)calloc(p_entry->input_variable_count, sizeof(*(p_entry->input_variables))); + p_entry->input_variables = (SpvReflectInterfaceVariable*)spv_calloc(p_allocator, p_entry->input_variable_count, sizeof(*(p_entry->input_variables))); if (IsNull(p_entry->input_variables)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2473,7 +2486,7 @@ static SpvReflectResult ParseInterfaceVariables( if (p_entry->output_variable_count > 0) { - p_entry->output_variables = (SpvReflectInterfaceVariable*)calloc(p_entry->output_variable_count, sizeof(*(p_entry->output_variables))); + p_entry->output_variables = (SpvReflectInterfaceVariable*)spv_calloc(p_allocator, p_entry->output_variable_count, sizeof(*(p_entry->output_variables))); if (IsNull(p_entry->output_variables)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2535,7 +2548,8 @@ static SpvReflectResult ParseInterfaceVariables( p_module, p_type, p_var, - &has_built_in); + &has_built_in, + p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2569,13 +2583,14 @@ static SpvReflectResult EnumerateAllPushConstants( SpvReflectShaderModule* p_module, size_t* p_push_constant_count, uint32_t** p_push_constants -) +, + const SpvAllocationCallbacks* p_allocator) { *p_push_constant_count = p_module->push_constant_block_count; if (*p_push_constant_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; } - *p_push_constants = (uint32_t*)calloc(*p_push_constant_count, sizeof(**p_push_constants)); + *p_push_constants = (uint32_t*)spv_calloc(p_allocator, *p_push_constant_count, sizeof(**p_push_constants)); if (IsNull(*p_push_constants)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -2626,7 +2641,8 @@ static SpvReflectResult ParseStaticallyUsedResources( uint32_t* uniforms, size_t push_constant_count, uint32_t* push_constants -) +, + const SpvAllocationCallbacks* p_allocator) { // Find function with the right id Function* p_func = NULL; @@ -2653,7 +2669,7 @@ static SpvReflectResult ParseStaticallyUsedResources( uint32_t* p_called_functions = NULL; if (called_function_count > 0) { - p_called_functions = (uint32_t*)calloc(called_function_count, sizeof(*p_called_functions)); + p_called_functions = (uint32_t*)spv_calloc(p_allocator, called_function_count, sizeof(*p_called_functions)); if (IsNull(p_called_functions)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2672,8 +2688,8 @@ static SpvReflectResult ParseStaticallyUsedResources( if (called_function_count > 0) { qsort( - p_called_functions, - called_function_count, + p_called_functions, + called_function_count, sizeof(*p_called_functions), SortCompareUint32); } @@ -2690,10 +2706,10 @@ static SpvReflectResult ParseStaticallyUsedResources( } uint32_t* used_variables = NULL; if (used_variable_count > 0) { - used_variables = (uint32_t*)calloc(used_variable_count, + used_variables = (uint32_t*)spv_calloc(p_allocator, used_variable_count, sizeof(*used_variables)); if (IsNull(used_variables)) { - SafeFree(p_called_functions); + SafeFree(p_allocator, p_called_functions); return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } } @@ -2708,13 +2724,13 @@ static SpvReflectResult ParseStaticallyUsedResources( p_parser->functions[j].accessed_ptr_count * sizeof(*used_variables)); used_variable_count += p_parser->functions[j].accessed_ptr_count; } - SafeFree(p_called_functions); + SafeFree(p_allocator, p_called_functions); if (used_variable_count > 0) { qsort(used_variables, used_variable_count, sizeof(*used_variables), SortCompareUint32); } - used_variable_count = (uint32_t)DedupSortedUint32(used_variables, + used_variable_count = (uint32_t)DedupSortedUint32(used_variables, used_variable_count); // Do set intersection to find the used uniform and push constants @@ -2723,20 +2739,22 @@ static SpvReflectResult ParseStaticallyUsedResources( SpvReflectResult result0 = IntersectSortedUint32( used_variables, used_variable_count, - uniforms, + uniforms, uniform_count, &p_entry->used_uniforms, - &used_uniform_count); + &used_uniform_count, + p_allocator); size_t used_push_constant_count = 0; // SpvReflectResult result1 = IntersectSortedUint32( - used_variables, + used_variables, used_variable_count, - push_constants, + push_constants, push_constant_count, &p_entry->used_push_constants, - &used_push_constant_count); + &used_push_constant_count, + p_allocator); for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) { SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[j]; @@ -2749,7 +2767,7 @@ static SpvReflectResult ParseStaticallyUsedResources( } } - SafeFree(used_variables); + SafeFree(p_allocator, used_variables); if (result0 != SPV_REFLECT_RESULT_SUCCESS) { return result0; } @@ -2763,14 +2781,14 @@ static SpvReflectResult ParseStaticallyUsedResources( return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModule* p_module) +static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { if (p_parser->entry_point_count == 0) { return SPV_REFLECT_RESULT_SUCCESS; } p_module->entry_point_count = p_parser->entry_point_count; - p_module->entry_points = (SpvReflectEntryPoint*)calloc(p_module->entry_point_count, + p_module->entry_points = (SpvReflectEntryPoint*)spv_calloc(p_allocator, p_module->entry_point_count, sizeof(*(p_module->entry_points))); if (IsNull(p_module->entry_points)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -2779,13 +2797,13 @@ static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModul SpvReflectResult result; size_t uniform_count = 0; uint32_t* uniforms = NULL; - if ((result = EnumerateAllUniforms(p_module, &uniform_count, &uniforms)) != + if ((result = EnumerateAllUniforms(p_module, &uniform_count, &uniforms, p_allocator)) != SPV_REFLECT_RESULT_SUCCESS) { return result; } size_t push_constant_count = 0; uint32_t* push_constants = NULL; - if ((result = EnumerateAllPushConstants(p_module, &push_constant_count, &push_constants)) != + if ((result = EnumerateAllPushConstants(p_module, &push_constant_count, &push_constants, p_allocator)) != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2826,7 +2844,7 @@ static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModul size_t interface_variable_count = (p_node->word_count - (name_start_word_offset + name_word_count)); uint32_t* interface_variables = NULL; if (interface_variable_count > 0) { - interface_variables = (uint32_t*)calloc(interface_variable_count, sizeof(*(interface_variables))); + interface_variables = (uint32_t*)spv_calloc(p_allocator, interface_variable_count, sizeof(*(interface_variables))); if (IsNull(interface_variables)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2844,11 +2862,12 @@ static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModul p_module, p_entry_point, interface_variable_count, - interface_variables); + interface_variables, + p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } - SafeFree(interface_variables); + SafeFree(p_allocator, interface_variables); result = ParseStaticallyUsedResources( p_parser, @@ -2857,19 +2876,20 @@ static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModul uniform_count, uniforms, push_constant_count, - push_constants); + push_constants, + p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } } - SafeFree(uniforms); - SafeFree(push_constants); + SafeFree(p_allocator, uniforms); + SafeFree(p_allocator, push_constants); return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParsePushConstantBlocks(Parser* p_parser, SpvReflectShaderModule* p_module) +static SpvReflectResult ParsePushConstantBlocks(Parser* p_parser, SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { for (size_t i = 0; i < p_parser->node_count; ++i) { Node* p_node = &(p_parser->nodes[i]); @@ -2884,7 +2904,7 @@ static SpvReflectResult ParsePushConstantBlocks(Parser* p_parser, SpvReflectShad return SPV_REFLECT_RESULT_SUCCESS; } - p_module->push_constant_blocks = (SpvReflectBlockVariable*)calloc(p_module->push_constant_block_count, sizeof(*p_module->push_constant_blocks)); + p_module->push_constant_blocks = (SpvReflectBlockVariable*)spv_calloc(p_allocator, p_module->push_constant_block_count, sizeof(*p_module->push_constant_blocks)); if (IsNull(p_module->push_constant_blocks)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } @@ -2921,7 +2941,7 @@ static SpvReflectResult ParsePushConstantBlocks(Parser* p_parser, SpvReflectShad SpvReflectBlockVariable* p_push_constant = &p_module->push_constant_blocks[push_constant_index]; p_push_constant->spirv_id = p_node->result_id; - SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, p_push_constant); + SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, p_push_constant, p_allocator); if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -2947,14 +2967,14 @@ static int SortCompareDescriptorSet(const void* a, const void* b) return value; } -static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_module) { +static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { // Update the entry point's sets for (uint32_t i = 0; i < p_module->entry_point_count; ++i) { SpvReflectEntryPoint* p_entry = &p_module->entry_points[i]; for (uint32_t j = 0; j < p_entry->descriptor_set_count; ++j) { - SafeFree(p_entry->descriptor_sets[j].bindings); + SafeFree(p_allocator, p_entry->descriptor_sets[j].bindings); } - SafeFree(p_entry->descriptor_sets); + SafeFree(p_allocator, p_entry->descriptor_sets); p_entry->descriptor_set_count = 0; for (uint32_t j = 0; j < p_module->descriptor_set_count; ++j) { const SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[j]; @@ -2972,7 +2992,7 @@ static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_ p_entry->descriptor_sets = NULL; if (p_entry->descriptor_set_count > 0) { - p_entry->descriptor_sets = (SpvReflectDescriptorSet*)calloc(p_entry->descriptor_set_count, + p_entry->descriptor_sets = (SpvReflectDescriptorSet*)spv_calloc(p_allocator, p_entry->descriptor_set_count, sizeof(*p_entry->descriptor_sets)); if (IsNull(p_entry->descriptor_sets)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -2997,7 +3017,7 @@ static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_ SpvReflectDescriptorSet* p_entry_set = &p_entry->descriptor_sets[ p_entry->descriptor_set_count++]; p_entry_set->set = p_set->set; - p_entry_set->bindings = (SpvReflectDescriptorBinding**)calloc(count, + p_entry_set->bindings = (SpvReflectDescriptorBinding**)spv_calloc(p_allocator, count, sizeof(*p_entry_set->bindings)); if (IsNull(p_entry_set->bindings)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; @@ -3017,7 +3037,7 @@ static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_ return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult ParseDescriptorSets(SpvReflectShaderModule* p_module) +static SpvReflectResult ParseDescriptorSets(SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { // Count the descriptors in each set for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) { @@ -3071,7 +3091,7 @@ static SpvReflectResult ParseDescriptorSets(SpvReflectShaderModule* p_module) // Build descriptor pointer array for (uint32_t i = 0; i descriptor_set_count; ++i) { SpvReflectDescriptorSet* p_set = &(p_module->descriptor_sets[i]); - p_set->bindings = (SpvReflectDescriptorBinding **)calloc(p_set->binding_count, sizeof(*(p_set->bindings))); + p_set->bindings = (SpvReflectDescriptorBinding **)spv_calloc(p_allocator, p_set->binding_count, sizeof(*(p_set->bindings))); uint32_t descriptor_index = 0; for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) { @@ -3084,7 +3104,7 @@ static SpvReflectResult ParseDescriptorSets(SpvReflectShaderModule* p_module) } } - return ParseEntrypointDescriptorSets(p_module); + return ParseEntrypointDescriptorSets(p_module, p_allocator); } static SpvReflectResult DisambiguateStorageBufferSrvUav(SpvReflectShaderModule* p_module) @@ -3114,19 +3134,19 @@ static SpvReflectResult DisambiguateStorageBufferSrvUav(SpvReflectShaderModule* return SPV_REFLECT_RESULT_SUCCESS; } -static SpvReflectResult SynchronizeDescriptorSets(SpvReflectShaderModule* p_module) +static SpvReflectResult SynchronizeDescriptorSets(SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { // Free and reset all descriptor set numbers for (uint32_t i = 0; i < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++i) { SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[i]; - SafeFree(p_set->bindings); + SafeFree(p_allocator, p_set->bindings); p_set->binding_count = 0; p_set->set = (uint32_t)INVALID_VALUE; } // Set descriptor set count to zero p_module->descriptor_set_count = 0; - SpvReflectResult result = ParseDescriptorSets(p_module); + SpvReflectResult result = ParseDescriptorSets(p_module, p_allocator); return result; } @@ -3136,33 +3156,42 @@ SpvReflectResult spvReflectGetShaderModule( SpvReflectShaderModule* p_module ) { - return spvReflectCreateShaderModule(size, p_code, p_module); + return spvReflectCreateShaderModuleEx(size, p_code, p_module, NULL); } SpvReflectResult spvReflectCreateShaderModule( + size_t size, + const void* p_code, + SpvReflectShaderModule* p_module +) +{ + return spvReflectCreateShaderModuleEx(size, p_code, p_module, NULL); +} + +SpvReflectResult spvReflectCreateShaderModuleEx( size_t size, const void* p_code, - SpvReflectShaderModule* p_module -) + SpvReflectShaderModule* p_module, + const SpvAllocationCallbacks* p_allocator) { // Initialize all module fields to zero memset(p_module, 0, sizeof(*p_module)); // Allocate module internals #ifdef __cplusplus - p_module->_internal = (SpvReflectShaderModule::Internal*)calloc(1, sizeof(*(p_module->_internal))); + p_module->_internal = (SpvReflectShaderModule::Internal*)spv_calloc(p_allocator, 1, sizeof(*(p_module->_internal))); #else - p_module->_internal = calloc(1, sizeof(*(p_module->_internal))); + p_module->_internal = spv_calloc(p_allocator, 1, sizeof(*(p_module->_internal))); #endif if (IsNull(p_module->_internal)) { return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } // Allocate SPIR-V code storage p_module->_internal->spirv_size = size; - p_module->_internal->spirv_code = (uint32_t*)calloc(1, p_module->_internal->spirv_size); + p_module->_internal->spirv_code = (uint32_t*)spv_calloc(p_allocator, 1, p_module->_internal->spirv_size); p_module->_internal->spirv_word_count = (uint32_t)(size / SPIRV_WORD_SIZE); if (IsNull(p_module->_internal->spirv_code)) { - SafeFree(p_module->_internal); + SafeFree(p_allocator, p_module->_internal); return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED; } memcpy(p_module->_internal->spirv_code, p_code, size); @@ -3179,19 +3208,19 @@ SpvReflectResult spvReflectCreateShaderModule( } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseNodes(&parser); + result = ParseNodes(&parser, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseStrings(&parser); + result = ParseStrings(&parser, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { result = ParseSource(&parser, p_module); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseFunctions(&parser); + result = ParseFunctions(&parser, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseMemberCounts(&parser); + result = ParseMemberCounts(&parser, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { result = ParseNames(&parser); @@ -3214,10 +3243,10 @@ SpvReflectResult spvReflectCreateShaderModule( } } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseTypes(&parser, p_module); + result = ParseTypes(&parser, p_module, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseDescriptorBindings(&parser, p_module); + result = ParseDescriptorBindings(&parser, p_module, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { result = ParseDescriptorType(p_module); @@ -3226,13 +3255,13 @@ SpvReflectResult spvReflectCreateShaderModule( result = ParseUAVCounterBindings(p_module); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseDescriptorBlocks(&parser, p_module); + result = ParseDescriptorBlocks(&parser, p_module, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParsePushConstantBlocks(&parser, p_module); + result = ParsePushConstantBlocks(&parser, p_module, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = ParseEntryPoints(&parser, p_module); + result = ParseEntryPoints(&parser, p_module, p_allocator); } if (result == SPV_REFLECT_RESULT_SUCCESS && p_module->entry_point_count > 0) { SpvReflectEntryPoint* p_entry = &(p_module->entry_points[0]); @@ -3249,20 +3278,20 @@ SpvReflectResult spvReflectCreateShaderModule( result = DisambiguateStorageBufferSrvUav(p_module); } if (result == SPV_REFLECT_RESULT_SUCCESS) { - result = SynchronizeDescriptorSets(p_module); + result = SynchronizeDescriptorSets(p_module, p_allocator); } // Destroy module if parse was not successful if (result != SPV_REFLECT_RESULT_SUCCESS) { - spvReflectDestroyShaderModule(p_module); + spvReflectDestroyShaderModuleEx(p_module, p_allocator); } - DestroyParser(&parser); + DestroyParser(&parser, p_allocator); return result; } -static void SafeFreeTypes(SpvReflectTypeDescription* p_type) +static void SafeFreeTypes(SpvReflectTypeDescription* p_type, const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_type)) { return; @@ -3271,15 +3300,15 @@ static void SafeFreeTypes(SpvReflectTypeDescription* p_type) if (IsNotNull(p_type->members)) { for (size_t i = 0; i < p_type->member_count; ++i) { SpvReflectTypeDescription* p_member = &p_type->members[i]; - SafeFreeTypes(p_member); + SafeFreeTypes(p_member, p_allocator); } - SafeFree(p_type->members); + SafeFree(p_allocator, p_type->members); p_type->members = NULL; } } -static void SafeFreeBlockVariables(SpvReflectBlockVariable* p_block) +static void SafeFreeBlockVariables(SpvReflectBlockVariable* p_block, const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_block)) { return; @@ -3288,15 +3317,15 @@ static void SafeFreeBlockVariables(SpvReflectBlockVariable* p_block) if (IsNotNull(p_block->members)) { for (size_t i = 0; i < p_block->member_count; ++i) { SpvReflectBlockVariable* p_member = &p_block->members[i]; - SafeFreeBlockVariables(p_member); + SafeFreeBlockVariables(p_member, p_allocator); } - SafeFree(p_block->members); + SafeFree(p_allocator, p_block->members); p_block->members = NULL; } } -static void SafeFreeInterfaceVariable(SpvReflectInterfaceVariable* p_interface) +static void SafeFreeInterfaceVariable(SpvReflectInterfaceVariable* p_interface, const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_interface)) { return; @@ -3305,15 +3334,20 @@ static void SafeFreeInterfaceVariable(SpvReflectInterfaceVariable* p_interface) if (IsNotNull(p_interface->members)) { for (size_t i = 0; i < p_interface->member_count; ++i) { SpvReflectInterfaceVariable* p_member = &p_interface->members[i]; - SafeFreeInterfaceVariable(p_member); + SafeFreeInterfaceVariable(p_member, p_allocator); } - SafeFree(p_interface->members); + SafeFree(p_allocator, p_interface->members); p_interface->members = NULL; } } void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) +{ + spvReflectDestroyShaderModuleEx(p_module, NULL); +} + +void spvReflectDestroyShaderModuleEx(SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_module->_internal)) { return; @@ -3322,56 +3356,56 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) // Descriptor set bindings for (size_t i = 0; i < p_module->descriptor_set_count; ++i) { SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[i]; - free(p_set->bindings); + SafeFree(p_allocator, p_set->bindings); } // Descriptor binding blocks for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) { SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i]; - SafeFreeBlockVariables(&p_descriptor->block); + SafeFreeBlockVariables(&p_descriptor->block, p_allocator); } - SafeFree(p_module->descriptor_bindings); + SafeFree(p_allocator, p_module->descriptor_bindings); // Entry points for (size_t i = 0; i < p_module->entry_point_count; ++i) { SpvReflectEntryPoint* p_entry = &p_module->entry_points[i]; for (size_t j = 0; j < p_entry->input_variable_count; j++) { - SafeFreeInterfaceVariable(&p_entry->input_variables[j]); + SafeFreeInterfaceVariable(&p_entry->input_variables[j], p_allocator); } for (size_t j = 0; j < p_entry->output_variable_count; j++) { - SafeFreeInterfaceVariable(&p_entry->output_variables[j]); + SafeFreeInterfaceVariable(&p_entry->output_variables[j], p_allocator); } for (uint32_t j = 0; j < p_entry->descriptor_set_count; ++j) { - SafeFree(p_entry->descriptor_sets[j].bindings); + SafeFree(p_allocator, p_entry->descriptor_sets[j].bindings); } - SafeFree(p_entry->descriptor_sets); - SafeFree(p_entry->input_variables); - SafeFree(p_entry->output_variables); - SafeFree(p_entry->used_uniforms); - SafeFree(p_entry->used_push_constants); + SafeFree(p_allocator, p_entry->descriptor_sets); + SafeFree(p_allocator, p_entry->input_variables); + SafeFree(p_allocator, p_entry->output_variables); + SafeFree(p_allocator, p_entry->used_uniforms); + SafeFree(p_allocator, p_entry->used_push_constants); } - SafeFree(p_module->entry_points); + SafeFree(p_allocator, p_module->entry_points); // Push constants for (size_t i = 0; i < p_module->push_constant_block_count; ++i) { - SafeFreeBlockVariables(&p_module->push_constant_blocks[i]); + SafeFreeBlockVariables(&p_module->push_constant_blocks[i], p_allocator); } - SafeFree(p_module->push_constant_blocks); + SafeFree(p_allocator, p_module->push_constant_blocks); // Type infos for (size_t i = 0; i < p_module->_internal->type_description_count; ++i) { SpvReflectTypeDescription* p_type = &p_module->_internal->type_descriptions[i]; if (IsNotNull(p_type->members)) { - SafeFreeTypes(p_type); + SafeFreeTypes(p_type, p_allocator); } - SafeFree(p_type->members); + SafeFree(p_allocator, p_type->members); } - SafeFree(p_module->_internal->type_descriptions); + SafeFree(p_allocator, p_module->_internal->type_descriptions); // Free SPIR-V code - SafeFree(p_module->_internal->spirv_code); + SafeFree(p_allocator, p_module->_internal->spirv_code); // Free internal - SafeFree(p_module->_internal); + SafeFree(p_allocator, p_module->_internal); } uint32_t spvReflectGetCodeSize(const SpvReflectShaderModule* p_module) @@ -3818,7 +3852,7 @@ const SpvReflectDescriptorBinding* spvReflectGetEntryPointDescriptorBinding( for (uint32_t index = 0; index < p_module->descriptor_binding_count; ++index) { const SpvReflectDescriptorBinding* p_potential = &p_module->descriptor_bindings[index]; bool found = SearchSortedUint32( - p_entry->used_uniforms, + p_entry->used_uniforms, p_entry->used_uniform_count, p_potential->spirv_id); if ((p_potential->binding == binding_number) && (p_potential->set == set_number) && found) { @@ -4276,6 +4310,15 @@ SpvReflectResult spvReflectChangeDescriptorBindingNumbers( uint32_t new_binding_number, uint32_t new_set_binding ) +{ + return spvReflectChangeDescriptorBindingNumbersEx(p_module, p_binding, new_binding_number, new_set_binding, NULL); +} + +SpvReflectResult spvReflectChangeDescriptorBindingNumbersEx(SpvReflectShaderModule* p_module, + const SpvReflectDescriptorBinding* p_binding, + uint32_t new_binding_number, + uint32_t new_set_binding, + const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_module)) { return SPV_REFLECT_RESULT_ERROR_NULL_POINTER; @@ -4312,7 +4355,7 @@ SpvReflectResult spvReflectChangeDescriptorBindingNumbers( SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS; if (new_set_binding != SPV_REFLECT_SET_NUMBER_DONT_CHANGE) { - result = SynchronizeDescriptorSets(p_module); + result = SynchronizeDescriptorSets(p_module, p_allocator); } return result; } @@ -4323,10 +4366,10 @@ SpvReflectResult spvReflectChangeDescriptorBindingNumber( uint32_t optional_new_set_number ) { - return spvReflectChangeDescriptorBindingNumbers( - p_module,p_descriptor_binding, - new_binding_number, - optional_new_set_number); + return spvReflectChangeDescriptorBindingNumbersEx( + p_module,p_descriptor_binding, + new_binding_number, + optional_new_set_number, NULL); } SpvReflectResult spvReflectChangeDescriptorSetNumber( @@ -4334,6 +4377,14 @@ SpvReflectResult spvReflectChangeDescriptorSetNumber( const SpvReflectDescriptorSet* p_set, uint32_t new_set_number ) +{ + return spvReflectChangeDescriptorSetNumberEx(p_module, p_set, new_set_number, NULL); +} + +SpvReflectResult spvReflectChangeDescriptorSetNumberEx(SpvReflectShaderModule* p_module, + const SpvReflectDescriptorSet* p_set, + uint32_t new_set_number, + const SpvAllocationCallbacks* p_allocator) { if (IsNull(p_module)) { return SPV_REFLECT_RESULT_ERROR_NULL_POINTER; @@ -4364,7 +4415,7 @@ SpvReflectResult spvReflectChangeDescriptorSetNumber( p_descriptor->set = new_set_number; } - result = SynchronizeDescriptorSets(p_module); + result = SynchronizeDescriptorSets(p_module, p_allocator); } return result; diff --git a/spirv_reflect.h b/spirv_reflect.h index 6554aaa2..c6f82aaf 100644 --- a/spirv_reflect.h +++ b/spirv_reflect.h @@ -49,6 +49,16 @@ VERSION HISTORY #define SPV_REFLECT_DEPRECATED(msg_str) #endif +typedef void* (* PFN_spvAllocationFunction)(void* pUserData, size_t size); + +typedef void (* PFN_spvFreeFunction)(void* pUserData, void* pMemory); + +typedef struct SpvAllocationCallbacks { + void* pUserData; + PFN_spvAllocationFunction pfnAllocation; + PFN_spvFreeFunction pfnFree; +} SpvAllocationCallbacks; + /*! @enum SpvReflectResult */ @@ -434,6 +444,22 @@ SpvReflectResult spvReflectCreateShaderModule( SpvReflectShaderModule* p_module ); +/*! @fn spvReflectCreateShaderModule + + @param size Size in bytes of SPIR-V code. + @param p_code Pointer to SPIR-V code. + @param p_module Pointer to an instance of SpvReflectShaderModule. + @param p_allocator Pointer to an instance of SpvAllocationCallbacks. + @return SPV_REFLECT_RESULT_SUCCESS on success. + +*/ +SpvReflectResult spvReflectCreateShaderModuleEx( + size_t size, + const void* p_code, + SpvReflectShaderModule* p_module, + const SpvAllocationCallbacks* p_allocator +); + SPV_REFLECT_DEPRECATED("renamed to spvReflectCreateShaderModule") SpvReflectResult spvReflectGetShaderModule( size_t size, @@ -449,6 +475,13 @@ SpvReflectResult spvReflectGetShaderModule( */ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module); +/*! @fn spvReflectDestroyShaderModuleEx + + @param p_module Pointer to an instance of SpvReflectShaderModule. + @param p_allocator Pointer to an instance of SpvAllocationCallbacks. + +*/ +void spvReflectDestroyShaderModuleEx(SpvReflectShaderModule* p_module, const SpvAllocationCallbacks* p_allocator); /*! @fn spvReflectGetCodeSize @@ -1172,6 +1205,40 @@ SpvReflectResult spvReflectChangeDescriptorBindingNumbers( uint32_t new_binding_number, uint32_t new_set_number ); + +/*! @fn spvReflectChangeDescriptorBindingNumbersEx + @brief Assign new set and/or binding numbers to a descriptor binding. + In addition to updating the reflection data, this function modifies + the underlying SPIR-V bytecode. The updated code can be retrieved + with spvReflectGetCode(). If the binding is used in multiple + entry points within the module, it will be changed in all of them. + @param p_module Pointer to an instance of SpvReflectShaderModule. + @param p_binding Pointer to the descriptor binding to modify. + @param new_binding_number The new binding number to assign to the + provided descriptor binding. + To leave the binding number unchanged, pass + SPV_REFLECT_BINDING_NUMBER_DONT_CHANGE. + @param new_set_number The new set number to assign to the + provided descriptor binding. Successfully changing + a descriptor binding's set number invalidates all + existing SpvReflectDescriptorBinding and + SpvReflectDescriptorSet pointers from this module. + To leave the set number unchanged, pass + SPV_REFLECT_SET_NUMBER_DONT_CHANGE. +@param p_allocator Pointer to an instance of SpvAllocationCallbacks. + @return If successful, returns SPV_REFLECT_RESULT_SUCCESS. + Otherwise, the error code indicates the cause of + the failure. +*/ +SpvReflectResult spvReflectChangeDescriptorBindingNumbersEx( + SpvReflectShaderModule* p_module, + const SpvReflectDescriptorBinding* p_binding, + uint32_t new_binding_number, + uint32_t new_set_number, + const SpvAllocationCallbacks* p_allocator +); + + SPV_REFLECT_DEPRECATED("Renamed to spvReflectChangeDescriptorBindingNumbers") SpvReflectResult spvReflectChangeDescriptorBindingNumber( SpvReflectShaderModule* p_module, @@ -1208,6 +1275,36 @@ SpvReflectResult spvReflectChangeDescriptorSetNumber( uint32_t new_set_number ); +/*! @fn spvReflectChangeDescriptorSetNumberEx + @brief Assign a new set number to an entire descriptor set (including + all descriptor bindings in that set). + In addition to updating the reflection data, this function modifies + the underlying SPIR-V bytecode. The updated code can be retrieved + with spvReflectGetCode(). If the descriptor set is used in + multiple entry points within the module, it will be modified in all + of them. + @param p_module Pointer to an instance of SpvReflectShaderModule. + @param p_set Pointer to the descriptor binding to modify. + @param new_set_number The new set number to assign to the + provided descriptor set, and all its descriptor + bindings. Successfully changing a descriptor + binding's set number invalidates all existing + SpvReflectDescriptorBinding and + SpvReflectDescriptorSet pointers from this module. + To leave the set number unchanged, pass + SPV_REFLECT_SET_NUMBER_DONT_CHANGE. + @param p_allocator Pointer to an instance of SpvAllocationCallbacks. + @return If successful, returns SPV_REFLECT_RESULT_SUCCESS. + Otherwise, the error code indicates the cause of + the failure. +*/ +SpvReflectResult spvReflectChangeDescriptorSetNumberEx( + SpvReflectShaderModule* p_module, + const SpvReflectDescriptorSet* p_set, + uint32_t new_set_number, + const SpvAllocationCallbacks* p_allocator +); + /*! @fn spvReflectChangeInputVariableLocation @brief Assign a new location to an input interface variable. In addition to updating the reflection data, this function modifies