diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 1a271e27..8af35997 100644 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -3789,10 +3789,11 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang // Find a way of executing both sides and selecting the right result. const auto executeBothSides = [&]() -> void { // execute both sides + spv::Id resultType = convertGlslangToSpvType(node->getType()); node->getTrueBlock()->traverse(this); spv::Id trueValue = accessChainLoad(node->getTrueBlock()->getAsTyped()->getType()); node->getFalseBlock()->traverse(this); - spv::Id falseValue = accessChainLoad(node->getTrueBlock()->getAsTyped()->getType()); + spv::Id falseValue = accessChainLoad(node->getFalseBlock()->getAsTyped()->getType()); builder.setLine(node->getLoc().line, node->getLoc().getFilename()); @@ -3801,8 +3802,8 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang return; // emit code to select between trueValue and falseValue - - // see if OpSelect can handle it + // see if OpSelect can handle the result type, and that the SPIR-V types + // of the inputs match the result type. if (isOpSelectable()) { // Emit OpSelect for this selection. @@ -3814,10 +3815,18 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang builder.getNumComponents(trueValue))); } + // If the types do not match, it is because of mismatched decorations on aggregates. + // Since isOpSelectable only lets us get here for SPIR-V >= 1.4, we can use OpCopyObject + // to get matching types. + if (builder.getTypeId(trueValue) != resultType) { + trueValue = builder.createUnaryOp(spv::OpCopyLogical, resultType, trueValue); + } + if (builder.getTypeId(falseValue) != resultType) { + falseValue = builder.createUnaryOp(spv::OpCopyLogical, resultType, falseValue); + } + // OpSelect - result = builder.createTriOp(spv::OpSelect, - convertGlslangToSpvType(node->getType()), condition, - trueValue, falseValue); + result = builder.createTriOp(spv::OpSelect, resultType, condition, trueValue, falseValue); builder.clearAccessChain(); builder.setAccessChainRValue(result); @@ -3825,7 +3834,7 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang // We need control flow to select the result. // TODO: Once SPIR-V OpSelect allows arbitrary types, eliminate this path. result = builder.createVariable(TranslatePrecisionDecoration(node->getType()), - spv::StorageClassFunction, convertGlslangToSpvType(node->getType())); + spv::StorageClassFunction, resultType); // Selection control: const spv::SelectionControlMask control = TranslateSelectionControl(*node); @@ -3834,10 +3843,15 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang spv::Builder::If ifBuilder(condition, control, builder); // emit the "then" statement - builder.createStore(trueValue, result); + builder.clearAccessChain(); + builder.setAccessChainLValue(result); + multiTypeStore(node->getType(), trueValue); + ifBuilder.makeBeginElse(); // emit the "else" statement - builder.createStore(falseValue, result); + builder.clearAccessChain(); + builder.setAccessChainLValue(result); + multiTypeStore(node->getType(), falseValue); // finish off the control flow ifBuilder.makeEndIf(); @@ -3864,16 +3878,26 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang // emit the "then" statement if (node->getTrueBlock() != nullptr) { node->getTrueBlock()->traverse(this); - if (result != spv::NoResult) - builder.createStore(accessChainLoad(node->getTrueBlock()->getAsTyped()->getType()), result); + if (result != spv::NoResult) { + spv::Id load = accessChainLoad(node->getTrueBlock()->getAsTyped()->getType()); + + builder.clearAccessChain(); + builder.setAccessChainLValue(result); + multiTypeStore(node->getType(), load); + } } if (node->getFalseBlock() != nullptr) { ifBuilder.makeBeginElse(); // emit the "else" statement node->getFalseBlock()->traverse(this); - if (result != spv::NoResult) - builder.createStore(accessChainLoad(node->getFalseBlock()->getAsTyped()->getType()), result); + if (result != spv::NoResult) { + spv::Id load = accessChainLoad(node->getFalseBlock()->getAsTyped()->getType()); + + builder.clearAccessChain(); + builder.setAccessChainLValue(result); + multiTypeStore(node->getType(), load); + } } // finish off the control flow