HLSL: add intrinsic function implicit promotions
This PR handles implicit promotions for intrinsics when there is no exact match, such as for example clamp(int, bool, float). In this case the int and bool will be promoted to a float, and the clamp(float, float, float) form used. These promotions can be mixed with shape conversions, e.g, clamp(int, bool2, float2). Output conversions are handled either via the existing addOutputArgumentConversion function, which this PR generalizes to handle either aggregates or unaries, or by intrinsic decomposition. If there are methods or intrinsics to be decomposed, then decomposition is responsible for any output conversions, which turns out to happen automatically in all current cases. This can be revisited once inout conversions are in place. Some cases of actual ambiguity were fixed in several tests, e.g, spv.register.autoassign.* Some intrinsics with only uint versions were expanded to signed ints natively, where the underlying AST and SPIR-V supports that. E.g, countbits. This avoids extraneous conversion nodes. A new function promoteAggregate is added, and used by findFunction. This is essentially a generalization of the "promote 1st or 2nd arg" algorithm in promoteBinary. The actual selection proceeds in three steps, as described in the comments in hlslParseContext::findFunction: 1. Attempt an exact match. If found, use it. 2. If not, obtain the operator from step 1, and promote arguments. 3. Re-select the intrinsic overload from the results of step 2.
This commit is contained in:
@@ -2555,7 +2555,7 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
|
||||
//
|
||||
const TFunction* fnCandidate;
|
||||
bool builtIn;
|
||||
fnCandidate = findFunction(loc, *function, builtIn);
|
||||
fnCandidate = findFunction(loc, *function, builtIn, arguments);
|
||||
if (fnCandidate) {
|
||||
// This is a declared function that might map to
|
||||
// - a built-in operator,
|
||||
@@ -2597,21 +2597,27 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
|
||||
}
|
||||
}
|
||||
|
||||
// for decompositions, since we want to operate on the function node, not the aggregate holding
|
||||
// output conversions.
|
||||
const TIntermTyped* fnNode = result;
|
||||
|
||||
decomposeIntrinsic(loc, result, arguments); // HLSL->AST intrinsic decompositions
|
||||
decomposeSampleMethods(loc, result, arguments); // HLSL->AST sample method decompositions
|
||||
decomposeGeometryMethods(loc, result, arguments); // HLSL->AST geometry method decompositions
|
||||
|
||||
// Convert 'out' arguments. If it was a constant folded built-in, it won't be an aggregate anymore.
|
||||
// Built-ins with a single argument aren't called with an aggregate, but they also don't have an output.
|
||||
// Also, build the qualifier list for user function calls, which are always called with an aggregate.
|
||||
if (result->getAsAggregate()) {
|
||||
// We don't do this is if there has been a decomposition, which will have added its own conversions
|
||||
// for output parameters.
|
||||
if (result == fnNode && result->getAsAggregate()) {
|
||||
TQualifierList& qualifierList = result->getAsAggregate()->getQualifierList();
|
||||
for (int i = 0; i < fnCandidate->getParamCount(); ++i) {
|
||||
TStorageQualifier qual = (*fnCandidate)[i].type->getQualifier().storage;
|
||||
qualifierList.push_back(qual);
|
||||
}
|
||||
result = addOutputArgumentConversions(*fnCandidate, *result->getAsAggregate());
|
||||
result = addOutputArgumentConversions(*fnCandidate, *result->getAsOperator());
|
||||
}
|
||||
|
||||
decomposeIntrinsic(loc, result, arguments); // HLSL->AST intrinsic decompositions
|
||||
decomposeSampleMethods(loc, result, arguments); // HLSL->AST sample method decompositions
|
||||
decomposeGeometryMethods(loc, result, arguments); // HLSL->AST geometry method decompositions
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2724,9 +2730,19 @@ void HlslParseContext::addInputArgumentConversions(const TFunction& function, TI
|
||||
//
|
||||
// Returns a node of a subtree that evaluates to the return value of the function.
|
||||
//
|
||||
TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& function, TIntermAggregate& intermNode)
|
||||
TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& function, TIntermOperator& intermNode)
|
||||
{
|
||||
TIntermSequence& arguments = intermNode.getSequence();
|
||||
assert (intermNode.getAsAggregate() != nullptr || intermNode.getAsUnaryNode() != nullptr);
|
||||
|
||||
const TSourceLoc& loc = intermNode.getLoc();
|
||||
|
||||
TIntermSequence argSequence; // temp sequence for unary node args
|
||||
|
||||
if (intermNode.getAsUnaryNode())
|
||||
argSequence.push_back(intermNode.getAsUnaryNode()->getOperand());
|
||||
|
||||
TIntermSequence& arguments = argSequence.empty() ? intermNode.getAsAggregate()->getSequence() : argSequence;
|
||||
|
||||
const auto needsConversion = [&](int argNum) {
|
||||
return function[argNum].type->getQualifier().isParamOutput() &&
|
||||
(*function[argNum].type != arguments[argNum]->getAsTyped()->getType() ||
|
||||
@@ -2759,8 +2775,8 @@ TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& fu
|
||||
if (intermNode.getBasicType() != EbtVoid) {
|
||||
// do the "tempRet = function(...), " bit from above
|
||||
tempRet = makeInternalVariable("tempReturn", intermNode.getType());
|
||||
TIntermSymbol* tempRetNode = intermediate.addSymbol(*tempRet, intermNode.getLoc());
|
||||
conversionTree = intermediate.addAssign(EOpAssign, tempRetNode, &intermNode, intermNode.getLoc());
|
||||
TIntermSymbol* tempRetNode = intermediate.addSymbol(*tempRet, loc);
|
||||
conversionTree = intermediate.addAssign(EOpAssign, tempRetNode, &intermNode, loc);
|
||||
} else
|
||||
conversionTree = &intermNode;
|
||||
|
||||
@@ -2775,7 +2791,7 @@ TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& fu
|
||||
// Make a temporary for what the function expects the argument to look like.
|
||||
TVariable* tempArg = makeInternalVariable("tempArg", *function[i].type);
|
||||
tempArg->getWritableType().getQualifier().makeTemporary();
|
||||
TIntermSymbol* tempArgNode = intermediate.addSymbol(*tempArg, intermNode.getLoc());
|
||||
TIntermSymbol* tempArgNode = intermediate.addSymbol(*tempArg, loc);
|
||||
|
||||
// This makes the deepest level, the member-wise copy
|
||||
TIntermTyped* tempAssign = handleAssign(arguments[i]->getLoc(), EOpAssign, arguments[i]->getAsTyped(), tempArgNode);
|
||||
@@ -2783,17 +2799,18 @@ TIntermTyped* HlslParseContext::addOutputArgumentConversions(const TFunction& fu
|
||||
conversionTree = intermediate.growAggregate(conversionTree, tempAssign, arguments[i]->getLoc());
|
||||
|
||||
// replace the argument with another node for the same tempArg variable
|
||||
arguments[i] = intermediate.addSymbol(*tempArg, intermNode.getLoc());
|
||||
arguments[i] = intermediate.addSymbol(*tempArg, loc);
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize the tree topology (see bigger comment above).
|
||||
if (tempRet) {
|
||||
// do the "..., tempRet" bit from above
|
||||
TIntermSymbol* tempRetNode = intermediate.addSymbol(*tempRet, intermNode.getLoc());
|
||||
conversionTree = intermediate.growAggregate(conversionTree, tempRetNode, intermNode.getLoc());
|
||||
TIntermSymbol* tempRetNode = intermediate.addSymbol(*tempRet, loc);
|
||||
conversionTree = intermediate.growAggregate(conversionTree, tempRetNode, loc);
|
||||
}
|
||||
conversionTree = intermediate.setAggregateOperator(conversionTree, EOpComma, intermNode.getType(), intermNode.getLoc());
|
||||
|
||||
conversionTree = intermediate.setAggregateOperator(conversionTree, EOpComma, intermNode.getType(), loc);
|
||||
|
||||
return conversionTree;
|
||||
}
|
||||
@@ -4339,7 +4356,8 @@ void HlslParseContext::mergeObjectLayoutQualifiers(TQualifier& dst, const TQuali
|
||||
//
|
||||
// Return the function symbol if found, otherwise nullptr.
|
||||
//
|
||||
const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn)
|
||||
const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn,
|
||||
TIntermNode* args)
|
||||
{
|
||||
// const TFunction* function = nullptr;
|
||||
|
||||
@@ -4445,9 +4463,81 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFu
|
||||
// send to the generic selector
|
||||
const TFunction* bestMatch = selectFunction(candidateList, call, convertible, better, tie);
|
||||
|
||||
if (bestMatch == nullptr)
|
||||
if (bestMatch == nullptr) {
|
||||
error(loc, "no matching overloaded function found", call.getName().c_str(), "");
|
||||
else if (tie)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// For builtins, we can convert across the arguments. This will happen in several steps:
|
||||
// Step 1: If there's an exact match, use it.
|
||||
// Step 2a: Otherwise, get the operator from the best match and promote arguments:
|
||||
// Step 2b: reconstruct the TFunction based on the new arg types
|
||||
// Step 3: Re-select after type promotion is applied, to find proper candidate.
|
||||
if (builtIn) {
|
||||
// Step 1: If there's an exact match, use it.
|
||||
if (call.getMangledName() == bestMatch->getMangledName())
|
||||
return bestMatch;
|
||||
|
||||
// Step 2a: Otherwise, get the operator from the best match and promote arguments as if we
|
||||
// are that kind of operator.
|
||||
if (args != nullptr) {
|
||||
// The arg list can be a unary node, or an aggregate. We have to handle both.
|
||||
// We will use the normal promote() facilities, which require an interm node.
|
||||
TIntermOperator* promote = nullptr;
|
||||
|
||||
if (call.getParamCount() == 1) {
|
||||
promote = new TIntermUnary(bestMatch->getBuiltInOp());
|
||||
promote->getAsUnaryNode()->setOperand(args->getAsTyped());
|
||||
} else {
|
||||
promote = new TIntermAggregate(bestMatch->getBuiltInOp());
|
||||
promote->getAsAggregate()->getSequence().swap(args->getAsAggregate()->getSequence());
|
||||
}
|
||||
|
||||
if (! intermediate.promote(promote))
|
||||
return nullptr;
|
||||
|
||||
// Obtain the promoted arg list.
|
||||
if (call.getParamCount() == 1) {
|
||||
args = promote->getAsUnaryNode()->getOperand();
|
||||
} else {
|
||||
promote->getAsAggregate()->getSequence().swap(args->getAsAggregate()->getSequence());
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2b: reconstruct the TFunction based on the new arg types
|
||||
TFunction convertedCall(&call.getName(), call.getType(), call.getBuiltInOp());
|
||||
|
||||
if (args->getAsAggregate()) {
|
||||
// Handle aggregates: put all args into the new function call
|
||||
for (int arg=0; arg<int(args->getAsAggregate()->getSequence().size()); ++arg) {
|
||||
// TODO: But for constness, we could avoid the new & shallowCopy, and use the pointer directly.
|
||||
TParameter param = { 0, new TType };
|
||||
param.type->shallowCopy(args->getAsAggregate()->getSequence()[arg]->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
}
|
||||
} else if (args->getAsUnaryNode()) {
|
||||
// Handle unaries: put all args into the new function call
|
||||
TParameter param = { 0, new TType };
|
||||
param.type->shallowCopy(args->getAsUnaryNode()->getOperand()->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
} else if (args->getAsTyped()) {
|
||||
// Handle bare e.g, floats, not in an aggregate.
|
||||
TParameter param = { 0, new TType };
|
||||
param.type->shallowCopy(args->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
} else {
|
||||
assert(0); // unknown argument list.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Step 3: Re-select after type promotion, to find proper candidate
|
||||
// send to the generic selector
|
||||
bestMatch = selectFunction(candidateList, convertedCall, convertible, better, tie);
|
||||
|
||||
// At this point, there should be no tie.
|
||||
}
|
||||
|
||||
if (tie)
|
||||
error(loc, "ambiguous best function under implicit type conversion", call.getName().c_str(), "");
|
||||
|
||||
return bestMatch;
|
||||
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
void decomposeGeometryMethods(const TSourceLoc&, TIntermTyped*& node, TIntermNode* arguments);
|
||||
TIntermTyped* handleLengthMethod(const TSourceLoc&, TFunction*, TIntermNode*);
|
||||
void addInputArgumentConversions(const TFunction&, TIntermNode*&) const;
|
||||
TIntermTyped* addOutputArgumentConversions(const TFunction&, TIntermAggregate&);
|
||||
TIntermTyped* addOutputArgumentConversions(const TFunction&, TIntermOperator&);
|
||||
void builtInOpCheck(const TSourceLoc&, const TFunction&, TIntermOperator&);
|
||||
TFunction* handleConstructorCall(const TSourceLoc&, const TType&);
|
||||
void handleSemantic(TSourceLoc, TQualifier&, const TString& semantic);
|
||||
@@ -125,7 +125,7 @@ public:
|
||||
void mergeObjectLayoutQualifiers(TQualifier& dest, const TQualifier& src, bool inheritOnly);
|
||||
void checkNoShaderLayouts(const TSourceLoc&, const TShaderQualifiers&);
|
||||
|
||||
const TFunction* findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn);
|
||||
const TFunction* findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn, TIntermNode* args);
|
||||
void declareTypedef(const TSourceLoc&, TString& identifier, const TType&, TArraySizes* typeArray = 0);
|
||||
TIntermNode* declareVariable(const TSourceLoc&, TString& identifier, TType&, TIntermTyped* initializer = 0);
|
||||
TIntermTyped* addConstructor(const TSourceLoc&, TIntermNode*, const TType&);
|
||||
|
||||
@@ -558,8 +558,8 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
{ "AllMemoryBarrier", nullptr, nullptr, "-", "-", EShLangCS },
|
||||
{ "AllMemoryBarrierWithGroupSync", nullptr, nullptr, "-", "-", EShLangCS },
|
||||
{ "any", "S", "B", "SVM", "BFIU", EShLangAll },
|
||||
{ "asdouble", "S", "D", "S,", "U,", EShLangAll },
|
||||
{ "asdouble", "V2", "D", "V2,", "U,", EShLangAll },
|
||||
{ "asdouble", "S", "D", "S,", "UI,", EShLangAll },
|
||||
{ "asdouble", "V2", "D", "V2,", "UI,", EShLangAll },
|
||||
{ "asfloat", nullptr, "F", "SVM", "BFIU", EShLangAll },
|
||||
{ "asin", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "asint", nullptr, "I", "SVM", "FU", EShLangAll },
|
||||
@@ -572,7 +572,7 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
{ "clip", "-", "-", "SVM", "F", EShLangPS },
|
||||
{ "cos", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "cosh", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "countbits", nullptr, nullptr, "SV", "U", EShLangAll },
|
||||
{ "countbits", nullptr, nullptr, "SV", "UI", EShLangAll },
|
||||
{ "cross", nullptr, nullptr, "V3,", "F,", EShLangAll },
|
||||
{ "D3DCOLORtoUBYTE4", "V4", "I", "V4", "F", EShLangAll },
|
||||
{ "ddx", nullptr, nullptr, "SVM", "F", EShLangPS },
|
||||
@@ -636,9 +636,9 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
{ "log10", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "log2", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "mad", nullptr, nullptr, "SVM,,", "DFUI,,", EShLangAll },
|
||||
{ "max", nullptr, nullptr, "SVM,", "FI,", EShLangAll },
|
||||
{ "min", nullptr, nullptr, "SVM,", "FI,", EShLangAll },
|
||||
{ "modf", nullptr, nullptr, "SVM,>", "FI,", EShLangAll },
|
||||
{ "max", nullptr, nullptr, "SVM,", "FIU,", EShLangAll },
|
||||
{ "min", nullptr, nullptr, "SVM,", "FIU,", EShLangAll },
|
||||
{ "modf", nullptr, nullptr, "SVM,>", "FIU,", EShLangAll },
|
||||
{ "msad4", "V4", "U", "S,V2,V4", "U,,", EShLangAll },
|
||||
{ "mul", "S", nullptr, "S,S", "FI,", EShLangAll },
|
||||
{ "mul", "V", nullptr, "S,V", "FI,", EShLangAll },
|
||||
@@ -665,7 +665,7 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
{ "rcp", nullptr, nullptr, "SVM", "FD", EShLangAll },
|
||||
{ "reflect", nullptr, nullptr, "V,", "F,", EShLangAll },
|
||||
{ "refract", nullptr, nullptr, "V,V,S", "F,,", EShLangAll },
|
||||
{ "reversebits", nullptr, nullptr, "SV", "U", EShLangAll },
|
||||
{ "reversebits", nullptr, nullptr, "SV", "UI", EShLangAll },
|
||||
{ "round", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "rsqrt", nullptr, nullptr, "SVM", "F", EShLangAll },
|
||||
{ "saturate", nullptr, nullptr , "SVM", "F", EShLangAll },
|
||||
@@ -735,7 +735,7 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
// RWTexture loads
|
||||
{ "Load", "V4", nullptr, "!#,V", "FIU,I", EShLangAll },
|
||||
// (RW)Buffer loads
|
||||
{ "Load", "V4", nullptr, "~*1,V", "FIU,I", EShLangAll },
|
||||
{ "Load", "V4", nullptr, "~*1,V", "FIU,I", EShLangAll },
|
||||
|
||||
{ "Gather", /*!O*/ "V4", nullptr, "%@,S,V", "FIU,S,F", EShLangAll },
|
||||
{ "Gather", /* O*/ "V4", nullptr, "%@,S,V,V", "FIU,S,F,I", EShLangAll },
|
||||
|
||||
Reference in New Issue
Block a user