From 48f9ed8b08be974f4e463ef38136c8f23513b2cf Mon Sep 17 00:00:00 2001 From: Arcady Goldmints-Orlov Date: Fri, 6 Oct 2023 17:50:27 -0400 Subject: [PATCH] spirv: only set LocalSizeId mode when necessary SPIR-V 1.6 added the LocalSizeId execution mode that allows using spec constants for setting the work-group size, however it does not deprecate the LocalSize mode. This change causes the LocalSizeId mode to only be used when at least one of the workgroup size is actually specified with a spec constant. Fixes #3200 --- SPIRV/GlslangToSpv.cpp | 37 ++++--- .../hlsl.structcopylogical.comp.out | 100 +++++++++--------- 2 files changed, 73 insertions(+), 64 deletions(-) diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 6eae76d6..576c680f 100755 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -1741,23 +1741,31 @@ TGlslangToSpvTraverser::TGlslangToSpvTraverser(unsigned int spvVersion, } break; - case EShLangCompute: + case EShLangCompute: { builder.addCapability(spv::CapabilityShader); - if (glslangIntermediate->getSpv().spv >= glslang::EShTargetSpv_1_6) { - std::vector dimConstId; - for (int dim = 0; dim < 3; ++dim) { - bool specConst = (glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet); - dimConstId.push_back(builder.makeUintConstant(glslangIntermediate->getLocalSize(dim), specConst)); - if (specConst) { - builder.addDecoration(dimConstId.back(), spv::DecorationSpecId, - glslangIntermediate->getLocalSizeSpecId(dim)); + bool needSizeId = false; + for (int dim = 0; dim < 3; ++dim) { + if ((glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet)) { + needSizeId = true; + break; } - } - builder.addExecutionModeId(shaderEntry, spv::ExecutionModeLocalSizeId, dimConstId); + } + if (glslangIntermediate->getSpv().spv >= glslang::EShTargetSpv_1_6 && needSizeId) { + std::vector dimConstId; + for (int dim = 0; dim < 3; ++dim) { + bool specConst = (glslangIntermediate->getLocalSizeSpecId(dim) != glslang::TQualifier::layoutNotSet); + dimConstId.push_back(builder.makeUintConstant(glslangIntermediate->getLocalSize(dim), specConst)); + if (specConst) { + builder.addDecoration(dimConstId.back(), spv::DecorationSpecId, + glslangIntermediate->getLocalSizeSpecId(dim)); + needSizeId = true; + } + } + builder.addExecutionModeId(shaderEntry, spv::ExecutionModeLocalSizeId, dimConstId); } else { - builder.addExecutionMode(shaderEntry, spv::ExecutionModeLocalSize, glslangIntermediate->getLocalSize(0), - glslangIntermediate->getLocalSize(1), - glslangIntermediate->getLocalSize(2)); + builder.addExecutionMode(shaderEntry, spv::ExecutionModeLocalSize, glslangIntermediate->getLocalSize(0), + glslangIntermediate->getLocalSize(1), + glslangIntermediate->getLocalSize(2)); } if (glslangIntermediate->getLayoutDerivativeModeNone() == glslang::LayoutDerivativeGroupQuads) { builder.addCapability(spv::CapabilityComputeDerivativeGroupQuadsNV); @@ -1769,6 +1777,7 @@ TGlslangToSpvTraverser::TGlslangToSpvTraverser(unsigned int spvVersion, builder.addExtension(spv::E_SPV_NV_compute_shader_derivatives); } break; + } case EShLangTessEvaluation: case EShLangTessControl: builder.addCapability(spv::CapabilityTessellation); diff --git a/Test/baseResults/hlsl.structcopylogical.comp.out b/Test/baseResults/hlsl.structcopylogical.comp.out index 31206566..a9b849be 100644 --- a/Test/baseResults/hlsl.structcopylogical.comp.out +++ b/Test/baseResults/hlsl.structcopylogical.comp.out @@ -248,17 +248,17 @@ local_size = (128, 1, 1) Capability Shader 1: ExtInstImport "GLSL.std.450" MemoryModel Logical GLSL450 - EntryPoint GLCompute 4 "main" 17 32 57 74 - ExecutionModeId 4 LocalSizeId 7 8 8 + EntryPoint GLCompute 4 "main" 16 32 57 74 + ExecutionMode 4 LocalSize 128 1 1 Source HLSL 500 Name 4 "main" - Name 12 "@main(u1;" - Name 11 "id" - Name 14 "MyStruct" - MemberName 14(MyStruct) 0 "a" - MemberName 14(MyStruct) 1 "b" - MemberName 14(MyStruct) 2 "c" - Name 17 "s" + Name 10 "@main(u1;" + Name 9 "id" + Name 12 "MyStruct" + MemberName 12(MyStruct) 0 "a" + MemberName 12(MyStruct) 1 "b" + MemberName 12(MyStruct) 2 "c" + Name 16 "s" Name 25 "count" Name 26 "MyStruct" MemberName 26(MyStruct) 0 "a" @@ -300,20 +300,20 @@ local_size = (128, 1, 1) 2: TypeVoid 3: TypeFunction 2 6: TypeInt 32 0 - 7: 6(int) Constant 128 - 8: 6(int) Constant 1 - 9: TypePointer Function 6(int) - 10: TypeFunction 2 9(ptr) - 14(MyStruct): TypeStruct 6(int) 6(int) 6(int) - 15: TypeArray 14(MyStruct) 7 - 16: TypePointer Workgroup 15 - 17(s): 16(ptr) Variable Workgroup - 18: TypeInt 32 1 - 19: 18(int) Constant 0 + 7: TypePointer Function 6(int) + 8: TypeFunction 2 7(ptr) + 12(MyStruct): TypeStruct 6(int) 6(int) 6(int) + 13: 6(int) Constant 128 + 14: TypeArray 12(MyStruct) 13 + 15: TypePointer Workgroup 14 + 16(s): 15(ptr) Variable Workgroup + 17: TypeInt 32 1 + 18: 17(int) Constant 0 + 19: 6(int) Constant 1 20: 6(int) Constant 2 21: 6(int) Constant 3 - 22:14(MyStruct) ConstantComposite 8 20 21 - 23: TypePointer Workgroup 14(MyStruct) + 22:12(MyStruct) ConstantComposite 19 20 21 + 23: TypePointer Workgroup 12(MyStruct) 26(MyStruct): TypeStruct 6(int) 6(int) 6(int) 27: TypeRuntimeArray 26(MyStruct) 28(MyStructs): TypeStruct 6(int) 27 @@ -322,64 +322,64 @@ local_size = (128, 1, 1) 31: TypePointer StorageBuffer 30(sb) 32(sb): 31(ptr) Variable StorageBuffer 33: TypePointer StorageBuffer 6(int) - 36: TypePointer Function 14(MyStruct) + 36: TypePointer Function 12(MyStruct) 40: TypeBool - 47: 18(int) Constant 1 + 47: 17(int) Constant 1 49: TypePointer StorageBuffer 26(MyStruct) 54: TypeRuntimeArray 26(MyStruct) 55(o): TypeStruct 54 56: TypePointer StorageBuffer 55(o) 57(o): 56(ptr) Variable StorageBuffer 61: 6(int) Constant 0 - 67: 18(int) Constant 2 + 67: 17(int) Constant 2 73: TypePointer Input 6(int) 74(id): 73(ptr) Variable Input 4(main): 2 Function None 3 5: Label - 72(id): 9(ptr) Variable Function - 76(param): 9(ptr) Variable Function + 72(id): 7(ptr) Variable Function + 76(param): 7(ptr) Variable Function 75: 6(int) Load 74(id) Store 72(id) 75 77: 6(int) Load 72(id) Store 76(param) 77 - 78: 2 FunctionCall 12(@main(u1;) 76(param) + 78: 2 FunctionCall 10(@main(u1;) 76(param) Return FunctionEnd - 12(@main(u1;): 2 Function None 10 - 11(id): 9(ptr) FunctionParameter - 13: Label - 25(count): 9(ptr) Variable Function + 10(@main(u1;): 2 Function None 8 + 9(id): 7(ptr) FunctionParameter + 11: Label + 25(count): 7(ptr) Variable Function 37(ms): 36(ptr) Variable Function - 24: 23(ptr) AccessChain 17(s) 19 + 24: 23(ptr) AccessChain 16(s) 18 Store 24 22 - 34: 33(ptr) AccessChain 32(sb) 19 19 19 + 34: 33(ptr) AccessChain 32(sb) 18 18 18 35: 6(int) Load 34 Store 25(count) 35 - 38: 6(int) Load 11(id) + 38: 6(int) Load 9(id) 39: 6(int) Load 25(count) 41: 40(bool) UGreaterThan 38 39 - 42: 6(int) Load 11(id) + 42: 6(int) Load 9(id) 43: 6(int) Load 25(count) 44: 6(int) ISub 42 43 - 45: 23(ptr) AccessChain 17(s) 44 - 46:14(MyStruct) Load 45 - 48: 6(int) Load 11(id) - 50: 49(ptr) AccessChain 32(sb) 19 19 47 48 + 45: 23(ptr) AccessChain 16(s) 44 + 46:12(MyStruct) Load 45 + 48: 6(int) Load 9(id) + 50: 49(ptr) AccessChain 32(sb) 18 18 47 48 51:26(MyStruct) Load 50 - 52:14(MyStruct) CopyLogical 51 - 53:14(MyStruct) Select 41 46 52 + 52:12(MyStruct) CopyLogical 51 + 53:12(MyStruct) Select 41 46 52 Store 37(ms) 53 - 58: 33(ptr) AccessChain 57(o) 19 19 19 - 59: 9(ptr) AccessChain 37(ms) 19 + 58: 33(ptr) AccessChain 57(o) 18 18 18 + 59: 7(ptr) AccessChain 37(ms) 18 60: 6(int) Load 59 - 62: 6(int) AtomicIAdd 58 8 61 60 - 63: 33(ptr) AccessChain 57(o) 19 19 47 - 64: 9(ptr) AccessChain 37(ms) 47 + 62: 6(int) AtomicIAdd 58 19 61 60 + 63: 33(ptr) AccessChain 57(o) 18 18 47 + 64: 7(ptr) AccessChain 37(ms) 47 65: 6(int) Load 64 - 66: 6(int) AtomicIAdd 63 8 61 65 - 68: 33(ptr) AccessChain 57(o) 19 19 67 - 69: 9(ptr) AccessChain 37(ms) 67 + 66: 6(int) AtomicIAdd 63 19 61 65 + 68: 33(ptr) AccessChain 57(o) 18 18 67 + 69: 7(ptr) AccessChain 37(ms) 67 70: 6(int) Load 69 - 71: 6(int) AtomicIAdd 68 8 61 70 + 71: 6(int) AtomicIAdd 68 19 61 70 Return FunctionEnd