diff --git a/hlsl/hlslAttributes.cpp b/hlsl/hlslAttributes.cpp index 61ef8055..2a8e3702 100644 --- a/hlsl/hlslAttributes.cpp +++ b/hlsl/hlslAttributes.cpp @@ -36,6 +36,7 @@ #include "hlslAttributes.h" #include #include +#include namespace glslang { // Map the given string to an attribute enum from TAttributeType, @@ -131,4 +132,51 @@ namespace glslang { return attributes.find(attr) != attributes.end(); } + // extract integers out of attribute arguments stored in attribute aggregate + bool TAttributeMap::getInt(TAttributeType attr, int& value, int argNum) const + { + const TConstUnion* intConst = getConstUnion(attr, EbtInt, argNum); + + if (intConst == nullptr) + return false; + + value = intConst->getIConst(); + return true; + }; + + // extract strings out of attribute arguments stored in attribute aggregate. + // convert to lower case if converToLower is true (for case-insensitive compare convenience) + bool TAttributeMap::getString(TAttributeType attr, TString& value, int argNum, bool convertToLower) const + { + const TConstUnion* stringConst = getConstUnion(attr, EbtString, argNum); + + if (stringConst == nullptr) + return false; + + value = *stringConst->getSConst(); + + // Convenience. + if (convertToLower) + std::transform(value.begin(), value.end(), value.begin(), ::tolower); + + return true; + }; + + // Helper to get attribute const union. Returns nullptr on failure. + const TConstUnion* TAttributeMap::getConstUnion(TAttributeType attr, TBasicType basicType, int argNum) const + { + const TIntermAggregate* attrAgg = (*this)[attr]; + if (attrAgg == nullptr) + return nullptr; + + if (argNum >= int(attrAgg->getSequence().size())) + return nullptr; + + const TConstUnion* constVal = &attrAgg->getSequence()[argNum]->getAsConstantUnion()->getConstArray()[0]; + if (constVal == nullptr || constVal->getType() != basicType) + return nullptr; + + return constVal; + } + } // end namespace glslang diff --git a/hlsl/hlslAttributes.h b/hlsl/hlslAttributes.h index 16ec31da..2d7b6c7a 100644 --- a/hlsl/hlslAttributes.h +++ b/hlsl/hlslAttributes.h @@ -93,7 +93,16 @@ namespace glslang { // True if entry exists in map (even if value is nullptr) bool contains(TAttributeType) const; + // Obtain attribute as integer + bool getInt(TAttributeType attr, int& value, int argNum = 0) const; + + // Obtain attribute as string, with optional to-lower transform + bool getString(TAttributeType attr, TString& value, int argNum = 0, bool convertToLower = true) const; + protected: + // Helper to get attribute const union + const TConstUnion* getConstUnion(TAttributeType attr, TBasicType, int argNum) const; + // Find an attribute enum given its name. static TAttributeType attributeFromName(const TString& nameSpace, const TString& name); diff --git a/hlsl/hlslParseHelper.cpp b/hlsl/hlslParseHelper.cpp index fe6333b9..b2b80d05 100755 --- a/hlsl/hlslParseHelper.cpp +++ b/hlsl/hlslParseHelper.cpp @@ -1717,36 +1717,33 @@ void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const T } // MaxVertexCount - const TIntermAggregate* maxVertexCount = attributes[EatMaxVertexCount]; - if (maxVertexCount != nullptr) { - if (! intermediate.setVertices(maxVertexCount->getSequence()[0]->getAsConstantUnion()-> - getConstArray()[0].getIConst())) { - error(loc, "cannot change previously set maxvertexcount attribute", "", ""); + if (attributes.contains(EatMaxVertexCount)) { + int maxVertexCount; + + if (! attributes.getInt(EatMaxVertexCount, maxVertexCount)) { + error(loc, "invalid maxvertexcount", "", ""); + } else { + if (! intermediate.setVertices(maxVertexCount)) + error(loc, "cannot change previously set maxvertexcount attribute", "", ""); } } // Handle [patchconstantfunction("...")] - const TIntermAggregate* pcfAttr = attributes[EatPatchConstantFunc]; - if (pcfAttr != nullptr) { - const TConstUnion& pcfName = pcfAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0]; - - if (pcfName.getType() != EbtString) { + if (attributes.contains(EatPatchConstantFunc)) { + TString pcfName; + if (! attributes.getString(EatPatchConstantFunc, pcfName, 0, false)) { error(loc, "invalid patch constant function", "", ""); } else { - patchConstantFunctionName = *pcfName.getSConst(); + patchConstantFunctionName = pcfName; } } // Handle [domain("...")] - const TIntermAggregate* domainAttr = attributes[EatDomain]; - if (domainAttr != nullptr) { - const TConstUnion& domainType = domainAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0]; - if (domainType.getType() != EbtString) { + if (attributes.contains(EatDomain)) { + TString domainStr; + if (! attributes.getString(EatDomain, domainStr)) { error(loc, "invalid domain", "", ""); } else { - TString domainStr = *domainType.getSConst(); - std::transform(domainStr.begin(), domainStr.end(), domainStr.begin(), ::tolower); - TLayoutGeometry domain = ElgNone; if (domainStr == "tri") { @@ -1770,15 +1767,11 @@ void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const T } // Handle [outputtopology("...")] - const TIntermAggregate* topologyAttr = attributes[EatOutputTopology]; - if (topologyAttr != nullptr) { - const TConstUnion& topoType = topologyAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0]; - if (topoType.getType() != EbtString) { + if (attributes.contains(EatOutputTopology)) { + TString topologyStr; + if (! attributes.getString(EatOutputTopology, topologyStr)) { error(loc, "invalid outputtopology", "", ""); } else { - TString topologyStr = *topoType.getSConst(); - std::transform(topologyStr.begin(), topologyStr.end(), topologyStr.begin(), ::tolower); - TVertexOrder vertexOrder = EvoNone; TLayoutGeometry primitive = ElgNone; @@ -1808,15 +1801,11 @@ void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const T } // Handle [partitioning("...")] - const TIntermAggregate* partitionAttr = attributes[EatPartitioning]; - if (partitionAttr != nullptr) { - const TConstUnion& partType = partitionAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0]; - if (partType.getType() != EbtString) { + if (attributes.contains(EatPartitioning)) { + TString partitionStr; + if (! attributes.getString(EatPartitioning, partitionStr)) { error(loc, "invalid partitioning", "", ""); } else { - TString partitionStr = *partType.getSConst(); - std::transform(partitionStr.begin(), partitionStr.end(), partitionStr.begin(), ::tolower); - TVertexSpacing partitioning = EvsNone; if (partitionStr == "integer") { @@ -1837,14 +1826,11 @@ void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const T } // Handle [outputcontrolpoints("...")] - const TIntermAggregate* outputControlPoints = attributes[EatOutputControlPoints]; - if (outputControlPoints != nullptr) { - const TConstUnion& ctrlPointConst = - outputControlPoints->getSequence()[0]->getAsConstantUnion()->getConstArray()[0]; - if (ctrlPointConst.getType() != EbtInt) { + if (attributes.contains(EatOutputControlPoints)) { + int ctrlPoints; + if (! attributes.getInt(EatOutputControlPoints, ctrlPoints)) { error(loc, "invalid outputcontrolpoints", "", ""); } else { - const int ctrlPoints = ctrlPointConst.getIConst(); if (! intermediate.setVertices(ctrlPoints)) { error(loc, "cannot change previously set outputcontrolpoints attribute", "", ""); } @@ -1856,37 +1842,23 @@ void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const T // attributes. void HlslParseContext::transferTypeAttributes(const TAttributeMap& attributes, TType& type) { - // extract integers out of attribute arguments stored in attribute aggregate - const auto getInt = [&](TAttributeType attr, int argNum, int& value) -> bool { - const TIntermAggregate* attrAgg = attributes[attr]; - if (attrAgg == nullptr) - return false; - if (argNum >= (int)attrAgg->getSequence().size()) - return false; - const TConstUnion& intConst = attrAgg->getSequence()[argNum]->getAsConstantUnion()->getConstArray()[0]; - if (intConst.getType() != EbtInt) - return false; - value = intConst.getIConst(); - return true; - }; - // location int value; - if (getInt(EatLocation, 0, value)) + if (attributes.getInt(EatLocation, value)) type.getQualifier().layoutLocation = value; // binding - if (getInt(EatBinding, 0, value)) { + if (attributes.getInt(EatBinding, value)) { type.getQualifier().layoutBinding = value; type.getQualifier().layoutSet = 0; } // set - if (getInt(EatBinding, 1, value)) + if (attributes.getInt(EatBinding, value, 1)) type.getQualifier().layoutSet = value; // input attachment - if (getInt(EatInputAttachment, 0, value)) + if (attributes.getInt(EatInputAttachment, value)) type.getQualifier().layoutAttachment = value; }