SPV: Implement all matrix operators {+,-,*,/} for {matrix,scalar,vector}.

This commit is contained in:
John Kessenich
2015-12-12 12:28:14 -07:00
parent 494a02a2b0
commit 04bb8a01d6
9 changed files with 1140 additions and 76 deletions

View File

@@ -108,6 +108,7 @@ protected:
spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*);
spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true);
spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right);
spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
@@ -2122,26 +2123,17 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
break;
case glslang::EOpVectorTimesMatrix:
case glslang::EOpVectorTimesMatrixAssign:
assert(builder.isVector(left));
assert(builder.isMatrix(right));
binOp = spv::OpVectorTimesMatrix;
break;
case glslang::EOpMatrixTimesVector:
assert(builder.isMatrix(left));
assert(builder.isVector(right));
binOp = spv::OpMatrixTimesVector;
break;
case glslang::EOpMatrixTimesScalar:
case glslang::EOpMatrixTimesScalarAssign:
if (builder.isMatrix(right))
std::swap(left, right);
assert(builder.isScalar(right));
binOp = spv::OpMatrixTimesScalar;
break;
case glslang::EOpMatrixTimesMatrix:
case glslang::EOpMatrixTimesMatrixAssign:
assert(builder.isMatrix(left));
assert(builder.isMatrix(right));
binOp = spv::OpMatrixTimesMatrix;
break;
case glslang::EOpOuterProduct:
@@ -2220,29 +2212,8 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
// handle mapped binary operations (should be non-comparison)
if (binOp != spv::OpNop) {
assert(comparison == false);
if (builder.isMatrix(left) || builder.isMatrix(right)) {
switch (binOp) {
case spv::OpMatrixTimesScalar:
case spv::OpVectorTimesMatrix:
case spv::OpMatrixTimesVector:
case spv::OpMatrixTimesMatrix:
break;
case spv::OpFDiv:
// turn it into a multiply...
assert(builder.isMatrix(left) && builder.isScalar(right));
right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
binOp = spv::OpFMul;
break;
default:
spv::MissingFunctionality("binary operation on matrix");
break;
}
spv::Id id = builder.createBinOp(binOp, typeId, left, right);
builder.setPrecision(id, precision);
return id;
}
if (builder.isMatrix(left) || builder.isMatrix(right))
return createBinaryMatrixOperation(binOp, precision, typeId, left, right);
// No matrix involved; make both operands be the same number of components, if needed
if (needMatchingVectors)
@@ -2326,6 +2297,111 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
return 0;
}
//
// Translate AST matrix operation to SPV operation, already having SPV-based operands/types.
// These can be any of:
//
// matrix * scalar
// scalar * matrix
// matrix * matrix linear algebraic
// matrix * vector
// vector * matrix
// matrix * matrix componentwise
// matrix op matrix op in {+, -, /}
// matrix op scalar op in {+, -, /}
// scalar op matrix op in {+, -, /}
//
spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right)
{
bool firstClass = true;
// First, handle first-class matrix operations (* and matrix/scalar)
switch (op) {
case spv::OpFDiv:
if (builder.isMatrix(left) && builder.isScalar(right)) {
// turn matrix / scalar into a multiply...
right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
op = spv::OpMatrixTimesScalar;
} else
firstClass = false;
break;
case spv::OpMatrixTimesScalar:
if (builder.isMatrix(right))
std::swap(left, right);
assert(builder.isScalar(right));
break;
case spv::OpVectorTimesMatrix:
assert(builder.isVector(left));
assert(builder.isMatrix(right));
break;
case spv::OpMatrixTimesVector:
assert(builder.isMatrix(left));
assert(builder.isVector(right));
break;
case spv::OpMatrixTimesMatrix:
assert(builder.isMatrix(left));
assert(builder.isMatrix(right));
break;
default:
firstClass = false;
break;
}
if (firstClass) {
spv::Id id = builder.createBinOp(op, typeId, left, right);
builder.setPrecision(id, precision);
return id;
}
// Handle component-wise +, -, *, and / for all combinations of type.
// The result type of all of them is the same type as the (a) matrix operand.
// The algorithm is to:
// - break the matrix(es) into vectors
// - smear any scalar to a vector
// - do vector operations
// - make a matrix out the vector results
switch (op) {
case spv::OpFAdd:
case spv::OpFSub:
case spv::OpFDiv:
case spv::OpFMul:
{
// one time set up...
bool leftMat = builder.isMatrix(left);
bool rightMat = builder.isMatrix(right);
unsigned int numCols = leftMat ? builder.getNumColumns(left) : builder.getNumColumns(right);
int numRows = leftMat ? builder.getNumRows(left) : builder.getNumRows(right);
spv::Id scalarType = builder.getScalarTypeId(typeId);
spv::Id vecType = builder.makeVectorType(scalarType, numRows);
std::vector<spv::Id> results;
spv::Id smearVec = spv::NoResult;
if (builder.isScalar(left))
smearVec = builder.smearScalar(precision, left, vecType);
else if (builder.isScalar(right))
smearVec = builder.smearScalar(precision, right, vecType);
// do each vector op
for (unsigned int c = 0; c < numCols; ++c) {
std::vector<unsigned int> indexes;
indexes.push_back(c);
spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec;
spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec;
results.push_back(builder.createBinOp(op, vecType, leftVec, rightVec));
builder.setPrecision(results.back(), precision);
}
// put the pieces together
spv::Id id = builder.createCompositeConstruct(typeId, results);
builder.setPrecision(id, precision);
return id;
}
default:
assert(0);
return spv::NoResult;
}
}
spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
{
spv::Op unaryOp = spv::OpNop;