HLSL: Recursive composite flattening

This PR implements recursive type flattening.  For example, an array of structs of other structs
can be flattened to individual member variables at the shader interface.

This is sufficient for many purposes, e.g., uniforms containing opaque types, but is not sufficient
for geometry shader arrayed inputs.  That will be handled separately with structure splitting,
 which is not implemented by this PR.  In the meantime, that case is detected and triggers an error.

The recursive flattening extends the following three aspects of single-level flattening:

- Flattening of structures to individual members with names such as "foo[0].samp[1]";

- Turning constant references to the nested composite type into a reference to a particular
  flattened member.

- Shadow copies between arrays of flattened members and the nested composite type.

Previous single-level flattening only flattened at the shader interface, and that is unchanged by this PR.
Internally, shadow copies are made as needed, such as when the type is passed to a function.

Also, the reasons for flattening are unchanged.  Uniforms containing opaque types, and interface struct
types are flattened.  (The latter will change with structure splitting).

One existing test changes: hlsl.structin.vert, which did in fact contain a nested composite type to be
flattened.

Two new tests are added: hlsl.structarray.flatten.frag, and hlsl.structarray.flatten.geom (currently
issues an error until type splitting is online).

The process of arriving at the individual member from chained postfix expressions is more complex than
it was with one level.  See large-ish comment above HlslParseContext::flatten() for details.
This commit is contained in:
steve-lunarg
2016-11-28 17:09:54 -07:00
parent b56f4ac72c
commit a2b01a0da8
10 changed files with 724 additions and 204 deletions

View File

