diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 21a04b0e..81f7cc1c 100755 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -98,6 +98,7 @@ protected: spv::Id convertGlslangToSpvType(const glslang::TType& type, glslang::TLayoutPacking, const glslang::TQualifier&); spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim); spv::Id accessChainLoad(const glslang::TType& type); + void accessChainStore(const glslang::TType& type, spv::Id rvalue); glslang::TLayoutPacking getExplicitLayout(const glslang::TType& type) const; int getArrayStride(const glslang::TType& arrayType, glslang::TLayoutPacking, glslang::TLayoutMatrix); int getMatrixStride(const glslang::TType& matrixType, glslang::TLayoutPacking, glslang::TLayoutMatrix); @@ -805,7 +806,7 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T // store the result builder.setAccessChain(lValue); - builder.accessChainStore(rValue); + accessChainStore(node->getType(), rValue); // assignments are expressions having an rValue after they are evaluated... builder.clearAccessChain(); @@ -1904,12 +1905,55 @@ spv::Id TGlslangToSpvTraverser::accessChainLoad(const glslang::TType& type) spv::Id loadedId = builder.accessChainLoad(TranslatePrecisionDecoration(type), nominalTypeId); // Need to convert to abstract types when necessary - if (builder.isScalarType(nominalTypeId) && type.getBasicType() == glslang::EbtBool && nominalTypeId != builder.makeBoolType()) - loadedId = builder.createBinOp(spv::OpINotEqual, builder.makeBoolType(), loadedId, builder.makeUintConstant(0)); + if (type.getBasicType() == glslang::EbtBool) { + if (builder.isScalarType(nominalTypeId)) { + // Conversion for bool + spv::Id boolType = builder.makeBoolType(); + if (nominalTypeId != boolType) + loadedId = builder.createBinOp(spv::OpINotEqual, boolType, loadedId, builder.makeUintConstant(0)); + } else if (builder.isVectorType(nominalTypeId)) { + // Conversion for bvec + int vecSize = builder.getNumTypeComponents(nominalTypeId); + spv::Id bvecType = builder.makeVectorType(builder.makeBoolType(), vecSize); + if (nominalTypeId != bvecType) + loadedId = builder.createBinOp(spv::OpINotEqual, bvecType, loadedId, makeSmearedConstant(builder.makeUintConstant(0), vecSize)); + } + } return loadedId; } +// Wrap the builder's accessChainStore to: +// - do conversion of concrete to abstract type +void TGlslangToSpvTraverser::accessChainStore(const glslang::TType& type, spv::Id rvalue) +{ + // Need to convert to abstract types when necessary + if (type.getBasicType() == glslang::EbtBool) { + spv::Id nominalTypeId = builder.accessChainGetInferredType(); + + if (builder.isScalarType(nominalTypeId)) { + // Conversion for bool + spv::Id boolType = builder.makeBoolType(); + if (nominalTypeId != boolType) { + spv::Id zero = builder.makeUintConstant(0); + spv::Id one = builder.makeUintConstant(1); + rvalue = builder.createTriOp(spv::OpSelect, nominalTypeId, rvalue, one, zero); + } + } else if (builder.isVectorType(nominalTypeId)) { + // Conversion for bvec + int vecSize = builder.getNumTypeComponents(nominalTypeId); + spv::Id bvecType = builder.makeVectorType(builder.makeBoolType(), vecSize); + if (nominalTypeId != bvecType) { + spv::Id zero = makeSmearedConstant(builder.makeUintConstant(0), vecSize); + spv::Id one = makeSmearedConstant(builder.makeUintConstant(1), vecSize); + rvalue = builder.createTriOp(spv::OpSelect, nominalTypeId, rvalue, one, zero); + } + } + } + + builder.accessChainStore(rvalue); +} + // Decide whether or not this type should be // decorated with offsets and strides, and if so // whether std140 or std430 rules should be applied. @@ -2470,7 +2514,7 @@ spv::Id TGlslangToSpvTraverser::handleUserFunctionCall(const glslang::TIntermAgg if (qualifiers[a] == glslang::EvqOut || qualifiers[a] == glslang::EvqInOut) { spv::Id copy = builder.createLoad(spvArgs[a]); builder.setAccessChain(lValues[lValueCount]); - builder.accessChainStore(copy); + accessChainStore(glslangArgs[a]->getAsTyped()->getType(), copy); } ++lValueCount; } diff --git a/Test/baseResults/spv.boolInBlock.frag.out b/Test/baseResults/spv.boolInBlock.frag.out new file mode 100644 index 00000000..e49d0678 --- /dev/null +++ b/Test/baseResults/spv.boolInBlock.frag.out @@ -0,0 +1,124 @@ +spv.boolInBlock.frag +Warning, version 450 is not yet complete; most version-specific features are present, but some are missing. + + +Linked fragment stage: + + +// Module Version 10000 +// Generated by (magic number): 80001 +// Id's are bound by 72 + + Capability Shader + 1: ExtInstImport "GLSL.std.450" + MemoryModel Logical GLSL450 + EntryPoint Fragment 4 "main" + ExecutionMode 4 OriginUpperLeft + Source GLSL 450 + Name 4 "main" + Name 14 "foo(vb4;vb2;" + Name 12 "paramb4" + Name 13 "paramb2" + Name 17 "b1" + Name 24 "Buffer" + MemberName 24(Buffer) 0 "b2" + Name 26 "" + Name 39 "Uniform" + MemberName 39(Uniform) 0 "b4" + Name 41 "" + Name 62 "param" + Name 67 "param" + MemberDecorate 24(Buffer) 0 Offset 0 + Decorate 24(Buffer) BufferBlock + Decorate 26 DescriptorSet 0 + Decorate 26 Binding 1 + MemberDecorate 39(Uniform) 0 Offset 0 + Decorate 39(Uniform) Block + Decorate 41 DescriptorSet 0 + Decorate 41 Binding 0 + 2: TypeVoid + 3: TypeFunction 2 + 6: TypeBool + 7: TypeVector 6(bool) 4 + 8: TypePointer Function 7(bvec4) + 9: TypeVector 6(bool) 2 + 10: TypePointer Function 9(bvec2) + 11: TypeFunction 2 8(ptr) 10(ptr) + 16: TypePointer Function 6(bool) + 22: TypeInt 32 0 + 23: TypeVector 22(int) 2 + 24(Buffer): TypeStruct 23(ivec2) + 25: TypePointer Uniform 24(Buffer) + 26: 25(ptr) Variable Uniform + 27: TypeInt 32 1 + 28: 27(int) Constant 0 + 29: 6(bool) ConstantFalse + 30: 9(bvec2) ConstantComposite 29 29 + 31: 22(int) Constant 0 + 32: 23(ivec2) ConstantComposite 31 31 + 33: 22(int) Constant 1 + 34: 23(ivec2) ConstantComposite 33 33 + 36: TypePointer Uniform 23(ivec2) + 38: TypeVector 22(int) 4 + 39(Uniform): TypeStruct 38(ivec4) + 40: TypePointer Uniform 39(Uniform) + 41: 40(ptr) Variable Uniform + 42: TypePointer Uniform 38(ivec4) + 65: 38(ivec4) ConstantComposite 31 31 31 31 + 4(main): 2 Function None 3 + 5: Label + 62(param): 8(ptr) Variable Function + 67(param): 10(ptr) Variable Function + 35: 23(ivec2) Select 30 34 32 + 37: 36(ptr) AccessChain 26 28 + Store 37 35 + 43: 42(ptr) AccessChain 41 28 + 44: 38(ivec4) Load 43 + 45: 22(int) CompositeExtract 44 2 + 46: 6(bool) INotEqual 45 31 + SelectionMerge 48 None + BranchConditional 46 47 48 + 47: Label + 49: 42(ptr) AccessChain 41 28 + 50: 38(ivec4) Load 49 + 51: 22(int) CompositeExtract 50 0 + 52: 6(bool) INotEqual 51 31 + 53: 9(bvec2) CompositeConstruct 52 52 + 54: 23(ivec2) Select 53 34 32 + 55: 36(ptr) AccessChain 26 28 + Store 55 54 + Branch 48 + 48: Label + 56: 36(ptr) AccessChain 26 28 + 57: 23(ivec2) Load 56 + 58: 22(int) CompositeExtract 57 0 + 59: 6(bool) INotEqual 58 31 + SelectionMerge 61 None + BranchConditional 59 60 61 + 60: Label + 63: 42(ptr) AccessChain 41 28 + 64: 38(ivec4) Load 63 + 66: 7(bvec4) INotEqual 64 65 + Store 62(param) 66 + 68: 2 FunctionCall 14(foo(vb4;vb2;) 62(param) 67(param) + 69: 9(bvec2) Load 67(param) + 70: 23(ivec2) Select 69 34 32 + 71: 36(ptr) AccessChain 26 28 + Store 71 70 + Branch 61 + 61: Label + Return + FunctionEnd +14(foo(vb4;vb2;): 2 Function None 11 + 12(paramb4): 8(ptr) FunctionParameter + 13(paramb2): 10(ptr) FunctionParameter + 15: Label + 17(b1): 16(ptr) Variable Function + 18: 7(bvec4) Load 12(paramb4) + 19: 6(bool) CompositeExtract 18 2 + Store 17(b1) 19 + 20: 6(bool) Load 17(b1) + 21: 9(bvec2) CompositeConstruct 20 20 + Store 13(paramb2) 21 + Return + FunctionEnd diff --git a/Test/spv.boolInBlock.frag b/Test/spv.boolInBlock.frag new file mode 100644 index 00000000..96b2de0e --- /dev/null +++ b/Test/spv.boolInBlock.frag @@ -0,0 +1,26 @@ +#version 450 + +layout(binding = 0, std140) uniform Uniform +{ + bvec4 b4; +}; + +layout(binding = 1, std430) buffer Buffer +{ + bvec2 b2; +}; + +void foo(bvec4 paramb4, out bvec2 paramb2) +{ + bool b1 = paramb4.z; + paramb2 = bvec2(b1); +} + +void main() +{ + b2 = bvec2(0.0); + if (b4.z) + b2 = bvec2(b4.x); + if (b2.x) + foo(b4, b2); +} \ No newline at end of file diff --git a/Test/test-spirv-list b/Test/test-spirv-list index 51e2e8f3..477067f4 100644 --- a/Test/test-spirv-list +++ b/Test/test-spirv-list @@ -35,6 +35,7 @@ spv.always-discard.frag spv.always-discard2.frag spv.bitCast.frag spv.bool.vert +spv.boolInBlock.frag spv.branch-return.vert spv.conditionalDiscard.frag spv.conversion.frag