Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support BLACK_SCREEN, add format inference, and add new demo #124

Merged
merged 12 commits into from
Apr 2, 2025
1 change: 1 addition & 0 deletions public/demos/gsplat2d-diff.slang
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ RWStructuredBuffer<float> adamFirstMoment;
RWStructuredBuffer<float> adamSecondMoment;

[playground::URL("static/jeep.jpg")]
[format("rgba8")]
Texture2D<float4> targetTexture;

// ----- Shared memory declarations --------
Expand Down
1 change: 1 addition & 0 deletions public/demos/image-from-url.slang
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import playground;

[playground::URL("static/jeep.jpg")]
[format("rgba8")]
Texture2D<float4> myImage;

float4 imageMain(uint2 dispatchThreadID, int2 screenSize)
Expand Down
49 changes: 49 additions & 0 deletions public/demos/painting.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import playground;

const static int MAX_BRUSH_SIZE = 16;

[playground::BLACK_SCREEN(1.0, 1.0)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering that if we can extend this attribute to like "COLOR_SCREEN" attribute that we can accept a single color to clear the texture?

Copy link
Contributor Author

@Devon7925 Devon7925 Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it but decided against it because:

  • People can use a call once shader to set the color anyways
  • Color doesn't generalize well to int, uint, or < 3 component textures
  • It would require either specifying a color for every texture

It might make sense to have a set of DEFAULT_VALUE_ attributes that can also apply to buffers though...

In which case the attributes should be renamed:

  • ZEROS -> BUFFER
  • BLACK -> TEXTURE
  • BLACK_SCREEN -> SCREEN_TEXTURE

RWTexture2D<float> tex_red;
[playground::BLACK_SCREEN(1.0, 1.0)]
RWTexture2D<float> tex_green;
[playground::BLACK_SCREEN(1.0, 1.0)]
RWTexture2D<float> tex_blue;

[playground::SLIDER(10.0, 4.0, 16.0)]
uniform float brush_size;
[playground::COLOR_PICK(1.0, 0.0, 1.0)]
uniform float3 color;

[playground::MOUSE_POSITION]
uniform float4 mousePosition;

[shader("compute")]
[numthreads(8, 8, 1)]
[playground::CALL(MAX_BRUSH_SIZE, MAX_BRUSH_SIZE, 1)]
void draw(uint2 dispatchThreadId: SV_DispatchThreadID)
{
if (mousePosition.z >= 0)
return;

let offset = float2(dispatchThreadId.xy) - float(MAX_BRUSH_SIZE) / 2;
if (length(offset) > brush_size / 2)
return;

var mouse_pos = uint2(mousePosition.xy + offset);
tex_red[mouse_pos] = color.r;
tex_green[mouse_pos] = color.g;
tex_blue[mouse_pos] = color.b;
}

float4 imageMain(uint2 dispatchThreadID, int2 screenSize)
{
uint imageW;
uint imageH;
tex_red.GetDimensions(imageW, imageH);

uint2 scaled = (uint2)floor(float2(dispatchThreadID.xy));
uint2 flipped = uint2(scaled.x, imageH - scaled.y);

float4 imageColor = float4(tex_red[flipped], tex_green[flipped], tex_blue[flipped], 1.0);
return imageColor;
}
8 changes: 4 additions & 4 deletions src/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ export type Shader = {
succ: true,
code: string,
layout: Bindings,
hashedStrings: HashedStringData[],
hashedStrings: HashedStringData,
reflection: ReflectionJSON,
threadGroupSize: { [key: string]: ThreadGroupSize },
threadGroupSizes: { [key: string]: [number, number, number] },
};

export type MaybeShader = Shader | {
Expand All @@ -370,7 +370,7 @@ function compileShader(userSource: string, entryPoint: string, compileTarget: ty
return { succ: false };
}

let [compiledCode, layout, hashedStrings, reflectionJsonObj, threadGroupSize] = compiledResult;
let [compiledCode, layout, hashedStrings, reflectionJsonObj, threadGroupSizes] = compiledResult;
reflectionJson = reflectionJsonObj;

codeGenArea.value?.setEditorValue(compiledCode);
Expand All @@ -380,7 +380,7 @@ function compileShader(userSource: string, entryPoint: string, compileTarget: ty
window.$jsontree.setJson("reflectionDiv", reflectionJson);
window.$jsontree.refreshAll();

return { succ: true, code: compiledCode, layout: layout, hashedStrings: hashedStrings, reflection: reflectionJson, threadGroupSize: threadGroupSize };
return { succ: true, code: compiledCode, layout: layout, hashedStrings, reflection: reflectionJson, threadGroupSizes };
}

function restoreFromURL(): boolean {
Expand Down
157 changes: 77 additions & 80 deletions src/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { ComponentType, EmbindString, GlobalSession, MainModule, Module, Pr
import playgroundSource from "./slang/playground.slang?raw";
import imageMainSource from "./slang/imageMain.slang?raw";
import printMainSource from "./slang/printMain.slang?raw";
import type { HashedStringData } from "./util.js";
import { ACCESS_MAP, getTextureFormat, sizeFromFormat, webgpuFormatfromSlangFormat, type HashedStringData, type ScalarType, type SlangFormat } from "./util.js";

export function isWholeProgramTarget(compileTarget: string) {
return compileTarget == "METAL" || compileTarget == "SPIRV" || compileTarget == "WGSL";
Expand All @@ -18,19 +18,6 @@ const RUNNABLE_ENTRY_POINT_SOURCE_MAP: { [key in RunnableShaderType]: string } =
'printMain': printMainSource,
};

type BindingDescriptor = {
storageTexture: {
access: "write-only" | "read-write",
format: GPUTextureFormat,
}
} | {
texture: {}
} | {
buffer: {
type: "uniform" | "storage"
}
};

export type Bindings = Map<string, GPUBindGroupLayoutEntry>;

export type ReflectionBinding = {
Expand All @@ -48,24 +35,26 @@ export type ReflectionType = {
"fields": ReflectionParameter[]
} | {
"kind": "vector",
"elementCount": number,
"elementCount": 2 | 3 | 4,
"elementType": ReflectionType,
} | {
"kind": "scalar",
"scalarType": `${"uint" | "int"}${8 | 16 | 32 | 64}` | `${"float"}${16 | 32 | 64}`,
"scalarType": ScalarType,
} | {
"kind": "resource",
"baseShape": "structuredBuffer",
"access"?: "readWrite",
"resultType": ReflectionType
"resultType": ReflectionType,
} | {
"kind": "resource",
"baseShape": "texture2D",
"access"?: "readWrite"
"access"?: "readWrite" | "write",
"resultType": ReflectionType,
};

export type ReflectionParameter = {
"binding": ReflectionBinding,
"format"?: SlangFormat,
"name": string,
"type": ReflectionType,
"userAttribs"?: ReflectionUserAttribute[],
Expand All @@ -74,13 +63,14 @@ export type ReflectionParameter = {
export type ReflectionJSON = {
"entryPoints": ReflectionEntryPoint[],
"parameters": ReflectionParameter[],
"hashedStrings": { [str: string]: number },
};

export type ReflectionEntryPoint = {
"name": string,
"parameters": ReflectionParameter[],
"stage": string,
"threadGroupSize": number[],
"threadGroupSize": [number, number, number],
"userAttribs"?: ReflectionUserAttribute[],
};

Expand Down Expand Up @@ -360,64 +350,77 @@ export class SlangCompiler {
return true;
}

getBindingDescriptor(index: number, programReflection: ProgramLayout, parameter: VariableLayoutReflection): BindingDescriptor | null {
const globalLayout = programReflection.getGlobalParamsTypeLayout();

if (globalLayout == null) {
throw new Error("Could not get layout");
}

const bindingType = globalLayout.getDescriptorSetDescriptorRangeType(0, index);
getBindingDescriptor(name: string, parameterReflection: ReflectionParameter): Partial<GPUBindGroupLayoutEntry> {
if (parameterReflection.type.kind == "resource") {
if (parameterReflection.type.baseShape == "texture2D") {
let slangAccess = parameterReflection.type.access;
if (slangAccess == undefined) {
return { texture: {} };
}
let access = ACCESS_MAP[slangAccess];

let scalarType: ScalarType;
let componentCount: 1 | 2 | 3 | 4;
if (parameterReflection.type.resultType.kind == "scalar") {
componentCount = 1;
scalarType = parameterReflection.type.resultType.scalarType;
} else if (parameterReflection.type.resultType.kind == "vector") {
componentCount = parameterReflection.type.resultType.elementCount;
if (parameterReflection.type.resultType.elementType.kind != "scalar") throw new Error(`Unhandled inner type for ${name}`)
scalarType = parameterReflection.type.resultType.elementType.scalarType;
} else {
throw new Error(`Unhandled inner type for ${name}`)
}

// Special case.. TODO: Remove this as soon as the reflection API properly reports write-only textures.
if (parameter.getName() == "outputTexture") {
return { storageTexture: { access: "write-only", format: "rgba8unorm" } };
}
let format: GPUTextureFormat;
if (parameterReflection.format) {
format = webgpuFormatfromSlangFormat(parameterReflection.format);
} else {
try {
format = getTextureFormat(componentCount, scalarType, access);
} catch (e) {
if (e instanceof Error)
throw new Error(`Could not get texture format for ${name}: ${e.message}`)
else
throw new Error(`Could not get texture format for ${name}`)
}
}

if (bindingType == this.slangWasmModule.BindingType.Texture) {
return { texture: {} };
}
else if (bindingType == this.slangWasmModule.BindingType.MutableTexture) {
return { storageTexture: { access: "read-write", format: "r32float" } };
}
else if (bindingType == this.slangWasmModule.BindingType.ConstantBuffer) {
return { storageTexture: { access, format } };
} else if (parameterReflection.type.baseShape == "structuredBuffer") {
return { buffer: { type: 'storage' } };
} else {
let _: never = parameterReflection.type;
console.error(`Could not generate binding for ${name}`)
return {}
}
} else if (parameterReflection.binding.kind == "uniform") {
return { buffer: { type: 'uniform' } };
} else {
console.error(`Could not generate binding for ${name}`)
return {}
}
else if (bindingType == this.slangWasmModule.BindingType.MutableTypedBuffer) {
return { buffer: { type: 'storage' } };
}
else if (bindingType == this.slangWasmModule.BindingType.MutableRawBuffer) {
return { buffer: { type: 'storage' } };
}
return null;
}

getResourceBindings(linkedProgram: ComponentType): Bindings {
const reflection: ProgramLayout | null = linkedProgram.getLayout(0); // assume target-index = 0

if (reflection == null) {
throw new Error("Could not get reflection!");
}

const count = reflection.getParameterCount();

let resourceDescriptors = new Map();
for (let i = 0; i < count; i++) {
const parameter = reflection.getParameterByIndex(i);
if (parameter == null) {
throw new Error("Invalid state!");
}
const name = parameter.getName();
getResourceBindings(reflectionJson: ReflectionJSON): Bindings {
let resourceDescriptors: Bindings = new Map();
for (let parameter of reflectionJson.parameters) {
const name = parameter.name;
let binding = {
binding: parameter.getBindingIndex(),
binding: parameter.binding.kind == "descriptorTableSlot" ? parameter.binding.index : 0,
visibility: GPUShaderStage.COMPUTE,
};

const resourceInfo = this.getBindingDescriptor(parameter.getBindingIndex(), reflection, parameter);
let parameterReflection = reflectionJson.parameters.find((p) => p.name == name)

if (parameterReflection == undefined) {
throw new Error("Could not find parameter in reflection JSON")
}

const resourceInfo = this.getBindingDescriptor(name, parameterReflection);

// extend binding with resourceInfo
if (resourceInfo)
Object.assign(binding, resourceInfo);
Object.assign(binding, resourceInfo);

resourceDescriptors.set(name, binding);
}
Expand All @@ -437,7 +440,7 @@ export class SlangCompiler {
return true;
}

compile(shaderSource: string, entryPointName: string, compileTargetStr: string, noWebGPU: boolean): null | [string, Bindings, HashedStringData[], ReflectionJSON, { [key: string]: ThreadGroupSize }] {
compile(shaderSource: string, entryPointName: string, compileTargetStr: string, noWebGPU: boolean): null | [string, Bindings, HashedStringData, ReflectionJSON, { [key: string]: [number, number, number] }] {
this.diagnosticsMsg = "";

let shouldLinkPlaygroundModule = RUNNABLE_ENTRY_POINT_NAMES.some((entry_point) => shaderSource.match(entry_point) != null);
Expand Down Expand Up @@ -476,7 +479,6 @@ export class SlangCompiler {
return null;
let program: ComponentType = slangSession.createCompositeComponentType(components);
let linkedProgram: ComponentType = program.link();
let hashedStrings: HashedStringData[] = linkedProgram.loadStrings();

let outCode: string;
if (compileTargetStr == "SPIRV") {
Expand All @@ -493,9 +495,10 @@ export class SlangCompiler {
0 /* entryPointIndex */, 0 /* targetIndex */);
}

let bindings: Bindings = noWebGPU ? new Map() : this.getResourceBindings(linkedProgram);

let reflectionJson: ReflectionJSON = linkedProgram.getLayout(0)?.toJsonObject();
let hashedStrings: HashedStringData = reflectionJson.hashedStrings ? Object.fromEntries(Object.entries(reflectionJson.hashedStrings).map(entry => entry.reverse())) : {};

let bindings: Bindings = noWebGPU ? new Map() : this.getResourceBindings(reflectionJson);

// remove incorrect uniform bindings
let has_uniform_been_binded = false;
Expand All @@ -515,15 +518,9 @@ export class SlangCompiler {
}

// Also read the shader work-group sizes.
let threadGroupSize: { [key: string]: ThreadGroupSize } = {};
const layout = linkedProgram.getLayout(0);
if (layout) {
const entryPoints = this.findDefinedEntryPoints(shaderSource);
for (const name of entryPoints) {
const entryPointReflection = layout.findEntryPointByName(name);
threadGroupSize[name] = entryPointReflection ? entryPointReflection.getComputeThreadGroupSize() :
{ x: 1, y: 1, z: 1 } as ThreadGroupSize;
}
let threadGroupSizes: { [key: string]: [number, number, number] } = {};
for (const entryPoint of reflectionJson.entryPoints) {
threadGroupSizes[entryPoint.name] = entryPoint.threadGroupSize;
}

if (outCode == "") {
Expand All @@ -538,7 +535,7 @@ export class SlangCompiler {
if (!outCode || outCode == "")
return null;

return [outCode, bindings, hashedStrings, reflectionJson, threadGroupSize];
return [outCode, bindings, hashedStrings, reflectionJson, threadGroupSizes];
} catch (e) {
console.error(e);
// typescript is missing the type for WebAssembly.Exception
Expand All @@ -550,4 +547,4 @@ export class SlangCompiler {
return null;
}
}
};
};
2 changes: 2 additions & 0 deletions src/components/Help.vue
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ defineExpose({
Initialize a <code>float</code> buffer with zeros of the provided size.
<h4 class="doc-header"><code>[playground::BLACK(512, 512)]</code></h4>
Initialize a <code>float</code> texture with zeros of the provided size.
<h4 class="doc-header"><code>[playground::BLACK_SCREEN(1.0, 1.0)]</code></h4>
Initialize a <code>float</code> texture with zeros with a size proportional to the screen size.
<h4 class="doc-header"><code>[playground::URL("https://example.com/image.png")]</code></h4>
Initialize a texture with image from URL.
<h4 class="doc-header"><code>[playground::RAND(1000)]</code></h4>
Expand Down
Loading