@@ -3,7 +3,7 @@ import type { ComponentType, EmbindString, GlobalSession, MainModule, Module, Pr
3
3
import { playgroundSource } from "./playgroundShader.js" ;
4
4
5
5
export function isWholeProgramTarget ( compileTarget : string ) {
6
- return compileTarget == "METAL" || compileTarget == "SPIRV" ;
6
+ return compileTarget == "METAL" || compileTarget == "SPIRV" || compileTarget == "WGSL" ;
7
7
}
8
8
9
9
export const RUNNABLE_ENTRY_POINT_NAMES = [ 'imageMain' , 'printMain' ] as const ;
@@ -243,9 +243,10 @@ export class SlangCompiler {
243
243
// we will also add them to the dropdown list.
244
244
findDefinedEntryPoints ( shaderSource : string ) : string [ ] {
245
245
let result : string [ ] = [ ] ;
246
+ let runnable : string [ ] = [ ] ;
246
247
for ( let entryPointName of RUNNABLE_ENTRY_POINT_NAMES ) {
247
248
if ( shaderSource . match ( entryPointName ) ) {
248
- result . push ( entryPointName ) ;
249
+ runnable . push ( entryPointName ) ;
249
250
}
250
251
}
251
252
let slangSession : Session | null | undefined ;
@@ -256,7 +257,7 @@ export class SlangCompiler {
256
257
return [ ] ;
257
258
}
258
259
let module : Module | null = null ;
259
- if ( result . length > 0 ) {
260
+ if ( runnable . length > 0 ) {
260
261
slangSession . loadModuleFromSource ( playgroundSource , "playground" , "/playground.slang" ) ;
261
262
}
262
263
module = slangSession . loadModuleFromSource ( shaderSource , "user" , "/user.slang" ) ;
@@ -278,6 +279,7 @@ export class SlangCompiler {
278
279
if ( slangSession )
279
280
slangSession . delete ( ) ;
280
281
}
282
+ result . push ( ...runnable ) ;
281
283
return result ;
282
284
}
283
285
@@ -342,7 +344,7 @@ export class SlangCompiler {
342
344
343
345
// If entry point is provided, we know for sure this is not a whole program compilation,
344
346
// so we will just go to find the correct module to include in the compilation.
345
- if ( entryPointName != "" ) {
347
+ if ( entryPointName != "" && ! isWholeProgram ) {
346
348
if ( this . isRunnableEntryPoint ( entryPointName ) ) {
347
349
// we use the same entry point name as module name
348
350
const mainProgram = this . getPrecompiledProgram ( slangSession , entryPointName ) ;
@@ -465,7 +467,7 @@ export class SlangCompiler {
465
467
return true ;
466
468
}
467
469
468
- compile ( shaderSource : string , entryPointName : string , compileTargetStr : string , noWebGPU : boolean ) : null | [ string , Bindings , any , ReflectionJSON , ThreadGroupSize | { x : number , y : number , z : number } ] {
470
+ compile ( shaderSource : string , entryPointName : string , compileTargetStr : string , noWebGPU : boolean ) : null | [ string , Bindings , any , ReflectionJSON , { [ key : string ] : ThreadGroupSize } ] {
469
471
this . diagnosticsMsg = "" ;
470
472
471
473
let shouldLinkPlaygroundModule = RUNNABLE_ENTRY_POINT_NAMES . some ( ( entry_point ) => shaderSource . match ( entry_point ) != null ) ;
@@ -537,10 +539,17 @@ export class SlangCompiler {
537
539
}
538
540
}
539
541
540
- // Also read the shader work-group size.
541
- const entryPointReflection = linkedProgram . getLayout ( 0 ) ?. findEntryPointByName ( entryPointName ) ;
542
- let threadGroupSize = entryPointReflection ? entryPointReflection . getComputeThreadGroupSize ( ) :
543
- { x : 1 , y : 1 , z : 1 } ;
542
+ // Also read the shader work-group sizes.
543
+ let threadGroupSize : { [ key : string ] : ThreadGroupSize } = { } ;
544
+ const layout = linkedProgram . getLayout ( 0 ) ;
545
+ if ( layout ) {
546
+ const entryPoints = this . findDefinedEntryPoints ( shaderSource ) ;
547
+ for ( const name of entryPoints ) {
548
+ const entryPointReflection = layout . findEntryPointByName ( name ) ;
549
+ threadGroupSize [ name ] = entryPointReflection ? entryPointReflection . getComputeThreadGroupSize ( ) :
550
+ { x : 1 , y : 1 , z : 1 } as ThreadGroupSize ;
551
+ }
552
+ }
544
553
545
554
if ( outCode == "" ) {
546
555
let error = this . slangWasmModule . getLastError ( ) ;
0 commit comments