HLSL default function parameters
This PR adds support for default function parameters in the following cases:
1. Simple constants, such as void fn(int x, float myparam = 3)
2. Expressions that can be const folded, such a ... myparam = sin(some_const)
3. Initializer lists that can be const folded, such as ... float2 myparam = {1,2}
New tests are added: hlsl.params.default.frag and hlsl.params.default.err.frag
(for testing error situations, such as ambiguity or non-const-foldable).
In order to avoid sampler method ambiguity, the hlsl better() lambda now
considers sampler matches. Previously, all sampler types looked identical
since only the basic type of EbtSampler was considered.
This commit is contained in:
@@ -1776,9 +1776,55 @@ bool HlslGrammar::acceptFunctionParameters(TFunction& function)
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// default_parameter_declaration
|
||||
// : EQUAL conditional_expression
|
||||
// : EQUAL initializer
|
||||
bool HlslGrammar::acceptDefaultParameterDeclaration(const TType& type, TIntermTyped*& node)
|
||||
{
|
||||
node = nullptr;
|
||||
|
||||
// Valid not to have a default_parameter_declaration
|
||||
if (!acceptTokenClass(EHTokAssign))
|
||||
return true;
|
||||
|
||||
if (!acceptConditionalExpression(node)) {
|
||||
if (!acceptInitializer(node))
|
||||
return false;
|
||||
|
||||
// For initializer lists, we have to const-fold into a constructor for the type, so build
|
||||
// that.
|
||||
TFunction* constructor = parseContext.handleConstructorCall(token.loc, type);
|
||||
if (constructor == nullptr) // cannot construct
|
||||
return false;
|
||||
|
||||
TIntermTyped* arguments = nullptr;
|
||||
for (int i=0; i<int(node->getAsAggregate()->getSequence().size()); i++)
|
||||
parseContext.handleFunctionArgument(constructor, arguments, node->getAsAggregate()->getSequence()[i]->getAsTyped());
|
||||
|
||||
node = parseContext.handleFunctionCall(token.loc, constructor, node);
|
||||
}
|
||||
|
||||
// If this is simply a constant, we can use it directly.
|
||||
if (node->getAsConstantUnion())
|
||||
return true;
|
||||
|
||||
// Otherwise, it has to be const-foldable.
|
||||
TIntermTyped* origNode = node;
|
||||
|
||||
node = intermediate.fold(node->getAsAggregate());
|
||||
|
||||
if (node != nullptr && origNode != node)
|
||||
return true;
|
||||
|
||||
parseContext.error(token.loc, "invalid default parameter value", "", "");
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// parameter_declaration
|
||||
// : fully_specified_type post_decls
|
||||
// | fully_specified_type identifier array_specifier post_decls
|
||||
// : fully_specified_type post_decls [ = default_parameter_declaration ]
|
||||
// | fully_specified_type identifier array_specifier post_decls [ = default_parameter_declaration ]
|
||||
//
|
||||
bool HlslGrammar::acceptParameterDeclaration(TFunction& function)
|
||||
{
|
||||
@@ -1806,9 +1852,19 @@ bool HlslGrammar::acceptParameterDeclaration(TFunction& function)
|
||||
// post_decls
|
||||
acceptPostDecls(type->getQualifier());
|
||||
|
||||
TIntermTyped* defaultValue;
|
||||
if (!acceptDefaultParameterDeclaration(*type, defaultValue))
|
||||
return false;
|
||||
|
||||
parseContext.paramFix(*type);
|
||||
|
||||
TParameter param = { idToken.string, type };
|
||||
// If any prior parameters have default values, all the parameters after that must as well.
|
||||
if (defaultValue == nullptr && function.getDefaultParamCount() > 0) {
|
||||
parseContext.error(idToken.loc, "invalid parameter after default value parameters", idToken.string->c_str(), "");
|
||||
return false;
|
||||
}
|
||||
|
||||
TParameter param = { idToken.string, type, defaultValue };
|
||||
function.addParameter(param);
|
||||
|
||||
return true;
|
||||
|
||||
@@ -111,6 +111,7 @@ namespace glslang {
|
||||
bool acceptDefaultLabel(TIntermNode*&);
|
||||
void acceptArraySpecifier(TArraySizes*&);
|
||||
void acceptPostDecls(TQualifier&);
|
||||
bool acceptDefaultParameterDeclaration(const TType&, TIntermTyped*&);
|
||||
|
||||
HlslParseContext& parseContext; // state of parsing and helper functions for building the intermediate
|
||||
TIntermediate& intermediate; // the final product, the intermediate representation, includes the AST
|
||||
|
||||
@@ -1375,7 +1375,7 @@ TIntermNode* HlslParseContext::handleReturnValue(const TSourceLoc& loc, TIntermT
|
||||
|
||||
void HlslParseContext::handleFunctionArgument(TFunction* function, TIntermTyped*& arguments, TIntermTyped* newArg)
|
||||
{
|
||||
TParameter param = { 0, new TType };
|
||||
TParameter param = { 0, new TType, nullptr };
|
||||
param.type->shallowCopy(newArg->getType());
|
||||
function->addParameter(param);
|
||||
if (arguments)
|
||||
@@ -2643,7 +2643,7 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*&
|
||||
// - user function
|
||||
// - subroutine call (not implemented yet)
|
||||
//
|
||||
TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunction* function, TIntermNode* arguments)
|
||||
TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunction* function, TIntermTyped* arguments)
|
||||
{
|
||||
TIntermTyped* result = nullptr;
|
||||
|
||||
@@ -2783,10 +2783,10 @@ TIntermTyped* HlslParseContext::handleLengthMethod(const TSourceLoc& loc, TFunct
|
||||
//
|
||||
// Add any needed implicit conversions for function-call arguments to input parameters.
|
||||
//
|
||||
void HlslParseContext::addInputArgumentConversions(const TFunction& function, TIntermNode*& arguments)
|
||||
void HlslParseContext::addInputArgumentConversions(const TFunction& function, TIntermTyped*& arguments)
|
||||
{
|
||||
TIntermAggregate* aggregate = arguments->getAsAggregate();
|
||||
const auto setArg = [&](int argNum, TIntermNode* arg) {
|
||||
const auto setArg = [&](int argNum, TIntermTyped* arg) {
|
||||
if (function.getParamCount() == 1)
|
||||
arguments = arg;
|
||||
else {
|
||||
@@ -4483,8 +4483,8 @@ void HlslParseContext::mergeObjectLayoutQualifiers(TQualifier& dst, const TQuali
|
||||
//
|
||||
// Return the function symbol if found, otherwise nullptr.
|
||||
//
|
||||
const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn,
|
||||
TIntermNode* args)
|
||||
const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, TFunction& call, bool& builtIn,
|
||||
TIntermTyped*& args)
|
||||
{
|
||||
// const TFunction* function = nullptr;
|
||||
|
||||
@@ -4583,6 +4583,22 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFu
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle sampler betterness: An exact sampler match beats a non-exact match.
|
||||
// (If we just looked at basic type, all EbtSamplers would look the same).
|
||||
// If any type is not a sampler, just use the linearize function below.
|
||||
if (from.getBasicType() == EbtSampler && to1.getBasicType() == EbtSampler && to2.getBasicType() == EbtSampler) {
|
||||
// We can ignore the vector size in the comparison.
|
||||
TSampler to1Sampler = to1.getSampler();
|
||||
TSampler to2Sampler = to2.getSampler();
|
||||
|
||||
to1Sampler.vectorSize = to2Sampler.vectorSize = from.getSampler().vectorSize;
|
||||
|
||||
if (from.getSampler() == to2Sampler)
|
||||
return from.getSampler() != to1Sampler;
|
||||
if (from.getSampler() == to1Sampler)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Might or might not be changing shape, which means basic type might
|
||||
// or might not match, so within that, the question is how big a
|
||||
// basic-type conversion is being done.
|
||||
@@ -4672,18 +4688,18 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFu
|
||||
// Handle aggregates: put all args into the new function call
|
||||
for (int arg=0; arg<int(args->getAsAggregate()->getSequence().size()); ++arg) {
|
||||
// TODO: But for constness, we could avoid the new & shallowCopy, and use the pointer directly.
|
||||
TParameter param = { 0, new TType };
|
||||
TParameter param = { 0, new TType, nullptr };
|
||||
param.type->shallowCopy(args->getAsAggregate()->getSequence()[arg]->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
}
|
||||
} else if (args->getAsUnaryNode()) {
|
||||
// Handle unaries: put all args into the new function call
|
||||
TParameter param = { 0, new TType };
|
||||
TParameter param = { 0, new TType, nullptr };
|
||||
param.type->shallowCopy(args->getAsUnaryNode()->getOperand()->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
} else if (args->getAsTyped()) {
|
||||
// Handle bare e.g, floats, not in an aggregate.
|
||||
TParameter param = { 0, new TType };
|
||||
TParameter param = { 0, new TType, nullptr };
|
||||
param.type->shallowCopy(args->getAsTyped()->getType());
|
||||
convertedCall.addParameter(param);
|
||||
} else {
|
||||
@@ -4701,6 +4717,13 @@ const TFunction* HlslParseContext::findFunction(const TSourceLoc& loc, const TFu
|
||||
if (tie)
|
||||
error(loc, "ambiguous best function under implicit type conversion", call.getName().c_str(), "");
|
||||
|
||||
// Append default parameter values if needed
|
||||
if (!tie && bestMatch != nullptr) {
|
||||
for (int defParam = call.getParamCount(); defParam < bestMatch->getParamCount(); ++defParam) {
|
||||
handleFunctionArgument(&call, args, (*bestMatch)[defParam].defaultValue);
|
||||
}
|
||||
}
|
||||
|
||||
return bestMatch;
|
||||
}
|
||||
|
||||
|
||||
@@ -79,12 +79,12 @@ public:
|
||||
TIntermNode* handleReturnValue(const TSourceLoc&, TIntermTyped*);
|
||||
void handleFunctionArgument(TFunction*, TIntermTyped*& arguments, TIntermTyped* newArg);
|
||||
TIntermTyped* handleAssign(const TSourceLoc&, TOperator, TIntermTyped* left, TIntermTyped* right) const;
|
||||
TIntermTyped* handleFunctionCall(const TSourceLoc&, TFunction*, TIntermNode*);
|
||||
TIntermTyped* handleFunctionCall(const TSourceLoc&, TFunction*, TIntermTyped*);
|
||||
void decomposeIntrinsic(const TSourceLoc&, TIntermTyped*& node, TIntermNode* arguments);
|
||||
void decomposeSampleMethods(const TSourceLoc&, TIntermTyped*& node, TIntermNode* arguments);
|
||||
void decomposeGeometryMethods(const TSourceLoc&, TIntermTyped*& node, TIntermNode* arguments);
|
||||
TIntermTyped* handleLengthMethod(const TSourceLoc&, TFunction*, TIntermNode*);
|
||||
void addInputArgumentConversions(const TFunction&, TIntermNode*&);
|
||||
void addInputArgumentConversions(const TFunction&, TIntermTyped*&);
|
||||
TIntermTyped* addOutputArgumentConversions(const TFunction&, TIntermOperator&);
|
||||
void builtInOpCheck(const TSourceLoc&, const TFunction&, TIntermOperator&);
|
||||
TFunction* handleConstructorCall(const TSourceLoc&, const TType&);
|
||||
@@ -126,7 +126,7 @@ public:
|
||||
void mergeObjectLayoutQualifiers(TQualifier& dest, const TQualifier& src, bool inheritOnly);
|
||||
void checkNoShaderLayouts(const TSourceLoc&, const TShaderQualifiers&);
|
||||
|
||||
const TFunction* findFunction(const TSourceLoc& loc, const TFunction& call, bool& builtIn, TIntermNode* args);
|
||||
const TFunction* findFunction(const TSourceLoc& loc, TFunction& call, bool& builtIn, TIntermTyped*& args);
|
||||
void declareTypedef(const TSourceLoc&, TString& identifier, const TType&, TArraySizes* typeArray = 0);
|
||||
TIntermNode* declareVariable(const TSourceLoc&, TString& identifier, TType&, TIntermTyped* initializer = 0);
|
||||
void lengthenList(const TSourceLoc&, TIntermSequence& list, int size);
|
||||
|
||||
@@ -586,7 +586,7 @@ void TBuiltInParseablesHlsl::initialize(int /*version*/, EProfile /*profile*/, c
|
||||
{ "DeviceMemoryBarrier", nullptr, nullptr, "-", "-", EShLangPSCS },
|
||||
{ "DeviceMemoryBarrierWithGroupSync", nullptr, nullptr, "-", "-", EShLangCS },
|
||||
{ "distance", "S", "F", "V,", "F,", EShLangAll },
|
||||
{ "dot", "S", nullptr, "V,", "FI,", EShLangAll },
|
||||
{ "dot", "S", nullptr, "SV,", "FI,", EShLangAll },
|
||||
{ "dst", nullptr, nullptr, "V4,", "F,", EShLangAll },
|
||||
// { "errorf", "-", "-", "", "", EShLangAll }, TODO: varargs
|
||||
{ "EvaluateAttributeAtCentroid", nullptr, nullptr, "SVM", "F", EShLangPS },
|
||||
|
||||
Reference in New Issue
Block a user