Skip to content

Commit

Permalink
Should now work with vector types and workgroupsize builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
shangjiaxuan committed Jul 31, 2022
1 parent cb7b2b9 commit 4c8a546
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 63 deletions.
94 changes: 78 additions & 16 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,13 @@ void ParseBlockMembersToTextLines(
// dim = 0 means it's an unbounded array
//
if (dim > 0) {
if (dim == 0xFFFFFFFF) {
SpvReflectValue val{};
SpvReflectResult res = obj.EvaluateResult(member.array.spec_constant_op_ids[array_dim_index], val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
dim = val.values[0].value.uint32_bool_value;
}
}
ss_array << "[" << dim << "]";
}
else {
Expand Down Expand Up @@ -702,8 +709,9 @@ void ParseBlockMembersToTextLines(
if (dim == 0xFFFFFFFF) {
SpvReflectValue val{};
SpvReflectResult res = obj.EvaluateResult(member.array.spec_constant_op_ids[array_dim_index], val);
if (res != SPV_REFLECT_RESULT_SUCCESS) throw;
dim = val.values[0].value.uint32_bool_value;
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
dim = val.values[0].value.uint32_bool_value;
}
}
ss_array << "[" << dim << "]";
}
Expand Down Expand Up @@ -1009,7 +1017,7 @@ void StreamWriteDescriptorBinding(std::ostream& os, const spv_reflect::ShaderMod
}
}

