Implement support for GL_KHR_cooperative_matrix extension
This commit is contained in:
committed by
arcady-lunarg
parent
91a97b4c69
commit
808c7ed17c
@@ -176,7 +176,7 @@ protected:
|
||||
glslang::TLayoutPacking, const glslang::TQualifier&);
|
||||
void decorateStructType(const glslang::TType&, const glslang::TTypeList* glslangStruct, glslang::TLayoutPacking,
|
||||
const glslang::TQualifier&, spv::Id, const std::vector<spv::Id>& spvMembers);
|
||||
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim);
|
||||
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim, bool allowZero = false);
|
||||
spv::Id accessChainLoad(const glslang::TType& type);
|
||||
void accessChainStore(const glslang::TType& type, spv::Id rvalue);
|
||||
void multiTypeStore(const glslang::TType&, spv::Id rValue);
|
||||
@@ -212,7 +212,7 @@ protected:
|
||||
glslang::TBasicType typeProxy);
|
||||
spv::Id createConversion(glslang::TOperator op, OpDecorations&, spv::Id destTypeId, spv::Id operand,
|
||||
glslang::TBasicType typeProxy);
|
||||
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize);
|
||||
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType);
|
||||
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
|
||||
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId,
|
||||
std::vector<spv::Id>& operands, glslang::TBasicType typeProxy,
|
||||
@@ -2560,12 +2560,15 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
|
||||
|
||||
spv::Id length;
|
||||
if (node->getOperand()->getType().isCoopMat()) {
|
||||
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
|
||||
|
||||
spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
|
||||
length = builder.createCooperativeMatrixLength(typeId);
|
||||
if (node->getOperand()->getType().isCoopMatKHR()) {
|
||||
length = builder.createCooperativeMatrixLengthKHR(typeId);
|
||||
} else {
|
||||
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
|
||||
length = builder.createCooperativeMatrixLengthNV(typeId);
|
||||
}
|
||||
} else {
|
||||
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
|
||||
block->traverse(this);
|
||||
@@ -3099,7 +3102,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
case glslang::EOpConstructStruct:
|
||||
case glslang::EOpConstructTextureSampler:
|
||||
case glslang::EOpConstructReference:
|
||||
case glslang::EOpConstructCooperativeMatrix:
|
||||
case glslang::EOpConstructCooperativeMatrixNV:
|
||||
case glslang::EOpConstructCooperativeMatrixKHR:
|
||||
{
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
std::vector<spv::Id> arguments;
|
||||
@@ -3116,7 +3120,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
} else
|
||||
constructed = builder.createOp(spv::OpSampledImage, resultType(), arguments);
|
||||
} else if (node->getOp() == glslang::EOpConstructStruct ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrix ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrixNV ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrixKHR ||
|
||||
node->getType().isArray()) {
|
||||
std::vector<spv::Id> constituents;
|
||||
for (int c = 0; c < (int)arguments.size(); ++c)
|
||||
@@ -3291,6 +3296,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
case glslang::EOpCooperativeMatrixLoadNV:
|
||||
case glslang::EOpCooperativeMatrixStoreNV:
|
||||
noReturnValue = true;
|
||||
break;
|
||||
case glslang::EOpBeginInvocationInterlock:
|
||||
@@ -3502,10 +3509,12 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
case glslang::EOpCooperativeMatrixLoadNV:
|
||||
if (arg == 0 || arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
case glslang::EOpCooperativeMatrixStoreNV:
|
||||
if (arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
@@ -3534,7 +3543,9 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
|
||||
#ifndef GLSLANG_WEB
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
|
||||
|
||||
if (arg == 1) {
|
||||
// fold "element" parameter into the access chain
|
||||
@@ -3555,9 +3566,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
unsigned int alignment = builder.getAccessChain().alignment;
|
||||
|
||||
int memoryAccess = TranslateMemoryAccess(coherentFlags);
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad)
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerAvailableKHRMask;
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerVisibleKHRMask;
|
||||
if (builder.getStorageClass(builder.getAccessChain().base) ==
|
||||
spv::StorageClassPhysicalStorageBufferEXT) {
|
||||
@@ -3655,31 +3668,48 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
#ifndef GLSLANG_WEB
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
} else {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
}
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
// get the pointee type
|
||||
spv::Id typeId = builder.getContainedTypeId(builder.getTypeId(operands[0]));
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
// do the op
|
||||
spv::Id result = builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
|
||||
spv::Id result = node->getOp() == glslang::EOpCooperativeMatrixLoad
|
||||
? builder.createOp(spv::OpCooperativeMatrixLoadKHR, typeId, idImmOps)
|
||||
: builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
|
||||
// store the result to the pointer (out param 'm')
|
||||
builder.createStore(result, operands[0]);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[0])); // object
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
} else {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
}
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreKHR, idImmOps);
|
||||
else
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpRayQueryGetIntersectionTriangleVertexPositionsEXT) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
@@ -3694,6 +3724,32 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
||||
// store the result to the pointer (out param 'm')
|
||||
builder.createStore(result, operands[2]);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixMulAdd) {
|
||||
uint32_t matrixOperands = 0;
|
||||
|
||||
// If the optional operand is present, initialize matrixOperands to that value.
|
||||
if (glslangOperands.size() == 4 && glslangOperands[3]->getAsConstantUnion()) {
|
||||
matrixOperands = glslangOperands[3]->getAsConstantUnion()->getConstArray()[0].getIConst();
|
||||
}
|
||||
|
||||
// Determine Cooperative Matrix Operands bits from the signedness of the types.
|
||||
if (isTypeSignedInt(glslangOperands[0]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixASignedComponentsMask;
|
||||
if (isTypeSignedInt(glslangOperands[1]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixBSignedComponentsMask;
|
||||
if (isTypeSignedInt(glslangOperands[2]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixCSignedComponentsMask;
|
||||
if (isTypeSignedInt(node->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixResultSignedComponentsMask;
|
||||
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[0]));
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1]));
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2]));
|
||||
if (matrixOperands != 0)
|
||||
idImmOps.push_back(spv::IdImmediate(false, matrixOperands));
|
||||
|
||||
result = builder.createOp(spv::OpCooperativeMatrixMulAddKHR, resultType(), idImmOps);
|
||||
} else
|
||||
#endif
|
||||
if (atomic) {
|
||||
@@ -4586,9 +4642,10 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
|
||||
spvType = builder.makeVectorType(spvType, type.getVectorSize());
|
||||
}
|
||||
|
||||
if (type.isCoopMat()) {
|
||||
if (type.isCoopMatNV()) {
|
||||
builder.addCapability(spv::CapabilityCooperativeMatrixNV);
|
||||
builder.addExtension(spv::E_SPV_NV_cooperative_matrix);
|
||||
|
||||
if (type.getBasicType() == glslang::EbtFloat16)
|
||||
builder.addCapability(spv::CapabilityFloat16);
|
||||
if (type.getBasicType() == glslang::EbtUint8 ||
|
||||
@@ -4596,11 +4653,29 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
|
||||
builder.addCapability(spv::CapabilityInt8);
|
||||
}
|
||||
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters(), 1);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters(), 2);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters(), 3);
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 3);
|
||||
|
||||
spvType = builder.makeCooperativeMatrixType(spvType, scope, rows, cols);
|
||||
spvType = builder.makeCooperativeMatrixTypeNV(spvType, scope, rows, cols);
|
||||
}
|
||||
|
||||
if (type.isCoopMatKHR()) {
|
||||
builder.addCapability(spv::CapabilityCooperativeMatrixKHR);
|
||||
builder.addExtension(spv::E_SPV_KHR_cooperative_matrix);
|
||||
|
||||
if (type.getBasicType() == glslang::EbtFloat16)
|
||||
builder.addCapability(spv::CapabilityFloat16);
|
||||
if (type.getBasicType() == glslang::EbtUint8 || type.getBasicType() == glslang::EbtInt8) {
|
||||
builder.addCapability(spv::CapabilityInt8);
|
||||
}
|
||||
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 0);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
|
||||
spv::Id use = builder.makeUintConstant(type.getCoopMatKHRuse());
|
||||
|
||||
spvType = builder.makeCooperativeMatrixTypeKHR(spvType, scope, rows, cols, use);
|
||||
}
|
||||
|
||||
if (type.isArray()) {
|
||||
@@ -4951,7 +5026,7 @@ void TGlslangToSpvTraverser::decorateStructType(const glslang::TType& type,
|
||||
// This is not quite trivial, because of specialization constants.
|
||||
// Sometimes, a raw constant is turned into an Id, and sometimes
|
||||
// a specialization constant expression is.
|
||||
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim)
|
||||
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim, bool allowZero)
|
||||
{
|
||||
// First, see if this is sized with a node, meaning a specialization constant:
|
||||
glslang::TIntermTyped* specNode = arraySizes.getDimNode(dim);
|
||||
@@ -4965,7 +5040,10 @@ spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arra
|
||||
|
||||
// Otherwise, need a compile-time (front end) size, get it:
|
||||
int size = arraySizes.getDimSize(dim);
|
||||
assert(size > 0);
|
||||
|
||||
if (!allowZero)
|
||||
assert(size > 0);
|
||||
|
||||
return builder.makeUintConstant(size);
|
||||
}
|
||||
|
||||
@@ -7287,7 +7365,9 @@ spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, OpDecorat
|
||||
// For converting integers where both the bitwidth and the signedness could
|
||||
// change, but only do the width change here. The caller is still responsible
|
||||
// for the signedness conversion.
|
||||
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize)
|
||||
// destType is the final type that will be converted to, but this function
|
||||
// may only be doing part of that conversion.
|
||||
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType)
|
||||
{
|
||||
// Get the result type width, based on the type to convert to.
|
||||
int width = 32;
|
||||
@@ -7358,6 +7438,11 @@ spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op,
|
||||
|
||||
if (vectorSize > 0)
|
||||
type = builder.makeVectorType(type, vectorSize);
|
||||
else if (builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixKHR ||
|
||||
builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixNV) {
|
||||
|
||||
type = builder.makeCooperativeMatrixTypeWithSameShape(type, destType);
|
||||
}
|
||||
|
||||
return builder.createUnaryOp(convOp, type, operand);
|
||||
}
|
||||
@@ -7630,7 +7715,7 @@ spv::Id TGlslangToSpvTraverser::createConversion(glslang::TOperator op, OpDecora
|
||||
case glslang::EOpConvUint64ToInt16:
|
||||
case glslang::EOpConvUint64ToInt:
|
||||
// OpSConvert/OpUConvert + OpBitCast
|
||||
operand = createIntWidthConversion(op, operand, vectorSize);
|
||||
operand = createIntWidthConversion(op, operand, vectorSize, destType);
|
||||
|
||||
if (builder.isInSpecConstCodeGenMode()) {
|
||||
// Build zero scalar or vector for OpIAdd.
|
||||
@@ -8963,7 +9048,7 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
|
||||
case glslang::EOpSetMeshOutputsEXT:
|
||||
builder.createNoResultOp(spv::OpSetMeshOutputsEXT, operands);
|
||||
return 0;
|
||||
case glslang::EOpCooperativeMatrixMulAdd:
|
||||
case glslang::EOpCooperativeMatrixMulAddNV:
|
||||
opCode = spv::OpCooperativeMatrixMulAddNV;
|
||||
break;
|
||||
case glslang::EOpHitObjectTraceRayNV:
|
||||
|
||||
Reference in New Issue
Block a user