Merge pull request #774 from steve-lunarg/tess-ctrlpt-pcf
HLSL: support per control point patch const fn invocation
This commit is contained in:
@@ -1047,6 +1047,8 @@ TType& HlslParseContext::split(TType& type, TString name, const TType* outerStru
|
||||
if (arraySizes)
|
||||
ioVar->getWritableType().newArraySizes(*arraySizes);
|
||||
|
||||
fixBuiltInArrayType(ioVar->getWritableType());
|
||||
|
||||
interstageBuiltInIo[tInterstageIoData(memberType, *outerStructType)] = ioVar;
|
||||
|
||||
// Merge qualifier from the user structure
|
||||
@@ -1381,6 +1383,34 @@ void HlslParseContext::trackLinkage(TSymbol& symbol)
|
||||
}
|
||||
|
||||
|
||||
// Some types require fixed array sizes in SPIR-V, but can be scalars or
|
||||
// arrays of sizes SPIR-V doesn't allow. For example, tessellation factors.
|
||||
// This creates the right size. A conversion is performed when the internal
|
||||
// type is copied to or from the external type.
|
||||
void HlslParseContext::fixBuiltInArrayType(TType& type)
|
||||
{
|
||||
int requiredSize = 0;
|
||||
|
||||
switch (type.getQualifier().builtIn) {
|
||||
case EbvTessLevelOuter: requiredSize = 4; break;
|
||||
case EbvTessLevelInner: requiredSize = 2; break;
|
||||
case EbvClipDistance: // TODO: ...
|
||||
case EbvCullDistance: // TODO: ...
|
||||
default:
|
||||
return;
|
||||
}
|
||||
|
||||
if (type.isArray()) {
|
||||
// Already an array. Fix the size.
|
||||
type.changeOuterArraySize(requiredSize);
|
||||
} else {
|
||||
// it wasn't an array, but needs to be.
|
||||
TArraySizes arraySizes;
|
||||
arraySizes.addInnerSize(requiredSize);
|
||||
type.newArraySizes(arraySizes);
|
||||
}
|
||||
}
|
||||
|
||||
// Variables that correspond to the user-interface in and out of a stage
|
||||
// (not the built-in interface) are assigned locations and
|
||||
// registered as a linkage node (part of the stage's external interface).
|
||||
@@ -1389,15 +1419,24 @@ void HlslParseContext::trackLinkage(TSymbol& symbol)
|
||||
void HlslParseContext::assignLocations(TVariable& variable)
|
||||
{
|
||||
const auto assignLocation = [&](TVariable& variable) {
|
||||
const TQualifier& qualifier = variable.getType().getQualifier();
|
||||
const TType& type = variable.getType();
|
||||
const TQualifier& qualifier = type.getQualifier();
|
||||
if (qualifier.storage == EvqVaryingIn || qualifier.storage == EvqVaryingOut) {
|
||||
if (qualifier.builtIn == EbvNone) {
|
||||
// Strip off the outer array dimension for those having an extra one.
|
||||
int size;
|
||||
if (type.isArray() && qualifier.isArrayedIo(language)) {
|
||||
TType elementType(type, 0);
|
||||
size = intermediate.computeTypeLocationSize(elementType);
|
||||
} else
|
||||
size = intermediate.computeTypeLocationSize(type);
|
||||
|
||||
if (qualifier.storage == EvqVaryingIn) {
|
||||
variable.getWritableType().getQualifier().layoutLocation = nextInLocation;
|
||||
nextInLocation += intermediate.computeTypeLocationSize(variable.getType());
|
||||
nextInLocation += size;
|
||||
} else {
|
||||
variable.getWritableType().getQualifier().layoutLocation = nextOutLocation;
|
||||
nextOutLocation += intermediate.computeTypeLocationSize(variable.getType());
|
||||
nextOutLocation += size;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1559,48 +1598,10 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
|
||||
return paramNodes;
|
||||
}
|
||||
|
||||
//
|
||||
// Do all special handling for the entry point, including wrapping
|
||||
// the shader's entry point with the official entry point that will call it.
|
||||
//
|
||||
// The following:
|
||||
//
|
||||
// retType shaderEntryPoint(args...) // shader declared entry point
|
||||
// { body }
|
||||
//
|
||||
// Becomes
|
||||
//
|
||||
// out retType ret;
|
||||
// in iargs<that are input>...;
|
||||
// out oargs<that are output> ...;
|
||||
//
|
||||
// void shaderEntryPoint() // synthesized, but official, entry point
|
||||
// {
|
||||
// args<that are input> = iargs...;
|
||||
// ret = @shaderEntryPoint(args...);
|
||||
// oargs = args<that are output>...;
|
||||
// }
|
||||
//
|
||||
// The symbol table will still map the original entry point name to the
|
||||
// the modified function and it's new name:
|
||||
//
|
||||
// symbol table: shaderEntryPoint -> @shaderEntryPoint
|
||||
//
|
||||
// Returns nullptr if no entry-point tree was built, otherwise, returns
|
||||
// a subtree that creates the entry point.
|
||||
//
|
||||
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
|
||||
|
||||
// Handle all [attrib] attribute for the shader entry point
|
||||
void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, const TAttributeMap& attributes)
|
||||
{
|
||||
// if we aren't in the entry point, fix the IO as such and exit
|
||||
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
|
||||
remapNonEntryPointIO(userFunction);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
entryPointFunction = &userFunction; // needed in finish()
|
||||
|
||||
// entry point logic...
|
||||
|
||||
// Handle entry-point function attributes
|
||||
const TIntermAggregate* numThreads = attributes[EatNumThreads];
|
||||
if (numThreads != nullptr) {
|
||||
@@ -1652,8 +1653,12 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
|
||||
error(loc, "unsupported domain type", domainStr.c_str(), "");
|
||||
}
|
||||
|
||||
if (! intermediate.setInputPrimitive(domain)) {
|
||||
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
|
||||
if (language == EShLangTessEvaluation) {
|
||||
if (! intermediate.setInputPrimitive(domain))
|
||||
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
|
||||
} else {
|
||||
if (! intermediate.setOutputPrimitive(domain))
|
||||
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1731,6 +1736,52 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Do all special handling for the entry point, including wrapping
|
||||
// the shader's entry point with the official entry point that will call it.
|
||||
//
|
||||
// The following:
|
||||
//
|
||||
// retType shaderEntryPoint(args...) // shader declared entry point
|
||||
// { body }
|
||||
//
|
||||
// Becomes
|
||||
//
|
||||
// out retType ret;
|
||||
// in iargs<that are input>...;
|
||||
// out oargs<that are output> ...;
|
||||
//
|
||||
// void shaderEntryPoint() // synthesized, but official, entry point
|
||||
// {
|
||||
// args<that are input> = iargs...;
|
||||
// ret = @shaderEntryPoint(args...);
|
||||
// oargs = args<that are output>...;
|
||||
// }
|
||||
//
|
||||
// The symbol table will still map the original entry point name to the
|
||||
// the modified function and it's new name:
|
||||
//
|
||||
// symbol table: shaderEntryPoint -> @shaderEntryPoint
|
||||
//
|
||||
// Returns nullptr if no entry-point tree was built, otherwise, returns
|
||||
// a subtree that creates the entry point.
|
||||
//
|
||||
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
|
||||
{
|
||||
// if we aren't in the entry point, fix the IO as such and exit
|
||||
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
|
||||
remapNonEntryPointIO(userFunction);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
entryPointFunction = &userFunction; // needed in finish()
|
||||
|
||||
// Handle entry point attributes
|
||||
handleEntryPointAttributes(loc, attributes);
|
||||
|
||||
// entry point logic...
|
||||
|
||||
// Move parameters and return value to shader in/out
|
||||
TVariable* entryPointOutput; // gets created in remapEntryPointIO
|
||||
@@ -1799,10 +1850,40 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
|
||||
currentCaller = userFunction.getMangledName();
|
||||
|
||||
// Return value
|
||||
if (entryPointOutput)
|
||||
intermediate.growAggregate(synthBody, handleAssign(loc, EOpAssign,
|
||||
intermediate.addSymbol(*entryPointOutput), callReturn));
|
||||
else
|
||||
if (entryPointOutput) {
|
||||
TIntermTyped* returnAssign;
|
||||
|
||||
// For hull shaders, the wrapped entry point return value is written to
|
||||
// an array element as indexed by invocation ID, which we might have to make up.
|
||||
// This is required to match SPIR-V semantics.
|
||||
if (language == EShLangTessControl) {
|
||||
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
|
||||
|
||||
// If there is no user declared invocation ID, we must make one.
|
||||
if (invocationIdSym == nullptr) {
|
||||
TType invocationIdType(EbtUint, EvqIn, 1);
|
||||
TString* invocationIdName = NewPoolTString("InvocationId");
|
||||
invocationIdType.getQualifier().builtIn = EbvInvocationId;
|
||||
|
||||
TVariable* variable = makeInternalVariable(*invocationIdName, invocationIdType);
|
||||
|
||||
globalQualifierFix(loc, variable->getWritableType().getQualifier());
|
||||
trackLinkage(*variable);
|
||||
|
||||
invocationIdSym = intermediate.addSymbol(*variable);
|
||||
}
|
||||
|
||||
TIntermTyped* element = intermediate.addIndex(EOpIndexIndirect, intermediate.addSymbol(*entryPointOutput),
|
||||
invocationIdSym, loc);
|
||||
element->setType(callReturn->getType());
|
||||
|
||||
returnAssign = handleAssign(loc, EOpAssign, element, callReturn);
|
||||
} else {
|
||||
returnAssign = handleAssign(loc, EOpAssign, intermediate.addSymbol(*entryPointOutput), callReturn);
|
||||
}
|
||||
|
||||
intermediate.growAggregate(synthBody, returnAssign);
|
||||
} else
|
||||
intermediate.growAggregate(synthBody, callReturn);
|
||||
|
||||
// Output copies
|
||||
@@ -1862,19 +1943,42 @@ void HlslParseContext::remapEntryPointIO(TFunction& function, TVariable*& return
|
||||
ioVariable->getWritableType().setStruct(newLists->second.output);
|
||||
}
|
||||
}
|
||||
if (storage == EvqVaryingIn)
|
||||
if (storage == EvqVaryingIn) {
|
||||
correctInput(ioVariable->getWritableType().getQualifier());
|
||||
else
|
||||
if (language == EShLangTessEvaluation)
|
||||
if (!ioVariable->getType().isArray())
|
||||
ioVariable->getWritableType().getQualifier().patch = true;
|
||||
} else {
|
||||
correctOutput(ioVariable->getWritableType().getQualifier());
|
||||
}
|
||||
ioVariable->getWritableType().getQualifier().storage = storage;
|
||||
return ioVariable;
|
||||
};
|
||||
|
||||
// return value is actually a shader-scoped output (out)
|
||||
if (function.getType().getBasicType() == EbtVoid)
|
||||
if (function.getType().getBasicType() == EbtVoid) {
|
||||
returnValue = nullptr;
|
||||
else
|
||||
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
|
||||
} else {
|
||||
if (language == EShLangTessControl) {
|
||||
// tessellation evaluation in HLSL writes a per-ctrl-pt value, but it needs to be an
|
||||
// array in SPIR-V semantics. We'll write to it indexed by invocation ID.
|
||||
|
||||
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
|
||||
|
||||
TType outputType;
|
||||
outputType.shallowCopy(function.getType());
|
||||
|
||||
// vertices has necessarily already been set when handling entry point attributes.
|
||||
TArraySizes arraySizes;
|
||||
arraySizes.addInnerSize(intermediate.getVertices());
|
||||
outputType.newArraySizes(arraySizes);
|
||||
|
||||
clearUniformInputOutput(function.getWritableType().getQualifier());
|
||||
returnValue = makeIoVariable("@entryPointOutput", outputType, EvqVaryingOut);
|
||||
} else {
|
||||
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
|
||||
}
|
||||
}
|
||||
|
||||
// parameters are actually shader-scoped inputs and outputs (in or out)
|
||||
for (int i = 0; i < function.getParamCount(); i++) {
|
||||
@@ -2031,7 +2135,11 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
const bool split = isLeft ? isSplitLeft : isSplitRight;
|
||||
const TIntermTyped* outer = isLeft ? outerLeft : outerRight;
|
||||
const TVector<TVariable*>& flatVariables = isLeft ? *leftVariables : *rightVariables;
|
||||
const TOperator op = node->getType().isArray() ? EOpIndexDirect : EOpIndexDirectStruct;
|
||||
|
||||
// Index operator if it's an aggregate, else EOpNull
|
||||
const TOperator op = node->getType().isArray() ? EOpIndexDirect :
|
||||
node->getType().isStruct() ? EOpIndexDirectStruct : EOpNull;
|
||||
|
||||
const TType derefType(node->getType(), member);
|
||||
|
||||
if (split && derefType.isBuiltInInterstageIO(language)) {
|
||||
@@ -2047,10 +2155,14 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
} else if (flattened && isFinalFlattening(derefType)) {
|
||||
subTree = intermediate.addSymbol(*flatVariables[memberIdx++]);
|
||||
} else {
|
||||
const TType splitDerefType(splitNode->getType(), splitMember);
|
||||
if (op == EOpNull) {
|
||||
subTree = splitNode;
|
||||
} else {
|
||||
const TType splitDerefType(splitNode->getType(), splitMember);
|
||||
|
||||
subTree = intermediate.addIndex(op, splitNode, intermediate.addConstantUnion(splitMember, loc), loc);
|
||||
subTree->setType(splitDerefType);
|
||||
subTree = intermediate.addIndex(op, splitNode, intermediate.addConstantUnion(splitMember, loc), loc);
|
||||
subTree->setType(splitDerefType);
|
||||
}
|
||||
}
|
||||
|
||||
return subTree;
|
||||
@@ -2069,11 +2181,15 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
// If we get here, we are assigning to or from a whole array or struct that must be
|
||||
// flattened, so have to do member-by-member assignment:
|
||||
|
||||
if (left->getType().isArray()) {
|
||||
const TType dereferencedType(left->getType(), 0);
|
||||
if (left->getType().isArray() || right->getType().isArray()) {
|
||||
const int elementsL = left->getType().isArray() ? left->getType().getOuterArraySize() : 1;
|
||||
const int elementsR = right->getType().isArray() ? right->getType().getOuterArraySize() : 1;
|
||||
|
||||
// The arrays may not be the same size, e.g, if the size has been forced for EbvTessLevelInner or Outer.
|
||||
const int elementsToCopy = std::min(elementsL, elementsR);
|
||||
|
||||
// array case
|
||||
for (int element=0; element < left->getType().getOuterArraySize(); ++element) {
|
||||
for (int element=0; element < elementsToCopy; ++element) {
|
||||
arrayElement.push_back(element);
|
||||
|
||||
// Add a new AST symbol node if we have a temp variable holding a complex RHS.
|
||||
@@ -2083,10 +2199,7 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
TIntermTyped* subSplitLeft = isSplitLeft ? getMember(true, left, element, splitLeft, element) : subLeft;
|
||||
TIntermTyped* subSplitRight = isSplitRight ? getMember(false, right, element, splitRight, element) : subRight;
|
||||
|
||||
if (isFinalFlattening(dereferencedType))
|
||||
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subLeft, subRight, loc), loc);
|
||||
else
|
||||
traverse(subLeft, subRight, subSplitLeft, subSplitRight);
|
||||
traverse(subLeft, subRight, subSplitLeft, subSplitRight);
|
||||
|
||||
arrayElement.pop_back();
|
||||
}
|
||||
@@ -2120,8 +2233,8 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
// subtree here IFF it does not itself contain any interstage built-in IO variables, so we only have to
|
||||
// recurse into it if there's something for splitting to do. That can save a lot of AST verbosity for
|
||||
// a bunch of memberwise copies.
|
||||
if (isFinalFlattening(typeL) || (!isFlattenLeft && !isFlattenRight &&
|
||||
!typeL.containsBuiltInInterstageIO(language) && !typeR.containsBuiltInInterstageIO(language))) {
|
||||
if ((!isFlattenLeft && !isFlattenRight &&
|
||||
!typeL.containsBuiltInInterstageIO(language) && !typeR.containsBuiltInInterstageIO(language))) {
|
||||
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, subSplitLeft, subSplitRight, loc), loc);
|
||||
} else {
|
||||
traverse(subLeft, subRight, subSplitLeft, subSplitRight);
|
||||
@@ -2131,8 +2244,8 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
|
||||
memberR += (typeR.isBuiltInInterstageIO(language) ? 0 : 1);
|
||||
}
|
||||
} else {
|
||||
assert(0); // we should never be called on a non-flattenable thing, because
|
||||
// that case bails out above to a simple copy.
|
||||
// Member copy
|
||||
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, left, right, loc), loc);
|
||||
}
|
||||
|
||||
};
|
||||
@@ -4178,6 +4291,10 @@ void HlslParseContext::handleSemantic(TSourceLoc loc, TQualifier& qualifier, TBu
|
||||
case EbvStencilRef:
|
||||
error(loc, "unimplemented; need ARB_shader_stencil_export", "SV_STENCILREF", "");
|
||||
break;
|
||||
case EbvTessLevelInner:
|
||||
case EbvTessLevelOuter:
|
||||
qualifier.patch = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -7225,6 +7342,8 @@ bool HlslParseContext::isInputBuiltIn(const TQualifier& qualifier) const
|
||||
case EbvTessLevelInner:
|
||||
case EbvTessLevelOuter:
|
||||
return language == EShLangTessEvaluation;
|
||||
case EbvTessCoord:
|
||||
return language == EShLangTessEvaluation;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -7362,6 +7481,17 @@ void HlslParseContext::clearUniformInputOutput(TQualifier& qualifier)
|
||||
correctUniform(qualifier);
|
||||
}
|
||||
|
||||
|
||||
// Return a symbol for the linkage variable of the given TBuiltInVariable type
|
||||
TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) const
|
||||
{
|
||||
const auto it = builtInLinkageSymbols.find(biType);
|
||||
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
|
||||
return nullptr;
|
||||
|
||||
return intermediate.addSymbol(*it->second->getAsVariable());
|
||||
}
|
||||
|
||||
// Add patch constant function invocation
|
||||
void HlslParseContext::addPatchConstantInvocation()
|
||||
{
|
||||
@@ -7433,13 +7563,9 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
}
|
||||
};
|
||||
|
||||
// Return a symbol for the linkage variable of the given TBuiltInVariable type
|
||||
const auto findLinkageSymbol = [this](TBuiltInVariable biType) -> TIntermSymbol* {
|
||||
const auto it = builtInLinkageSymbols.find(biType);
|
||||
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
|
||||
return nullptr;
|
||||
|
||||
return intermediate.addSymbol(*it->second->getAsVariable());
|
||||
const auto isPerCtrlPt = [this](const TType& type) {
|
||||
// TODO: this is not sufficient to reject all such cases in malformed shaders.
|
||||
return type.isArray() && !type.isRuntimeSizedArray();
|
||||
};
|
||||
|
||||
// We will perform these steps. Each is in a scoped block for separation: they could
|
||||
@@ -7451,21 +7577,25 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
// 2. Synthesizes a call to the patchconstfunction using builtin variables from either main,
|
||||
// or the ones we created. Matching is based on builtin type. We may use synthesized
|
||||
// variables from (1) above.
|
||||
//
|
||||
// 2B: Synthesize per control point invocations of wrapped entry point if the PCF requires them.
|
||||
//
|
||||
// 3. Create a return sequence: copy the return value (if any) from the PCF to a
|
||||
// (non-sanitized) output variable. In case this may involve multiple copies, such as for
|
||||
// an arrayed variable, a temporary copy of the PCF output is created to avoid multiple
|
||||
// indirections into a complex R-value coming from the call to the PCF.
|
||||
//
|
||||
// 4. Add a barrier to the end of the entry point body
|
||||
//
|
||||
// 5. Call the PCF inside an if test for (invocation id == 0).
|
||||
//
|
||||
// 4. Create a barrier.
|
||||
//
|
||||
// 5/5B. Call the PCF inside an if test for (invocation id == 0).
|
||||
|
||||
TFunction& patchConstantFunction = const_cast<TFunction&>(*candidateList[0]);
|
||||
const int pcfParamCount = patchConstantFunction.getParamCount();
|
||||
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
|
||||
TIntermSequence& epBodySeq = entryPointFunctionBody->getAsAggregate()->getSequence();
|
||||
|
||||
int perCtrlPtParam = -1; // -1 means there isn't one.
|
||||
|
||||
// ================ Step 1A: Union Interfaces ================
|
||||
// Our patch constant function.
|
||||
{
|
||||
@@ -7478,16 +7608,6 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
findBuiltIns(patchConstantFunction, pcfBuiltIns);
|
||||
findBuiltIns(*entryPointFunction, epfBuiltIns);
|
||||
|
||||
// Patchconstantfunction can contain only builtin qualified variables. (Technically, only HS inputs,
|
||||
// but this test is less assertive than that).
|
||||
|
||||
for (auto bi = pcfBuiltIns.begin(); bi != pcfBuiltIns.end(); ++bi) {
|
||||
if (bi->builtIn == EbvNone) {
|
||||
error(loc, "patch constant function invalid parameter", "", "");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Find the set of builtins in the PCF that are not present in the entry point.
|
||||
std::set<tInterstageIoData> notInEntryPoint;
|
||||
|
||||
@@ -7499,15 +7619,27 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
|
||||
// Now we'll add those to the entry and to the linkage.
|
||||
for (int p=0; p<pcfParamCount; ++p) {
|
||||
TType* paramType = patchConstantFunction[p].type->clone();
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
const TStorageQualifier storage = patchConstantFunction[p].type->getQualifier().storage;
|
||||
|
||||
// Use the original declaration type for the linkage
|
||||
paramType->getQualifier().builtIn = biType;
|
||||
// Track whether there is any per control point input
|
||||
if (isPerCtrlPt(*patchConstantFunction[p].type)) {
|
||||
if (perCtrlPtParam >= 0) {
|
||||
// Presently we only support one per ctrl pt input. TODO: does HLSL even allow multiple?
|
||||
error(loc, "unimplemented: multiple per control point inputs to patch constant function", "", "");
|
||||
return;
|
||||
}
|
||||
perCtrlPtParam = p;
|
||||
}
|
||||
|
||||
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
|
||||
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
|
||||
if (biType != EbvNone) {
|
||||
TType* paramType = patchConstantFunction[p].type->clone();
|
||||
// Use the original declaration type for the linkage
|
||||
paramType->getQualifier().builtIn = biType;
|
||||
|
||||
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
|
||||
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't find it because the shader made one, add our own.
|
||||
@@ -7522,36 +7654,50 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
}
|
||||
|
||||
TIntermTyped* pcfArguments = nullptr;
|
||||
TVariable* perCtrlPtVar = nullptr;
|
||||
|
||||
// ================ Step 1B: Argument synthesis ================
|
||||
// Create pcfArguments for synthesis of patchconstantfunction invocation
|
||||
// TODO: handle struct or array inputs
|
||||
{
|
||||
for (int p=0; p<pcfParamCount; ++p) {
|
||||
if (patchConstantFunction[p].type->isArray() ||
|
||||
patchConstantFunction[p].type->isStruct()) {
|
||||
if ((patchConstantFunction[p].type->isArray() && !isPerCtrlPt(*patchConstantFunction[p].type)) ||
|
||||
(!patchConstantFunction[p].type->isArray() && patchConstantFunction[p].type->isStruct())) {
|
||||
error(loc, "unimplemented array or variable in patch constant function signature", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
// find which builtin it is
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
TIntermSymbol* inputArg = nullptr;
|
||||
|
||||
TIntermSymbol* builtIn = findLinkageSymbol(biType);
|
||||
if (p == perCtrlPtParam) {
|
||||
if (perCtrlPtVar == nullptr) {
|
||||
perCtrlPtVar = makeInternalVariable(*patchConstantFunction[perCtrlPtParam].name,
|
||||
*patchConstantFunction[perCtrlPtParam].type);
|
||||
|
||||
perCtrlPtVar->getWritableType().getQualifier().makeTemporary();
|
||||
}
|
||||
inputArg = intermediate.addSymbol(*perCtrlPtVar, loc);
|
||||
} else {
|
||||
// find which builtin it is
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
|
||||
inputArg = findLinkageSymbol(biType);
|
||||
|
||||
if (builtIn == nullptr) {
|
||||
error(loc, "unable to find patch constant function builtin variable", "", "");
|
||||
return;
|
||||
if (inputArg == nullptr) {
|
||||
error(loc, "unable to find patch constant function builtin variable", "", "");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (pcfParamCount == 1)
|
||||
pcfArguments = builtIn;
|
||||
pcfArguments = inputArg;
|
||||
else
|
||||
pcfArguments = intermediate.growAggregate(pcfArguments, builtIn);
|
||||
pcfArguments = intermediate.growAggregate(pcfArguments, inputArg);
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 2: Synthesize call to PCF ================
|
||||
TIntermAggregate* pcfCallSequence = nullptr;
|
||||
TIntermTyped* pcfCall = nullptr;
|
||||
|
||||
{
|
||||
@@ -7563,7 +7709,8 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
pcfCall = intermediate.setAggregateOperator(pcfArguments, EOpFunctionCall, patchConstantFunction.getType(), loc);
|
||||
pcfCall->getAsAggregate()->setUserDefined();
|
||||
pcfCall->getAsAggregate()->setName(patchConstantFunction.getMangledName());
|
||||
intermediate.addToCallGraph(infoSink, entryPointFunction->getMangledName(), patchConstantFunction.getMangledName());
|
||||
intermediate.addToCallGraph(infoSink, intermediate.getEntryPointMangledName().c_str(),
|
||||
patchConstantFunction.getMangledName());
|
||||
|
||||
if (pcfCall->getAsAggregate()) {
|
||||
TQualifierList& qualifierList = pcfCall->getAsAggregate()->getQualifierList();
|
||||
@@ -7575,6 +7722,71 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 2B: Per Control Point synthesis ================
|
||||
// If there is per control point data, we must either emulate that with multiple
|
||||
// invocations of the entry point to build up an array, or (TODO:) use a yet
|
||||
// unavailable extension to look across the SIMD lanes. This is the former
|
||||
// as a placeholder for the latter.
|
||||
if (perCtrlPtParam >= 0) {
|
||||
// We must introduce a local temp variable of the type wanted by the PCF input.
|
||||
const int arraySize = patchConstantFunction[perCtrlPtParam].type->getOuterArraySize();
|
||||
|
||||
if (entryPointFunction->getType().getBasicType() == EbtVoid) {
|
||||
error(loc, "entry point must return a value for use with patch constant function", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
// Create calls to wrapped main to fill in the array. We will substitute fixed values
|
||||
// of invocation ID when calling the wrapped main.
|
||||
|
||||
// This is the type of the each member of the per ctrl point array.
|
||||
const TType derefType(perCtrlPtVar->getType(), 0);
|
||||
|
||||
for (int cpt = 0; cpt < arraySize; ++cpt) {
|
||||
// TODO: improve. substr(1) here is to avoid the '@' that was grafted on but isn't in the symtab
|
||||
// for this function.
|
||||
const TString origName = entryPointFunction->getName().substr(1);
|
||||
TFunction callee(&origName, TType(EbtVoid));
|
||||
TIntermTyped* callingArgs = nullptr;
|
||||
|
||||
for (int i = 0; i < entryPointFunction->getParamCount(); i++) {
|
||||
TParameter& param = (*entryPointFunction)[i];
|
||||
TType& paramType = *param.type;
|
||||
|
||||
if (paramType.getQualifier().isParamOutput()) {
|
||||
error(loc, "unimplemented: entry point outputs in patch constant function invocation", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
if (paramType.getQualifier().isParamInput()) {
|
||||
TIntermTyped* arg = nullptr;
|
||||
if ((*entryPointFunction)[i].declaredBuiltIn == EbvInvocationId) {
|
||||
// substitute invocation ID with the array element ID
|
||||
arg = intermediate.addConstantUnion(cpt, loc);
|
||||
} else {
|
||||
TVariable* argVar = makeInternalVariable(*param.name, *param.type);
|
||||
argVar->getWritableType().getQualifier().makeTemporary();
|
||||
arg = intermediate.addSymbol(*argVar);
|
||||
}
|
||||
|
||||
handleFunctionArgument(&callee, callingArgs, arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Call and assign to per ctrl point variable
|
||||
currentCaller = intermediate.getEntryPointMangledName().c_str();
|
||||
TIntermTyped* callReturn = handleFunctionCall(loc, &callee, callingArgs);
|
||||
TIntermTyped* index = intermediate.addConstantUnion(cpt, loc);
|
||||
TIntermSymbol* perCtrlPtSym = intermediate.addSymbol(*perCtrlPtVar, loc);
|
||||
TIntermTyped* element = intermediate.addIndex(EOpIndexDirect, perCtrlPtSym, index, loc);
|
||||
element->setType(derefType);
|
||||
element->setLoc(loc);
|
||||
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence,
|
||||
handleAssign(loc, EOpAssign, element, callReturn));
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 3: Create return Sequence ================
|
||||
// Return sequence: copy PCF result to a temporary, then to shader output variable.
|
||||
if (pcfCall->getBasicType() != EbtVoid) {
|
||||
@@ -7591,30 +7803,31 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
if (patchConstantFunction.getDeclaredBuiltInType() != EbvNone)
|
||||
outType.getQualifier().builtIn = patchConstantFunction.getDeclaredBuiltInType();
|
||||
|
||||
outType.getQualifier().patch = true; // make it a per-patch variable
|
||||
|
||||
TVariable* pcfOutput = makeInternalVariable("@patchConstantOutput", outType);
|
||||
pcfOutput->getWritableType().getQualifier().storage = EvqVaryingOut;
|
||||
|
||||
if (pcfOutput->getType().containsBuiltInInterstageIO(language))
|
||||
split(*pcfOutput);
|
||||
|
||||
assignLocations(*pcfOutput);
|
||||
|
||||
TIntermSymbol* pcfOutputSym = intermediate.addSymbol(*pcfOutput, loc);
|
||||
|
||||
// The call to the PCF is a complex R-value: we want to store it in a temp to avoid
|
||||
// repeated calls to the PCF:
|
||||
TVariable* pcfCallResult = makeInternalVariable("@patchConstantResult", *retType);
|
||||
pcfCallResult->getWritableType().getQualifier().makeTemporary();
|
||||
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
|
||||
// sanitizeType(&pcfCall->getWritableType());
|
||||
TIntermNode* pcfResultAssign = intermediate.addAssign(EOpAssign, pcfResultVar, pcfCall, loc);
|
||||
|
||||
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
|
||||
TIntermNode* pcfResultAssign = handleAssign(loc, EOpAssign, pcfResultVar, pcfCall);
|
||||
TIntermNode* pcfResultToOut = handleAssign(loc, EOpAssign, pcfOutputSym, intermediate.addSymbol(*pcfCallResult, loc));
|
||||
|
||||
TIntermTyped* pcfAggregate = nullptr;
|
||||
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultAssign);
|
||||
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultToOut);
|
||||
pcfAggregate = intermediate.setAggregateOperator(pcfAggregate, EOpSequence, *retType, loc);
|
||||
|
||||
pcfCall = pcfAggregate;
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultAssign);
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultToOut);
|
||||
} else {
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfCall);
|
||||
}
|
||||
|
||||
// ================ Step 4: Barrier ================
|
||||
@@ -7623,12 +7836,14 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
barrier->setType(TType(EbtVoid));
|
||||
epBodySeq.insert(epBodySeq.end(), barrier);
|
||||
|
||||
// ================ Step 5: Test on invocation ID ================
|
||||
// ================ Step 5: Test on invocation ID ================
|
||||
TIntermTyped* zero = intermediate.addConstantUnion(0, loc, true);
|
||||
TIntermTyped* cmp = intermediate.addBinaryNode(EOpEqual, invocationIdSym, zero, loc, TType(EbtBool));
|
||||
|
||||
// Create if statement
|
||||
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCall, nullptr);
|
||||
|
||||
// ================ Step 5B: Create if statement on Invocation ID == 0 ================
|
||||
intermediate.setAggregateOperator(pcfCallSequence, EOpSequence, TType(EbtVoid), loc);
|
||||
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCallSequence, nullptr);
|
||||
invocationIdTest->setLoc(loc);
|
||||
|
||||
// add our test sequence before the return.
|
||||
|
||||
Reference in New Issue
Block a user