void StreamWriteInterfaceVariable(std::ostream& os, const SpvReflectInterfaceVariable& obj, const char* indent)
void StreamWriteInterfaceVariable(std::ostream& os, const spv_reflect::ShaderModule& shader, const SpvReflectInterfaceVariable& obj, const char* indent)
{
const char* t = indent;
os << t << "spirv id : " << obj.spirv_id << "\n";
Expand All @@ -1028,6 +1036,14 @@ void StreamWriteInterfaceVariable(std::ostream& os, const SpvReflectInterfaceVar
if (obj.array.dims_count > 0) {
os << t << "array : ";
for (uint32_t dim_index = 0; dim_index < obj.array.dims_count; ++dim_index) {
uint32_t dim = obj.array.dims[dim_index];
if (dim == 0xFFFFFFFF) {
SpvReflectValue val{};
SpvReflectResult res = shader.EvaluateResult(obj.array.spec_constant_op_ids[dim_index], val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
dim = val.values[0].value.uint32_bool_value;
}
}
os << "[" << obj.array.dims[dim_index] << "]";
}
os << "\n";
Expand Down Expand Up @@ -1108,28 +1124,74 @@ void StreamWriteSpecializationConstant(std::ostream& os, const SpvReflectSpecial
}
}

void StreamWriteEntryPoint(std::ostream& os, const SpvReflectEntryPoint& obj, const char* indent)
void StreamWriteEntryPoint(std::ostream& os, const spv_reflect::ShaderModule& shader, const SpvReflectEntryPoint& obj, const char* indent)
{
os << indent << "entry point : " << obj.name;
os << " (stage=" << ToStringShaderStage(obj.shader_stage) << ")";
if (obj.shader_stage == SPV_REFLECT_SHADER_STAGE_COMPUTE_BIT) {
os << "\n";
os << "local size : " << "(" << obj.local_size.x << ", " << obj.local_size.y << ", " << obj.local_size.z << ")";
if (obj.local_size.flags & 2) {
os << "local size : ";
}
else {
os << "local size hint : ";
}
if(obj.local_size.flags & 4) {
SpvReflectValue val{};
SpvReflectResult res = shader.EvaluateResult(obj.local_size.x, val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == (SPV_REFLECT_TYPE_FLAG_INT | SPV_REFLECT_TYPE_FLAG_VECTOR))
&& (val.type->traits.numeric.scalar.width == 32)) {
os << "(" << val.values[0].value.uint32_bool_value << ", " << val.values[1].value.uint32_bool_value << ", " << val.values[2].value.uint32_bool_value << ")";
}
else {
os << "(failed evaluation of WorkGroupSize Builtin)";
}
}
else if(obj.local_size.flags & 1) {
os << "(";
SpvReflectValue val = {0};
SpvReflectResult res = shader.EvaluateResult(obj.local_size.x, val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
os << val.values[0].value.uint32_bool_value;
}
else {
os << "unknown";
}
os << ", ";
res = shader.EvaluateResult(obj.local_size.y, val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
os << val.values[0].value.uint32_bool_value;
}
else {
os << "unknown";
} os << ", ";
res = shader.EvaluateResult(obj.local_size.z, val);
if ((res == SPV_REFLECT_RESULT_SUCCESS) && val.type && (val.type->type_flags == SPV_REFLECT_TYPE_FLAG_INT) && (val.type->traits.numeric.scalar.width == 32)) {
os << val.values[0].value.uint32_bool_value;
}
else {
os << "unknown";
}
os << ")";
}
else{
os << "(" << obj.local_size.x << ", " << obj.local_size.y << ", " << obj.local_size.z << ")";
}
}
}

void StreamWriteShaderModule(std::ostream& os, const SpvReflectShaderModule& obj, const char* indent)
void StreamWriteShaderModule(std::ostream& os, const spv_reflect::ShaderModule& obj, const char* indent)
{
(void)indent;
os << "generator : " << ToStringGenerator(obj.generator) << "\n";
os << "source lang : " << spvReflectSourceLanguage(obj.source_language) << "\n";
os << "source lang ver : " << obj.source_language_version << "\n";
os << "source file : " << (obj.source_file != NULL ? obj.source_file : "") << "\n";
os << "generator : " << ToStringGenerator(obj.GetShaderModule().generator) << "\n";
os << "source lang : " << spvReflectSourceLanguage(obj.GetShaderModule().source_language) << "\n";
os << "source lang ver : " << obj.GetShaderModule().source_language_version << "\n";
os << "source file : " << (obj.GetShaderModule().source_file != NULL ? obj.GetShaderModule().source_file : "") << "\n";
//os << "shader stage : " << ToStringShaderStage(obj.shader_stage) << "\n";

for (uint32_t i = 0; i < obj.entry_point_count; ++i) {
StreamWriteEntryPoint(os, obj.entry_points[i], "");
if (i < (obj.entry_point_count - 1)) {
for (uint32_t i = 0; i < obj.GetShaderModule().entry_point_count; ++i) {
StreamWriteEntryPoint(os, obj, obj.GetShaderModule().entry_points[i], "");
if (i < (obj.GetShaderModule().entry_point_count - 1)) {
os << "\n";
}
}
Expand All @@ -1149,7 +1211,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
const char* tt = " ";
const char* ttt = " ";

StreamWriteShaderModule(os, obj.GetShaderModule(), "");
StreamWriteShaderModule(os, obj, "");

uint32_t count = 0;
std::vector<SpvReflectInterfaceVariable*> variables;
Expand Down Expand Up @@ -1195,7 +1257,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
auto p_var = variables[i];
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
os << tt << i << ":" << "\n";
StreamWriteInterfaceVariable(os, *p_var, ttt);
StreamWriteInterfaceVariable(os, obj,*p_var, ttt);
if (i < (count - 1)) {
os << "\n";
}
Expand All @@ -1217,7 +1279,7 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
auto p_var = variables[i];
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
os << tt << i << ":" << "\n";
StreamWriteInterfaceVariable(os, *p_var, ttt);
StreamWriteInterfaceVariable(os, obj, *p_var, ttt);
if (i < (count - 1)) {
os << "\n";
}
Expand Down
105 changes: 62 additions & 43 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,7 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
case SpvDecorationBuiltIn: {
p_target_decorations->is_built_in = true;
uint32_t word_offset = p_node->word_offset + member_offset + 3;
// no rule specifies a result cannot be decorated twice. But let's assume this for now...
CHECKED_READU32_CAST(p_parser, word_offset, SpvBuiltIn, p_target_decorations->built_in);
}
break;
Expand Down Expand Up @@ -3397,34 +3398,16 @@ SpvReflectResult GetTypeByTypeId(const SpvReflectShaderModule* p_module, uint32_
#define COMPOSITE_TYPE_FLAGS (VECTOR_TYPE_FLAGS|SPV_REFLECT_TYPE_FLAG_MATRIX|SPV_REFLECT_TYPE_FLAG_STRUCT|SPV_REFLECT_TYPE_FLAG_ARRAY)
#define COMPOSITE_DISALLOWED_FLAGS (~0 ^ COMPOSITE_TYPE_FLAGS)

static SpvReflectScalarType ScalarGeneralTypeFromType(SpvReflectTypeDescription* type)
{
if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_BOOL) {
return SPV_REFLECT_SCALAR_TYPE_BOOL;
}
else if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_INT) {
return SPV_REFLECT_SCALAR_TYPE_INT;
}
else if ((type->type_flags & SCALAR_TYPE_FLAGS) == SPV_REFLECT_TYPE_FLAG_FLOAT) {
return SPV_REFLECT_SCALAR_TYPE_FLOAT;
}
else {
return SPV_REFLECT_SCALAR_TYPE_UNKNOWN;
}
}

static SpvReflectResult GetScalarConstant(const SpvReflectShaderModule* p_module, SpvReflectPrvNode* p_node,
SpvReflectScalarValue* result, SpvReflectScalarType* general_type, SpvReflectTypeDescription** type)
SpvReflectScalarValue* result, SpvReflectTypeDescription** type)
{
SpvReflectPrvParser* p_parser = p_module->_internal->parser;

SpvReflectScalarType g_type;
SpvReflectTypeDescription* d_type;
SpvReflectResult res = GetTypeByTypeId(p_module, p_node->result_type_id, &d_type);
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;

if(d_type->type_flags & SCALAR_DISALLOWED_FLAGS) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
g_type = ScalarGeneralTypeFromType(d_type);
uint32_t low_word;
CHECKED_READU32(p_parser, p_node->word_offset + 3, low_word);
// There is no alignment requirements in c/cpp for unions
Expand All @@ -3440,7 +3423,6 @@ static SpvReflectResult GetScalarConstant(const SpvReflectShaderModule* p_module
else {
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
}
*general_type = g_type;
*type = d_type;
return SPV_REFLECT_RESULT_SUCCESS;
}
Expand All @@ -3466,6 +3448,18 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars

for (size_t i = 0; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
// check first if it's WorkGroupSize builtin
// maybe handling builtin as global map may be better.
if (p_node->decorations.built_in == SpvBuiltInWorkgroupSize) {
// WorkGroupSize builtin's target is all ExecutionMode instructions.
for(uint32_t j = 0; j<p_module->entry_point_count; ++j) {
if(p_module->entry_points[j].spirv_execution_model == SpvExecutionModelKernel||
p_module->entry_points[j].spirv_execution_model == SpvExecutionModelGLCompute){
p_module->entry_points[j].local_size.flags = 4;
p_module->entry_points[j].local_size.x = p_node->result_id;
}
}
}
// Specconstants with no id means constant
switch(p_node->op) {
default: continue;
Expand All @@ -3482,19 +3476,19 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars
case SpvOpSpecConstant: {
SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
SpvReflectScalarValue default_value = { 0 };
result = GetScalarConstant(p_module, p_node, &default_value, &p_module->specialization_constants[index].general_type, &p_module->specialization_constants[index].type);
result = GetScalarConstant(p_module, p_node, &default_value, &p_module->specialization_constants[index].type);
if (result != SPV_REFLECT_RESULT_SUCCESS) return result;
p_module->specialization_constants[index].default_value = default_value;
p_module->specialization_constants[index].current_value = p_module->specialization_constants[index].default_value;
} break;
}
// spec constant id cannot be the same, at least for valid values. (invalid value is just constant?)
if (p_node->decorations.specialization_constant.value != (uint32_t)INVALID_VALUE) {
for (uint32_t j = 0; j < index; ++j) {
if (p_module->specialization_constants[j].constant_id == p_node->decorations.specialization_constant.value) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_DUPLICATE_SPEC_CONSTANT_NAME;
}
for (uint32_t j = 0; j < index; ++j) {
if (p_module->specialization_constants[j].constant_id == p_node->decorations.specialization_constant.value) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_DUPLICATE_SPEC_CONSTANT_NAME;
}
}
}

p_module->specialization_constants[index].name = p_node->name;
Expand Down Expand Up @@ -3895,10 +3889,6 @@ static SpvReflectResult CreateShaderModule(
result = ParsePushConstantBlocks(parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}
if (result == SPV_REFLECT_RESULT_SUCCESS) {
result = ParseSpecializationConstants(parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}
if (result == SPV_REFLECT_RESULT_SUCCESS) {
result = ParseEntryPoints(parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
Expand Down Expand Up @@ -3928,6 +3918,11 @@ static SpvReflectResult CreateShaderModule(
result = ParseExecutionModes(parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}
// WorkGroupSize builtin needs to update entry point localsize member
if (result == SPV_REFLECT_RESULT_SUCCESS) {
result = ParseSpecializationConstants(parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}

// Destroy module if parse was not successful
if (result != SPV_REFLECT_RESULT_SUCCESS) {
Expand Down Expand Up @@ -5723,22 +5718,20 @@ SpvReflectResult EvaluateResultImpl(const SpvReflectShaderModule* p_module, uint
if (!p_node) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
switch (p_node->op) {
default:
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION;
case SpvOpConstantTrue:
{
result->general_type = SPV_REFLECT_SCALAR_TYPE_BOOL;
result->values[0].value.uint32_bool_value = 1;
}
return SPV_REFLECT_RESULT_SUCCESS;
case SpvOpConstantFalse:
{
result->general_type = SPV_REFLECT_SCALAR_TYPE_BOOL;
result->values[0].value.uint32_bool_value = 0;
}
return SPV_REFLECT_RESULT_SUCCESS;
case SpvOpConstant:
CONSTANT_RESULT:
return GetScalarConstant(p_module, p_node, &result->values[0], &result->general_type, &result->type);
return GetScalarConstant(p_module, p_node, &result->values[0], &result->type);
case SpvOpSpecConstantTrue: case SpvOpSpecConstantFalse:
case SpvOpSpecConstant:
{
Expand All @@ -5748,32 +5741,58 @@ SpvReflectResult EvaluateResultImpl(const SpvReflectShaderModule* p_module, uint
SpvReflectSpecializationConstant* p_constant;
res = GetSpecContantById(p_module, p_node->decorations.specialization_constant.value, &p_constant);
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
result->general_type = p_constant->general_type;
result->type = p_constant->type;
result->values[0] = p_constant->current_value;
}
return SPV_REFLECT_RESULT_SUCCESS;
case SpvOpSpecConstantComposite:
{
// only support scalar types for now...
// only support compositing vector types for now...
// vectors are needed for spv compiled to WorkgroupSize builtin
// in expressing actual localsize
res = GetTypeByTypeId(p_module, p_node->result_type_id, &result->type);
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
// compositing types
if (result->type->type_flags & VECTOR_DISALLOWED_FLAGS) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
}
uint32_t vec_size = 1;
// should always have, since scalars do not need composite
if (result->type->type_flags & SPV_REFLECT_TYPE_FLAG_VECTOR) {
vec_size = result->type->traits.numeric.vector.component_count;
}
// check instruction size
if (p_node->word_count != 3 + vec_size) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION;
}
for (uint32_t i = 0; i < vec_size; ++i) {
SpvReflectValue operandi = {0};
GET_OPERAND(p_module, p_node, 3 + i, &operandi, maxRecursion);
// check type compatibility
if (operandi.type && (operandi.type->type_flags & SCALAR_DISALLOWED_FLAGS)) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
}
if ((!operandi.type && !(result->type->type_flags& SPV_REFLECT_TYPE_FLAG_BOOL))
||(operandi.type && ((operandi.type->type_flags & SCALAR_TYPE_FLAGS) != (result->type->type_flags & SCALAR_TYPE_FLAGS)))) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
}
result->values[i] = operandi.values[0];
}
}
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
return SPV_REFLECT_RESULT_SUCCESS;
case SpvOpSpecConstantOp:
{
// operation has result type id, thus must be typed
res = GetTypeByTypeId(p_module, p_node->result_type_id, &result->type);
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;

// no support for vectors yet...
if (result->type->type_flags & SPV_REFLECT_TYPE_FLAG_VECTOR) {
// only vector and scalar types of int/bool/float types implemented
// only OpSelect, OpUndef and access chain instructions can work with non-vector or scalar types
// they are not currently supported... (likely never will)
if (result->type->type_flags & VECTOR_DISALLOWED_FLAGS) {
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
}

// only vector and scalar types of int/bool/float types allowed
CHECK_VECTOR_OR_SCALAR_TYPE(result)

result->general_type = ScalarGeneralTypeFromType(result->type);

// evaluate op
uint32_t spec_op;
CHECKED_READU32(p_parser, p_node->word_offset + 3, spec_op);
Expand Down
Loading

0 comments on commit 4c8a546

Please sign in to comment.