Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 87 additions & 55 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,26 @@ typedef struct SpvReflectPrvString {
const char* string;
} SpvReflectPrvString;

// There are a limit set of instructions that can touch an OpVariable,
// these are represented here with how it was accessed
// Examples:
// OpImageRead -> OpLoad -> OpVariable
// OpImageWrite -> OpLoad -> OpVariable
// OpStore -> OpAccessChain -> OpAccessChain -> OpVariable
// OpAtomicIAdd -> OpAccessChain -> OpVariable
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
typedef struct SpvReflectPrvAccessedVariable {
uint32_t result_id;
uint32_t variable_ptr;
} SpvReflectPrvAccessedVariable;

typedef struct SpvReflectPrvFunction {
uint32_t id;
uint32_t callee_count;
uint32_t* callees;
struct SpvReflectPrvFunction** callee_ptrs;
uint32_t accessed_ptr_count;
uint32_t* accessed_ptrs;
uint32_t accessed_variable_count;
SpvReflectPrvAccessedVariable* accessed_variables;
} SpvReflectPrvFunction;

typedef struct SpvReflectPrvAccessChain {
Expand Down Expand Up @@ -233,6 +246,13 @@ static int SortCompareUint32(const void* a, const void* b) {
return (int)*p_a - (int)*p_b;
}

static int SortCompareAccessedVariable(const void* a, const void* b) {
const SpvReflectPrvAccessedVariable* p_a = (const SpvReflectPrvAccessedVariable*)a;
const SpvReflectPrvAccessedVariable* p_b = (const SpvReflectPrvAccessedVariable*)b;

return (int)p_a->variable_ptr - (int)p_b->variable_ptr;
}

//
// De-duplicates a sorted array and returns the new size.
//
Expand Down Expand Up @@ -270,23 +290,24 @@ static bool SearchSortedUint32(const uint32_t* arr, size_t size, uint32_t target
return false;
}

static SpvReflectResult IntersectSortedUint32(const uint32_t* p_arr0, size_t arr0_size, const uint32_t* p_arr1, size_t arr1_size,
uint32_t** pp_res, size_t* res_size) {
static SpvReflectResult IntersectSortedAccessedVariable(const SpvReflectPrvAccessedVariable* p_arr0, size_t arr0_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be one param per line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const uint32_t* p_arr1, size_t arr1_size, uint32_t** pp_res,
size_t* res_size) {
*pp_res = NULL;
*res_size = 0;
if (IsNull(p_arr0) || IsNull(p_arr1)) {
return SPV_REFLECT_RESULT_SUCCESS;
}

const uint32_t* arr0_end = p_arr0 + arr0_size;
const SpvReflectPrvAccessedVariable* arr0_end = p_arr0 + arr0_size;
const uint32_t* arr1_end = p_arr1 + arr1_size;

const uint32_t* idx0 = p_arr0;
const SpvReflectPrvAccessedVariable* idx0 = p_arr0;
const uint32_t* idx1 = p_arr1;
while (idx0 != arr0_end && idx1 != arr1_end) {
if (*idx0 < *idx1) {
if (idx0->variable_ptr < *idx1) {
++idx0;
} else if (*idx0 > *idx1) {
} else if (idx0->variable_ptr > *idx1) {
++idx1;
} else {
++*res_size;
Expand All @@ -304,12 +325,12 @@ static SpvReflectResult IntersectSortedUint32(const uint32_t* p_arr0, size_t arr
idx0 = p_arr0;
idx1 = p_arr1;
while (idx0 != arr0_end && idx1 != arr1_end) {
if (*idx0 < *idx1) {
if (idx0->variable_ptr < *idx1) {
++idx0;
} else if (*idx0 > *idx1) {
} else if (idx0->variable_ptr > *idx1) {
++idx1;
} else {
*(idxr++) = *idx0;
*(idxr++) = idx0->variable_ptr;
++idx0;
++idx1;
}
Expand Down Expand Up @@ -602,7 +623,7 @@ static void DestroyParser(SpvReflectPrvParser* p_parser) {
for (size_t i = 0; i < p_parser->function_count; ++i) {
SafeFree(p_parser->functions[i].callees);
SafeFree(p_parser->functions[i].callee_ptrs);
SafeFree(p_parser->functions[i].accessed_ptrs);
SafeFree(p_parser->functions[i].accessed_variables);
}

// Free access chains
Expand Down Expand Up @@ -1039,8 +1060,9 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
p_func->id = p_func_node->result_id;

p_func->callee_count = 0;
p_func->accessed_ptr_count = 0;
p_func->accessed_variable_count = 0;

// First get count to know how much to allocate
for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpFunctionEnd) {
Expand All @@ -1059,11 +1081,11 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
case SpvOpInBoundsPtrAccessChain:
case SpvOpStore:
case SpvOpImageTexelPointer: {
++(p_func->accessed_ptr_count);
++(p_func->accessed_variable_count);
} break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized: {
p_func->accessed_ptr_count += 2;
p_func->accessed_variable_count += 2;
} break;
default:
break;
Expand All @@ -1077,15 +1099,17 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
}
}

if (p_func->accessed_ptr_count > 0) {
p_func->accessed_ptrs = (uint32_t*)calloc(p_func->accessed_ptr_count, sizeof(*(p_func->accessed_ptrs)));
if (IsNull(p_func->accessed_ptrs)) {
if (p_func->accessed_variable_count > 0) {
p_func->accessed_variables =
(SpvReflectPrvAccessedVariable*)calloc(p_func->accessed_variable_count, sizeof(*(p_func->accessed_variables)));
if (IsNull(p_func->accessed_variables)) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}
}

p_func->callee_count = 0;
p_func->accessed_ptr_count = 0;
p_func->accessed_variable_count = 0;
// Now have allocation, fill in values
for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpFunctionEnd) {
Expand All @@ -1104,19 +1128,29 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
case SpvOpGenericPtrMemSemantics:
case SpvOpInBoundsPtrAccessChain:
case SpvOpImageTexelPointer: {
CHECKED_READU32(p_parser, p_node->word_offset + 3, p_func->accessed_ptrs[p_func->accessed_ptr_count]);
(++p_func->accessed_ptr_count);
const uint32_t result_index = p_node->word_offset + 2;
const uint32_t ptr_index = p_node->word_offset + 3;
SpvReflectPrvAccessedVariable* access_ptr = &p_func->accessed_variables[p_func->accessed_variable_count];

// Need to track Result ID as not sure there has been any memory access through here yet
CHECKED_READU32(p_parser, result_index, access_ptr->result_id);
CHECKED_READU32(p_parser, ptr_index, access_ptr->variable_ptr);
(++p_func->accessed_variable_count);
} break;
case SpvOpStore: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_func->accessed_ptrs[p_func->accessed_ptr_count]);
(++p_func->accessed_ptr_count);
const uint32_t result_index = p_node->word_offset + 2;
CHECKED_READU32(p_parser, result_index, p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
} break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_func->accessed_ptrs[p_func->accessed_ptr_count]);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test for OpCopyMemory.. we were off by 1 as the Target and Source operand are 1 and 2

(++p_func->accessed_ptr_count);
CHECKED_READU32(p_parser, p_node->word_offset + 3, p_func->accessed_ptrs[p_func->accessed_ptr_count]);
(++p_func->accessed_ptr_count);
// There is no result_id is being zero is same as being invalid
CHECKED_READU32(p_parser, p_node->word_offset + 1,
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
CHECKED_READU32(p_parser, p_node->word_offset + 2,
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
} break;
default:
break;
Expand All @@ -1128,10 +1162,10 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
}
p_func->callee_count = (uint32_t)DedupSortedUint32(p_func->callees, p_func->callee_count);

if (p_func->accessed_ptr_count > 0) {
qsort(p_func->accessed_ptrs, p_func->accessed_ptr_count, sizeof(*(p_func->accessed_ptrs)), SortCompareUint32);
if (p_func->accessed_variable_count > 0) {
qsort(p_func->accessed_variables, p_func->accessed_variable_count, sizeof(*(p_func->accessed_variables)),
SortCompareAccessedVariable);
}
p_func->accessed_ptr_count = (uint32_t)DedupSortedUint32(p_func->accessed_ptrs, p_func->accessed_ptr_count);

return SPV_REFLECT_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -3046,60 +3080,58 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
}
called_function_count = DedupSortedUint32(p_called_functions, called_function_count);

uint32_t used_variable_count = 0;
uint32_t used_acessed_count = 0;
for (size_t i = 0, j = 0; i < called_function_count; ++i) {
// No need to bounds check j because a missing ID issue would have been
// found during TraverseCallGraph
while (p_parser->functions[j].id != p_called_functions[i]) {
++j;
}
used_variable_count += p_parser->functions[j].accessed_ptr_count;
used_acessed_count += p_parser->functions[j].accessed_variable_count;
}
uint32_t* used_variables = NULL;
if (used_variable_count > 0) {
used_variables = (uint32_t*)calloc(used_variable_count, sizeof(*used_variables));
if (IsNull(used_variables)) {
SpvReflectPrvAccessedVariable* used_accesses = NULL;
if (used_acessed_count > 0) {
used_accesses = (SpvReflectPrvAccessedVariable*)calloc(used_acessed_count, sizeof(SpvReflectPrvAccessedVariable));
if (IsNull(used_accesses)) {
SafeFree(p_called_functions);
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}
}
used_variable_count = 0;
used_acessed_count = 0;
for (size_t i = 0, j = 0; i < called_function_count; ++i) {
while (p_parser->functions[j].id != p_called_functions[i]) {
++j;
}

memcpy(&used_variables[used_variable_count], p_parser->functions[j].accessed_ptrs,
p_parser->functions[j].accessed_ptr_count * sizeof(*used_variables));
used_variable_count += p_parser->functions[j].accessed_ptr_count;
memcpy(&used_accesses[used_acessed_count], p_parser->functions[j].accessed_variables,
p_parser->functions[j].accessed_variable_count * sizeof(SpvReflectPrvAccessedVariable));
used_acessed_count += p_parser->functions[j].accessed_variable_count;
}
SafeFree(p_called_functions);

if (used_variable_count > 0) {
qsort(used_variables, used_variable_count, sizeof(*used_variables), SortCompareUint32);
if (used_acessed_count > 0) {
qsort(used_accesses, used_acessed_count, sizeof(*used_accesses), SortCompareAccessedVariable);
}
used_variable_count = (uint32_t)DedupSortedUint32(used_variables, used_variable_count);

// Do set intersection to find the used uniform and push constants
size_t used_uniform_count = 0;
//
SpvReflectResult result0 = IntersectSortedUint32(used_variables, used_variable_count, uniforms, uniform_count,
&p_entry->used_uniforms, &used_uniform_count);
SpvReflectResult result0 = IntersectSortedAccessedVariable(used_accesses, used_acessed_count, uniforms, uniform_count,
&p_entry->used_uniforms, &used_uniform_count);

size_t used_push_constant_count = 0;
//
SpvReflectResult result1 = IntersectSortedUint32(used_variables, used_variable_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);
SpvReflectResult result1 = IntersectSortedAccessedVariable(used_accesses, used_acessed_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);

for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[j];
bool found = SearchSortedUint32(used_variables, used_variable_count, p_binding->spirv_id);
if (found) {
p_binding->accessed = 1;
for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[i];
for (uint32_t j = 0; j < used_acessed_count; j++) {
if (used_accesses[j].variable_ptr == p_binding->spirv_id) {
p_binding->accessed = 1;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the goal here is we can find all the accesses pointing to this descriptor binding variable

}
}
}

SafeFree(used_variables);
SafeFree(used_accesses);
if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
return result0;
}
Expand Down
Binary file added tests/variable_access/copy_memory.spv
Binary file not shown.
Loading