SPV: Implement composite comparisons (reductions across hierchical compare).
This commit is contained in:
@@ -435,7 +435,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const
|
||||
}
|
||||
}
|
||||
|
||||
int Builder::getNumTypeComponents(Id typeId) const
|
||||
int Builder::getNumTypeConstituents(Id typeId) const
|
||||
{
|
||||
Instruction* instr = module.getInstruction(typeId);
|
||||
|
||||
@@ -447,7 +447,10 @@ int Builder::getNumTypeComponents(Id typeId) const
|
||||
return 1;
|
||||
case OpTypeVector:
|
||||
case OpTypeMatrix:
|
||||
case OpTypeArray:
|
||||
return instr->getImmediateOperand(1);
|
||||
case OpTypeStruct:
|
||||
return instr->getNumOperands();
|
||||
default:
|
||||
assert(0);
|
||||
return 1;
|
||||
@@ -1411,88 +1414,78 @@ Id Builder::createTextureQueryCall(Op opCode, const TextureParameters& parameter
|
||||
return query->getResultId();
|
||||
}
|
||||
|
||||
// Comments in header
|
||||
Id Builder::createCompare(Decoration precision, Id value1, Id value2, bool equal)
|
||||
// External comments in header.
|
||||
// Operates recursively to visit the composite's hierarchy.
|
||||
Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, bool equal)
|
||||
{
|
||||
Id boolType = makeBoolType();
|
||||
Id valueType = getTypeId(value1);
|
||||
|
||||
assert(valueType == getTypeId(value2));
|
||||
assert(! isScalar(value1));
|
||||
|
||||
// Vectors
|
||||
Id resultId;
|
||||
|
||||
if (isVectorType(valueType)) {
|
||||
Id boolVectorType = makeVectorType(boolType, getNumTypeComponents(valueType));
|
||||
Id boolVector;
|
||||
int numConstituents = getNumTypeConstituents(valueType);
|
||||
|
||||
// Scalars and Vectors
|
||||
|
||||
if (isScalarType(valueType) || isVectorType(valueType)) {
|
||||
// These just need a single comparison, just have
|
||||
// to figure out what it is.
|
||||
Op op;
|
||||
if (getMostBasicTypeClass(valueType) == OpTypeFloat)
|
||||
switch (getMostBasicTypeClass(valueType)) {
|
||||
case OpTypeFloat:
|
||||
op = equal ? OpFOrdEqual : OpFOrdNotEqual;
|
||||
else
|
||||
break;
|
||||
case OpTypeInt:
|
||||
op = equal ? OpIEqual : OpINotEqual;
|
||||
break;
|
||||
case OpTypeBool:
|
||||
op = equal ? OpLogicalEqual : OpLogicalNotEqual;
|
||||
precision = NoPrecision;
|
||||
break;
|
||||
}
|
||||
|
||||
boolVector = createBinOp(op, boolVectorType, value1, value2);
|
||||
setPrecision(boolVector, precision);
|
||||
if (isScalarType(valueType)) {
|
||||
// scalar
|
||||
resultId = createBinOp(op, boolType, value1, value2);
|
||||
setPrecision(resultId, precision);
|
||||
} else {
|
||||
// vector
|
||||
resultId = createBinOp(op, makeVectorType(boolType, numConstituents), value1, value2);
|
||||
setPrecision(resultId, precision);
|
||||
// reduce vector compares...
|
||||
resultId = createUnaryOp(equal ? OpAll : OpAny, boolType, resultId);
|
||||
}
|
||||
|
||||
// Reduce vector compares with any() and all().
|
||||
|
||||
op = equal ? OpAll : OpAny;
|
||||
|
||||
return createUnaryOp(op, boolType, boolVector);
|
||||
return resultId;
|
||||
}
|
||||
|
||||
spv::MissingFunctionality("Composite comparison of non-vectors");
|
||||
// Only structs, arrays, and matrices should be left.
|
||||
// They share in common the reduction operation across their constituents.
|
||||
assert(isAggregateType(valueType) || isMatrixType(valueType));
|
||||
|
||||
return NoResult;
|
||||
// Compare each pair of constituents
|
||||
for (int constituent = 0; constituent < numConstituents; ++constituent) {
|
||||
std::vector<unsigned> indexes(1, constituent);
|
||||
Id constituentType = getContainedTypeId(valueType, constituent);
|
||||
Id constituent1 = createCompositeExtract(value1, constituentType, indexes);
|
||||
Id constituent2 = createCompositeExtract(value2, constituentType, indexes);
|
||||
|
||||
// Recursively handle aggregates, which include matrices, arrays, and structures
|
||||
// and accumulate the results.
|
||||
Id subResultId = createCompositeCompare(precision, constituent1, constituent2, equal);
|
||||
|
||||
// Matrices
|
||||
if (constituent == 0)
|
||||
resultId = subResultId;
|
||||
else
|
||||
resultId = createBinOp(equal ? OpLogicalAnd : OpLogicalOr, boolType, resultId, subResultId);
|
||||
}
|
||||
|
||||
// Arrays
|
||||
|
||||
//int numElements;
|
||||
//const llvm::ArrayType* arrayType = llvm::dyn_cast<llvm::ArrayType>(value1->getType());
|
||||
//if (arrayType)
|
||||
// numElements = (int)arrayType->getNumElements();
|
||||
//else {
|
||||
// // better be structure
|
||||
// const llvm::StructType* structType = llvm::dyn_cast<llvm::StructType>(value1->getType());
|
||||
// assert(structType);
|
||||
// numElements = structType->getNumElements();
|
||||
//}
|
||||
|
||||
//assert(numElements > 0);
|
||||
|
||||
//for (int element = 0; element < numElements; ++element) {
|
||||
// // Get intermediate comparison values
|
||||
// llvm::Value* element1 = builder.CreateExtractValue(value1, element, "element1");
|
||||
// setInstructionPrecision(element1, precision);
|
||||
// llvm::Value* element2 = builder.CreateExtractValue(value2, element, "element2");
|
||||
// setInstructionPrecision(element2, precision);
|
||||
|
||||
// llvm::Value* subResult = createCompare(precision, element1, element2, equal, "comp");
|
||||
|
||||
// // Accumulate intermediate comparison
|
||||
// if (element == 0)
|
||||
// result = subResult;
|
||||
// else {
|
||||
// if (equal)
|
||||
// result = builder.CreateAnd(result, subResult);
|
||||
// else
|
||||
// result = builder.CreateOr(result, subResult);
|
||||
// setInstructionPrecision(result, precision);
|
||||
// }
|
||||
//}
|
||||
|
||||
//return result;
|
||||
return resultId;
|
||||
}
|
||||
|
||||
// OpCompositeConstruct
|
||||
Id Builder::createCompositeConstruct(Id typeId, std::vector<Id>& constituents)
|
||||
{
|
||||
assert(isAggregateType(typeId) || (getNumTypeComponents(typeId) > 1 && getNumTypeComponents(typeId) == (int)constituents.size()));
|
||||
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 && getNumTypeConstituents(typeId) == (int)constituents.size()));
|
||||
|
||||
Instruction* op = new Instruction(getUniqueId(), typeId, OpCompositeConstruct);
|
||||
for (int c = 0; c < (int)constituents.size(); ++c)
|
||||
|
||||
Reference in New Issue
Block a user