Implement support for GL_KHR_cooperative_matrix extension

This commit is contained in:
Boris Zanin
2023-03-16 13:01:01 +01:00
committed by arcady-lunarg
parent 91a97b4c69
commit 808c7ed17c
40 changed files with 8227 additions and 5733 deletions

View File

@@ -481,15 +481,41 @@ Id Builder::makeMatrixType(Id component, int cols, int rows)
return type->getResultId();
}
Id Builder::makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols)
Id Builder::makeCooperativeMatrixTypeKHR(Id component, Id scope, Id rows, Id cols, Id use)
{
// try to find it
Instruction* type;
for (int t = 0; t < (int)groupedTypes[OpTypeCooperativeMatrixKHR].size(); ++t) {
type = groupedTypes[OpTypeCooperativeMatrixKHR][t];
if (type->getIdOperand(0) == component &&
type->getIdOperand(1) == scope &&
type->getIdOperand(2) == rows &&
type->getIdOperand(3) == cols &&
type->getIdOperand(4) == use)
return type->getResultId();
}
// not found, make it
type = new Instruction(getUniqueId(), NoType, OpTypeCooperativeMatrixKHR);
type->addIdOperand(component);
type->addIdOperand(scope);
type->addIdOperand(rows);
type->addIdOperand(cols);
type->addIdOperand(use);
groupedTypes[OpTypeCooperativeMatrixKHR].push_back(type);
constantsTypesGlobals.push_back(std::unique_ptr<Instruction>(type));
module.mapInstruction(type);
return type->getResultId();
}
Id Builder::makeCooperativeMatrixTypeNV(Id component, Id scope, Id rows, Id cols)
{
// try to find it
Instruction* type;
for (int t = 0; t < (int)groupedTypes[OpTypeCooperativeMatrixNV].size(); ++t) {
type = groupedTypes[OpTypeCooperativeMatrixNV][t];
if (type->getIdOperand(0) == component &&
type->getIdOperand(1) == scope &&
type->getIdOperand(2) == rows &&
if (type->getIdOperand(0) == component && type->getIdOperand(1) == scope && type->getIdOperand(2) == rows &&
type->getIdOperand(3) == cols)
return type->getResultId();
}
@@ -507,6 +533,17 @@ Id Builder::makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols)
return type->getResultId();
}
Id Builder::makeCooperativeMatrixTypeWithSameShape(Id component, Id otherType)
{
Instruction* instr = module.getInstruction(otherType);
if (instr->getOpCode() == OpTypeCooperativeMatrixNV) {
return makeCooperativeMatrixTypeNV(component, instr->getIdOperand(1), instr->getIdOperand(2), instr->getIdOperand(3));
} else {
assert(instr->getOpCode() == OpTypeCooperativeMatrixKHR);
return makeCooperativeMatrixTypeKHR(component, instr->getIdOperand(1), instr->getIdOperand(2), instr->getIdOperand(3), instr->getIdOperand(4));
}
}
Id Builder::makeGenericType(spv::Op opcode, std::vector<spv::IdImmediate>& operands)
{
// try to find it
@@ -1254,6 +1291,7 @@ int Builder::getNumTypeConstituents(Id typeId) const
}
case OpTypeStruct:
return instr->getNumOperands();
case OpTypeCooperativeMatrixKHR:
case OpTypeCooperativeMatrixNV:
// has only one constituent when used with OpCompositeConstruct.
return 1;
@@ -1303,6 +1341,7 @@ Id Builder::getContainedTypeId(Id typeId, int member) const
case OpTypeMatrix:
case OpTypeArray:
case OpTypeRuntimeArray:
case OpTypeCooperativeMatrixKHR:
case OpTypeCooperativeMatrixNV:
return instr->getIdOperand(0);
case OpTypePointer:
@@ -1769,6 +1808,7 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector<Id>& members, boo
case OpTypeVector:
case OpTypeArray:
case OpTypeMatrix:
case OpTypeCooperativeMatrixKHR:
case OpTypeCooperativeMatrixNV:
if (! specConstant) {
Id existing = findCompositeConstant(typeClass, typeId, members);
@@ -2405,7 +2445,24 @@ Id Builder::createArrayLength(Id base, unsigned int member)
return length->getResultId();
}
Id Builder::createCooperativeMatrixLength(Id type)
Id Builder::createCooperativeMatrixLengthKHR(Id type)
{
spv::Id intType = makeUintType(32);
// Generate code for spec constants if in spec constant operation
// generation mode.
if (generatingOpCodeForSpecConst) {
return createSpecConstantOp(OpCooperativeMatrixLengthKHR, intType, std::vector<Id>(1, type), std::vector<Id>());
}
Instruction* length = new Instruction(getUniqueId(), intType, OpCooperativeMatrixLengthKHR);
length->addIdOperand(type);
buildPoint->addInstruction(std::unique_ptr<Instruction>(length));
return length->getResultId();
}
Id Builder::createCooperativeMatrixLengthNV(Id type)
{
spv::Id intType = makeUintType(32);