SPV: Implement composite comparisons (reductions across hierchical compare).

This commit is contained in:
John Kessenich
2015-12-21 20:54:09 -07:00
parent 59420fd356
commit 2211835b4d
6 changed files with 470 additions and 71 deletions

View File

@@ -435,7 +435,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const
}
}
int Builder::getNumTypeComponents(Id typeId) const
int Builder::getNumTypeConstituents(Id typeId) const
{
Instruction* instr = module.getInstruction(typeId);
@@ -447,7 +447,10 @@ int Builder::getNumTypeComponents(Id typeId) const
return 1;
case OpTypeVector:
case OpTypeMatrix:
case OpTypeArray:
return instr->getImmediateOperand(1);
case OpTypeStruct:
return instr->getNumOperands();
default:
assert(0);
return 1;
@@ -1411,88 +1414,78 @@ Id Builder::createTextureQueryCall(Op opCode, const TextureParameters& parameter
return query->getResultId();
}
// Comments in header
Id Builder::createCompare(Decoration precision, Id value1, Id value2, bool equal)
// External comments in header.
// Operates recursively to visit the composite's hierarchy.
Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, bool equal)
{
Id boolType = makeBoolType();
Id valueType = getTypeId(value1);
assert(valueType == getTypeId(value2));
assert(! isScalar(value1));
// Vectors
Id resultId;
if (isVectorType(valueType)) {
Id boolVectorType = makeVectorType(boolType, getNumTypeComponents(valueType));
Id boolVector;
int numConstituents = getNumTypeConstituents(valueType);
// Scalars and Vectors
if (isScalarType(valueType) || isVectorType(valueType)) {
// These just need a single comparison, just have
// to figure out what it is.
Op op;
if (getMostBasicTypeClass(valueType) == OpTypeFloat)
switch (getMostBasicTypeClass(valueType)) {
case OpTypeFloat:
op = equal ? OpFOrdEqual : OpFOrdNotEqual;
else
break;
case OpTypeInt:
op = equal ? OpIEqual : OpINotEqual;
break;
case OpTypeBool:
op = equal ? OpLogicalEqual : OpLogicalNotEqual;
precision = NoPrecision;
break;
}
boolVector = createBinOp(op, boolVectorType, value1, value2);
setPrecision(boolVector, precision);
if (isScalarType(valueType)) {
// scalar
resultId = createBinOp(op, boolType, value1, value2);
setPrecision(resultId, precision);
} else {
// vector
resultId = createBinOp(op, makeVectorType(boolType, numConstituents), value1, value2);
setPrecision(resultId, precision);
// reduce vector compares...
resultId = createUnaryOp(equal ? OpAll : OpAny, boolType, resultId);
}
// Reduce vector compares with any() and all().
op = equal ? OpAll : OpAny;
return createUnaryOp(op, boolType, boolVector);
return resultId;
}
spv::MissingFunctionality("Composite comparison of non-vectors");
// Only structs, arrays, and matrices should be left.
// They share in common the reduction operation across their constituents.
assert(isAggregateType(valueType) || isMatrixType(valueType));
return NoResult;
// Compare each pair of constituents
for (int constituent = 0; constituent < numConstituents; ++constituent) {
std::vector<unsigned> indexes(1, constituent);
Id constituentType = getContainedTypeId(valueType, constituent);
Id constituent1 = createCompositeExtract(value1, constituentType, indexes);
Id constituent2 = createCompositeExtract(value2, constituentType, indexes);
// Recursively handle aggregates, which include matrices, arrays, and structures
// and accumulate the results.
Id subResultId = createCompositeCompare(precision, constituent1, constituent2, equal);
// Matrices
if (constituent == 0)
resultId = subResultId;
else
resultId = createBinOp(equal ? OpLogicalAnd : OpLogicalOr, boolType, resultId, subResultId);
}
// Arrays
//int numElements;
//const llvm::ArrayType* arrayType = llvm::dyn_cast<llvm::ArrayType>(value1->getType());
//if (arrayType)
// numElements = (int)arrayType->getNumElements();
//else {
// // better be structure
// const llvm::StructType* structType = llvm::dyn_cast<llvm::StructType>(value1->getType());
// assert(structType);
// numElements = structType->getNumElements();
//}
//assert(numElements > 0);
//for (int element = 0; element < numElements; ++element) {
// // Get intermediate comparison values
// llvm::Value* element1 = builder.CreateExtractValue(value1, element, "element1");
// setInstructionPrecision(element1, precision);
// llvm::Value* element2 = builder.CreateExtractValue(value2, element, "element2");
// setInstructionPrecision(element2, precision);
// llvm::Value* subResult = createCompare(precision, element1, element2, equal, "comp");
// // Accumulate intermediate comparison
// if (element == 0)
// result = subResult;
// else {
// if (equal)
// result = builder.CreateAnd(result, subResult);
// else
// result = builder.CreateOr(result, subResult);
// setInstructionPrecision(result, precision);
// }
//}
//return result;
return resultId;
}
// OpCompositeConstruct
Id Builder::createCompositeConstruct(Id typeId, std::vector<Id>& constituents)
{
assert(isAggregateType(typeId) || (getNumTypeComponents(typeId) > 1 && getNumTypeComponents(typeId) == (int)constituents.size()));
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 && getNumTypeConstituents(typeId) == (int)constituents.size()));
Instruction* op = new Instruction(getUniqueId(), typeId, OpCompositeConstruct);
for (int c = 0; c < (int)constituents.size(); ++c)