HLSL: support per control point patch const fn invocation
This PR emulates per control point inputs to patch constant functions. Without either an extension to look across SIMD lanes or a dedicated stage, the emulation must use separate invocations of the wrapped entry point to obtain the per control point values. This is provided because shaders need this functionality now, but such an extension is not yet available. Entry point arguments qualified as an invocation ID are replaced by the current control point number when calling the wrapped entry point. There is no particular optimization for the case where the entry point has no such input but the PCF still accepts control point frequency data. That case will work, but it does not make much sense. The wrapped entry point must return the per control point data by value. At this time, returning that data through an output parameter is not supported.
This commit is contained in:
@@ -7431,6 +7431,11 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
|
||||
return intermediate.addSymbol(*it->second->getAsVariable());
|
||||
};
|
||||
|
||||
const auto isPerCtrlPt = [this](const TType& type) {
|
||||
// TODO: this is not sufficient to reject all such cases in malformed shaders.
|
||||
return type.isArray() && !type.isRuntimeSizedArray();
|
||||
};
|
||||
|
||||
// We will perform these steps. Each is in a scoped block for separation: they could
|
||||
// become separate functions to make addPatchConstantInvocation shorter.
|
||||
@@ -7441,21 +7446,25 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
// 2. Synthesizes a call to the patchconstfunction using builtin variables from either main,
|
||||
// or the ones we created. Matching is based on builtin type. We may use synthesized
|
||||
// variables from (1) above.
|
||||
//
|
||||
// 2B: Synthesize per control point invocations of wrapped entry point if the PCF requires them.
|
||||
//
|
||||
// 3. Create a return sequence: copy the return value (if any) from the PCF to a
|
||||
// (non-sanitized) output variable. In case this may involve multiple copies, such as for
|
||||
// an arrayed variable, a temporary copy of the PCF output is created to avoid multiple
|
||||
// indirections into a complex R-value coming from the call to the PCF.
|
||||
//
|
||||
// 4. Add a barrier to the end of the entry point body
|
||||
//
|
||||
// 5. Call the PCF inside an if test for (invocation id == 0).
|
||||
//
|
||||
// 4. Create a barrier.
|
||||
//
|
||||
// 5/5B. Call the PCF inside an if test for (invocation id == 0).
|
||||
|
||||
TFunction& patchConstantFunction = const_cast<TFunction&>(*candidateList[0]);
|
||||
const int pcfParamCount = patchConstantFunction.getParamCount();
|
||||
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
|
||||
TIntermSequence& epBodySeq = entryPointFunctionBody->getAsAggregate()->getSequence();
|
||||
|
||||
int perCtrlPtParam = -1; // -1 means there isn't one.
|
||||
|
||||
// ================ Step 1A: Union Interfaces ================
|
||||
// Our patch constant function.
|
||||
{
|
||||
@@ -7468,16 +7477,6 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
findBuiltIns(patchConstantFunction, pcfBuiltIns);
|
||||
findBuiltIns(*entryPointFunction, epfBuiltIns);
|
||||
|
||||
// Patchconstantfunction can contain only builtin qualified variables. (Technically, only HS inputs,
|
||||
// but this test is less assertive than that).
|
||||
|
||||
for (auto bi = pcfBuiltIns.begin(); bi != pcfBuiltIns.end(); ++bi) {
|
||||
if (bi->builtIn == EbvNone) {
|
||||
error(loc, "patch constant function invalid parameter", "", "");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Find the set of builtins in the PCF that are not present in the entry point.
|
||||
std::set<tInterstageIoData> notInEntryPoint;
|
||||
|
||||
@@ -7489,15 +7488,27 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
|
||||
// Now we'll add those to the entry and to the linkage.
|
||||
for (int p=0; p<pcfParamCount; ++p) {
|
||||
TType* paramType = patchConstantFunction[p].type->clone();
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
const TStorageQualifier storage = patchConstantFunction[p].type->getQualifier().storage;
|
||||
|
||||
// Use the original declaration type for the linkage
|
||||
paramType->getQualifier().builtIn = biType;
|
||||
// Track whether there is any per control point input
|
||||
if (isPerCtrlPt(*patchConstantFunction[p].type)) {
|
||||
if (perCtrlPtParam >= 0) {
|
||||
// Presently we only support one per ctrl pt input. TODO: does HLSL even allow multiple?
|
||||
error(loc, "unimplemented: multiple per control point inputs to patch constant function", "", "");
|
||||
return;
|
||||
}
|
||||
perCtrlPtParam = p;
|
||||
}
|
||||
|
||||
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
|
||||
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
|
||||
if (biType != EbvNone) {
|
||||
TType* paramType = patchConstantFunction[p].type->clone();
|
||||
// Use the original declaration type for the linkage
|
||||
paramType->getQualifier().builtIn = biType;
|
||||
|
||||
if (notInEntryPoint.count(tInterstageIoData(biType, storage)) == 1)
|
||||
addToLinkage(*paramType, patchConstantFunction[p].name, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't find it because the shader made one, add our own.
|
||||
@@ -7512,36 +7523,50 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
}
|
||||
|
||||
TIntermTyped* pcfArguments = nullptr;
|
||||
TVariable* perCtrlPtVar = nullptr;
|
||||
|
||||
// ================ Step 1B: Argument synthesis ================
|
||||
// Create pcfArguments for synthesis of patchconstantfunction invocation
|
||||
// TODO: handle struct or array inputs
|
||||
{
|
||||
for (int p=0; p<pcfParamCount; ++p) {
|
||||
if (patchConstantFunction[p].type->isArray() ||
|
||||
patchConstantFunction[p].type->isStruct()) {
|
||||
if ((patchConstantFunction[p].type->isArray() && !isPerCtrlPt(*patchConstantFunction[p].type)) ||
|
||||
(!patchConstantFunction[p].type->isArray() && patchConstantFunction[p].type->isStruct())) {
|
||||
error(loc, "unimplemented array or variable in patch constant function signature", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
// find which builtin it is
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
TIntermSymbol* inputArg = nullptr;
|
||||
|
||||
TIntermSymbol* builtIn = findLinkageSymbol(biType);
|
||||
if (p == perCtrlPtParam) {
|
||||
if (perCtrlPtVar == nullptr) {
|
||||
perCtrlPtVar = makeInternalVariable(*patchConstantFunction[perCtrlPtParam].name,
|
||||
*patchConstantFunction[perCtrlPtParam].type);
|
||||
|
||||
perCtrlPtVar->getWritableType().getQualifier().makeTemporary();
|
||||
}
|
||||
inputArg = intermediate.addSymbol(*perCtrlPtVar, loc);
|
||||
} else {
|
||||
// find which builtin it is
|
||||
const TBuiltInVariable biType = patchConstantFunction[p].declaredBuiltIn;
|
||||
|
||||
inputArg = findLinkageSymbol(biType);
|
||||
|
||||
if (builtIn == nullptr) {
|
||||
error(loc, "unable to find patch constant function builtin variable", "", "");
|
||||
return;
|
||||
if (inputArg == nullptr) {
|
||||
error(loc, "unable to find patch constant function builtin variable", "", "");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (pcfParamCount == 1)
|
||||
pcfArguments = builtIn;
|
||||
pcfArguments = inputArg;
|
||||
else
|
||||
pcfArguments = intermediate.growAggregate(pcfArguments, builtIn);
|
||||
pcfArguments = intermediate.growAggregate(pcfArguments, inputArg);
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 2: Synthesize call to PCF ================
|
||||
TIntermAggregate* pcfCallSequence = nullptr;
|
||||
TIntermTyped* pcfCall = nullptr;
|
||||
|
||||
{
|
||||
@@ -7553,7 +7578,8 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
pcfCall = intermediate.setAggregateOperator(pcfArguments, EOpFunctionCall, patchConstantFunction.getType(), loc);
|
||||
pcfCall->getAsAggregate()->setUserDefined();
|
||||
pcfCall->getAsAggregate()->setName(patchConstantFunction.getMangledName());
|
||||
intermediate.addToCallGraph(infoSink, entryPointFunction->getMangledName(), patchConstantFunction.getMangledName());
|
||||
intermediate.addToCallGraph(infoSink, intermediate.getEntryPointMangledName().c_str(),
|
||||
patchConstantFunction.getMangledName());
|
||||
|
||||
if (pcfCall->getAsAggregate()) {
|
||||
TQualifierList& qualifierList = pcfCall->getAsAggregate()->getQualifierList();
|
||||
@@ -7565,6 +7591,71 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 2B: Per Control Point synthesis ================
|
||||
// If there is per control point data, we must either emulate that with multiple
|
||||
// invocations of the entry point to build up an array, or (TODO:) use a yet
|
||||
// unavailable extension to look across the SIMD lanes. This is the former
|
||||
// as a placeholder for the latter.
|
||||
if (perCtrlPtParam >= 0) {
|
||||
// We must introduce a local temp variable of the type wanted by the PCF input.
|
||||
const int arraySize = patchConstantFunction[perCtrlPtParam].type->getOuterArraySize();
|
||||
|
||||
if (entryPointFunction->getType().getBasicType() == EbtVoid) {
|
||||
error(loc, "entry point must return a value for use with patch constant function", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
// Create calls to wrapped main to fill in the array. We will substitute fixed values
|
||||
// of invocation ID when calling the wrapped main.
|
||||
|
||||
// This is the type of each member of the per ctrl point array.
|
||||
const TType derefType(perCtrlPtVar->getType(), 0);
|
||||
|
||||
for (int cpt = 0; cpt < arraySize; ++cpt) {
|
||||
// TODO: improve. substr(1) here is to avoid the '@' that was grafted on but isn't in the symtab
|
||||
// for this function.
|
||||
const TString origName = entryPointFunction->getName().substr(1);
|
||||
TFunction callee(&origName, TType(EbtVoid));
|
||||
TIntermTyped* callingArgs = nullptr;
|
||||
|
||||
for (int i = 0; i < entryPointFunction->getParamCount(); i++) {
|
||||
TParameter& param = (*entryPointFunction)[i];
|
||||
TType& paramType = *param.type;
|
||||
|
||||
if (paramType.getQualifier().isParamOutput()) {
|
||||
error(loc, "unimplemented: entry point outputs in patch constant function invocation", "", "");
|
||||
return;
|
||||
}
|
||||
|
||||
if (paramType.getQualifier().isParamInput()) {
|
||||
TIntermTyped* arg = nullptr;
|
||||
if ((*entryPointFunction)[i].declaredBuiltIn == EbvInvocationId) {
|
||||
// substitute invocation ID with the array element ID
|
||||
arg = intermediate.addConstantUnion(cpt, loc);
|
||||
} else {
|
||||
TVariable* argVar = makeInternalVariable(*param.name, *param.type);
|
||||
argVar->getWritableType().getQualifier().makeTemporary();
|
||||
arg = intermediate.addSymbol(*argVar);
|
||||
}
|
||||
|
||||
handleFunctionArgument(&callee, callingArgs, arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Call and assign to per ctrl point variable
|
||||
currentCaller = intermediate.getEntryPointMangledName().c_str();
|
||||
TIntermTyped* callReturn = handleFunctionCall(loc, &callee, callingArgs);
|
||||
TIntermTyped* index = intermediate.addConstantUnion(cpt, loc);
|
||||
TIntermSymbol* perCtrlPtSym = intermediate.addSymbol(*perCtrlPtVar, loc);
|
||||
TIntermTyped* element = intermediate.addIndex(EOpIndexDirect, perCtrlPtSym, index, loc);
|
||||
element->setType(derefType);
|
||||
element->setLoc(loc);
|
||||
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence,
|
||||
handleAssign(loc, EOpAssign, element, callReturn));
|
||||
}
|
||||
}
|
||||
|
||||
// ================ Step 3: Create return Sequence ================
|
||||
// Return sequence: copy PCF result to a temporary, then to shader output variable.
|
||||
if (pcfCall->getBasicType() != EbtVoid) {
|
||||
@@ -7581,30 +7672,31 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
if (patchConstantFunction.getDeclaredBuiltInType() != EbvNone)
|
||||
outType.getQualifier().builtIn = patchConstantFunction.getDeclaredBuiltInType();
|
||||
|
||||
outType.getQualifier().patch = true; // make it a per-patch variable
|
||||
|
||||
TVariable* pcfOutput = makeInternalVariable("@patchConstantOutput", outType);
|
||||
pcfOutput->getWritableType().getQualifier().storage = EvqVaryingOut;
|
||||
|
||||
if (pcfOutput->getType().containsBuiltInInterstageIO(language))
|
||||
split(*pcfOutput);
|
||||
|
||||
assignLocations(*pcfOutput);
|
||||
|
||||
TIntermSymbol* pcfOutputSym = intermediate.addSymbol(*pcfOutput, loc);
|
||||
|
||||
// The call to the PCF is a complex R-value: we want to store it in a temp to avoid
|
||||
// repeated calls to the PCF:
|
||||
TVariable* pcfCallResult = makeInternalVariable("@patchConstantResult", *retType);
|
||||
pcfCallResult->getWritableType().getQualifier().makeTemporary();
|
||||
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
|
||||
// sanitizeType(&pcfCall->getWritableType());
|
||||
TIntermNode* pcfResultAssign = intermediate.addAssign(EOpAssign, pcfResultVar, pcfCall, loc);
|
||||
|
||||
TIntermSymbol* pcfResultVar = intermediate.addSymbol(*pcfCallResult, loc);
|
||||
TIntermNode* pcfResultAssign = handleAssign(loc, EOpAssign, pcfResultVar, pcfCall);
|
||||
TIntermNode* pcfResultToOut = handleAssign(loc, EOpAssign, pcfOutputSym, intermediate.addSymbol(*pcfCallResult, loc));
|
||||
|
||||
TIntermTyped* pcfAggregate = nullptr;
|
||||
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultAssign);
|
||||
pcfAggregate = intermediate.growAggregate(pcfAggregate, pcfResultToOut);
|
||||
pcfAggregate = intermediate.setAggregateOperator(pcfAggregate, EOpSequence, *retType, loc);
|
||||
|
||||
pcfCall = pcfAggregate;
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultAssign);
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfResultToOut);
|
||||
} else {
|
||||
pcfCallSequence = intermediate.growAggregate(pcfCallSequence, pcfCall);
|
||||
}
|
||||
|
||||
// ================ Step 4: Barrier ================
|
||||
@@ -7613,12 +7705,14 @@ void HlslParseContext::addPatchConstantInvocation()
|
||||
barrier->setType(TType(EbtVoid));
|
||||
epBodySeq.insert(epBodySeq.end(), barrier);
|
||||
|
||||
// ================ Step 5: Test on invocation ID ================
|
||||
// ================ Step 5: Test on invocation ID ================
|
||||
TIntermTyped* zero = intermediate.addConstantUnion(0, loc, true);
|
||||
TIntermTyped* cmp = intermediate.addBinaryNode(EOpEqual, invocationIdSym, zero, loc, TType(EbtBool));
|
||||
|
||||
// Create if statement
|
||||
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCall, nullptr);
|
||||
|
||||
// ================ Step 5B: Create if statement on Invocation ID == 0 ================
|
||||
intermediate.setAggregateOperator(pcfCallSequence, EOpSequence, TType(EbtVoid), loc);
|
||||
TIntermTyped* invocationIdTest = new TIntermSelection(cmp, pcfCallSequence, nullptr);
|
||||
invocationIdTest->setLoc(loc);
|
||||
|
||||
// add our test sequence before the return.
|
||||
|
||||
Reference in New Issue
Block a user