Implement support for GL_KHR_cooperative_matrix extension

This commit is contained in:
Boris Zanin
2023-03-16 13:01:01 +01:00
committed by arcady-lunarg
parent 91a97b4c69
commit 808c7ed17c
40 changed files with 8227 additions and 5733 deletions

View File

@@ -176,7 +176,7 @@ protected:
glslang::TLayoutPacking, const glslang::TQualifier&);
void decorateStructType(const glslang::TType&, const glslang::TTypeList* glslangStruct, glslang::TLayoutPacking,
const glslang::TQualifier&, spv::Id, const std::vector<spv::Id>& spvMembers);
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim);
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim, bool allowZero = false);
spv::Id accessChainLoad(const glslang::TType& type);
void accessChainStore(const glslang::TType& type, spv::Id rvalue);
void multiTypeStore(const glslang::TType&, spv::Id rValue);
@@ -212,7 +212,7 @@ protected:
glslang::TBasicType typeProxy);
spv::Id createConversion(glslang::TOperator op, OpDecorations&, spv::Id destTypeId, spv::Id operand,
glslang::TBasicType typeProxy);
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize);
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId,
std::vector<spv::Id>& operands, glslang::TBasicType typeProxy,
@@ -2560,12 +2560,15 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
spv::Id length;
if (node->getOperand()->getType().isCoopMat()) {
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
assert(builder.isCooperativeMatrixType(typeId));
length = builder.createCooperativeMatrixLength(typeId);
if (node->getOperand()->getType().isCoopMatKHR()) {
length = builder.createCooperativeMatrixLengthKHR(typeId);
} else {
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
length = builder.createCooperativeMatrixLengthNV(typeId);
}
} else {
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
block->traverse(this);
@@ -3099,7 +3102,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
case glslang::EOpConstructStruct:
case glslang::EOpConstructTextureSampler:
case glslang::EOpConstructReference:
case glslang::EOpConstructCooperativeMatrix:
case glslang::EOpConstructCooperativeMatrixNV:
case glslang::EOpConstructCooperativeMatrixKHR:
{
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
std::vector<spv::Id> arguments;
@@ -3116,7 +3120,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
} else
constructed = builder.createOp(spv::OpSampledImage, resultType(), arguments);
} else if (node->getOp() == glslang::EOpConstructStruct ||
node->getOp() == glslang::EOpConstructCooperativeMatrix ||
node->getOp() == glslang::EOpConstructCooperativeMatrixNV ||
node->getOp() == glslang::EOpConstructCooperativeMatrixKHR ||
node->getType().isArray()) {
std::vector<spv::Id> constituents;
for (int c = 0; c < (int)arguments.size(); ++c)
@@ -3291,6 +3296,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
break;
case glslang::EOpCooperativeMatrixLoad:
case glslang::EOpCooperativeMatrixStore:
case glslang::EOpCooperativeMatrixLoadNV:
case glslang::EOpCooperativeMatrixStoreNV:
noReturnValue = true;
break;
case glslang::EOpBeginInvocationInterlock:
@@ -3502,10 +3509,12 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
lvalue = true;
break;
case glslang::EOpCooperativeMatrixLoad:
case glslang::EOpCooperativeMatrixLoadNV:
if (arg == 0 || arg == 1)
lvalue = true;
break;
case glslang::EOpCooperativeMatrixStore:
case glslang::EOpCooperativeMatrixStoreNV:
if (arg == 1)
lvalue = true;
break;
@@ -3534,7 +3543,9 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
#ifndef GLSLANG_WEB
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
node->getOp() == glslang::EOpCooperativeMatrixStore) {
node->getOp() == glslang::EOpCooperativeMatrixStore ||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV ||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
if (arg == 1) {
// fold "element" parameter into the access chain
@@ -3555,9 +3566,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
unsigned int alignment = builder.getAccessChain().alignment;
int memoryAccess = TranslateMemoryAccess(coherentFlags);
if (node->getOp() == glslang::EOpCooperativeMatrixLoad)
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV)
memoryAccess &= ~spv::MemoryAccessMakePointerAvailableKHRMask;
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV)
memoryAccess &= ~spv::MemoryAccessMakePointerVisibleKHRMask;
if (builder.getStorageClass(builder.getAccessChain().base) ==
spv::StorageClassPhysicalStorageBufferEXT) {
@@ -3655,31 +3668,48 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
#ifndef GLSLANG_WEB
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV) {
std::vector<spv::IdImmediate> idImmOps;
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
} else {
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
}
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
// get the pointee type
spv::Id typeId = builder.getContainedTypeId(builder.getTypeId(operands[0]));
assert(builder.isCooperativeMatrixType(typeId));
// do the op
spv::Id result = builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
spv::Id result = node->getOp() == glslang::EOpCooperativeMatrixLoad
? builder.createOp(spv::OpCooperativeMatrixLoadKHR, typeId, idImmOps)
: builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
// store the result to the pointer (out param 'm')
builder.createStore(result, operands[0]);
result = 0;
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
std::vector<spv::IdImmediate> idImmOps;
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
idImmOps.push_back(spv::IdImmediate(true, operands[0])); // object
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
} else {
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
}
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
builder.createNoResultOp(spv::OpCooperativeMatrixStoreKHR, idImmOps);
else
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
result = 0;
} else if (node->getOp() == glslang::EOpRayQueryGetIntersectionTriangleVertexPositionsEXT) {
std::vector<spv::IdImmediate> idImmOps;
@@ -3694,6 +3724,32 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
// store the result to the pointer (out param 'm')
builder.createStore(result, operands[2]);
result = 0;
} else if (node->getOp() == glslang::EOpCooperativeMatrixMulAdd) {
uint32_t matrixOperands = 0;
// If the optional operand is present, initialize matrixOperands to that value.
if (glslangOperands.size() == 4 && glslangOperands[3]->getAsConstantUnion()) {
matrixOperands = glslangOperands[3]->getAsConstantUnion()->getConstArray()[0].getIConst();
}
// Determine Cooperative Matrix Operands bits from the signedness of the types.
if (isTypeSignedInt(glslangOperands[0]->getAsTyped()->getBasicType()))
matrixOperands |= spv::CooperativeMatrixOperandsMatrixASignedComponentsMask;
if (isTypeSignedInt(glslangOperands[1]->getAsTyped()->getBasicType()))
matrixOperands |= spv::CooperativeMatrixOperandsMatrixBSignedComponentsMask;
if (isTypeSignedInt(glslangOperands[2]->getAsTyped()->getBasicType()))
matrixOperands |= spv::CooperativeMatrixOperandsMatrixCSignedComponentsMask;
if (isTypeSignedInt(node->getBasicType()))
matrixOperands |= spv::CooperativeMatrixOperandsMatrixResultSignedComponentsMask;
std::vector<spv::IdImmediate> idImmOps;
idImmOps.push_back(spv::IdImmediate(true, operands[0]));
idImmOps.push_back(spv::IdImmediate(true, operands[1]));
idImmOps.push_back(spv::IdImmediate(true, operands[2]));
if (matrixOperands != 0)
idImmOps.push_back(spv::IdImmediate(false, matrixOperands));
result = builder.createOp(spv::OpCooperativeMatrixMulAddKHR, resultType(), idImmOps);
} else
#endif
if (atomic) {
@@ -4586,9 +4642,10 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
spvType = builder.makeVectorType(spvType, type.getVectorSize());
}
if (type.isCoopMat()) {
if (type.isCoopMatNV()) {
builder.addCapability(spv::CapabilityCooperativeMatrixNV);
builder.addExtension(spv::E_SPV_NV_cooperative_matrix);
if (type.getBasicType() == glslang::EbtFloat16)
builder.addCapability(spv::CapabilityFloat16);
if (type.getBasicType() == glslang::EbtUint8 ||
@@ -4596,11 +4653,29 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
builder.addCapability(spv::CapabilityInt8);
}
spv::Id scope = makeArraySizeId(*type.getTypeParameters(), 1);
spv::Id rows = makeArraySizeId(*type.getTypeParameters(), 2);
spv::Id cols = makeArraySizeId(*type.getTypeParameters(), 3);
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 3);
spvType = builder.makeCooperativeMatrixType(spvType, scope, rows, cols);
spvType = builder.makeCooperativeMatrixTypeNV(spvType, scope, rows, cols);
}
if (type.isCoopMatKHR()) {
builder.addCapability(spv::CapabilityCooperativeMatrixKHR);
builder.addExtension(spv::E_SPV_KHR_cooperative_matrix);
if (type.getBasicType() == glslang::EbtFloat16)
builder.addCapability(spv::CapabilityFloat16);
if (type.getBasicType() == glslang::EbtUint8 || type.getBasicType() == glslang::EbtInt8) {
builder.addCapability(spv::CapabilityInt8);
}
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 0);
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
spv::Id use = builder.makeUintConstant(type.getCoopMatKHRuse());
spvType = builder.makeCooperativeMatrixTypeKHR(spvType, scope, rows, cols, use);
}
if (type.isArray()) {
@@ -4951,7 +5026,7 @@ void TGlslangToSpvTraverser::decorateStructType(const glslang::TType& type,
// This is not quite trivial, because of specialization constants.
// Sometimes, a raw constant is turned into an Id, and sometimes
// a specialization constant expression is.
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim)
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim, bool allowZero)
{
// First, see if this is sized with a node, meaning a specialization constant:
glslang::TIntermTyped* specNode = arraySizes.getDimNode(dim);
@@ -4965,7 +5040,10 @@ spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arra
// Otherwise, need a compile-time (front end) size, get it:
int size = arraySizes.getDimSize(dim);
assert(size > 0);
if (!allowZero)
assert(size > 0);
return builder.makeUintConstant(size);
}
@@ -7287,7 +7365,9 @@ spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, OpDecorat
// For converting integers where both the bitwidth and the signedness could
// change, but only do the width change here. The caller is still responsible
// for the signedness conversion.
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize)
// destType is the final type that will be converted to, but this function
// may only be doing part of that conversion.
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType)
{
// Get the result type width, based on the type to convert to.
int width = 32;
@@ -7358,6 +7438,11 @@ spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op,
if (vectorSize > 0)
type = builder.makeVectorType(type, vectorSize);
else if (builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixKHR ||
builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixNV) {
type = builder.makeCooperativeMatrixTypeWithSameShape(type, destType);
}
return builder.createUnaryOp(convOp, type, operand);
}
@@ -7630,7 +7715,7 @@ spv::Id TGlslangToSpvTraverser::createConversion(glslang::TOperator op, OpDecora
case glslang::EOpConvUint64ToInt16:
case glslang::EOpConvUint64ToInt:
// OpSConvert/OpUConvert + OpBitCast
operand = createIntWidthConversion(op, operand, vectorSize);
operand = createIntWidthConversion(op, operand, vectorSize, destType);
if (builder.isInSpecConstCodeGenMode()) {
// Build zero scalar or vector for OpIAdd.
@@ -8963,7 +9048,7 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
case glslang::EOpSetMeshOutputsEXT:
builder.createNoResultOp(spv::OpSetMeshOutputsEXT, operands);
return 0;
case glslang::EOpCooperativeMatrixMulAdd:
case glslang::EOpCooperativeMatrixMulAddNV:
opCode = spv::OpCooperativeMatrixMulAddNV;
break;
case glslang::EOpHitObjectTraceRayNV: