Merge pull request #835 from steve-lunarg/sb-counters
HLSL: structuredbuffer counter functionality
This commit is contained in:
@@ -842,7 +842,11 @@ bool HlslParseContext::isStructBufferMethod(const TString& name) const
|
||||
name == "InterlockedMax" ||
|
||||
name == "InterlockedMin" ||
|
||||
name == "InterlockedOr" ||
|
||||
name == "InterlockedXor";
|
||||
name == "InterlockedXor" ||
|
||||
name == "IncrementCounter" ||
|
||||
name == "DecrementCounter" ||
|
||||
name == "Append" ||
|
||||
name == "Consume";
|
||||
}
|
||||
|
||||
//
|
||||
@@ -1514,7 +1518,7 @@ void HlslParseContext::handleFunctionDeclarator(const TSourceLoc& loc, TFunction
|
||||
error(loc, "function name is redeclaration of existing name", function.getName().c_str(), "");
|
||||
}
|
||||
|
||||
// Add interstage IO variables to the linkage in canonical order.
|
||||
// Finalization step: Add interstage IO variables to the linkage in canonical order.
|
||||
void HlslParseContext::addInterstageIoToLinkage()
|
||||
{
|
||||
TSourceLoc loc;
|
||||
@@ -2438,24 +2442,125 @@ TIntermAggregate* HlslParseContext::handleSamplerTextureCombine(const TSourceLoc
|
||||
return txcombine;
|
||||
}
|
||||
|
||||
// Return true if this a buffer type that has an associated counter buffer.
|
||||
bool HlslParseContext::hasStructBuffCounter(const TString& name) const
|
||||
{
|
||||
const auto bivIt = structBufferBuiltIn.find(name);
|
||||
if (bivIt == structBufferBuiltIn.end())
|
||||
return false;
|
||||
|
||||
switch (bivIt->second) {
|
||||
case EbvAppendConsume: // fall through...
|
||||
case EbvRWStructuredBuffer: // ...
|
||||
return true;
|
||||
default:
|
||||
return false; // the other structuredbfufer types do not have a counter.
|
||||
}
|
||||
}
|
||||
|
||||
// declare counter for a structured buffer type
|
||||
void HlslParseContext::declareStructBufferCounter(const TSourceLoc& loc, const TType& bufferType, const TString& name)
|
||||
{
|
||||
// Bail out if not a struct buffer
|
||||
if (! isStructBufferType(bufferType))
|
||||
return;
|
||||
|
||||
if (! hasStructBuffCounter(name))
|
||||
return;
|
||||
|
||||
// Counter type
|
||||
TType* counterType = new TType(EbtInt, EvqBuffer);
|
||||
counterType->setFieldName("@count");
|
||||
|
||||
TTypeList* blockStruct = new TTypeList;
|
||||
TTypeLoc member = { counterType, loc };
|
||||
blockStruct->push_back(member);
|
||||
|
||||
TString* blockName = new TString(name);
|
||||
*blockName += "@count";
|
||||
|
||||
structBufferCounter[*blockName] = false;
|
||||
|
||||
TType blockType(blockStruct, "", counterType->getQualifier());
|
||||
blockType.getQualifier().storage = EvqBuffer;
|
||||
|
||||
shareStructBufferType(blockType);
|
||||
declareBlock(loc, blockType, blockName);
|
||||
}
|
||||
|
||||
// return the counter that goes with a given structuredbuffer
|
||||
TIntermTyped* HlslParseContext::getStructBufferCounter(const TSourceLoc& loc, TIntermTyped* buffer)
|
||||
{
|
||||
// Bail out if not a struct buffer
|
||||
if (buffer == nullptr || ! isStructBufferType(buffer->getType()))
|
||||
return nullptr;
|
||||
|
||||
TString blockName(buffer->getAsSymbolNode()->getName());
|
||||
blockName += "@count";
|
||||
|
||||
// Mark the counter as being used
|
||||
structBufferCounter[blockName] = true;
|
||||
|
||||
TIntermTyped* counterVar = handleVariable(loc, &blockName); // find the block structure
|
||||
TIntermTyped* index = intermediate.addConstantUnion(0, loc); // index to counter inside block struct
|
||||
|
||||
TIntermTyped* counterMember = intermediate.addIndex(EOpIndexDirectStruct, counterVar, index, loc);
|
||||
counterMember->setType(TType(EbtInt));
|
||||
return counterMember;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Decompose structure buffer methods into AST
|
||||
//
|
||||
void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TIntermTyped*& node, TIntermNode* arguments)
|
||||
{
|
||||
if (!node || !node->getAsOperator())
|
||||
if (node == nullptr || node->getAsOperator() == nullptr || arguments == nullptr)
|
||||
return;
|
||||
|
||||
const TOperator op = node->getAsOperator()->getOp();
|
||||
TIntermAggregate* argAggregate = arguments ? arguments->getAsAggregate() : nullptr;
|
||||
if (argAggregate == nullptr)
|
||||
return;
|
||||
|
||||
if (argAggregate->getSequence().empty())
|
||||
return;
|
||||
TIntermAggregate* argAggregate = arguments->getAsAggregate();
|
||||
|
||||
// Buffer is the object upon which method is called, so always arg 0
|
||||
TIntermTyped* bufferObj = argAggregate->getSequence()[0]->getAsTyped();
|
||||
TIntermTyped* bufferObj = nullptr;
|
||||
|
||||
// The parameters can be an aggregate, or just a the object as a symbol if there are no fn params.
|
||||
if (argAggregate) {
|
||||
if (argAggregate->getSequence().empty())
|
||||
return;
|
||||
bufferObj = argAggregate->getSequence()[0]->getAsTyped();
|
||||
} else {
|
||||
bufferObj = arguments->getAsSymbolNode();
|
||||
}
|
||||
|
||||
if (bufferObj == nullptr || bufferObj->getAsSymbolNode() == nullptr)
|
||||
return;
|
||||
|
||||
TString bufferName(bufferObj->getAsSymbolNode()->getName());
|
||||
|
||||
const auto bivIt = structBufferBuiltIn.find(bufferName);
|
||||
if (bivIt == structBufferBuiltIn.end())
|
||||
return;
|
||||
|
||||
const TBuiltInVariable builtInType = bivIt->second;
|
||||
|
||||
// Some methods require a hidden internal counter, obtained via getStructBufferCounter().
|
||||
// This lambda adds something to it and returns the old value.
|
||||
const auto incDecCounter = [&](int incval) -> TIntermTyped* {
|
||||
TIntermTyped* incrementValue = intermediate.addConstantUnion(incval, loc, true);
|
||||
TIntermTyped* counter = getStructBufferCounter(loc, bufferObj); // obtain the counter member
|
||||
|
||||
if (counter == nullptr)
|
||||
return nullptr;
|
||||
|
||||
TIntermAggregate* counterIncrement = new TIntermAggregate(EOpAtomicAdd);
|
||||
counterIncrement->setType(TType(EbtUint, EvqTemporary));
|
||||
counterIncrement->setLoc(loc);
|
||||
counterIncrement->getSequence().push_back(counter);
|
||||
counterIncrement->getSequence().push_back(incrementValue);
|
||||
|
||||
return counterIncrement;
|
||||
};
|
||||
|
||||
// Index to obtain the runtime sized array out of the buffer.
|
||||
TIntermTyped* argArray = indexStructBufferContent(loc, bufferObj);
|
||||
@@ -2469,7 +2574,9 @@ void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TInte
|
||||
|
||||
// Byte address buffers index in bytes (only multiples of 4 permitted... not so much a byte address
|
||||
// buffer then, but that's what it calls itself.
|
||||
const bool isByteAddressBuffer = (argArray->getBasicType() == EbtUint);
|
||||
const bool isByteAddressBuffer = (builtInType == EbvByteAddressBuffer ||
|
||||
builtInType == EbvRWByteAddressBuffer);
|
||||
|
||||
if (isByteAddressBuffer)
|
||||
argIndex = intermediate.addBinaryNode(EOpRightShift, argIndex, intermediate.addConstantUnion(2, loc, true),
|
||||
loc, TType(EbtInt));
|
||||
@@ -2670,6 +2777,50 @@ void HlslParseContext::decomposeStructBufferMethods(const TSourceLoc& loc, TInte
|
||||
}
|
||||
break;
|
||||
|
||||
case EOpMethodIncrementCounter:
|
||||
{
|
||||
node = incDecCounter(1);
|
||||
break;
|
||||
}
|
||||
|
||||
case EOpMethodDecrementCounter:
|
||||
{
|
||||
TIntermTyped* preIncValue = incDecCounter(-1); // result is original value
|
||||
node = intermediate.addBinaryNode(EOpAdd, preIncValue, intermediate.addConstantUnion(-1, loc, true), loc,
|
||||
preIncValue->getType());
|
||||
break;
|
||||
}
|
||||
|
||||
case EOpMethodAppend:
|
||||
{
|
||||
TIntermTyped* oldCounter = incDecCounter(1);
|
||||
|
||||
TIntermTyped* lValue = intermediate.addIndex(EOpIndexIndirect, argArray, oldCounter, loc);
|
||||
TIntermTyped* rValue = argAggregate->getSequence()[1]->getAsTyped();
|
||||
|
||||
const TType derefType(argArray->getType(), 0);
|
||||
lValue->setType(derefType);
|
||||
|
||||
node = intermediate.addAssign(EOpAssign, lValue, rValue, loc);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case EOpMethodConsume:
|
||||
{
|
||||
TIntermTyped* oldCounter = incDecCounter(-1);
|
||||
|
||||
TIntermTyped* newCounter = intermediate.addBinaryNode(EOpAdd, oldCounter, intermediate.addConstantUnion(-1, loc, true), loc,
|
||||
oldCounter->getType());
|
||||
|
||||
node = intermediate.addIndex(EOpIndexIndirect, argArray, newCounter, loc);
|
||||
|
||||
const TType derefType(argArray->getType(), 0);
|
||||
node->setType(derefType);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break; // most pass through unchanged
|
||||
}
|
||||
@@ -3978,10 +4129,18 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
|
||||
// TODO: this needs improvement: there's no way at present to look up a signature in
|
||||
// the symbol table for an arbitrary type. This is a temporary hack until that ability exists.
|
||||
// It will have false positives, since it doesn't check arg counts or types.
|
||||
if (arguments && arguments->getAsAggregate()) {
|
||||
const TIntermSequence& sequence = arguments->getAsAggregate()->getSequence();
|
||||
if (arguments) {
|
||||
// Check if first argument is struct buffer type. It may be an aggregate or a symbol, so we
|
||||
// look for either case.
|
||||
|
||||
if (!sequence.empty() && isStructBufferType(sequence[0]->getAsTyped()->getType())) {
|
||||
TIntermTyped* arg0 = nullptr;
|
||||
|
||||
if (arguments->getAsAggregate() && arguments->getAsAggregate()->getSequence().size() > 0)
|
||||
arg0 = arguments->getAsAggregate()->getSequence()[0]->getAsTyped();
|
||||
else if (arguments->getAsSymbolNode())
|
||||
arg0 = arguments->getAsSymbolNode();
|
||||
|
||||
if (arg0 != nullptr && isStructBufferType(arg0->getType())) {
|
||||
static const int methodPrefixSize = sizeof(BUILTIN_PREFIX)-1;
|
||||
|
||||
if (function->getName().length() > methodPrefixSize &&
|
||||
@@ -5845,8 +6004,11 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, TFunction
|
||||
// These builtin ops can accept any type, so we bypass the argument selection
|
||||
if (candidateList.size() == 1 && builtIn &&
|
||||
(candidateList[0]->getBuiltInOp() == EOpMethodAppend ||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodRestartStrip)) {
|
||||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodRestartStrip ||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodIncrementCounter ||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodDecrementCounter ||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodAppend ||
|
||||
candidateList[0]->getBuiltInOp() == EOpMethodConsume)) {
|
||||
return candidateList[0];
|
||||
}
|
||||
|
||||
@@ -6856,6 +7018,10 @@ void HlslParseContext::declareBlock(const TSourceLoc& loc, TType& type, const TS
|
||||
switch (type.getQualifier().storage) {
|
||||
case EvqUniform:
|
||||
case EvqBuffer:
|
||||
// remember pre-sanitized builtin type
|
||||
if (type.getQualifier().storage == EvqBuffer && instanceName != nullptr)
|
||||
structBufferBuiltIn[*instanceName] = type.getQualifier().builtIn;
|
||||
|
||||
correctUniform(type.getQualifier());
|
||||
break;
|
||||
case EvqVaryingIn:
|
||||
@@ -7670,7 +7836,7 @@ TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) cons
|
||||
return intermediate.addSymbol(*it->second->getAsVariable());
|
||||
}
|
||||
|
||||
// Add patch constant function invocation
|
||||
// Finalization step: Add patch constant function invocation
|
||||
void HlslParseContext::addPatchConstantInvocation()
|
||||
{
|
||||
TSourceLoc loc;
|
||||
@@ -8039,9 +8205,23 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
epBodySeq.insert(epBodySeq.end(), invocationIdTest);
|
||||
}
|
||||
|
||||
// Finalization step: remove unused buffer blocks from linkage (we don't know until the
|
||||
// shader is entirely compiled)
|
||||
void HlslParseContext::removeUnusedStructBufferCounters()
|
||||
{
|
||||
const auto endIt = std::remove_if(linkageSymbols.begin(), linkageSymbols.end(),
|
||||
[this](const TSymbol* sym) {
|
||||
const auto sbcIt = structBufferCounter.find(sym->getName());
|
||||
return sbcIt != structBufferCounter.end() && !sbcIt->second;
|
||||
});
|
||||
|
||||
linkageSymbols.erase(endIt, linkageSymbols.end());
|
||||
}
|
||||
|
||||
// post-processing
|
||||
void HlslParseContext::finish()
|
||||
{
|
||||
removeUnusedStructBufferCounters();
|
||||
addPatchConstantInvocation();
|
||||
addInterstageIoToLinkage();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user