HLSL: support per control point patch const fn invocation

This PR emulates per control point inputs to patch constant functions.
Without either an extension to look across SIMD lanes or a dedicated
stage, the emulation must use separate invocations of the wrapped
entry point to obtain the per control point values.  This emulation is
provided because shaders need this functionality now, but such an
extension is not yet available.

Entry point arguments qualified as an invocation ID are replaced by the
current control point number when calling the wrapped entry point.  There
is no particular optimization for the case of the entry point not having
such an input but the PCF still accepting ctrl pt frequency data.  That
case works, but it does not make much sense.

The wrapped entry point must return the per control point data by value.
Returning that data through an output parameter is not yet supported.
This commit is contained in:
steve-lunarg
2017-03-14 17:37:10 -06:00
parent e434ad923e
commit 9cee73e028
7 changed files with 815 additions and 68 deletions

View File

@@ -7431,6 +7431,11 @@ void HlslParseContext::addPatchConstantInvocation()
return intermediate.addSymbol(*it->second->getAsVariable());
};
// Predicate: true when 'type' is an array that is not runtime-sized. It is applied to the
// patch constant function's parameter types to decide which one carries per control point data.
// NOTE(review): the body does not reference any member, so the 'this' capture looks unnecessary.
const auto isPerCtrlPt = [this](const TType& type) {
// TODO: this is not sufficient to reject all such cases in malformed shaders.
return type.isArray() && !type.isRuntimeSizedArray();
};
// We will perform these steps. Each is in a scoped block for separation: they could
// become separate functions to make addPatchConstantInvocation shorter.
@@ -7441,21 +7446,25 @@ void HlslParseContext::addPatchConstantInvocation()
// 2. Synthesizes a call to the patchconstfunction using builtin variables from either main,
// or the ones we created. Matching is based on builtin type. We may use synthesized
// variables from (1) above.
//
// 2B: Synthesize per control point invocations of wrapped entry point if the PCF requires them.
//
// 3. Create a return sequence: copy the return value (if any) from the PCF to a
// (non-sanitized) output variable. In case this may involve multiple copies, such as for
// an arrayed variable, a temporary copy of the PCF output is created to avoid multiple
// indirections into a complex R-value coming from the call to the PCF.
//
// 4. Add a barrier to the end of the entry point body
//
// 5. Call the PCF inside an if test for (invocation id == 0).
//
// 4. Create a barrier.
//
// 5/5B. Call the PCF inside an if test for (invocation id == 0).
TFunction& patchConstantFunction = const_cast<TFunction&>(*candidateList[0]);
const int pcfParamCount = patchConstantFunction.getParamCount();
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
TIntermSequence& epBodySeq = entryPointFunctionBody->getAsAggregate()->getSequence();
int perCtrlPtParam = -1; // -1 means there isn't one.
// ================ Step 1A: Union Interfaces ================
// Our patch constant function.
{
@@ -7468,16 +7477,6 @@ void HlslParseContext::addPatchConstantInvocation()
findBuiltIns(patchConstantFunction, pcfBuiltIns);
findBuiltIns(*entryPointFunction, epfBuiltIns);
// Patchconstantfunction can contain only builtin qualified variables. (Technically, only HS inputs,
// but this test is less assertive than that).
for (auto bi = pcfBuiltIns.begin(); bi != pcfBuiltIns.end(); ++bi) {
if (bi->builtIn == EbvNone) {
error(loc, "patch constant function invalid parameter", "", "");
return;
}
}
// Find the set of builtins in the PCF that are not present in the entry point.
std::set<tInterstageIoData> notInEntryPoint;
@@ -7489,15 +7488,27 @@ void HlslParseContext::addPatchConstantInvocation()
// Now we'll add those to the entry and to the linkage.
for (int p=0; p<pcfParamCount; ++p) {
TType* paramType = patchConstantFunction[p].type->clone();
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
const TStorageQualifier storage = patchConstantFunction[p].type->getQualifier().storage;
// Use the original declaration type for the linkage
paramType->getQualifier().builtIn = biType;
// Track whether there is any per control point input
if (isPerCtrlPt(*patchConstantFunction[p].type)) {
if (perCtrlPtParam >= 0) {
// Presently we only support one per ctrl pt input. TODO: does HLSL even allow multiple?
error(loc, "unimplemented: multiple per control point inputs to patch constant function", "", "");
return;
}
perCtrlPtParam = p;
}
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
if (biType != EbvNone) {
TType* paramType = patchConstantFunction[p].type->clone();
// Use the original declaration type for the linkage
paramType->getQualifier().builtIn = biType;
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
}
}
// If we didn't find it because the shader made one, add our own.
@@ -7512,36 +7523,50 @@ void HlslParseContext::addPatchConstantInvocation()
}
TIntermTyped* pcfArguments = nullptr;
TVariable* perCtrlPtVar = nullptr;
// ================ Step 1B: Argument synthesis ================
// Create pcfArguments for synthesis of patchconstantfunction invocation
// TODO: handle struct or array inputs
{
for (int p=0; p<pcfParamCount; ++p) {
if (patchConstantFunction[p].type->isArray() ||
patchConstantFunction[p].type->isStruct()) {
if ((patchConstantFunction[p].type->isArray() && !isPerCtrlPt(*patchConstantFunction[p].type)) ||
(!patchConstantFunction[p].type->isArray() && patchConstantFunction[p].type->isStruct())) {
error(loc, "unimplemented array or variable in patch constant function signature", "", "");
return;
}
// find which builtin it is
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
TIntermSymbol* inputArg = nullptr;
TIntermSymbol* builtIn = findLinkageSymbol(biType);
if (p == perCtrlPtParam) {
if (perCtrlPtVar == nullptr) {
perCtrlPtVar = makeInternalVariable(*patchConstantFunction[perCtrlPtParam].name,
*patchConstantFunction[perCtrlPtParam].type);
perCtrlPtVar->getWritableType().getQualifier().makeTemporary();
}
inputArg = intermediate.addSymbol(*perCtrlPtVar, loc);
} else {
// find which builtin it is
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
inputArg = findLinkageSymbol(biType);
if (builtIn == nullptr) {
error(loc, "unable to find patch constant function builtin variable", "", "");
return;
if (inputArg == nullptr) {
error(loc, "unable to find patch constant function builtin variable", "", "");
return;
}
}
if (pcfParamCount == 1)
pcfArguments = builtIn;
pcfArguments = inputArg;
else
pcfArguments = intermediate.growAggregate(pcfArguments, builtIn);
pcfArguments = intermediate.growAggregate(pcfArguments, inputArg);
}
}
// ================ Step 2: Synthesize call to PCF ================
TIntermAggregate* pcfCallSequence = nullptr;
TIntermTyped* pcfCall = nullptr;
{
@@ -7553,7 +7578,8 @@ void HlslParseContext::addPatchConstantInvocation()
pcfCall = intermediate.setAggregateOperator(pcfArguments, EOpFunctionCall, patchConstantFunction.getType(), loc);
pcfCall->getAsAggregate()->setUserDefined();
pcfCall->getAsAggregate()->setName(patchConstantFunction.getMangledName());
intermediate.addToCallGraph(infoSink, entryPointFunction->getMangledName(), patchConstantFunction.getMangledName());
intermediate.addToCallGraph(infoSink, intermediate.getEntryPointMangledName().c_str(),
patchConstantFunction.getMangledName());
if (pcfCall->getAsAggregate()) {
TQualifierList& qualifierList = pcfCall->getAsAggregate()->getQualifierList();
@@ -7565,6 +7591,71 @@ void HlslParseContext::addPatchConstantInvocation()
}
}
// ================ Step 2B: Per Control Point synthesis ================
// If there is per control point data, we must either emulate that with multiple
// invocations of the entry point to build up an array, or (TODO:) use a yet
// unavailable extension to look across the SIMD lanes. This is the former
// as a placeholder for the latter.
if (perCtrlPtParam >= 0) {
// We must introduce a local temp variable of the type wanted by the PCF input.
const int arraySize = patchConstantFunction[perCtrlPtParam].type->getOuterArraySize();
if (entryPointFunction->getType().getBasicType() == EbtVoid) {
error(loc, "entry point must return a value for use with patch constant function", "", "");
return;
}
// Create calls to wrapped main to fill in the array. We will substitute fixed values
// of invocation ID when calling the wrapped main.
// This is the type of each member of the per ctrl point array.
const TType derefType(perCtrlPtVar->getType(), 0);
for (int cpt = 0; cpt < arraySize; ++cpt) {
// TODO: improve. substr(1) here is to avoid the '@' that was grafted on but isn't in the symtab
// for this function.
const TString origName = entryPointFunction->getName().substr(1);
TFunction callee(&origName, TType(EbtVoid));
TIntermTyped* callingArgs = nullptr;
for (int i = 0; i < entryPointFunction->getParamCount(); i++) {
TParameter& param = (*entryPointFunction)[i];
TType& paramType = *param.type;
if (paramType.getQualifier().isParamOutput()) {
error(loc, "unimplemented: entry point outputs in patch constant function invocation", "", "");
return;
}
if (paramType.getQualifier().isParamInput()) {
TIntermTyped* arg = nullptr;
if ((*entryPointFunction)[i].declaredBuiltIn == EbvInvocationId) {
// substitute invocation ID with the array element ID
arg = intermediate.addConstantUnion(cpt, loc);
} else {
TVariable* argVar = makeInternalVariable(*param.name, *param.type);
argVar->getWritableType().getQualifier().makeTemporary();
arg = intermediate.addSymbol(*argVar);
}
handleFunctionArgument(&callee, callingArgs, arg);
}
}
// Call and assign to per ctrl point variable
currentCaller = intermediate.getEntryPointMangledName().c_str();
TIntermTyped* callReturn = handleFunctionCall(loc, &callee, callingArgs);
TIntermTyped* index = intermediate.addConstantUnion(cpt, loc);
TIntermSymbol* perCtrlPtSym = intermediate.addSymbol(*perCtrlPtVar, loc);
TIntermTyped* element = intermediate.addIndex(EOpIndexDirect, perCtrlPtSym, index, loc);
element->setType(derefType);
element->setLoc(loc);
pcfCallSequence = intermediate.growAggregate(pcfCallSequence,
handleAssign(loc, EOpAssign, element, callReturn));
}
}
// ================ Step 3: Create return Sequence ================
// Return sequence: copy PCF result to a temporary, then to shader output variable.
if (pcfCall->getBasicType() != EbtVoid) {
@@ -7581,30 +7672,31 @@ void HlslParseContext::addPatchConstantInvocation()
if (patchConstantFunction.getDeclaredBuiltInType() != EbvNone)
outType.getQualifier().builtIn = patchConstantFunction.getDeclaredBuiltInType();
outType.getQualifier().patch = true; // make it a per-patch variable
TVariable* pcfOutput = makeInternalVariable("@patchConstantOutput", outType);
pcfOutput->getWritableType().getQualifier().storage = EvqVaryingOut;
if (pcfOutput->getType().containsBuiltInInterstageIO(language))
split(*pcfOutput);
assignLocations(*pcfOutput);
TIntermSymbol* pcfOutputSym = intermediate.addSymbol(*pcfOutput, loc);
// The call to the PCF is a complex R-value: we want to store it in a temp to avoid
// repeated calls to the PCF:
TVariable* pcfCallResult = makeInternalVariable("@patchConstantResult", *retType);
pcfCallResult->getWritableType().getQualifier().makeTemporary();
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
// sanitizeType(&pcfCall->getWritableType());
TIntermNode* pcfResultAssign = intermediate.addAssign(EOpAssign, pcfResultVar, pcfCall, loc);
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
TIntermNode* pcfResultAssign = handleAssign(loc, EOpAssign, pcfResultVar, pcfCall);
TIntermNode* pcfResultToOut = handleAssign(loc, EOpAssign, pcfOutputSym, intermediate.addSymbol(*pcfCallResult, loc));
TIntermTyped* pcfAggregate = nullptr;
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultAssign);
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultToOut);
pcfAggregate = intermediate.setAggregateOperator(pcfAggregate, EOpSequence, *retType, loc);
pcfCall = pcfAggregate;
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultAssign);
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultToOut);
} else {
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfCall);
}
// ================ Step 4: Barrier ================
@@ -7613,12 +7705,14 @@ void HlslParseContext::addPatchConstantInvocation()
barrier->setType(TType(EbtVoid));
epBodySeq.insert(epBodySeq.end(), barrier);
// ================ Step 5: Test on invocation ID ================
// ================ Step 5: Test on invocation ID ================
TIntermTyped* zero = intermediate.addConstantUnion(0, loc, true);
TIntermTyped* cmp = intermediate.addBinaryNode(EOpEqual, invocationIdSym, zero, loc, TType(EbtBool));
// Create if statement
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCall, nullptr);
// ================ Step 5B: Create if statement on Invocation ID == 0 ================
intermediate.setAggregateOperator(pcfCallSequence, EOpSequence, TType(EbtVoid), loc);
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCallSequence, nullptr);
invocationIdTest->setLoc(loc);
// add our test sequence before the return.