Merge pull request #835 from steve-lunarg/sb-counters

HLSL: structuredbuffer counter functionality
This commit is contained in:
John Kessenich
2017-04-19 17:42:22 -06:00
committed by GitHub
15 changed files with 847 additions and 27 deletions

View File

@@ -475,9 +475,10 @@ bool HlslGrammar::acceptDeclaration(TIntermNode*& nodeList)
if (variableType.getBasicType() != EbtString && parseContext.getAnnotationNestingLevel() == 0) {
if (typedefDecl)
parseContext.declareTypedef(idToken.loc, *fullName, variableType);
else if (variableType.getBasicType() == EbtBlock)
else if (variableType.getBasicType() == EbtBlock) {
parseContext.declareBlock(idToken.loc, variableType, fullName);
else {
parseContext.declareStructBufferCounter(idToken.loc, variableType, *fullName);
} else {
if (variableType.getQualifier().storage == EvqUniform && ! variableType.containsOpaque()) {
// this isn't really an individual variable, but a member of the $Global buffer
parseContext.growGlobalUniformBlock(idToken.loc, variableType, *fullName);
@@ -1955,24 +1956,29 @@ bool HlslGrammar::acceptStructBufferType(TType& type)
bool readonly = false;
TStorageQualifier storage = EvqBuffer;
TBuiltInVariable builtinType = EbvNone;
switch (structBuffType) {
case EHTokAppendStructuredBuffer:
unimplemented("AppendStructuredBuffer");
return false;
builtinType = EbvAppendConsume;
break;
case EHTokByteAddressBuffer:
hasTemplateType = false;
readonly = true;
builtinType = EbvByteAddressBuffer;
break;
case EHTokConsumeStructuredBuffer:
unimplemented("ConsumeStructuredBuffer");
return false;
builtinType = EbvAppendConsume;
break;
case EHTokRWByteAddressBuffer:
hasTemplateType = false;
builtinType = EbvRWByteAddressBuffer;
break;
case EHTokRWStructuredBuffer:
builtinType = EbvRWStructuredBuffer;
break;
case EHTokStructuredBuffer:
builtinType = EbvStructuredBuffer;
readonly = true;
break;
default:
@@ -2014,8 +2020,6 @@ bool HlslGrammar::acceptStructBufferType(TType& type)
// field name is canonical for all structbuffers
templateType->setFieldName("@data");
// Create block type. TODO: hidden internal uint member when needed
TTypeList* blockStruct = new TTypeList;
TTypeLoc member = { templateType, token.loc };
blockStruct->push_back(member);
@@ -2025,6 +2029,7 @@ bool HlslGrammar::acceptStructBufferType(TType& type)
blockType.getQualifier().storage = storage;
blockType.getQualifier().readonly = readonly;
blockType.getQualifier().builtIn = builtinType;
// We may have created an equivalent type before, in which case we should use its
// deep structure.

View File

@@ -842,7 +842,11 @@ bool HlslParseContext::isStructBufferMethod(const TString& name) const
name == "InterlockedMax" ||
name == "InterlockedMin" ||
name == "InterlockedOr" ||
name == "InterlockedXor";
name == "InterlockedXor" ||
name == "IncrementCounter" ||
name == "DecrementCounter" ||
name == "Append" ||
name == "Consume";
}
//
@@ -1514,7 +1518,7 @@ void HlslParseContext::handleFunctionDeclarator(const TSourceLoc& loc, TFunction
error(loc, "function name is redeclaration of existing name", function.getName().c_str(), "");
}
// Add interstage IO variables to the linkage in canonical order.
// Finalization step: Add interstage IO variables to the linkage in canonical order.
void HlslParseContext::addInterstageIoToLinkage()
{
TSourceLoc loc;
@@ -2438,24 +2442,125 @@ TIntermAggregate* HlslParseContext::handleSamplerTextureCombine(const TSourceLoc
return txcombine;
}
// Return true if the buffer declared under 'name' is of a kind that has an associated
// counter buffer. The kind is looked up in structBufferBuiltIn; a name with no entry
// there is reported as having no counter.
bool HlslParseContext::hasStructBuffCounter(const TString& name) const
{
    const auto found = structBufferBuiltIn.find(name);
    if (found == structBufferBuiltIn.end())
        return false;

    // Only append/consume and RW structured buffers carry a counter; the other
    // structured buffer types do not.
    const auto bufferKind = found->second;
    return bufferKind == EbvAppendConsume || bufferKind == EbvRWStructuredBuffer;
}
// declare counter for a structured buffer type
void HlslParseContext::declareStructBufferCounter(const TSourceLoc& loc, const TType& bufferType, const TString& name)
{
// Bail out if not a struct buffer
if (! isStructBufferType(bufferType))
return;
if (! hasStructBuffCounter(name))
return;
// Counter type
TType* counterType = new TType(EbtInt, EvqBuffer);
counterType->setFieldName("@count");
TTypeList* blockStruct = new TTypeList;
TTypeLoc member = { counterType, loc };
blockStruct->push_back(member);
TString* blockName = new TString(name);
*blockName += "@count";
structBufferCounter[*blockName] = false;
TType blockType(blockStruct, "", counterType->getQualifier());
blockType.getQualifier().storage = EvqBuffer;
shareStructBufferType(blockType);
declareBlock(loc, blockType, blockName);
}
// Return the counter that goes with a given structuredbuffer.
//
// 'buffer' is the buffer object the counter belongs to. The result is an AST node that
// indexes member 0 of the block named <buffer symbol name> + "@count", typed as an integer.
// As a side effect the counter block is marked as used in structBufferCounter.
//
// Returns nullptr if 'buffer' is null, is not a struct buffer type, or is not a plain
// symbol (the counter block is found by symbol name, so there is nothing to look up).
TIntermTyped* HlslParseContext::getStructBufferCounter(const TSourceLoc& loc, TIntermTyped* buffer)
{
    // Bail out if not a struct buffer
    if (buffer == nullptr || ! isStructBufferType(buffer->getType()))
        return nullptr;

    // Bail out rather than dereference a null symbol node when the buffer expression
    // is not a simple symbol.
    const auto bufferSymbol = buffer->getAsSymbolNode();
    if (bufferSymbol == nullptr)
        return nullptr;

    TString blockName(bufferSymbol->getName());
    blockName += "@count";

    // Mark the counter as being used
    structBufferCounter[blockName] = true;

    TIntermTyped* counterVar = handleVariable(loc, &blockName);  // find the block structure
    TIntermTyped* index = intermediate.addConstantUnion(0, loc); // index to counter inside block struct

    TIntermTyped* counterMember = intermediate.addIndex(EOpIndexDirectStruct, counterVar, index, loc);
    counterMember->setType(TType(EbtInt));
    return counterMember;
}
//
// Decompose structure buffer methods into AST
//
void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TIntermTyped*& node, TIntermNode* arguments)
{
if (!node || !node->getAsOperator())
if (node == nullptr || node->getAsOperator() == nullptr || arguments == nullptr)
return;
const TOperator op = node->getAsOperator()->getOp();
TIntermAggregate* argAggregate = arguments ? arguments->getAsAggregate() : nullptr;
if (argAggregate == nullptr)
return;
if (argAggregate->getSequence().empty())
return;
TIntermAggregate* argAggregate = arguments->getAsAggregate();
// Buffer is the object upon which method is called, so always arg 0
TIntermTyped* bufferObj = argAggregate->getSequence()[0]->getAsTyped();
TIntermTyped* bufferObj = nullptr;
// The parameters can be an aggregate, or just the object as a symbol if there are no fn params.
if (argAggregate) {
if (argAggregate->getSequence().empty())
return;
bufferObj = argAggregate->getSequence()[0]->getAsTyped();
} else {
bufferObj = arguments->getAsSymbolNode();
}
if (bufferObj == nullptr || bufferObj->getAsSymbolNode() == nullptr)
return;
TString bufferName(bufferObj->getAsSymbolNode()->getName());
const auto bivIt = structBufferBuiltIn.find(bufferName);
if (bivIt == structBufferBuiltIn.end())
return;
const TBuiltInVariable builtInType = bivIt->second;
// Some methods require a hidden internal counter, obtained via getStructBufferCounter().
// This lambda adds something to it and returns the old value.
const auto incDecCounter = [&](int incval) -> TIntermTyped* {
TIntermTyped* incrementValue = intermediate.addConstantUnion(incval, loc, true);
TIntermTyped* counter = getStructBufferCounter(loc, bufferObj); // obtain the counter member
if (counter == nullptr)
return nullptr;
TIntermAggregate* counterIncrement = new TIntermAggregate(EOpAtomicAdd);
counterIncrement->setType(TType(EbtUint, EvqTemporary));
counterIncrement->setLoc(loc);
counterIncrement->getSequence().push_back(counter);
counterIncrement->getSequence().push_back(incrementValue);
return counterIncrement;
};
// Index to obtain the runtime sized array out of the buffer.
TIntermTyped* argArray = indexStructBufferContent(loc, bufferObj);
@@ -2469,7 +2574,9 @@ void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TInte
// Byte address buffers index in bytes (only multiples of 4 permitted... not so much a byte address
// buffer then, but that's what it calls itself.
const bool isByteAddressBuffer = (argArray->getBasicType() == EbtUint);
const bool isByteAddressBuffer = (builtInType == EbvByteAddressBuffer ||
builtInType == EbvRWByteAddressBuffer);
if (isByteAddressBuffer)
argIndex = intermediate.addBinaryNode(EOpRightShift, argIndex, intermediate.addConstantUnion(2, loc, true),
loc, TType(EbtInt));
@@ -2670,6 +2777,50 @@ void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TInte
}
break;
case EOpMethodIncrementCounter:
{
node = incDecCounter(1);
break;
}
case EOpMethodDecrementCounter:
{
TIntermTyped* preIncValue = incDecCounter(-1); // result is original value
node = intermediate.addBinaryNode(EOpAdd, preIncValue, intermediate.addConstantUnion(-1, loc, true), loc,
preIncValue->getType());
break;
}
case EOpMethodAppend:
{
TIntermTyped* oldCounter = incDecCounter(1);
TIntermTyped* lValue = intermediate.addIndex(EOpIndexIndirect, argArray, oldCounter, loc);
TIntermTyped* rValue = argAggregate->getSequence()[1]->getAsTyped();
const TType derefType(argArray->getType(), 0);
lValue->setType(derefType);
node = intermediate.addAssign(EOpAssign, lValue, rValue, loc);
break;
}
case EOpMethodConsume:
{
TIntermTyped* oldCounter = incDecCounter(-1);
TIntermTyped* newCounter = intermediate.addBinaryNode(EOpAdd, oldCounter, intermediate.addConstantUnion(-1, loc, true), loc,
oldCounter->getType());
node = intermediate.addIndex(EOpIndexIndirect, argArray, newCounter, loc);
const TType derefType(argArray->getType(), 0);
node->setType(derefType);
break;
}
default:
break; // most pass through unchanged
}
@@ -3978,10 +4129,18 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
// TODO: this needs improvement: there's no way at present to look up a signature in
// the symbol table for an arbitrary type. This is a temporary hack until that ability exists.
// It will have false positives, since it doesn't check arg counts or types.
if (arguments && arguments->getAsAggregate()) {
const TIntermSequence& sequence = arguments->getAsAggregate()->getSequence();
if (arguments) {
// Check if first argument is struct buffer type. It may be an aggregate or a symbol, so we
// look for either case.
if (!sequence.empty() && isStructBufferType(sequence[0]->getAsTyped()->getType())) {
TIntermTyped* arg0 = nullptr;
if (arguments->getAsAggregate() && arguments->getAsAggregate()->getSequence().size() > 0)
arg0 = arguments->getAsAggregate()->getSequence()[0]->getAsTyped();
else if (arguments->getAsSymbolNode())
arg0 = arguments->getAsSymbolNode();
if (arg0 != nullptr && isStructBufferType(arg0->getType())) {
static const int methodPrefixSize = sizeof(BUILTIN_PREFIX)-1;
if (function->getName().length() > methodPrefixSize &&
@@ -5845,8 +6004,11 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, TFunction
// These builtin ops can accept any type, so we bypass the argument selection
if (candidateList.size() == 1 && builtIn &&
(candidateList[0]->getBuiltInOp() == EOpMethodAppend ||
candidateList[0]->getBuiltInOp() == EOpMethodRestartStrip)) {
candidateList[0]->getBuiltInOp() == EOpMethodRestartStrip ||
candidateList[0]->getBuiltInOp() == EOpMethodIncrementCounter ||
candidateList[0]->getBuiltInOp() == EOpMethodDecrementCounter ||
candidateList[0]->getBuiltInOp() == EOpMethodAppend ||
candidateList[0]->getBuiltInOp() == EOpMethodConsume)) {
return candidateList[0];
}
@@ -6856,6 +7018,10 @@ void HlslParseContext::declareBlock(const TSourceLoc& loc, TType& type, const TS
switch (type.getQualifier().storage) {
case EvqUniform:
case EvqBuffer:
// remember pre-sanitized builtin type
if (type.getQualifier().storage == EvqBuffer && instanceName != nullptr)
structBufferBuiltIn[*instanceName] = type.getQualifier().builtIn;
correctUniform(type.getQualifier());
break;
case EvqVaryingIn:
@@ -7670,7 +7836,7 @@ TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) cons
return intermediate.addSymbol(*it->second->getAsVariable());
}
// Add patch constant function invocation
// Finalization step: Add patch constant function invocation
void HlslParseContext::addPatchConstantInvocation()
{
TSourceLoc loc;
@@ -8039,9 +8205,23 @@ void HlslParseContext::addPatchConstantInvocation()
epBodySeq.insert(epBodySeq.end(), invocationIdTest);
}
// Finalization step: drop from the linkage every struct buffer counter block that was
// declared but never marked as used. This can only be done once the shader has been
// entirely compiled, since a use may appear anywhere.
void HlslParseContext::removeUnusedStructBufferCounters()
{
    // True for linkage symbols that are registered counters still flagged as unused.
    // Symbols with no entry in structBufferCounter are left alone.
    const auto isUnusedCounter = [this](const TSymbol* symbol) {
        const auto entry = structBufferCounter.find(symbol->getName());
        if (entry == structBufferCounter.end())
            return false;
        return !entry->second;
    };

    linkageSymbols.erase(std::remove_if(linkageSymbols.begin(), linkageSymbols.end(), isUnusedCounter),
                         linkageSymbols.end());
}
// post-processing
void HlslParseContext::finish()
{
removeUnusedStructBufferCounters();
addPatchConstantInvocation();
addInterstageIoToLinkage();

View File

@@ -146,6 +146,7 @@ public:
TIntermTyped* constructAggregate(TIntermNode*, const TType&, int, const TSourceLoc&);
TIntermTyped* constructBuiltIn(const TType&, TOperator, TIntermTyped*, const TSourceLoc&, bool subset);
void declareBlock(const TSourceLoc&, TType&, const TString* instanceName = 0, TArraySizes* arraySizes = 0);
void declareStructBufferCounter(const TSourceLoc& loc, const TType& bufferType, const TString& name);
void fixBlockLocations(const TSourceLoc&, TQualifier&, TTypeList&, bool memberWithLocation, bool memberWithoutLocation);
void fixBlockXfbOffsets(TQualifier&, TTypeList&);
void fixBlockUniformOffsets(const TQualifier&, TTypeList&);
@@ -274,11 +275,19 @@ protected:
TType* getStructBufferContentType(const TType& type) const;
bool isStructBufferType(const TType& type) const { return getStructBufferContentType(type) != nullptr; }
TIntermTyped* indexStructBufferContent(const TSourceLoc& loc, TIntermTyped* buffer) const;
TIntermTyped* getStructBufferCounter(const TSourceLoc& loc, TIntermTyped* buffer);
// Return true if this type is a reference. This is not currently a type method in case that's
// a language specific answer.
bool isReference(const TType& type) const { return isStructBufferType(type); }
// Return true if this is a buffer type that has an associated counter buffer.
bool hasStructBuffCounter(const TString& name) const;
// Finalization step: remove unused buffer blocks from linkage (we don't know until the
// shader is entirely compiled)
void removeUnusedStructBufferCounters();
// Pass through to base class after remembering builtin mappings.
using TParseContextBase::trackLinkage;
void trackLinkage(TSymbol& variable) override;
@@ -366,6 +375,9 @@ protected:
// Structuredbuffer shared types. Typically there are only a few.
TVector<TType*> structBufferTypes;
TMap<TString, TBuiltInVariable> structBufferBuiltIn;
TMap<TString, bool> structBufferCounter;
// The builtin interstage IO map considers e.g, EvqPosition on input and output separately, so that we
// can build the linkage correctly if position appears on both sides. Otherwise, multiple positions

View File

@@ -871,6 +871,9 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
{ "InterlockedMin", nullptr, nullptr, "-", "-", EShLangAll, true },
{ "InterlockedOr", nullptr, nullptr, "-", "-", EShLangAll, true },
{ "InterlockedXor", nullptr, nullptr, "-", "-", EShLangAll, true },
{ "IncrementCounter", nullptr, nullptr, "-", "-", EShLangAll, true },
{ "DecrementCounter", nullptr, nullptr, "-", "-", EShLangAll, true },
{ "Consume", nullptr, nullptr, "-", "-", EShLangAll, true },
// Mark end of list, since we want to avoid a range-based for, as some compilers don't handle it yet.
{ nullptr, nullptr, nullptr, nullptr, nullptr, 0, false },
@@ -1180,6 +1183,10 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int /*version*/, EProfile /*profil
symbolTable.relateToOperator(BUILTIN_PREFIX "Store2", EOpMethodStore2);
symbolTable.relateToOperator(BUILTIN_PREFIX "Store3", EOpMethodStore3);
symbolTable.relateToOperator(BUILTIN_PREFIX "Store4", EOpMethodStore4);
symbolTable.relateToOperator(BUILTIN_PREFIX "IncrementCounter", EOpMethodIncrementCounter);
symbolTable.relateToOperator(BUILTIN_PREFIX "DecrementCounter", EOpMethodDecrementCounter);
// Append is also a GS method: we don't add it twice
symbolTable.relateToOperator(BUILTIN_PREFIX "Consume", EOpMethodConsume);
symbolTable.relateToOperator(BUILTIN_PREFIX "InterlockedAdd", EOpInterlockedAdd);
symbolTable.relateToOperator(BUILTIN_PREFIX "InterlockedAnd", EOpInterlockedAnd);