Skip to content

Commit 4581549

Browse files
Added task and mesh shaders for Slang
1 parent e9a70d7 commit 4581549

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

lvk/vulkan/VulkanUtils.cpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -772,9 +772,7 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
772772
}
773773

774774
#if defined(LVK_WITH_SLANG) && LVK_WITH_SLANG
775-
using namespace Slang;
776-
777-
ComPtr<slang::IGlobalSession> slangGlobalSession;
775+
Slang::ComPtr<slang::IGlobalSession> slangGlobalSession;
778776
if (SLANG_FAILED(slang::createGlobalSession(slangGlobalSession.writeRef()))) {
779777
return Result(Result::Code::RuntimeError, "slang::createGlobalSession() failed");
780778
}
@@ -810,14 +808,14 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
810808
.targetCount = 1,
811809
};
812810

813-
ComPtr<slang::ISession> session;
811+
Slang::ComPtr<slang::ISession> session;
814812
if (SLANG_FAILED(slangGlobalSession->createSession(sessionDesc, session.writeRef()))) {
815813
return Result(Result::Code::RuntimeError, "slang::createSession() failed");
816814
}
817815

818816
slang::IModule* slangModule = nullptr;
819817
{
820-
ComPtr<slang::IBlob> diagnosticBlob;
818+
Slang::ComPtr<slang::IBlob> diagnosticBlob;
821819
slangModule = session->loadModuleFromSourceString("", "", code, diagnosticBlob.writeRef());
822820
if (diagnosticBlob) {
823821
LLOGW("%s", (const char*)diagnosticBlob->getBufferPointer());
@@ -827,28 +825,32 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
827825
}
828826
}
829827

830-
ComPtr<slang::IEntryPoint> entryPointVert;
831-
ComPtr<slang::IEntryPoint> entryPointFrag;
832-
if (SLANG_FAILED(slangModule->findEntryPointByName("vertexMain", entryPointVert.writeRef()))) {
833-
LVK_ASSERT_MSG(entryPointVert, "vertexMain() not found");
834-
return Result(Result::Code::RuntimeError, "vertexMain() not found");
835-
}
836-
if (SLANG_FAILED(slangModule->findEntryPointByName("fragmentMain", entryPointFrag.writeRef()))) {
837-
LVK_ASSERT_MSG(entryPointFrag, "fragmentMain() not found");
838-
return Result(Result::Code::RuntimeError, "fragmentMain() not found");
828+
Slang::ComPtr<slang::IEntryPoint> entryPoint;
829+
const char* entryPointName = [stage]() {
830+
switch (stage) {
831+
case lvk::Stage_Vert:
832+
return "vertexMain";
833+
case lvk::Stage_Frag:
834+
return "fragmentMain";
835+
case lvk::Stage_Task:
836+
return "taskMain";
837+
case lvk::Stage_Mesh:
838+
return "meshMain";
839+
}
840+
return "unknown shader type";
841+
}();
842+
if (SLANG_FAILED(slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()))) {
843+
LVK_ASSERT_MSG(entryPoint, "Entry point %s() not found", entryPointName);
844+
return Result(Result::Code::RuntimeError, "Entry point not found");
839845
}
840846

841847
Slang::List<slang::IComponentType*> componentTypes;
842848
componentTypes.add(slangModule);
843-
int entryPointCount = 0;
844-
int vertexEntryPointIndex = entryPointCount++;
845-
componentTypes.add(entryPointVert);
846-
int fragmentEntryPointIndex = entryPointCount++;
847-
componentTypes.add(entryPointFrag);
849+
componentTypes.add(entryPoint);
848850

849-
ComPtr<slang::IComponentType> composedProgram;
851+
Slang::ComPtr<slang::IComponentType> composedProgram;
850852
{
851-
ComPtr<slang::IBlob> diagnosticBlob;
853+
Slang::ComPtr<slang::IBlob> diagnosticBlob;
852854
SlangResult result = session->createCompositeComponentType(
853855
componentTypes.getBuffer(), componentTypes.getCount(), composedProgram.writeRef(), diagnosticBlob.writeRef());
854856
if (diagnosticBlob) {
@@ -860,11 +862,11 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
860862
}
861863
}
862864

863-
ComPtr<slang::IBlob> spirvCode;
865+
Slang::ComPtr<slang::IBlob> spirvCode;
864866
{
865-
ComPtr<slang::IBlob> diagnosticBlob;
866-
const int entryPoint = stage == lvk::Stage_Vert ? vertexEntryPointIndex : fragmentEntryPointIndex;
867-
SlangResult result = composedProgram->getEntryPointCode(entryPoint, 0, spirvCode.writeRef(), diagnosticBlob.writeRef());
867+
Slang::ComPtr<slang::IBlob> diagnosticBlob;
868+
const int entryPointIndex = 0;
869+
const SlangResult result = composedProgram->getEntryPointCode(entryPointIndex, 0, spirvCode.writeRef(), diagnosticBlob.writeRef());
868870
if (diagnosticBlob) {
869871
LLOGW("%s\n", (const char*)diagnosticBlob->getBufferPointer());
870872
}

0 commit comments

Comments
 (0)