Merge pull request #337 from steve-lunarg/intrinsics

HLSL: Add decompositions for some intrinsics.
This commit is contained in:
John Kessenich
2016-06-13 08:54:45 -06:00
committed by GitHub
11 changed files with 5546 additions and 3245 deletions

View File

@@ -1,5 +1,6 @@
//
//Copyright (C) 2016 Google, Inc.
//Copyright (C) 2016 LunarG, Inc.
//
//All rights reserved.
//
@@ -770,6 +771,184 @@ void HlslParseContext::handleFunctionArgument(TFunction* function, TIntermTyped*
arguments = newArg;
}
// Optionally decompose intrinsics to AST opcodes.
//
// loc       - source location attached to the nodes created here.
// node      - in/out: the operator node built for the intrinsic call.  When the
//             intrinsic is decomposed this is replaced by the equivalent tree
//             (or, for vector * vector mul, retagged in place as a dot product).
//             It is left untouched when it is null, is not an operator node,
//             its opcode has no decomposition, or the expected operands are
//             not available.
// arguments - the call's argument tree; read as an aggregate for the
//             multi-argument intrinsics (mul, sincos, dst).
//
void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*& node, TIntermNode* arguments)
{
    // HLSL intrinsics can be passed through to native AST opcodes, or decomposed here to existing AST
    // opcodes for compatibility with existing software stacks.
    static const bool decomposeHlslIntrinsics = true;

    if (!decomposeHlslIntrinsics || !node || !node->getAsOperator())
        return;

    const TIntermAggregate* argAggregate = arguments ? arguments->getAsAggregate() : nullptr;
    TIntermUnary* fnUnary = node->getAsUnaryNode();
    const TOperator op = node->getAsOperator()->getOp();

    // Number of arguments reachable through the aggregate; 0 when the arguments
    // are not an aggregate.  Used to avoid indexing past what was supplied.
    const int argCount = argAggregate ? static_cast<int>(argAggregate->getSequence().size()) : 0;

    switch (op) {
    case EOpGenMul:
        {
            // mul(a,b) -> MatrixTimesMatrix, MatrixTimesVector, MatrixTimesScalar, VectorTimesScalar, Dot, Mul
            if (argCount < 2)
                break;

            TIntermTyped* arg0 = argAggregate->getSequence()[0]->getAsTyped();
            TIntermTyped* arg1 = argAggregate->getSequence()[1]->getAsTyped();

            if (arg0->isVector() && arg1->isVector()) {  // vec * vec
                // keep the existing call node, just retag it as a dot product
                node->getAsAggregate()->setOperator(EOpDot);
            } else {
                node = handleBinaryMath(loc, "mul", EOpMul, arg0, arg1);
            }

            break;
        }

    case EOpRcp:
        {
            // rcp(a) -> 1 / a
            if (fnUnary == nullptr)
                break;

            TIntermTyped* arg0 = fnUnary->getOperand();
            TBasicType    type0 = arg0->getBasicType();
            TIntermTyped* one = intermediate.addConstantUnion(1, type0, loc, true);
            node = handleBinaryMath(loc, "rcp", EOpDiv, one, arg0);

            break;
        }

    case EOpSaturate:
        {
            // saturate(a) -> clamp(a,0,1)
            if (fnUnary == nullptr)
                break;

            TIntermTyped* arg0 = fnUnary->getOperand();
            TBasicType    type0 = arg0->getBasicType();
            TIntermAggregate* clamp = new TIntermAggregate(EOpClamp);

            clamp->getSequence().push_back(arg0);
            clamp->getSequence().push_back(intermediate.addConstantUnion(0, type0, loc, true));
            clamp->getSequence().push_back(intermediate.addConstantUnion(1, type0, loc, true));
            clamp->setLoc(loc);
            clamp->setType(node->getType());  // result has the type of the original call
            node = clamp;

            break;
        }

    case EOpSinCos:
        {
            // sincos(a,b,c) -> b = sin(a), c = cos(a)
            if (argCount < 3)
                break;

            TIntermTyped* arg0 = argAggregate->getSequence()[0]->getAsTyped();
            TIntermTyped* arg1 = argAggregate->getSequence()[1]->getAsTyped();
            TIntermTyped* arg2 = argAggregate->getSequence()[2]->getAsTyped();

            TIntermTyped* sinStatement = handleUnaryMath(loc, "sin", EOpSin, arg0);
            TIntermTyped* cosStatement = handleUnaryMath(loc, "cos", EOpCos, arg0);
            TIntermTyped* sinAssign    = intermediate.addAssign(EOpAssign, arg1, sinStatement, loc);
            TIntermTyped* cosAssign    = intermediate.addAssign(EOpAssign, arg2, cosStatement, loc);

            // the two assignments replace the call as a single sequence node
            TIntermAggregate* compoundStatement = intermediate.makeAggregate(sinAssign, loc);
            compoundStatement = intermediate.growAggregate(compoundStatement, cosAssign);
            compoundStatement->setOperator(EOpSequence);
            compoundStatement->setLoc(loc);

            node = compoundStatement;

            break;
        }

    case EOpClip:
        {
            // clip(a) -> if (any(a<0)) discard;
            if (fnUnary == nullptr)
                break;

            TIntermTyped* arg0 = fnUnary->getOperand();
            TBasicType    type0 = arg0->getBasicType();
            TIntermTyped* compareNode = nullptr;

            // For non-scalars: per experiment with FXC compiler, discard if any component < 0.
            if (!arg0->isScalar()) {
                // component-wise compare: a < 0
                TIntermAggregate* less = new TIntermAggregate(EOpLessThan);
                less->getSequence().push_back(arg0);
                less->setLoc(loc);

                // make vec or mat of bool matching dimensions of input
                less->setType(TType(EbtBool, EvqTemporary,
                                    arg0->getType().getVectorSize(),
                                    arg0->getType().getMatrixCols(),
                                    arg0->getType().getMatrixRows(),
                                    arg0->getType().isVector()));

                // calculate # of components for comparison const
                const int constComponentCount =
                    std::max(arg0->getType().getVectorSize(), 1) *
                    std::max(arg0->getType().getMatrixCols(), 1) *
                    std::max(arg0->getType().getMatrixRows(), 1);

                TConstUnion zero;
                zero.setDConst(0.0);
                TConstUnionArray zeros(constComponentCount, zero);

                less->getSequence().push_back(intermediate.addConstantUnion(zeros, arg0->getType(), loc, true));

                compareNode = intermediate.addBuiltInFunctionCall(loc, EOpAny, true, less, TType(EbtBool));
            } else {
                TIntermTyped* zero = intermediate.addConstantUnion(0, type0, loc, true);
                compareNode = handleBinaryMath(loc, "clip", EOpLessThan, arg0, zero);
            }

            TIntermBranch* killNode = intermediate.addBranch(EOpKill, loc);

            // selection with a true branch only: kill when the comparison holds
            node = new TIntermSelection(compareNode, killNode, nullptr);
            node->setLoc(loc);

            break;
        }

    case EOpLog10:
        {
            // log10(a) -> log2(a) * 0.301029995663981  (== 1/log2(10))
            if (fnUnary == nullptr)
                break;

            TIntermTyped* arg0 = fnUnary->getOperand();
            TIntermTyped* log2 = handleUnaryMath(loc, "log2", EOpLog2, arg0);
            TIntermTyped* base = intermediate.addConstantUnion(0.301029995663981f, EbtFloat, loc, true);
            node = handleBinaryMath(loc, "mul", EOpMul, log2, base);

            break;
        }

    case EOpDst:
        {
            // dest.x = 1;
            // dest.y = src0.y * src1.y;
            // dest.z = src0.z;
            // dest.w = src1.w;
            if (argCount < 2)
                break;

            TIntermTyped* arg0 = argAggregate->getSequence()[0]->getAsTyped();
            TIntermTyped* arg1 = argAggregate->getSequence()[1]->getAsTyped();

            // component indices; .x is never read from either source, so no index 0 is needed
            TIntermTyped* y = intermediate.addConstantUnion(1, loc, true);
            TIntermTyped* z = intermediate.addConstantUnion(2, loc, true);
            TIntermTyped* w = intermediate.addConstantUnion(3, loc, true);

            // NOTE(review): confirm addIndex() gives these nodes their scalar type; if it
            // leaves that to the caller, each of them needs a setType() as well.
            TIntermTyped* src0y = intermediate.addIndex(EOpIndexDirect, arg0, y, loc);
            TIntermTyped* src1y = intermediate.addIndex(EOpIndexDirect, arg1, y, loc);
            TIntermTyped* src0z = intermediate.addIndex(EOpIndexDirect, arg0, z, loc);
            TIntermTyped* src1w = intermediate.addIndex(EOpIndexDirect, arg1, w, loc);

            TIntermAggregate* dst = new TIntermAggregate(EOpConstructVec4);

            dst->getSequence().push_back(intermediate.addConstantUnion(1.0, EbtFloat, loc, true));
            dst->getSequence().push_back(handleBinaryMath(loc, "mul", EOpMul, src0y, src1y));
            dst->getSequence().push_back(src0z);
            dst->getSequence().push_back(src1w);
            // the constructor yields a 4-component float vector; give the new node that
            // type, as the other decompositions that build an aggregate by hand do
            dst->setType(TType(EbtFloat, EvqTemporary, 4));
            dst->setLoc(loc);
            node = dst;

            break;
        }

    default:
        break; // most pass through unchanged
    }
}
//
// Handle seeing function call syntax in the grammar, which could be any of
// - .length() method
@@ -872,6 +1051,8 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
}
result = addOutputArgumentConversions(*fnCandidate, *result->getAsAggregate());
}
decomposeIntrinsic(loc, result, arguments);
}
}

