@@ -772,9 +772,7 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
772772 }
773773
774774#if defined(LVK_WITH_SLANG) && LVK_WITH_SLANG
775- using namespace Slang ;
776-
777- ComPtr<slang::IGlobalSession> slangGlobalSession;
775+ Slang::ComPtr<slang::IGlobalSession> slangGlobalSession;
778776 if (SLANG_FAILED (slang::createGlobalSession (slangGlobalSession.writeRef ()))) {
779777 return Result (Result::Code::RuntimeError, " slang::createGlobalSession() failed" );
780778 }
@@ -810,14 +808,14 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
810808 .targetCount = 1 ,
811809 };
812810
813- ComPtr<slang::ISession> session;
811+ Slang:: ComPtr<slang::ISession> session;
814812 if (SLANG_FAILED (slangGlobalSession->createSession (sessionDesc, session.writeRef ()))) {
815813 return Result (Result::Code::RuntimeError, " slang::createSession() failed" );
816814 }
817815
818816 slang::IModule* slangModule = nullptr ;
819817 {
820- ComPtr<slang::IBlob> diagnosticBlob;
818+ Slang:: ComPtr<slang::IBlob> diagnosticBlob;
821819 slangModule = session->loadModuleFromSourceString (" " , " " , code, diagnosticBlob.writeRef ());
822820 if (diagnosticBlob) {
823821 LLOGW (" %s" , (const char *)diagnosticBlob->getBufferPointer ());
@@ -827,28 +825,32 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
827825 }
828826 }
829827
830- ComPtr<slang::IEntryPoint> entryPointVert;
831- ComPtr<slang::IEntryPoint> entryPointFrag;
832- if (SLANG_FAILED (slangModule->findEntryPointByName (" vertexMain" , entryPointVert.writeRef ()))) {
833- LVK_ASSERT_MSG (entryPointVert, " vertexMain() not found" );
834- return Result (Result::Code::RuntimeError, " vertexMain() not found" );
835- }
836- if (SLANG_FAILED (slangModule->findEntryPointByName (" fragmentMain" , entryPointFrag.writeRef ()))) {
837- LVK_ASSERT_MSG (entryPointFrag, " fragmentMain() not found" );
838- return Result (Result::Code::RuntimeError, " fragmentMain() not found" );
828+ Slang::ComPtr<slang::IEntryPoint> entryPoint;
829+ const char * entryPointName = [stage]() {
830+ switch (stage) {
831+ case lvk::Stage_Vert:
832+ return " vertexMain" ;
833+ case lvk::Stage_Frag:
834+ return " fragmentMain" ;
835+ case lvk::Stage_Task:
836+ return " taskMain" ;
837+ case lvk::Stage_Mesh:
838+ return " meshMain" ;
839+ }
840+ return " unknown shader type" ;
841+ }();
842+ if (SLANG_FAILED (slangModule->findEntryPointByName (entryPointName, entryPoint.writeRef ()))) {
843+ LVK_ASSERT_MSG (entryPoint, " Entry point %s() not found" , entryPointName);
844+ return Result (Result::Code::RuntimeError, " Entry point not found" );
839845 }
840846
841847 Slang::List<slang::IComponentType*> componentTypes;
842848 componentTypes.add (slangModule);
843- int entryPointCount = 0 ;
844- int vertexEntryPointIndex = entryPointCount++;
845- componentTypes.add (entryPointVert);
846- int fragmentEntryPointIndex = entryPointCount++;
847- componentTypes.add (entryPointFrag);
849+ componentTypes.add (entryPoint);
848850
849- ComPtr<slang::IComponentType> composedProgram;
851+ Slang:: ComPtr<slang::IComponentType> composedProgram;
850852 {
851- ComPtr<slang::IBlob> diagnosticBlob;
853+ Slang:: ComPtr<slang::IBlob> diagnosticBlob;
852854 SlangResult result = session->createCompositeComponentType (
853855 componentTypes.getBuffer (), componentTypes.getCount (), composedProgram.writeRef (), diagnosticBlob.writeRef ());
854856 if (diagnosticBlob) {
@@ -860,11 +862,11 @@ lvk::Result lvk::compileShaderSlang(lvk::ShaderStage stage,
860862 }
861863 }
862864
863- ComPtr<slang::IBlob> spirvCode;
865+ Slang:: ComPtr<slang::IBlob> spirvCode;
864866 {
865- ComPtr<slang::IBlob> diagnosticBlob;
866- const int entryPoint = stage == lvk::Stage_Vert ? vertexEntryPointIndex : fragmentEntryPointIndex ;
867- SlangResult result = composedProgram->getEntryPointCode (entryPoint , 0 , spirvCode.writeRef (), diagnosticBlob.writeRef ());
867+ Slang:: ComPtr<slang::IBlob> diagnosticBlob;
868+ const int entryPointIndex = 0 ;
869+ const SlangResult result = composedProgram->getEntryPointCode (entryPointIndex , 0 , spirvCode.writeRef (), diagnosticBlob.writeRef ());
868870 if (diagnosticBlob) {
869871 LLOGW (" %s\n " , (const char *)diagnosticBlob->getBufferPointer ());
870872 }
0 commit comments