-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Param block support for Metal #208
base: main
Are you sure you want to change the base?
Changes from 8 commits
a79b48d
8f92108
4f7a9e0
a8f2045
1004bec
dbb7f23
09e4fc2
8dc896e
a28c73e
fdcb49d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ | |
CMAKE_PRESET = { | ||
"windows": "windows-msvc", | ||
"linux": "linux-gcc", | ||
"macos": "macos-clang", | ||
"macos": "macos-arm64-clang", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did you change this? Does the macos-clang preset build for x64 only? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no macos-clang preset, only macos-arm64-clang and macos-x64-clang, and I don't think we care about macos-x64-clang. |
||
}[PLATFORM] | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
kaizhangNV marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import pytest | ||
import numpy as np | ||
import sgl | ||
import sys | ||
from pathlib import Path | ||
|
||
sys.path.append(str(Path(__file__).parent)) | ||
import sglhelpers as helpers | ||
|
||
# The shader code is in test_parameterBlock.slang. Fill in the parameter block | ||
# 'inputStruct' with a.aa = 1.0, a.bb = 2, b = 3, nestParamBlock.aa = 4.0, | ||
# nestParamBlock.bb = 5 and nestParamBlock.nc[0] = 6.0, then read the sum back from | ||
# buffer 'c' and assert it equals 21.0. The test launches only 1 thread. | ||
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_parameter_block(device_type: sgl.DeviceType):
    """Bind a nested ``ParameterBlock`` and verify the sum the shader writes back.

    The kernel ``computeMain`` in ``test_parameterBlock.slang`` adds up the
    scalar fields of ``inputStruct`` (including its nested parameter block)
    plus the first element of the nested input buffer, and stores the sum in
    ``inputStruct.c[0]``. With the values bound below the expected sum is
    1.0 + 2 + 3 + 4.0 + 5 + 6.0 == 21.0. Only a single thread is dispatched.

    Args:
        device_type: Backend to run on; one entry of
            ``helpers.DEFAULT_DEVICE_TYPES`` per parametrized run.
    """
    # Skip this test on Metal devices
    if device_type == sgl.DeviceType.metal:
        pytest.skip(
            "Skipping parameter block test until https://github.com/shader-slang/slang/pull/6577 is merged"
        )

    # Create device
    print(f"Testing {device_type}")
    device = helpers.get_device(type=device_type)

    # Load the shader program
    program = device.load_program("test_parameterBlock.slang", ["computeMain"])
    kernel = device.create_compute_kernel(program)

    # Both buffers hold a single float and are bound as RWStructuredBuffer<float>.
    buffer_usage = sgl.ResourceUsage.unordered_access | sgl.ResourceUsage.shader_resource

    # Create a buffer for the output
    output_buffer = device.create_buffer(
        element_count=1,  # Only need one element as we're only launching one thread
        struct_size=4,  # float is 4 bytes
        usage=buffer_usage,
    )

    # Create the input buffer read through the nested parameter block (value 6.0)
    input_buffer = device.create_buffer(
        element_count=1,  # Only need one element as we're only launching one thread
        struct_size=4,  # float is 4 bytes
        usage=buffer_usage,
        data=np.array([6.0], dtype=np.float32),
    )

    # Create a command buffer
    command_buffer = device.create_command_buffer()

    # Encode compute commands
    with command_buffer.encode_compute_commands() as encoder:
        # Bind the pipeline
        shader_object = encoder.bind_pipeline(kernel.pipeline)

        # Create a shader cursor for the parameter block
        cursor = sgl.ShaderCursor(shader_object)

        # Fill in the outer parameter block values
        cursor["inputStruct"]["a"]["aa"] = 1.0
        cursor["inputStruct"]["a"]["bb"] = 2
        cursor["inputStruct"]["b"] = 3
        cursor["inputStruct"]["c"] = output_buffer

        # Fill in the nested parameter block values
        cursor["inputStruct"]["nestParamBlock"]["aa"] = 4.0
        cursor["inputStruct"]["nestParamBlock"]["bb"] = 5
        cursor["inputStruct"]["nestParamBlock"]["nc"] = input_buffer

        # Dispatch a single thread
        encoder.dispatch(thread_count=[1, 1, 1])

    # Submit the command buffer
    command_buffer.submit()

    # Read back the result
    result = output_buffer.to_numpy().view(np.float32)[0]

    # Verify the result
    assert result == 21.0, f"Expected 21.0, got {result}"
|
||
# Allow running this test file directly as a script; pytest is invoked on this
# file with verbose output (-v) and output capturing disabled (-s).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
kaizhangNV marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Small aggregate used both as a plain member of MyStruct and as the element
// type of a nested ParameterBlock.
struct NestedStruct
{
    float aa;                     // scalar term of the sum computed in computeMain
    int bb;                       // integer term of the sum computed in computeMain
    RWStructuredBuffer<float> nc; // buffer whose element 0 is read by computeMain
                                  // (only via the nested parameter block)
}
|
||
// Top-level parameter block layout: ordinary data, a nested parameter block
// and the output buffer.
struct MyStruct
{
    NestedStruct a;                              // embedded struct; computeMain reads a.aa and a.bb
    uint b;                                      // unsigned term of the sum
    ParameterBlock<NestedStruct> nestParamBlock; // parameter block nested inside this one

    RWStructuredBuffer<float> c;                 // output buffer; computeMain writes c[0]
};
|
||
// Parameter block holding every input and the output buffer of the kernel.
ParameterBlock<MyStruct> inputStruct;

// Single-thread kernel: accumulates the scalar members of inputStruct, the
// scalar members of its nested parameter block and the first element of the
// nested buffer, then stores the total in inputStruct.c[0].
[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 tid: SV_DispatchThreadID)
{
    // Accumulate left to right, in the same order as a single chained sum.
    float total = inputStruct.a.aa + inputStruct.a.bb;
    total = total + inputStruct.b;
    total = total + inputStruct.nestParamBlock.aa;
    total = total + inputStruct.nestParamBlock.bb;
    total = total + inputStruct.nestParamBlock.nc[0];

    inputStruct.c[0] = total;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we just keep this for developers using clangd as their language server?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm inclined to agree with @skallweitNV , but I don't know enough about clangd. What are the consequences of enabling this by default in practice, and if it's not enabled, what do users with clangd lose out on? Is it just completely broken for them?
If this is critical for a clangd user then we need it in some form - it's not good to expect people to have to modify cmake files to make their workflow work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm personally a clangd user, and there won't be any side effects from enabling this. Slang already supports this.
This flag does nothing but export a compilation command database in JSON format, which is required by clangd.
One thing to note is that users who sync this repo can always add this single line themselves, but I just want this option to work out of the box.
As far as I know, the MS language server is broken on Cursor, so clangd is almost the only option for now. That makes this option nicer to have.