@@ -45,6 +45,7 @@
#include "../glslang/OSDependent/osinclude.h"
#include <algorithm>
#include <functional>
#include <cctype>
namespace glslang {
@@ -651,11 +652,11 @@ TIntermTyped* HlslParseContext::handleBracketDereference(const TSourceLoc& loc,
else {
// at least one of base and index is variable...
if (base->getAsSymbolNode() && shouldFlatten(base->getType())) {
if (base->getAsSymbolNode() && (wasFlattened(base) || shouldFlatten(base->getType()))) {
if (index->getQualifier().storage != EvqConst)
error(loc, "Invalid variable index to flattened uniform array", base->getAsSymbolNode()->getName().c_str(), "");
result = flattenAccess(base, indexValue);
result = flattenAccess(loc, base, indexValue);
flattened = (result != base);
} else {
if (index->getQualifier().storage == EvqConst) {
@@ -831,8 +832,8 @@ TIntermTyped* HlslParseContext::handleDotDereference(const TSourceLoc& loc, TInt
}
}
if (fieldFound) {
if (base->getAsSymbolNode() && shouldFlatten(base->getType()))
result = flattenAccess(base, member);
if (base->getAsSymbolNode() && (wasFlattened(base) || shouldFlatten(base->getType())))
result = flattenAccess(loc, base, member);
else {
if (base->getType().getQualifier().storage == EvqConst)
result = intermediate.foldDereference(base, member, loc);
@@ -850,6 +851,12 @@ TIntermTyped* HlslParseContext::handleDotDereference(const TSourceLoc& loc, TInt
return result;
}
// Determine whether we should flatten an arbitrary type.
bool HlslParseContext::shouldFlatten(const TType& type) const
{
return shouldFlattenIO(type) || shouldFlattenUniform(type);
}
// Is this an IO variable that can't be passed down the stack?
// E.g., pipeline inputs to the vertex stage and outputs from the fragment stage.
bool HlslParseContext::shouldFlattenIO(const TType& type) const
@@ -869,27 +876,98 @@ bool HlslParseContext::shouldFlattenUniform(const TType& type) const
{
const TStorageQualifier qualifier = type.getQualifier().storage;
return type.isArray() &&
intermediate.getFlattenUniformArrays() &&
return ((type.isArray() && intermediate.getFlattenUniformArrays()) || type.isStruct()) &&
qualifier == EvqUniform &&
type.isOpaque();
type.containsOpaque();
}
// Top level variable flattening: construct data
void HlslParseContext::flatten(const TSourceLoc& loc, const TVariable& variable)
{
const TType& type = variable.getType();
// Presently, flattening of structure arrays is unimplemented.
// We handle one, or the other.
if (type.isArray() && type.isStruct()) {
error(loc, "cannot flatten structure array", variable.getName().c_str(), "");
// emplace gives back a pair whose .first is an iterator to the item...
auto entry = flattenMap.emplace(variable.getUniqueId(),
TFlattenData(type.getQualifier().layoutBinding));
// ... and the item is a map pair, so first->second is the TFlattenData itself.
flatten(loc, variable, type, entry.first->second, "");
}
// Recursively flatten the given variable at the provided type, building the flattenData as we go.
//
// This is mutually recursive with flattenStruct and flattenArray.
// We are going to flatten an arbitrarily nested composite structure into a linear sequence of
// members, and later on, we want to turn a path through the tree structure into a final
// location in this linear sequence.
//
// If the tree was N-ary, that can be directly calculated. However, we are dealing with
// arbitrary numbers - perhaps a struct of 7 members containing an array of 3.  Thus, we must
// build a data structure to allow the sequence of bracket and dot operators on arrays and
// structs to arrive at the proper member.
//
// To avoid storing a tree with pointers, we are going to flatten the tree into a vector of integers.
// The leaves are the indexes into the flattened member array.
// Each level will have the next location for the Nth item stored sequentially, so for instance:
//
// struct { float2 a[2]; int b; float4 c[3] };
//
// This will produce the following flattened tree:
// Pos: 0 1 2 3 4 5 6 7 8 9 10 11 12 13
// {3, 7, 8, 5, 6, 0, 1, 2, 11, 12, 13, 3, 4, 5}
//
// Given a reference to mystruct.c[1], the access chain is (2,1), so we traverse:
// (0+2) = 8 --> (8+1) = 12 --> 12 = 4
//
// so the 4th flattened member in traversal order is ours.
//
int HlslParseContext::flatten(const TSourceLoc& loc, const TVariable& variable, const TType& type,
TFlattenData& flattenData, TString name)
{
// TODO: when struct splitting is in place we can remove this restriction.
if (language == EShLangGeometry) {
const TType derefType(type, 0);
if (!isFinalFlattening(derefType) && type.getQualifier().storage == EvqVaryingIn)
error(loc, "recursive type not yet supported in GS input", variable.getName().c_str(), "");
}
if (type.isStruct())
flattenStruct(variable);
// If something is an arrayed struct, the array flattener will recursively call flatten()
// to then flatten the struct, so this is an "if else": we don't do both.
if (type.isArray())
flattenArray(loc, variable);
return flattenArray(loc, variable, type, flattenData, name);
else if (type.isStruct())
return flattenStruct(loc, variable, type, flattenData, name);
else {
assert(0); // should never happen
return -1;
}
}
// Add a single flattened member to the flattened data being tracked for the composite
// Returns the location of the member within the flattened data.
int HlslParseContext::addFlattenedMember(const TSourceLoc& loc,
const TVariable& variable, const TType& type, TFlattenData& flattenData,
const TString& memberName, bool track)
{
if (isFinalFlattening(type)) {
// This is as far as we flatten. Insert the variable.
TVariable* memberVariable = makeInternalVariable(memberName.c_str(), type);
mergeQualifiers(memberVariable->getWritableType().getQualifier(), variable.getType().getQualifier());
if (flattenData.nextBinding != TQualifier::layoutBindingEnd)
memberVariable->getWritableType().getQualifier().layoutBinding = flattenData.nextBinding++;
flattenData.offsets.push_back(flattenData.members.size());
flattenData.members.push_back(memberVariable);
if (track)
trackLinkageDeferred(*memberVariable);
return flattenData.offsets.size()-1; // location of the member reference
} else {
// Further recursion required
return flatten(loc, variable, type, flattenData, memberName);
}
}
// Figure out the mapping between an aggregate's top members and an
@@ -899,84 +977,103 @@ void HlslParseContext::flatten(const TSourceLoc& loc, const TVariable& variable)
// effecting a transfer of this information to the flattened variable form.
//
// Assumes shouldFlatten() or equivalent was called first.
//
// TODO: generalize this to arbitrary nesting?
void HlslParseContext::flattenStruct(const TVariable& variable)
int HlslParseContext::flattenStruct(const TSourceLoc& loc, const TVariable& variable, const TType& type,
TFlattenData& flattenData, TString name)
{
TVector<TVariable*> memberVariables;
assert(type.isStruct());
auto members = *type.getStruct();
// Reserve space for this tree level.
int start = flattenData.offsets.size();
int pos = start;
flattenData.offsets.resize(int(pos + members.size()), -1);
auto members = *variable.getType().getStruct();
for (int member = 0; member < (int)members.size(); ++member) {
TVariable* memberVariable = makeInternalVariable(members[member].type->getFieldName().c_str(),
*members[member].type);
mergeQualifiers(memberVariable->getWritableType().getQualifier(), variable.getType().getQualifier());
memberVariables.push_back(memberVariable);
TType& dereferencedType = *members[member].type;
const TString memberName = name + (name.empty() ? "" : ".") + dereferencedType.getFieldName();
const int mpos = addFlattenedMember(loc, variable, dereferencedType, flattenData, memberName, false);
flattenData.offsets[pos++] = mpos;
// N.B. Erase I/O-related annotations from the source-type member.
members[member].type->getQualifier().makeTemporary();
dereferencedType.getQualifier().makeTemporary();
}
flattenMap[variable.getUniqueId()] = memberVariables;
return start;
}
// Figure out mapping between an array's members and an
// equivalent set of individual variables.
//
// Assumes shouldFlatten() or equivalent was called first.
void HlslParseContext::flattenArray(const TSourceLoc& loc, const TVariable& variable)
int HlslParseContext::flattenArray(const TSourceLoc& loc, const TVariable& variable, const TType& type,
TFlattenData& flattenData, TString name)
{
const TType& type = variable.getType();
assert(type.isArray());
if (type.isImplicitlySizedArray())
error(loc, "cannot flatten implicitly sized array", variable.getName().c_str(), "");
if (type.getArraySizes()->getNumDims() != 1)
error(loc, "cannot flatten multi-dimensional array", variable.getName().c_str(), "");
const int size = type.getCumulativeArraySize();
TVector<TVariable*> memberVariables;
const int size = type.getOuterArraySize();
const TType dereferencedType(type, 0);
int binding = type.getQualifier().layoutBinding;
if (dereferencedType.isStruct() || dereferencedType.isArray()) {
error(loc, "cannot flatten array of aggregate types", variable.getName().c_str(), "");
}
if (name.empty())
name = variable.getName();
for (int element=0; element < size; ++element) {
// Reserve space for this tree level.
int start = flattenData.offsets.size();
int pos = start;
flattenData.offsets.resize(int(pos + size), -1);
for (int element=0; element < size; ++element) {
char elementNumBuf[20]; // sufficient for MAXINT
snprintf(elementNumBuf, sizeof(elementNumBuf)-1, "[%d]", element);
const TString memberName = variable.getName() + elementNumBuf;
const int mpos = addFlattenedMember(loc, variable, dereferencedType, flattenData,
name + elementNumBuf, true);
TVariable* memberVariable = makeInternalVariable(memberName.c_str(), dereferencedType);
memberVariable->getWritableType().getQualifier() = variable.getType().getQualifier();
memberVariable->getWritableType().getQualifier().layoutBinding = binding;
if (binding != TQualifier::layoutBindingEnd)
++binding;
memberVariables.push_back(memberVariable);
trackLinkageDeferred(*memberVariable);
flattenData.offsets[pos++] = mpos;
}
flattenMap[variable.getUniqueId()] = memberVariables;
return start;
}
// Return true if we have flattened this node.
bool HlslParseContext::wasFlattened(const TIntermTyped* node) const
{
return node != nullptr &&
node->getAsSymbolNode() != nullptr &&
wasFlattened(node->getAsSymbolNode()->getId());
}
// Turn an access into an aggregate that was flattened to instead be
// an access to the individual variable the member was flattened to.
// Assumes shouldFlatten() or equivalent was called first.
TIntermTyped* HlslParseContext::flattenAccess(TIntermTyped* base, int member)
TIntermTyped* HlslParseContext::flattenAccess(const TSourceLoc&, TIntermTyped* base, int member)
{
const TType dereferencedType(base->getType(), member); // dereferenced type
const TIntermSymbol& symbolNode = *base->getAsSymbolNode();
if (flattenMap.find(symbolNode.getId()) == flattenMap.end())
const auto flattenData = flattenMap.find(symbolNode.getId());
if (flattenData == flattenMap.end())
return base;
const TVariable* memberVariable = flattenMap[symbolNode.getId()][member];
return intermediate.addSymbol(*memberVariable);
// Calculate new cumulative offset from the packed tree
flattenOffset.back() = flattenData->second.offsets[flattenOffset.back() + member];
if (isFinalFlattening(dereferencedType)) {
// Finished flattening: create symbol for variable
member = flattenData->second.offsets[flattenOffset.back()];
const TVariable* memberVariable = flattenData->second.members[member];
return intermediate.addSymbol(*memberVariable);
} else {
// If this is not the final flattening, accumulate the position and return
// an object of the partially dereferenced type.
return new TIntermSymbol(symbolNode.getId(), "flattenShadow", dereferencedType);
}
}
// Variables that correspond to the user-interface in and out of a stage
@@ -1002,8 +1099,8 @@ void HlslParseContext::assignLocations(TVariable& variable)
}
};
if (shouldFlatten(variable.getType())) {
auto& memberList = flattenMap[variable.getUniqueId()];
if (wasFlattened(variable.getUniqueId())) {
auto& memberList = flattenMap[variable.getUniqueId()].members;
for (auto member = memberList.begin(); member != memberList.end(); ++member)
assignLocation(**member);
} else
@@ -1294,7 +1391,7 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
return nullptr;
const auto mustFlatten = [&](const TIntermTyped& node) {
return shouldFlatten(node.getType()) && node.getAsSymbolNode() &&
return wasFlattened(&node) && node.getAsSymbolNode() &&
flattenMap.find(node.getAsSymbolNode()->getId()) != flattenMap.end();
};
@@ -1327,10 +1424,10 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
memberCount = left->getType().getCumulativeArraySize();
if (flattenLeft)
leftVariables = &flattenMap.find(left->getAsSymbolNode()->getId())->second;
leftVariables = &flattenMap.find(left->getAsSymbolNode()->getId())->second.members;
if (flattenRight) {
rightVariables = &flattenMap.find(right->getAsSymbolNode()->getId())->second;
rightVariables = &flattenMap.find(right->getAsSymbolNode()->getId())->second.members;
} else {
// The RHS is not flattened. There are several cases:
// 1. 1 item to copy: Use the RHS directly.
@@ -1355,13 +1452,15 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
}
}
int memberIdx = 0;
const auto getMember = [&](bool flatten, TIntermTyped* node,
const TVector<TVariable*>& memberVariables, int member,
TOperator op, const TType& memberType) -> TIntermTyped * {
TIntermTyped* subTree;
if (flatten)
subTree = intermediate.addSymbol(*memberVariables[member]);
else {
if (flatten && isFinalFlattening(memberType)) {
subTree = intermediate.addSymbol(*memberVariables[memberIdx++]);
} else {
subTree = intermediate.addIndex(op, node, intermediate.addConstantUnion(member, loc), loc);
subTree->setType(memberType);
}
@@ -1369,46 +1468,59 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
return subTree;
};
// Return the proper RHS node: a new symbol from a TVariable, copy
// of an TIntermSymbol node, or sometimes the right node directly.
const auto getRHS = [&]() {
return rhsTempVar ? intermediate.addSymbol(*rhsTempVar, loc) :
cloneSymNode ? intermediate.addSymbol(*cloneSymNode) :
right;
// Cannot use auto here, because this is recursive, and auto can't work out the type without seeing the
// whole thing. So, we'll resort to an explicit type via std::function.
const std::function<void(TIntermTyped* left, TIntermTyped* right)>
traverse = [&](TIntermTyped* left, TIntermTyped* right) -> void {
// If we get here, we are assigning to or from a whole array or struct that must be
// flattened, so have to do member-by-member assignment:
if (left->getType().isArray()) {
// array case
const TType dereferencedType(left->getType(), 0);
for (int element=0; element < left->getType().getOuterArraySize(); ++element) {
// Add a new AST symbol node if we have a temp variable holding a complex RHS.
TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, element,
EOpIndexDirect, dereferencedType);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, element,
EOpIndexDirect, dereferencedType);
if (isFinalFlattening(dereferencedType))
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subLeft, subRight, loc), loc);
else
traverse(subLeft, subRight);
}
} else if (left->getType().isStruct()) {
// struct case
const auto& members = *left->getType().getStruct();
for (int member = 0; member < (int)members.size(); ++member) {
TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, member,
EOpIndexDirectStruct, *members[member].type);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, member,
EOpIndexDirectStruct, *members[member].type);
if (isFinalFlattening(*members[member].type))
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subLeft, subRight, loc), loc);
else
traverse(subLeft, subRight);
}
} else {
assert(0); // we should never be called on a non-flattenable thing, because
// that case bails out above to a simple copy.
}
};
// Handle struct assignment
if (left->getType().isStruct()) {
// If we get here, we are assigning to or from a whole struct that must be
// flattened, so have to do member-by-member assignment:
const auto& members = *left->getType().getStruct();
// Use the proper RHS node: a new symbol from a TVariable, copy
// of an TIntermSymbol node, or sometimes the right node directly.
right = rhsTempVar ? intermediate.addSymbol(*rhsTempVar, loc) :
cloneSymNode ? intermediate.addSymbol(*cloneSymNode) :
right;
for (int member = 0; member < (int)members.size(); ++member) {
TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, member,
EOpIndexDirectStruct, *members[member].type);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, member,
EOpIndexDirectStruct, *members[member].type);
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subLeft, subRight, loc), loc);
}
}
// Handle array assignment
if (left->getType().isArray()) {
// If we get here, we are assigning to or from a whole array that must be
// flattened, so have to do member-by-member assignment:
const TType dereferencedType(left->getType(), 0);
for (int element=0; element < memberCount; ++element) {
// Add a new AST symbol node if we have a temp variable holding a complex RHS.
TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, element,
EOpIndexDirect, dereferencedType);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, element,
EOpIndexDirect, dereferencedType);
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subLeft, subRight, loc), loc);
}
}
// This makes the whole assignment, recursing through subtypes as needed.
traverse(left, right);
assert(assignList != nullptr);
assignList->setOperator(EOpSequence);
@@ -2701,7 +2813,7 @@ void HlslParseContext::addInputArgumentConversions(const TFunction& function, TI
arg = intermediate.addShapeConversion(EOpFunctionCall, *function[i].type, arg);
setArg(i, arg);
} else {
if (shouldFlatten(arg->getType())) {
if (wasFlattened(arg)) {
// Will make a two-level subtree.
// The deepest will copy member-by-member to build the structure to pass.
// The level above that will be a two-operand EOpComma sequence that follows the copy by the
@@ -2749,7 +2861,7 @@ TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& fu
return function[argNum].type->getQualifier().isParamOutput() &&
(*function[argNum].type != arguments[argNum]->getAsTyped()->getType() ||
shouldConvertLValue(arguments[argNum]) ||
shouldFlatten(arguments[argNum]->getAsTyped()->getType()));
wasFlattened(arguments[argNum]->getAsTyped()));
};
// Will there be any output conversions?
@@ -4589,23 +4701,23 @@ TIntermNode* HlslParseContext::declareVariable(const TSourceLoc& loc, TString& i
inheritGlobalDefaults(type.getQualifier());
bool flattenVar = false;
const bool flattenVar = shouldFlatten(type);
// Declare the variable
if (type.isArray()) {
// array case
flattenVar = shouldFlatten(type);
declareArray(loc, identifier, type, symbol, !flattenVar);
if (flattenVar)
flatten(loc, *symbol->getAsVariable());
} else {
// non-array case
if (! symbol)
symbol = declareNonArray(loc, identifier, type);
symbol = declareNonArray(loc, identifier, type, !flattenVar);
else if (type != symbol->getType())
error(loc, "cannot change the type of", "redeclaration", symbol->getName().c_str());
}
if (flattenVar)
flatten(loc, *symbol->getAsVariable());
if (! symbol)
return nullptr;
@@ -4658,14 +4770,14 @@ TVariable* HlslParseContext::makeInternalVariable(const char* name, const TType&
//
// Return the successfully declared variable.
//
TVariable* HlslParseContext::declareNonArray(const TSourceLoc& loc, TString& identifier, TType& type)
TVariable* HlslParseContext::declareNonArray(const TSourceLoc& loc, TString& identifier, TType& type, bool track)
{
// make a new variable
TVariable* variable = new TVariable(&identifier, type);
// add variable to symbol table
if (symbolTable.insert(*variable)) {
if (symbolTable.atGlobalLevel())
if (track && symbolTable.atGlobalLevel())
trackLinkageDeferred(*variable);
return variable;
}