Skip to content

Commit a6d4d6c

Browse files
committed
Adding preliminary SpecConstant support
Used code form godot/godot as starting point. See #121 Next steps: 64 bit types need 2 words as operand, `SpecConstantOp` need implementation for evaluating spec constant (spirv-cross has an implementation for uint KhronosGroup/SPIRV-Cross#1463). Also maybe composite type evaluation support?
1 parent 1ef99b0 commit a6d4d6c

File tree

5 files changed

+729
-27
lines changed

5 files changed

+729
-27
lines changed

common/output_stream.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,31 @@ void StreamWriteInterfaceVariable(std::ostream& os, const SpvReflectInterfaceVar
10401040
}
10411041
}
10421042

1043+
void StreamWriteSpecializationConstant(std::ostream& os, const SpvReflectSpecializationConstant& obj, const char* indent)
1044+
{
1045+
const char* t = indent;
1046+
os << t << "spirv id : " << obj.spirv_id << "\n";
1047+
os << t << "constant id: " << obj.constant_id << "\n";
1048+
os << t << "name : " << (obj.name != NULL ? obj.name : "") << '\n';
1049+
os << t << "type : ";
1050+
switch (obj.constant_type) {
1051+
case SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL:
1052+
os << "boolean\n";
1053+
os << t << "default : " << obj.default_value.int_bool_value;
1054+
break;
1055+
case SPV_REFLECT_SPECIALIZATION_CONSTANT_INT:
1056+
os << "integer\n";
1057+
os << t << "default : "<<obj.default_value.int_bool_value;
1058+
break;
1059+
case SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT:
1060+
os << "float\n";
1061+
os << t << "default : " << obj.default_value.float_value;
1062+
break;
1063+
default:
1064+
os << "unknown type";
1065+
}
1066+
}
1067+
10431068
void StreamWriteEntryPoint(std::ostream& os, const SpvReflectEntryPoint& obj, const char* indent)
10441069
{
10451070
os << indent << "entry point : " << obj.name;
@@ -1088,9 +1113,32 @@ void WriteReflection(const spv_reflect::ShaderModule& obj, bool flatten_cbuffers
10881113
std::vector<SpvReflectDescriptorBinding*> bindings;
10891114
std::vector<SpvReflectDescriptorSet*> sets;
10901115
std::vector<SpvReflectBlockVariable*> push_constant_bocks;
1116+
std::vector<SpvReflectSpecializationConstant*> specialization_constants;
1117+
1118+
count = 0;
1119+
SpvReflectResult result = obj.EnumerateSpecializationConstants(&count, nullptr);
1120+
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
1121+
specialization_constants.resize(count);
1122+
result = obj.EnumerateSpecializationConstants(&count, specialization_constants.data());
1123+
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
1124+
if (count > 0) {
1125+
os << "\n";
1126+
os << "\n";
1127+
os << "\n";
1128+
os << t << "Sepecialization constants: " << count << "\n\n";
1129+
for (size_t i = 0; i < specialization_constants.size(); ++i) {
1130+
auto p_var = specialization_constants[i];
1131+
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
1132+
os << tt << i << ":" << "\n";
1133+
StreamWriteSpecializationConstant(os, *p_var, ttt);
1134+
if (i < (count - 1)) {
1135+
os << "\n";
1136+
}
1137+
}
1138+
}
10911139

10921140
count = 0;
1093-
SpvReflectResult result = obj.EnumerateInputVariables(&count, nullptr);
1141+
result = obj.EnumerateInputVariables(&count, nullptr);
10941142
USE_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
10951143
variables.resize(count);
10961144
result = obj.EnumerateInputVariables(&count, variables.data());

main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ int main(int argn, char** argv)
107107
std::vector<char> spv_data(size);
108108
spv_ifstream.read(spv_data.data(), size);
109109

110-
spv_reflect::ShaderModule reflection(spv_data.size(), spv_data.data());
110+
spv_reflect::ShaderModule reflection(spv_data.size(), spv_data.data(), SPV_REFLECT_MODULE_FLAG_EVALUATE_SPEC_CONSTANT);
111111
if (reflection.GetResult() != SPV_REFLECT_RESULT_SUCCESS) {
112112
std::cerr << "ERROR: could not process '" << input_spv_path
113113
<< "' (is it a valid SPIR-V bytecode?)" << std::endl;

spec_constant.patch

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
diff --git a/thirdparty/spirv-reflect/spirv_reflect.c b/thirdparty/spirv-reflect/spirv_reflect.c
2+
index e9b11bf495..f181df5fa2 100644
3+
--- a/thirdparty/spirv-reflect/spirv_reflect.c
4+
+++ b/thirdparty/spirv-reflect/spirv_reflect.c
5+
@@ -125,6 +125,9 @@ typedef struct SpvReflectPrvDecorations {
6+
SpvReflectPrvNumberDecoration location;
7+
SpvReflectPrvNumberDecoration offset;
8+
SpvReflectPrvNumberDecoration uav_counter_buffer;
9+
+// -- GODOT begin --
10+
+ SpvReflectPrvNumberDecoration specialization_constant;
11+
+// -- GODOT end --
12+
SpvReflectPrvStringDecoration semantic;
13+
uint32_t array_stride;
14+
uint32_t matrix_stride;
15+
@@ -631,6 +634,9 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser)
16+
p_parser->nodes[i].decorations.offset.value = (uint32_t)INVALID_VALUE;
17+
p_parser->nodes[i].decorations.uav_counter_buffer.value = (uint32_t)INVALID_VALUE;
18+
p_parser->nodes[i].decorations.built_in = (SpvBuiltIn)INVALID_VALUE;
19+
+// -- GODOT begin --
20+
+ p_parser->nodes[i].decorations.specialization_constant.value = (SpvBuiltIn)INVALID_VALUE;
21+
+// -- GODOT end --
22+
}
23+
// Mark source file id node
24+
p_parser->source_file_id = (uint32_t)INVALID_VALUE;
25+
@@ -821,10 +827,16 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser)
26+
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
27+
}
28+
break;
29+
-
30+
+// -- GODOT begin --
31+
case SpvOpSpecConstantTrue:
32+
case SpvOpSpecConstantFalse:
33+
- case SpvOpSpecConstant:
34+
+ case SpvOpSpecConstant: {
35+
+ CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
36+
+ CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
37+
+ p_node->is_type = true;
38+
+ }
39+
+ break;
40+
+// -- GODOT end --
41+
case SpvOpSpecConstantComposite:
42+
case SpvOpSpecConstantOp: {
43+
CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
44+
@@ -856,7 +868,7 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser)
45+
CHECKED_READU32(p_parser, p_node->word_offset + 3, p_access_chain->base_id);
46+
//
47+
// SPIRV_ACCESS_CHAIN_INDEX_OFFSET (4) is the number of words up until the first index:
48+
- // [Node, Result Type Id, Result Id, Base Id, <Indexes>]
49+
+ // [SpvReflectPrvNode, Result Type Id, Result Id, Base Id, <Indexes>]
50+
//
51+
p_access_chain->index_count = (node_word_count - SPIRV_ACCESS_CHAIN_INDEX_OFFSET);
52+
if (p_access_chain->index_count > 0) {
53+
@@ -1338,6 +1350,9 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
54+
skip = true;
55+
}
56+
break;
57+
+// -- GODOT begin --
58+
+ case SpvDecorationSpecId:
59+
+// -- GODOT end --
60+
case SpvDecorationRelaxedPrecision:
61+
case SpvDecorationBlock:
62+
case SpvDecorationBufferBlock:
63+
@@ -1481,7 +1496,14 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
64+
p_target_decorations->input_attachment_index.word_offset = word_offset;
65+
}
66+
break;
67+
-
68+
+// -- GODOT begin --
69+
+ case SpvDecorationSpecId: {
70+
+ uint32_t word_offset = p_node->word_offset + member_offset+ 3;
71+
+ CHECKED_READU32(p_parser, word_offset, p_target_decorations->specialization_constant.value);
72+
+ p_target_decorations->specialization_constant.word_offset = word_offset;
73+
+ }
74+
+ break;
75+
+// -- GODOT end --
76+
case SpvReflectDecorationHlslCounterBufferGOOGLE: {
77+
uint32_t word_offset = p_node->word_offset + member_offset+ 3;
78+
CHECKED_READU32(p_parser, word_offset, p_target_decorations->uav_counter_buffer.value);
79+
@@ -1789,6 +1811,13 @@ static SpvReflectResult ParseType(
80+
p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_ACCELERATION_STRUCTURE;
81+
}
82+
break;
83+
+// -- GODOT begin --
84+
+ case SpvOpSpecConstantTrue:
85+
+ case SpvOpSpecConstantFalse:
86+
+ case SpvOpSpecConstant: {
87+
+ }
88+
+ break;
89+
+// -- GODOT end --
90+
}
91+
92+
if (result == SPV_REFLECT_RESULT_SUCCESS) {
93+
@@ -3269,6 +3298,69 @@ static SpvReflectResult ParseExecutionModes(
94+
return SPV_REFLECT_RESULT_SUCCESS;
95+
}
96+
97+
+// -- GODOT begin --
98+
+static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module)
99+
+{
100+
+ p_module->specialization_constant_count = 0;
101+
+ p_module->specialization_constants = NULL;
102+
+ for (size_t i = 0; i < p_parser->node_count; ++i) {
103+
+ SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
104+
+ if (p_node->op == SpvOpSpecConstantTrue || p_node->op == SpvOpSpecConstantFalse || p_node->op == SpvOpSpecConstant) {
105+
+ p_module->specialization_constant_count++;
106+
+ }
107+
+ }
108+
+
109+
+ if (p_module->specialization_constant_count == 0) {
110+
+ return SPV_REFLECT_RESULT_SUCCESS;
111+
+ }
112+
+
113+
+ p_module->specialization_constants = (SpvReflectSpecializationConstant*)calloc(p_module->specialization_constant_count, sizeof(SpvReflectSpecializationConstant));
114+
+
115+
+ uint32_t index = 0;
116+
+
117+
+ for (size_t i = 0; i < p_parser->node_count; ++i) {
118+
+ SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
119+
+ switch(p_node->op) {
120+
+ default: continue;
121+
+ case SpvOpSpecConstantTrue: {
122+
+ p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
123+
+ p_module->specialization_constants[index].default_value.int_bool_value = 1;
124+
+ } break;
125+
+ case SpvOpSpecConstantFalse: {
126+
+ p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
127+
+ p_module->specialization_constants[index].default_value.int_bool_value = 0;
128+
+ } break;
129+
+ case SpvOpSpecConstant: {
130+
+ SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
131+
+ uint32_t element_type_id = (uint32_t)INVALID_VALUE;
132+
+ uint32_t default_value = 0;
133+
+ IF_READU32(result, p_parser, p_node->word_offset + 1, element_type_id);
134+
+ IF_READU32(result, p_parser, p_node->word_offset + 3, default_value);
135+
+
136+
+ SpvReflectPrvNode* p_next_node = FindNode(p_parser, element_type_id);
137+
+
138+
+ if (p_next_node->op == SpvOpTypeInt) {
139+
+ p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_INT;
140+
+ } else if (p_next_node->op == SpvOpTypeFloat) {
141+
+ p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT;
142+
+ } else {
143+
+ return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED;
144+
+ }
145+
+
146+
+ p_module->specialization_constants[index].default_value.int_bool_value = default_value; //bits are the same for int and float
147+
+ } break;
148+
+ }
149+
+
150+
+ p_module->specialization_constants[index].name = p_node->name;
151+
+ p_module->specialization_constants[index].constant_id = p_node->decorations.specialization_constant.value;
152+
+ p_module->specialization_constants[index].spirv_id = p_node->result_id;
153+
+ index++;
154+
+ }
155+
+
156+
+ return SPV_REFLECT_RESULT_SUCCESS;
157+
+}
158+
+// -- GODOT end --
159+
+
160+
static SpvReflectResult ParsePushConstantBlocks(
161+
SpvReflectPrvParser* p_parser,
162+
SpvReflectShaderModule* p_module)
163+
@@ -3650,6 +3742,12 @@ static SpvReflectResult CreateShaderModule(
164+
result = ParsePushConstantBlocks(&parser, p_module);
165+
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
166+
}
167+
+// -- GODOT begin --
168+
+ if (result == SPV_REFLECT_RESULT_SUCCESS) {
169+
+ result = ParseSpecializationConstants(&parser, p_module);
170+
+ SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
171+
+ }
172+
+// -- GODOT end --
173+
if (result == SPV_REFLECT_RESULT_SUCCESS) {
174+
result = ParseEntryPoints(&parser, p_module);
175+
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
176+
@@ -3807,6 +3905,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module)
177+
SafeFree(p_entry->used_push_constants);
178+
}
179+
SafeFree(p_module->entry_points);
180+
+// -- GODOT begin --
181+
+ SafeFree(p_module->specialization_constants);
182+
+// -- GODOT end --
183+
184+
// Push constants
185+
for (size_t i = 0; i < p_module->push_constant_block_count; ++i) {
186+
@@ -4077,6 +4178,38 @@ SpvReflectResult spvReflectEnumerateEntryPointInterfaceVariables(
187+
return SPV_REFLECT_RESULT_SUCCESS;
188+
}
189+
190+
+// -- GODOT begin --
191+
+SpvReflectResult spvReflectEnumerateSpecializationConstants(
192+
+ const SpvReflectShaderModule* p_module,
193+
+ uint32_t* p_count,
194+
+ SpvReflectSpecializationConstant** pp_constants
195+
+)
196+
+{
197+
+ if (IsNull(p_module)) {
198+
+ return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
199+
+ }
200+
+ if (IsNull(p_count)) {
201+
+ return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
202+
+ }
203+
+
204+
+ if (IsNotNull(pp_constants)) {
205+
+ if (*p_count != p_module->specialization_constant_count) {
206+
+ return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
207+
+ }
208+
+
209+
+ for (uint32_t index = 0; index < *p_count; ++index) {
210+
+ SpvReflectSpecializationConstant *p_const = &p_module->specialization_constants[index];
211+
+ pp_constants[index] = p_const;
212+
+ }
213+
+ }
214+
+ else {
215+
+ *p_count = p_module->specialization_constant_count;
216+
+ }
217+
+
218+
+ return SPV_REFLECT_RESULT_SUCCESS;
219+
+}
220+
+// -- GODOT end --
221+
+
222+
SpvReflectResult spvReflectEnumerateInputVariables(
223+
const SpvReflectShaderModule* p_module,
224+
uint32_t* p_count,
225+
diff --git a/thirdparty/spirv-reflect/spirv_reflect.h b/thirdparty/spirv-reflect/spirv_reflect.h
226+
index e9e4c40755..948533d3c0 100644
227+
--- a/thirdparty/spirv-reflect/spirv_reflect.h
228+
+++ b/thirdparty/spirv-reflect/spirv_reflect.h
229+
@@ -323,6 +323,28 @@ typedef struct SpvReflectTypeDescription {
230+
struct SpvReflectTypeDescription* members;
231+
} SpvReflectTypeDescription;
232+
233+
+// -- GODOT begin --
234+
+/*! @struct SpvReflectSpecializationConstant
235+
+
236+
+*/
237+
+
238+
+typedef enum SpvReflectSpecializationConstantType {
239+
+ SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL = 0,
240+
+ SPV_REFLECT_SPECIALIZATION_CONSTANT_INT = 1,
241+
+ SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT = 2,
242+
+} SpvReflectSpecializationConstantType;
243+
+
244+
+typedef struct SpvReflectSpecializationConstant {
245+
+ const char* name;
246+
+ uint32_t spirv_id;
247+
+ uint32_t constant_id;
248+
+ SpvReflectSpecializationConstantType constant_type;
249+
+ union {
250+
+ float float_value;
251+
+ uint32_t int_bool_value;
252+
+ } default_value;
253+
+} SpvReflectSpecializationConstant;
254+
+// -- GODOT end --
255+
256+
/*! @struct SpvReflectInterfaceVariable
257+
258+
@@ -472,6 +494,10 @@ typedef struct SpvReflectShaderModule {
259+
SpvReflectInterfaceVariable* interface_variables; // Uses value(s) from first entry point
260+
uint32_t push_constant_block_count; // Uses value(s) from first entry point
261+
SpvReflectBlockVariable* push_constant_blocks; // Uses value(s) from first entry point
262+
+ // -- GODOT begin --
263+
+ uint32_t specialization_constant_count;
264+
+ SpvReflectSpecializationConstant* specialization_constants;
265+
+ // -- GODOT end --
266+
267+
struct Internal {
268+
SpvReflectModuleFlags module_flags;
269+
@@ -744,6 +770,33 @@ SpvReflectResult spvReflectEnumerateInputVariables(
270+
SpvReflectInterfaceVariable** pp_variables
271+
);
272+
273+
+// -- GOODT begin --
274+
+/*! @fn spvReflectEnumerateSpecializationConstants
275+
+ @brief If the module contains multiple entry points, this will only get
276+
+ the specialization constants for the first one.
277+
+ @param p_module Pointer to an instance of SpvReflectShaderModule.
278+
+ @param p_count If pp_constants is NULL, the module's specialization constant
279+
+ count will be stored here.
280+
+ If pp_variables is not NULL, *p_count must contain
281+
+ the module's specialization constant count.
282+
+ @param pp_variables If NULL, the module's specialization constant count will be
283+
+ written to *p_count.
284+
+ If non-NULL, pp_constants must point to an array with
285+
+ *p_count entries, where pointers to the module's
286+
+ specialization constants will be written. The caller must not
287+
+ free the specialization constants written to this array.
288+
+ @return If successful, returns SPV_REFLECT_RESULT_SUCCESS.
289+
+ Otherwise, the error code indicates the cause of the
290+
+ failure.
291+
+
292+
+*/
293+
+SpvReflectResult spvReflectEnumerateSpecializationConstants(
294+
+ const SpvReflectShaderModule* p_module,
295+
+ uint32_t* p_count,
296+
+ SpvReflectSpecializationConstant** pp_constants
297+
+);
298+
+// -- GODOT end --
299+
+
300+
/*! @fn spvReflectEnumerateEntryPointInputVariables
301+
@brief Enumerate the input variables for a given entry point.
302+
@param entry_point The name of the entry point to get the input variables for.

0 commit comments

Comments
 (0)