SPV: Implement all matrix operators {+,-,*,/} for {matrix,scalar,vector}.
This commit is contained in:
@@ -108,6 +108,7 @@ protected:
|
||||
spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*);
|
||||
|
||||
spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true);
|
||||
spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right);
|
||||
spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
|
||||
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand);
|
||||
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
|
||||
@@ -2122,26 +2123,17 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
|
||||
break;
|
||||
case glslang::EOpVectorTimesMatrix:
|
||||
case glslang::EOpVectorTimesMatrixAssign:
|
||||
assert(builder.isVector(left));
|
||||
assert(builder.isMatrix(right));
|
||||
binOp = spv::OpVectorTimesMatrix;
|
||||
break;
|
||||
case glslang::EOpMatrixTimesVector:
|
||||
assert(builder.isMatrix(left));
|
||||
assert(builder.isVector(right));
|
||||
binOp = spv::OpMatrixTimesVector;
|
||||
break;
|
||||
case glslang::EOpMatrixTimesScalar:
|
||||
case glslang::EOpMatrixTimesScalarAssign:
|
||||
if (builder.isMatrix(right))
|
||||
std::swap(left, right);
|
||||
assert(builder.isScalar(right));
|
||||
binOp = spv::OpMatrixTimesScalar;
|
||||
break;
|
||||
case glslang::EOpMatrixTimesMatrix:
|
||||
case glslang::EOpMatrixTimesMatrixAssign:
|
||||
assert(builder.isMatrix(left));
|
||||
assert(builder.isMatrix(right));
|
||||
binOp = spv::OpMatrixTimesMatrix;
|
||||
break;
|
||||
case glslang::EOpOuterProduct:
|
||||
@@ -2220,29 +2212,8 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
|
||||
// handle mapped binary operations (should be non-comparison)
|
||||
if (binOp != spv::OpNop) {
|
||||
assert(comparison == false);
|
||||
if (builder.isMatrix(left) || builder.isMatrix(right)) {
|
||||
switch (binOp) {
|
||||
case spv::OpMatrixTimesScalar:
|
||||
case spv::OpVectorTimesMatrix:
|
||||
case spv::OpMatrixTimesVector:
|
||||
case spv::OpMatrixTimesMatrix:
|
||||
break;
|
||||
case spv::OpFDiv:
|
||||
// turn it into a multiply...
|
||||
assert(builder.isMatrix(left) && builder.isScalar(right));
|
||||
right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
|
||||
binOp = spv::OpFMul;
|
||||
break;
|
||||
default:
|
||||
spv::MissingFunctionality("binary operation on matrix");
|
||||
break;
|
||||
}
|
||||
|
||||
spv::Id id = builder.createBinOp(binOp, typeId, left, right);
|
||||
builder.setPrecision(id, precision);
|
||||
|
||||
return id;
|
||||
}
|
||||
if (builder.isMatrix(left) || builder.isMatrix(right))
|
||||
return createBinaryMatrixOperation(binOp, precision, typeId, left, right);
|
||||
|
||||
// No matrix involved; make both operands be the same number of components, if needed
|
||||
if (needMatchingVectors)
|
||||
@@ -2326,6 +2297,111 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Translate AST matrix operation to SPV operation, already having SPV-based operands/types.
|
||||
// These can be any of:
|
||||
//
|
||||
// matrix * scalar
|
||||
// scalar * matrix
|
||||
// matrix * matrix linear algebraic
|
||||
// matrix * vector
|
||||
// vector * matrix
|
||||
// matrix * matrix componentwise
|
||||
// matrix op matrix op in {+, -, /}
|
||||
// matrix op scalar op in {+, -, /}
|
||||
// scalar op matrix op in {+, -, /}
|
||||
//
|
||||
spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right)
|
||||
{
|
||||
bool firstClass = true;
|
||||
|
||||
// First, handle first-class matrix operations (* and matrix/scalar)
|
||||
switch (op) {
|
||||
case spv::OpFDiv:
|
||||
if (builder.isMatrix(left) && builder.isScalar(right)) {
|
||||
// turn matrix / scalar into a multiply...
|
||||
right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
|
||||
op = spv::OpMatrixTimesScalar;
|
||||
} else
|
||||
firstClass = false;
|
||||
break;
|
||||
case spv::OpMatrixTimesScalar:
|
||||
if (builder.isMatrix(right))
|
||||
std::swap(left, right);
|
||||
assert(builder.isScalar(right));
|
||||
break;
|
||||
case spv::OpVectorTimesMatrix:
|
||||
assert(builder.isVector(left));
|
||||
assert(builder.isMatrix(right));
|
||||
break;
|
||||
case spv::OpMatrixTimesVector:
|
||||
assert(builder.isMatrix(left));
|
||||
assert(builder.isVector(right));
|
||||
break;
|
||||
case spv::OpMatrixTimesMatrix:
|
||||
assert(builder.isMatrix(left));
|
||||
assert(builder.isMatrix(right));
|
||||
break;
|
||||
default:
|
||||
firstClass = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (firstClass) {
|
||||
spv::Id id = builder.createBinOp(op, typeId, left, right);
|
||||
builder.setPrecision(id, precision);
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
// Handle component-wise +, -, *, and / for all combinations of type.
|
||||
// The result type of all of them is the same type as the (a) matrix operand.
|
||||
// The algorithm is to:
|
||||
// - break the matrix(es) into vectors
|
||||
// - smear any scalar to a vector
|
||||
// - do vector operations
|
||||
// - make a matrix out the vector results
|
||||
switch (op) {
|
||||
case spv::OpFAdd:
|
||||
case spv::OpFSub:
|
||||
case spv::OpFDiv:
|
||||
case spv::OpFMul:
|
||||
{
|
||||
// one time set up...
|
||||
bool leftMat = builder.isMatrix(left);
|
||||
bool rightMat = builder.isMatrix(right);
|
||||
unsigned int numCols = leftMat ? builder.getNumColumns(left) : builder.getNumColumns(right);
|
||||
int numRows = leftMat ? builder.getNumRows(left) : builder.getNumRows(right);
|
||||
spv::Id scalarType = builder.getScalarTypeId(typeId);
|
||||
spv::Id vecType = builder.makeVectorType(scalarType, numRows);
|
||||
std::vector<spv::Id> results;
|
||||
spv::Id smearVec = spv::NoResult;
|
||||
if (builder.isScalar(left))
|
||||
smearVec = builder.smearScalar(precision, left, vecType);
|
||||
else if (builder.isScalar(right))
|
||||
smearVec = builder.smearScalar(precision, right, vecType);
|
||||
|
||||
// do each vector op
|
||||
for (unsigned int c = 0; c < numCols; ++c) {
|
||||
std::vector<unsigned int> indexes;
|
||||
indexes.push_back(c);
|
||||
spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec;
|
||||
spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec;
|
||||
results.push_back(builder.createBinOp(op, vecType, leftVec, rightVec));
|
||||
builder.setPrecision(results.back(), precision);
|
||||
}
|
||||
|
||||
// put the pieces together
|
||||
spv::Id id = builder.createCompositeConstruct(typeId, results);
|
||||
builder.setPrecision(id, precision);
|
||||
return id;
|
||||
}
|
||||
default:
|
||||
assert(0);
|
||||
return spv::NoResult;
|
||||
}
|
||||
}
|
||||
|
||||
spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
|
||||
{
|
||||
spv::Op unaryOp = spv::OpNop;
|
||||
|
||||
Reference in New Issue
Block a user