Skip to content

Commit 0ea01bb

Browse files
authored
Make WGSL a whole program target (#114)
* Make WGSL a whole program target * Fix bugs with whole program targets not generating all entrypoints * Compile WGSL shader module once for all entrypoints * Clean up code to address review comments * Update type
1 parent 6e9c4b3 commit 0ea01bb

File tree

4 files changed

+37
-38
lines changed

4 files changed

+37
-38
lines changed

src/App.vue

+6-17
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Help from './components/Help.vue'
88
import RenderCanvas from './components/RenderCanvas.vue'
99
import { compiler, checkShaderType, slangd, moduleLoadingMessage } from './try-slang'
1010
import { computed, defineAsyncComponent, onBeforeMount, onMounted, ref, useTemplateRef, watch, type Ref } from 'vue'
11-
import { isWholeProgramTarget, type Bindings, type ReflectionJSON, type ShaderType } from './compiler'
11+
import { isWholeProgramTarget, type Bindings, type ReflectionJSON, type RunnableShaderType, type ShaderType } from './compiler'
1212
import { demoList } from './demo-list'
1313
import { compressToBase64URL, decompressFromBase64URL, getResourceCommandsFromAttributes, getUniformSize, getUniformSliders, isWebGPUSupported, parseCallCommands, type CallCommand, type ResourceCommand, type UniformController } from './util'
1414
import type { ThreadGroupSize } from './slang-wasm'
@@ -237,8 +237,8 @@ function compileOrRun() {
237237
238238
export type CompiledPlayground = {
239239
slangSource: string,
240-
mainShader: Shader,
241-
callCommandShaders: Shader[],
240+
shader: Shader,
241+
mainEntryPoint: RunnableShaderType,
242242
resourceCommands: ResourceCommand[],
243243
callCommands: CallCommand[],
244244
uniformSize: number,
@@ -286,26 +286,15 @@ function doRun() {
286286
throw new Error("Error while parsing '//! CALL' commands: " + error.message);
287287
}
288288
289-
let callCommandShaders: Shader[] = [];
290-
if (callCommands && (callCommands.length > 0)) {
291-
for (const command of callCommands) {
292-
const compiledResult = compileShader(userSource, command.fnName, "WGSL");
293-
if (!compiledResult.succ) {
294-
throw new Error("Failed to compile shader for requested entry-point: " + command.fnName);
295-
}
296-
callCommandShaders.push(compiledResult);
297-
}
298-
}
299-
300289
if (compiler == null) {
301290
throw new Error("Could not get compiler");
302291
}
303292
toggleDisplayMode(compiler.shaderType);
304293
305294
renderCanvas.value.onRun({
306295
slangSource: userSource,
307-
mainShader: ret,
308-
callCommandShaders,
296+
shader: ret,
297+
mainEntryPoint: entryPointName,
309298
resourceCommands,
310299
callCommands,
311300
uniformSize,
@@ -354,7 +343,7 @@ export type Shader = {
354343
layout: Bindings,
355344
hashedStrings: any,
356345
reflection: ReflectionJSON,
357-
threadGroupSize: ThreadGroupSize | { x: number, y: number, z: number },
346+
threadGroupSize: { [key: string]: ThreadGroupSize },
358347
};
359348
360349
export type MaybeShader = Shader | {

src/compiler.ts

+18-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { ComponentType, EmbindString, GlobalSession, MainModule, Module, Pr
33
import { playgroundSource } from "./playgroundShader.js";
44

55
export function isWholeProgramTarget(compileTarget: string) {
6-
return compileTarget == "METAL" || compileTarget == "SPIRV";
6+
return compileTarget == "METAL" || compileTarget == "SPIRV" || compileTarget == "WGSL";
77
}
88

99
export const RUNNABLE_ENTRY_POINT_NAMES = ['imageMain', 'printMain'] as const;
@@ -243,9 +243,10 @@ export class SlangCompiler {
243243
// we will also add them to the dropdown list.
244244
findDefinedEntryPoints(shaderSource: string): string[] {
245245
let result: string[] = [];
246+
let runnable: string[] = [];
246247
for (let entryPointName of RUNNABLE_ENTRY_POINT_NAMES) {
247248
if (shaderSource.match(entryPointName)) {
248-
result.push(entryPointName);
249+
runnable.push(entryPointName);
249250
}
250251
}
251252
let slangSession: Session | null | undefined;
@@ -256,7 +257,7 @@ export class SlangCompiler {
256257
return [];
257258
}
258259
let module: Module | null = null;
259-
if (result.length > 0) {
260+
if (runnable.length > 0) {
260261
slangSession.loadModuleFromSource(playgroundSource, "playground", "/playground.slang");
261262
}
262263
module = slangSession.loadModuleFromSource(shaderSource, "user", "/user.slang");
@@ -278,6 +279,7 @@ export class SlangCompiler {
278279
if (slangSession)
279280
slangSession.delete();
280281
}
282+
result.push(...runnable);
281283
return result;
282284
}
283285

@@ -342,7 +344,7 @@ export class SlangCompiler {
342344

343345
// If entry point is provided, we know for sure this is not a whole program compilation,
344346
// so we will just go to find the correct module to include in the compilation.
345-
if (entryPointName != "") {
347+
if (entryPointName != "" && !isWholeProgram) {
346348
if (this.isRunnableEntryPoint(entryPointName)) {
347349
// we use the same entry point name as module name
348350
const mainProgram = this.getPrecompiledProgram(slangSession, entryPointName);
@@ -465,7 +467,7 @@ export class SlangCompiler {
465467
return true;
466468
}
467469

468-
compile(shaderSource: string, entryPointName: string, compileTargetStr: string, noWebGPU: boolean): null | [string, Bindings, any, ReflectionJSON, ThreadGroupSize | { x: number, y: number, z: number }] {
470+
compile(shaderSource: string, entryPointName: string, compileTargetStr: string, noWebGPU: boolean): null | [string, Bindings, any, ReflectionJSON, { [key: string]: ThreadGroupSize }] {
469471
this.diagnosticsMsg = "";
470472

471473
let shouldLinkPlaygroundModule = RUNNABLE_ENTRY_POINT_NAMES.some((entry_point) => shaderSource.match(entry_point) != null);
@@ -537,10 +539,17 @@ export class SlangCompiler {
537539
}
538540
}
539541

540-
// Also read the shader work-group size.
541-
const entryPointReflection = linkedProgram.getLayout(0)?.findEntryPointByName(entryPointName);
542-
let threadGroupSize = entryPointReflection ? entryPointReflection.getComputeThreadGroupSize() :
543-
{ x: 1, y: 1, z: 1 };
542+
// Also read the shader work-group sizes.
543+
let threadGroupSize: { [key: string]: ThreadGroupSize } = {};
544+
const layout = linkedProgram.getLayout(0);
545+
if (layout) {
546+
const entryPoints = this.findDefinedEntryPoints(shaderSource);
547+
for (const name of entryPoints) {
548+
const entryPointReflection = layout.findEntryPointByName(name);
549+
threadGroupSize[name] = entryPointReflection ? entryPointReflection.getComputeThreadGroupSize() :
550+
{ x: 1, y: 1, z: 1 } as ThreadGroupSize;
551+
}
552+
}
544553

545554
if (outCode == "") {
546555
let error = this.slangWasmModule.getLastError();

src/components/RenderCanvas.vue

+11-10
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
584584
randomPipeline.createPipelineLayout(layout);
585585
586586
// Create the pipeline (without resource bindings for now)
587-
randomPipeline.createPipeline(module, null);
587+
randomPipeline.createPipeline(module, "computeMain", null);
588588
589589
randFloatPipeline = randomPipeline;
590590
}
@@ -705,21 +705,23 @@ function onRun(compiledCode: CompiledPlayground) {
705705
withRenderLock(
706706
// setupFn
707707
async () => {
708-
hashedStrings = compiledCode.mainShader.hashedStrings;
708+
hashedStrings = compiledCode.shader.hashedStrings;
709709
710-
resourceBindings = compiledCode.mainShader.layout;
710+
resourceBindings = compiledCode.shader.layout;
711711
// create a pipeline resource 'signature' based on the bindings found in the program.
712712
computePipeline.createPipelineLayout(resourceBindings);
713713
714714
if (extraComputePipelines.length > 0)
715715
extraComputePipelines = []; // This should release the resources of the extra pipelines.
716716
717-
for (const callShader of compiledCode.callCommandShaders) {
718-
const module = device.createShaderModule({ code: callShader.code });
717+
const module = device.createShaderModule({ code: compiledCode.shader.code });
718+
719+
for (const callCommand of compiledCode.callCommands) {
720+
const entryPoint = callCommand.fnName;
719721
const pipeline = new ComputePipeline(device);
720-
pipeline.createPipelineLayout(callShader.layout);
721-
pipeline.createPipeline(module, null);
722-
pipeline.setThreadGroupSize(callShader.threadGroupSize);
722+
pipeline.createPipelineLayout(compiledCode.shader.layout);
723+
pipeline.createPipeline(module, entryPoint, null);
724+
pipeline.setThreadGroupSize(compiledCode.shader.threadGroupSize[entryPoint]);
723725
extraComputePipelines.push(pipeline);
724726
}
725727
@@ -742,8 +744,7 @@ function onRun(compiledCode: CompiledPlayground) {
742744
passThroughPipeline.inputTexture = outputTexture;
743745
passThroughPipeline.createBindGroup();
744746
745-
const module = device.createShaderModule({ code: compiledCode.mainShader.code });
746-
computePipeline.createPipeline(module, allocatedResources);
747+
computePipeline.createPipeline(module, compiledCode.mainEntryPoint, allocatedResources);
747748
748749
// Create bind groups for the extra pipelines
749750
for (const pipeline of extraComputePipelines)

src/compute.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ export class ComputePipeline {
4646
this.pipelineLayout = layout;
4747
}
4848

49-
createPipeline(shaderModule: GPUShaderModule, resources: Map<string, GPUTexture | GPUBuffer> | null) {
49+
createPipeline(shaderModule: GPUShaderModule, entryPoint: string, resources: Map<string, GPUTexture | GPUBuffer> | null) {
5050
if (this.pipelineLayout == undefined)
5151
throw new Error("Cannot create pipeline without layout");
5252
const pipeline = this.device.createComputePipeline({
5353
label: 'compute pipeline',
5454
layout: this.pipelineLayout,
55-
compute: { module: shaderModule },
55+
compute: { module: shaderModule, entryPoint },
5656
});
5757

5858
this.pipeline = pipeline;

0 commit comments

Comments
 (0)