View File

@@ -1,5 +1,6 @@
//
//Copyright (C) 2016 Google, Inc.
//Copyright (C) 2016 LunarG, Inc.
//
//All rights reserved.
//
@@ -85,6 +86,7 @@ public:
TIntermAggregate* handleFunctionDefinition(const TSourceLoc&, TFunction&);
void handleFunctionArgument(TFunction*, TIntermTyped*& arguments, TIntermTyped* newArg);
TIntermTyped* handleFunctionCall(const TSourceLoc&, TFunction*, TIntermNode*);
// Optionally rewrites the operator node of an HLSL intrinsic call into equivalent AST opcodes; may replace 'node'.
void decomposeIntrinsic(const TSourceLoc&, TIntermTyped*& node, TIntermNode* arguments);
TIntermTyped* handleLengthMethod(const TSourceLoc&, TFunction*, TIntermNode*);
void addInputArgumentConversions(const TFunction&, TIntermNode*&) const;
TIntermTyped* addOutputArgumentConversions(const TFunction&, TIntermAggregate&) const;

View File

@@ -279,7 +279,7 @@ void TBuiltInParseablesHlsl::initialize(int version, EProfile profile, int spv,
{ "DeviceMemoryBarrierWithGroupSync", nullptr, nullptr, "-", "-", EShLangComputeMask },
{ "distance", "S", "F", "V,", "F,", EShLangAll },
{ "dot", "S", nullptr, "V,", "FI,", EShLangAll },
{ "dst", nullptr, nullptr, "V,", "F,", EShLangAll },
{ "dst", nullptr, nullptr, "V4,V4", "F,", EShLangAll },
// { "errorf", "-", "-", "", "", EShLangAll }, TODO: varargs
{ "EvaluateAttributeAtCentroid", nullptr, nullptr, "SVM", "F", EShLangFragmentMask },
{ "EvaluateAttributeAtSample", nullptr, nullptr, "SVM,S", "F,U", EShLangFragmentMask },
@@ -324,6 +324,7 @@ void TBuiltInParseablesHlsl::initialize(int version, EProfile profile, int spv,
{ "min", nullptr, nullptr, "SVM,", "FI,", EShLangAll },
{ "modf", nullptr, nullptr, "SVM,>", "FI,", EShLangAll },
{ "msad4", "V4", "U", "S,V2,V4", "U,,", EShLangAll },
// TODO: fix matrix return size for non-square mats used with mul opcode
{ "mul", "S", nullptr, "S,S", "FI,", EShLangAll },
{ "mul", "V", nullptr, "S,V", "FI,", EShLangAll },
{ "mul", "M", nullptr, "S,M", "FI,", EShLangAll },
@@ -508,7 +509,7 @@ void TBuiltInParseablesHlsl::initialize(const TBuiltInResource &resources, int v
void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int spv, int vulkan, EShLanguage language,
TSymbolTable& symbolTable)
{
// symbolTable.relateToOperator("abort");
// symbolTable.relateToOperator("abort", EOpAbort);
symbolTable.relateToOperator("abs", EOpAbs);
symbolTable.relateToOperator("acos", EOpAcos);
symbolTable.relateToOperator("all", EOpAll);
@@ -525,12 +526,12 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int
symbolTable.relateToOperator("ceil", EOpCeil);
// symbolTable.relateToOperator("CheckAccessFullyMapped");
symbolTable.relateToOperator("clamp", EOpClamp);
// symbolTable.relateToOperator("clip");
symbolTable.relateToOperator("clip", EOpClip);
symbolTable.relateToOperator("cos", EOpCos);
symbolTable.relateToOperator("cosh", EOpCosh);
symbolTable.relateToOperator("countbits", EOpBitCount);
symbolTable.relateToOperator("cross", EOpCross);
// symbolTable.relateToOperator("D3DCOLORtoUBYTE4");
// symbolTable.relateToOperator("D3DCOLORtoUBYTE4", EOpD3DCOLORtoUBYTE4);
symbolTable.relateToOperator("ddx", EOpDPdx);
symbolTable.relateToOperator("ddx_coarse", EOpDPdxCoarse);
symbolTable.relateToOperator("ddx_fine", EOpDPdxFine);
@@ -543,7 +544,7 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int
// symbolTable.relateToOperator("DeviceMemoryBarrierWithGroupSync");
symbolTable.relateToOperator("distance", EOpDistance);
symbolTable.relateToOperator("dot", EOpDot);
// symbolTable.relateToOperator("dst");
symbolTable.relateToOperator("dst", EOpDst);
// symbolTable.relateToOperator("errorf");
symbolTable.relateToOperator("EvaluateAttributeAtCentroid", EOpInterpolateAtCentroid);
symbolTable.relateToOperator("EvaluateAttributeAtSample", EOpInterpolateAtSample);
@@ -557,7 +558,7 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int
symbolTable.relateToOperator("firstbitlow", EOpFindLSB);
symbolTable.relateToOperator("floor", EOpFloor);
symbolTable.relateToOperator("fma", EOpFma);
// symbolTable.relateToOperator("fmod");
symbolTable.relateToOperator("fmod", EOpMod);
symbolTable.relateToOperator("frac", EOpFract);
symbolTable.relateToOperator("frexp", EOpFrexp);
symbolTable.relateToOperator("fwidth", EOpFwidth);
@@ -574,21 +575,21 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int
// symbolTable.relateToOperator("InterlockedMin");
// symbolTable.relateToOperator("InterlockedOr");
// symbolTable.relateToOperator("InterlockedXor");
// symbolTable.relateToOperator("isfinite");
symbolTable.relateToOperator("isfinite", EOpIsFinite);
symbolTable.relateToOperator("isinf", EOpIsInf);
symbolTable.relateToOperator("isnan", EOpIsNan);
symbolTable.relateToOperator("ldexp", EOpLdexp);
symbolTable.relateToOperator("length", EOpLength);
// symbolTable.relateToOperator("lit");
symbolTable.relateToOperator("log", EOpLog);
// symbolTable.relateToOperator("log10");
symbolTable.relateToOperator("log10", EOpLog10);
symbolTable.relateToOperator("log2", EOpLog2);
// symbolTable.relateToOperator("mad");
symbolTable.relateToOperator("max", EOpMax);
symbolTable.relateToOperator("min", EOpMin);
symbolTable.relateToOperator("modf", EOpModf);
// symbolTable.relateToOperator("msad4");
// symbolTable.relateToOperator("mul");
// symbolTable.relateToOperator("msad4", EOpMsad4);
symbolTable.relateToOperator("mul", EOpGenMul);
// symbolTable.relateToOperator("noise", EOpNoise); // TODO: check return type
symbolTable.relateToOperator("normalize", EOpNormalize);
symbolTable.relateToOperator("pow", EOpPow);
@@ -604,16 +605,16 @@ void TBuiltInParseablesHlsl::identifyBuiltIns(int version, EProfile profile, int
// symbolTable.relateToOperator("ProcessTriTessFactorsMax");
// symbolTable.relateToOperator("ProcessTriTessFactorsMin");
symbolTable.relateToOperator("radians", EOpRadians);
// symbolTable.relateToOperator("rcp");
symbolTable.relateToOperator("rcp", EOpRcp);
symbolTable.relateToOperator("reflect", EOpReflect);
symbolTable.relateToOperator("refract", EOpRefract);
symbolTable.relateToOperator("reversebits", EOpBitFieldReverse);
symbolTable.relateToOperator("round", EOpRoundEven);
symbolTable.relateToOperator("rsqrt", EOpInverseSqrt);
// symbolTable.relateToOperator("saturate");
symbolTable.relateToOperator("saturate", EOpSaturate);
symbolTable.relateToOperator("sign", EOpSign);
symbolTable.relateToOperator("sin", EOpSin);
// symbolTable.relateToOperator("sincos");
symbolTable.relateToOperator("sincos", EOpSinCos);
symbolTable.relateToOperator("sinh", EOpSinh);
symbolTable.relateToOperator("smoothstep", EOpSmoothStep);
symbolTable.relateToOperator("sqrt", EOpSqrt);