Param block support for Metal #208

Open
wants to merge 10 commits into
base: main
Changes from 8 commits
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -22,6 +22,8 @@ elseif(UNIX)
     set(SGL_ORIGIN "$ORIGIN")
 endif()

+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
Collaborator

remove this

Contributor Author

@kaizhangNV, Mar 11, 2025

Can we keep this for developers who use clangd as their language server?

Collaborator

I'm inclined to agree with @skallweitNV, but I don't know enough about clangd. What are the practical consequences of enabling this by default, and if it's not enabled, what do clangd users lose? Is it completely broken for them?

If this is critical for clangd users, then we need it in some form. We shouldn't expect people to modify CMake files to make their workflow work.

Contributor Author

I'm a clangd user myself, and enabling this has no side effects. Slang already does the same.
The flag does nothing except export a compilation command database in JSON format, which clangd requires.

Users who sync this repo can always add this single line themselves, but I want the option to work out of the box.

As far as I know, the Microsoft language server is broken on Cursor, so clangd is almost the only option for now. That makes this option even nicer to have.
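For readers who have not seen the file under discussion: with `CMAKE_EXPORT_COMPILE_COMMANDS` set to `ON`, CMake's Makefile and Ninja generators write a `compile_commands.json` into the build directory. It is a JSON array whose entries carry `directory`, `file`, and either `command` or `arguments`. The sketch below only illustrates that shape; the helper name `files_in_compile_db` and the sample paths are hypothetical and not part of this PR.

```python
import json
from typing import List


def files_in_compile_db(text: str) -> List[str]:
    """Return the source files listed in a compile_commands.json document."""
    entries = json.loads(text)
    files = []
    for entry in entries:
        # Each entry names one translation unit plus either a "command"
        # string or an "arguments" list describing how it is compiled.
        if "command" not in entry and "arguments" not in entry:
            raise ValueError(f"entry for {entry.get('file')} has no compile command")
        files.append(entry["file"])
    return files


# Hypothetical two-entry database, for illustration only.
SAMPLE = json.dumps([
    {"directory": "/build", "file": "/src/a.cpp", "command": "clang++ -c /src/a.cpp"},
    {"directory": "/build", "file": "/src/b.cpp", "arguments": ["clang++", "-c", "/src/b.cpp"]},
])

print(files_in_compile_db(SAMPLE))
```

If the line is dropped from CMakeLists.txt, the same variable can still be set per build by passing `-DCMAKE_EXPORT_COMPILE_COMMANDS=ON` on the `cmake` command line, which is the manual step the author wants to avoid.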


# TODO we should probably set install RPATH on individual targets instead
set(CMAKE_INSTALL_RPATH "${SGL_ORIGIN};${SGL_ORIGIN}/../lib")

2 changes: 1 addition & 1 deletion setup.py
@@ -37,7 +37,7 @@
 CMAKE_PRESET = {
     "windows": "windows-msvc",
     "linux": "linux-gcc",
-    "macos": "macos-clang",
+    "macos": "macos-arm64-clang",
Collaborator

Why did you change this? Does the macos-clang preset build for x64 only?

Contributor Author

There is no macos-clang preset, only macos-arm64-clang and macos-x64-clang, and I don't think we care about macos-x64-clang.

}[PLATFORM]
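
The change above hardcodes the arm64 preset for macOS. If x64 Macs ever need to be supported, one possible alternative is to choose the preset from the CPU architecture. This is a sketch only: the helper `select_cmake_preset` is hypothetical, the preset names are the ones mentioned in the review thread, and it assumes the architecture string comes from `platform.machine()`, which reports `arm64` on Apple Silicon and `x86_64` on Intel Macs.

```python
# Preset names as mentioned in the review thread; the mapping itself is an assumption.
MACOS_PRESETS = {
    "arm64": "macos-arm64-clang",
    "x86_64": "macos-x64-clang",
}


def select_cmake_preset(system: str, machine: str) -> str:
    """Pick a CMake preset name from an OS name and a CPU architecture."""
    if system == "windows":
        return "windows-msvc"
    if system == "linux":
        return "linux-gcc"
    if system == "macos":
        if machine not in MACOS_PRESETS:
            raise ValueError(f"unsupported macOS architecture: {machine}")
        return MACOS_PRESETS[machine]
    raise ValueError(f"unsupported platform: {system}")


print(select_cmake_preset("macos", "arm64"))
```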


4 changes: 3 additions & 1 deletion src/sgl/device/device.cpp
@@ -322,8 +322,10 @@ Device::Device(const DeviceDesc& desc)
     if (m_desc.type == DeviceType::automatic) {
 #if SGL_WINDOWS
         m_desc.type = DeviceType::d3d12;
-#elif SGL_LINUX || SGL_MACOS
+#elif SGL_LINUX
         m_desc.type = DeviceType::vulkan;
+#elif SGL_MACOS
+        m_desc.type = DeviceType::metal;
 #endif
     }
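
The effect of this hunk is that an `automatic` device request now resolves to Metal on macOS instead of Vulkan. The sketch below restates that selection in Python for illustration. The function `resolve_device_type` and the string names are hypothetical stand-ins for the `DeviceType` enum, and `sys.platform` values stand in for the `SGL_*` macros.

```python
import sys

# Mirrors the #if chain in the hunk above, keyed by sys.platform values.
DEFAULT_DEVICE_TYPE = {
    "win32": "d3d12",
    "linux": "vulkan",
    "darwin": "metal",
}


def resolve_device_type(requested: str, platform_name: str = sys.platform) -> str:
    """Resolve 'automatic' to a platform default; pass explicit requests through."""
    if requested != "automatic":
        return requested
    # Like the C++ code, leave the request unchanged on an unlisted platform.
    return DEFAULT_DEVICE_TYPE.get(platform_name, requested)


print(resolve_device_type("automatic", "darwin"))
```

An explicit request such as `resolve_device_type("vulkan", "darwin")` is passed through unchanged, which matches the C++ code only touching `m_desc.type` when it is `automatic`.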

15 changes: 10 additions & 5 deletions src/sgl/device/shader_cursor.cpp
@@ -51,7 +51,16 @@ ShaderCursor ShaderCursor::dereference() const
     switch ((TypeReflection::Kind)m_type_layout->getKind()) {
     case TypeReflection::Kind::constant_buffer:
     case TypeReflection::Kind::parameter_block:
-        return ShaderCursor(m_shader_object->get_object(m_offset));
+    {
+        ShaderCursor d = ShaderCursor(m_shader_object->get_object(m_offset));
+#if SGL_MACOS
+        d.m_type_layout = m_shader_object->get_slang_session()->getTypeLayout(
+            m_type_layout->getElementTypeLayout()->getType(),
+            0,
+            slang::LayoutRules::MetalArgumentBufferTier2);
+#endif
+        return d;
+    }
     default:
         return {};
     }
@@ -155,10 +164,6 @@ ShaderCursor ShaderCursor::find_field(std::string_view name) const
     //
     case TypeReflection::Kind::constant_buffer:
     case TypeReflection::Kind::parameter_block: {
-        // We basically need to "dereference" the current cursor
-        // to go from a pointer to a constant buffer to a pointer
-        // to the *contents* of the constant buffer.
-        //
         ShaderCursor d = dereference();
         return d.find_field(name);
     }
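
The dereference-then-lookup pattern in `find_field` can be illustrated with a toy model. This is a sketch under loud assumptions: `Block` and `Cursor` are hypothetical Python stand-ins, nested dicts play the role of structs, and nothing here models Slang type layouts or the Metal argument buffer layout rules that the `dereference()` change deals with.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Block:
    """Stand-in for a constant buffer or parameter block: a handle to separate contents."""
    contents: Dict[str, Any]


class Cursor:
    """Toy cursor over nested dicts (structs) and Block handles."""

    def __init__(self, node: Any):
        self.node = node

    def dereference(self) -> "Cursor":
        # Step from the handle to the contents it refers to.
        if isinstance(self.node, Block):
            return Cursor(self.node.contents)
        raise TypeError("only Block nodes can be dereferenced")

    def find_field(self, name: str) -> "Cursor":
        if isinstance(self.node, Block):
            # Same shape as the C++ case: dereference first, then look up the field.
            return self.dereference().find_field(name)
        if isinstance(self.node, dict):
            return Cursor(self.node[name])
        raise TypeError("cannot look up a field on a leaf value")

    __getitem__ = find_field


# Hypothetical data shaped like MyStruct from the test shader.
root = {
    "inputStruct": Block({
        "a": {"aa": 1.0, "bb": 2},
        "b": 3,
        "nestParamBlock": Block({"aa": 4.0, "bb": 5}),
    })
}

cursor = Cursor(root)
print(cursor["inputStruct"]["nestParamBlock"]["aa"].node)
```

The caller never dereferences explicitly; each lookup on a block transparently steps into its contents first, which is why the test can write `cursor["inputStruct"]["nestParamBlock"]["aa"]` directly.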
5 changes: 5 additions & 0 deletions src/sgl/device/shader_object.cpp
@@ -96,6 +96,11 @@ void ShaderObject::get_cuda_interop_buffers(std::vector<ref<cuda::InteropBuffer>
         .insert(cuda_interop_buffers.end(), m_cuda_interop_buffers.begin(), m_cuda_interop_buffers.end());
 }

+slang::ISession* ShaderObject::get_slang_session() const
+{
+    return m_device->slang_session()->get_slang_session();
+}
+
 //
 // TransientShaderObject
 //
1 change: 1 addition & 0 deletions src/sgl/device/shader_object.h
@@ -44,6 +44,7 @@ class SGL_API ShaderObject : public Object {

     gfx::IShaderObject* gfx_shader_object() const { return m_shader_object; }

+    slang::ISession* get_slang_session() const;
 protected:
     ref<Device> m_device;
     gfx::IShaderObject* m_shader_object;
80 changes: 80 additions & 0 deletions src/sgl/device/tests/test_parameterBlock.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import sgl
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent))
import sglhelpers as helpers

# The shader code is in test_parameterBlock.slang. The test fills in the parameter
# block 'inputStruct' (a.aa = 1.0, a.bb = 2, b = 3, nestParamBlock.aa = 4.0,
# nestParamBlock.bb = 5, nestParamBlock.nc[0] = 6.0), then reads back the result
# from the output buffer 'c' and asserts it equals 21.0. Only one thread is
# launched, since the test only exercises the ParameterBlock binding.
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_parameter_block(device_type: sgl.DeviceType):

    # Skip this test on Metal devices
    if device_type == sgl.DeviceType.metal:
        pytest.skip("Skipping parameter block test until https://github.com/shader-slang/slang/pull/6577 is merged")

    # Create device
    print(f"Testing {device_type}")

    device = helpers.get_device(type=device_type)

    # Load the shader program
    program = device.load_program("test_parameterBlock.slang", ["computeMain"])
    kernel = device.create_compute_kernel(program)

    # Create a buffer for the output
    output_buffer = device.create_buffer(
        element_count=1,  # Only need one element as we're only launching one thread
        struct_size=1024,  # Oversized; only the first float (4 bytes) is read back
        usage=sgl.ResourceUsage.unordered_access | sgl.ResourceUsage.shader_resource,
    )

    # Create a buffer holding the input value 6.0, read as nestParamBlock.nc[0]
    input_buffer = device.create_buffer(
        element_count=1,  # Only need one element as we're only launching one thread
        struct_size=4,  # float is 4 bytes
        usage=sgl.ResourceUsage.unordered_access | sgl.ResourceUsage.shader_resource,
        data=np.array([6.0], dtype=np.float32),
    )

    # Create a command buffer
    command_buffer = device.create_command_buffer()

    # Encode compute commands
    with command_buffer.encode_compute_commands() as encoder:
        # Bind the pipeline
        shader_object = encoder.bind_pipeline(kernel.pipeline)

        # Create a shader cursor for the parameter block
        cursor = sgl.ShaderCursor(shader_object)

        # Fill in the parameter block values
        cursor["inputStruct"]["a"]["aa"] = 1.0
        cursor["inputStruct"]["a"]["bb"] = 2
        cursor["inputStruct"]["b"] = 3
        cursor["inputStruct"]["c"] = output_buffer

        # Fill in the nested parameter block values
        cursor["inputStruct"]["nestParamBlock"]["aa"] = 4.0
        cursor["inputStruct"]["nestParamBlock"]["bb"] = 5
        cursor["inputStruct"]["nestParamBlock"]["nc"] = input_buffer

        # Dispatch a single thread
        encoder.dispatch(thread_count=[1, 1, 1])

    # Submit the command buffer
    command_buffer.submit()

    # Read back the result
    result = output_buffer.to_numpy().view(np.float32)[0]

    # Verify the result: 1.0 + 2 + 3 + 4.0 + 5 + 6.0
    assert result == 21.0, f"Expected 21.0, got {result}"

if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
25 changes: 25 additions & 0 deletions src/sgl/device/tests/test_parameterBlock.slang
@@ -0,0 +1,25 @@

struct NestedStruct
{
    float aa;
    int bb;
    RWStructuredBuffer<float> nc;
}

struct MyStruct
{
    NestedStruct a;
    uint b;
    ParameterBlock<NestedStruct> nestParamBlock;
    RWStructuredBuffer<float> c;
};

ParameterBlock<MyStruct> inputStruct;

[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 tid: SV_DispatchThreadID)
{
    inputStruct.c[0] = inputStruct.a.aa + inputStruct.a.bb + inputStruct.b +
        inputStruct.nestParamBlock.aa + inputStruct.nestParamBlock.bb + inputStruct.nestParamBlock.nc[0];
}
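
The value the Python test asserts follows directly from this kernel. As a quick cross-check, here is a CPU restatement of the sum using the inputs the test binds; the function name `compute_main_reference` is hypothetical and only for illustration.

```python
def compute_main_reference(a_aa, a_bb, b, nested_aa, nested_bb, nested_nc0):
    """CPU restatement of the sum that computeMain writes to inputStruct.c[0]."""
    return a_aa + a_bb + b + nested_aa + nested_bb + nested_nc0


# Inputs as bound in test_parameterBlock.py; nested_nc0 comes from input_buffer.
expected = compute_main_reference(1.0, 2, 3, 4.0, 5, 6.0)
print(expected)
```

All six inputs are small integers or whole-valued floats, so the sum is exactly representable in float32, and the test's exact comparison against 21.0 does not need a tolerance.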