AST/SPV: Fix #930: translate uvec4 <-> uint64 for SubgroupGeMask et. al.
On reading built-in variables SubgroupEqMask, SubgroupGeMask, SubgroupGtMask, SubgroupLeMask, and SubgroupLtMask, the AST expects 64-bit ints, while SPIR-V is defined as vectors of 32-bit ints. The declaration type has to be translated in the opposite direction.
This commit is contained in:
@@ -138,7 +138,7 @@ protected:
|
||||
spv::LoopControlMask TranslateLoopControl(const glslang::TIntermLoop&, std::vector<unsigned int>& operands) const;
|
||||
spv::StorageClass TranslateStorageClass(const glslang::TType&);
|
||||
void addIndirectionIndexCapabilities(const glslang::TType& baseType, const glslang::TType& indexType);
|
||||
spv::Id createSpvVariable(const glslang::TIntermSymbol*);
|
||||
spv::Id createSpvVariable(const glslang::TIntermSymbol*, spv::Id forcedType);
|
||||
spv::Id getSampledType(const glslang::TSampler&);
|
||||
spv::Id getInvertedSwizzleType(const glslang::TIntermTyped&);
|
||||
spv::Id createInvertedSwizzle(spv::Decoration precision, const glslang::TIntermTyped&, spv::Id parentResult);
|
||||
@@ -208,6 +208,8 @@ protected:
|
||||
if (builder.getSpvVersion() < glslang::EShTargetSpv_1_3)
|
||||
builder.addExtension(ext);
|
||||
}
|
||||
std::pair<spv::Id, spv::Id> getForcedType(spv::BuiltIn, const glslang::TType&);
|
||||
spv::Id translateForcedType(spv::Id object);
|
||||
|
||||
glslang::SpvOptions& options;
|
||||
spv::Function* shaderEntry;
|
||||
@@ -238,6 +240,10 @@ protected:
|
||||
std::unordered_map<std::string, const glslang::TIntermSymbol*> counterOriginator;
|
||||
// Map pointee types for EbtReference to their forward pointers
|
||||
std::map<const glslang::TType *, spv::Id> forwardPointers;
|
||||
// Type forcing, for when SPIR-V wants a different type than the AST,
|
||||
// requiring local translation to and from SPIR-V type on every access.
|
||||
// Maps <builtin-variable-id -> AST-required-type-id>
|
||||
std::unordered_map<spv::Id, spv::Id> forceType;
|
||||
};
|
||||
|
||||
//
|
||||
@@ -733,27 +739,27 @@ spv::BuiltIn TGlslangToSpvTraverser::TranslateBuiltInDecoration(glslang::TBuiltI
|
||||
case glslang::EbvSubGroupEqMask:
|
||||
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
|
||||
builder.addCapability(spv::CapabilitySubgroupBallotKHR);
|
||||
return spv::BuiltInSubgroupEqMaskKHR;
|
||||
return spv::BuiltInSubgroupEqMask;
|
||||
|
||||
case glslang::EbvSubGroupGeMask:
|
||||
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
|
||||
builder.addCapability(spv::CapabilitySubgroupBallotKHR);
|
||||
return spv::BuiltInSubgroupGeMaskKHR;
|
||||
return spv::BuiltInSubgroupGeMask;
|
||||
|
||||
case glslang::EbvSubGroupGtMask:
|
||||
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
|
||||
builder.addCapability(spv::CapabilitySubgroupBallotKHR);
|
||||
return spv::BuiltInSubgroupGtMaskKHR;
|
||||
return spv::BuiltInSubgroupGtMask;
|
||||
|
||||
case glslang::EbvSubGroupLeMask:
|
||||
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
|
||||
builder.addCapability(spv::CapabilitySubgroupBallotKHR);
|
||||
return spv::BuiltInSubgroupLeMaskKHR;
|
||||
return spv::BuiltInSubgroupLeMask;
|
||||
|
||||
case glslang::EbvSubGroupLtMask:
|
||||
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
|
||||
builder.addCapability(spv::CapabilitySubgroupBallotKHR);
|
||||
return spv::BuiltInSubgroupLtMaskKHR;
|
||||
return spv::BuiltInSubgroupLtMask;
|
||||
|
||||
case glslang::EbvNumSubgroups:
|
||||
builder.addCapability(spv::CapabilityGroupNonUniform);
|
||||
@@ -795,6 +801,7 @@ spv::BuiltIn TGlslangToSpvTraverser::TranslateBuiltInDecoration(glslang::TBuiltI
|
||||
builder.addCapability(spv::CapabilityGroupNonUniform);
|
||||
builder.addCapability(spv::CapabilityGroupNonUniformBallot);
|
||||
return spv::BuiltInSubgroupLtMask;
|
||||
|
||||
#ifdef AMD_EXTENSIONS
|
||||
case glslang::EbvBaryCoordNoPersp:
|
||||
builder.addExtension(spv::E_SPV_AMD_shader_explicit_vertex_parameter);
|
||||
@@ -1620,8 +1627,8 @@ void TGlslangToSpvTraverser::visitSymbol(glslang::TIntermSymbol* symbol)
|
||||
// Formal function parameters were mapped during makeFunctions().
|
||||
spv::Id id = getSymbolId(symbol);
|
||||
|
||||
// Include all "static use" and "linkage only" interface variables on the OpEntryPoint instruction
|
||||
if (builder.isPointer(id)) {
|
||||
// Include all "static use" and "linkage only" interface variables on the OpEntryPoint instruction
|
||||
// Consider adding to the OpEntryPoint interface list.
|
||||
// Only looking at structures if they have at least one member.
|
||||
if (!symbol->getType().isStruct() || symbol->getType().getStruct()->size() > 0) {
|
||||
@@ -1633,6 +1640,14 @@ void TGlslangToSpvTraverser::visitSymbol(glslang::TIntermSymbol* symbol)
|
||||
iOSet.insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
// If the SPIR-V type is required to be different than the AST type,
|
||||
// translate now from the SPIR-V type to the AST type, for the consuming
|
||||
// operation.
|
||||
// Note this turns it from an l-value to an r-value.
|
||||
// Currently, all symbols needing this are inputs; avoid the map lookup when non-input.
|
||||
if (symbol->getType().getQualifier().storage == glslang::EvqVaryingIn)
|
||||
id = translateForcedType(id);
|
||||
}
|
||||
|
||||
// Only process non-linkage-only nodes for generating actual static uses
|
||||
@@ -1650,8 +1665,10 @@ void TGlslangToSpvTraverser::visitSymbol(glslang::TIntermSymbol* symbol)
|
||||
// See comments in handleUserFunctionCall().
|
||||
// B) Specialization constants (normal constants don't even come in as a variable),
|
||||
// These are also pure R-values.
|
||||
// C) R-Values from type translation, see above call to translateForcedType()
|
||||
glslang::TQualifier qualifier = symbol->getQualifier();
|
||||
if (qualifier.isSpecConstant() || rValueParameters.find(symbol->getId()) != rValueParameters.end())
|
||||
if (qualifier.isSpecConstant() || rValueParameters.find(symbol->getId()) != rValueParameters.end() ||
|
||||
!builder.isPointerType(builder.getTypeId(id)))
|
||||
builder.setAccessChainRValue(id);
|
||||
else
|
||||
builder.setAccessChainLValue(id);
|
||||
@@ -1908,6 +1925,71 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out what, if any, type changes are needed when accessing a specific built-in.
|
||||
// Returns <the type SPIR-V requires for declarion, the type to translate to on use>.
|
||||
// Also see comment for 'forceType', regarding tracking SPIR-V-required types.
|
||||
std::pair<spv::Id, spv::Id> TGlslangToSpvTraverser::getForcedType(spv::BuiltIn builtIn,
|
||||
const glslang::TType& glslangType)
|
||||
{
|
||||
switch(builtIn)
|
||||
{
|
||||
case spv::BuiltInSubgroupEqMask:
|
||||
case spv::BuiltInSubgroupGeMask:
|
||||
case spv::BuiltInSubgroupGtMask:
|
||||
case spv::BuiltInSubgroupLeMask:
|
||||
case spv::BuiltInSubgroupLtMask: {
|
||||
// these require changing a 64-bit scaler -> a vector of 32-bit components
|
||||
if (glslangType.isVector())
|
||||
break;
|
||||
std::pair<spv::Id, spv::Id> ret(builder.makeVectorType(builder.makeUintType(32), 4),
|
||||
builder.makeUintType(64));
|
||||
return ret;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
std::pair<spv::Id, spv::Id> ret(spv::NoType, spv::NoType);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// For an object previously identified (see getForcedType() and forceType)
|
||||
// as needing type translations, do the translation needed for a load, turning
|
||||
// an L-value into in R-value.
|
||||
spv::Id TGlslangToSpvTraverser::translateForcedType(spv::Id object)
|
||||
{
|
||||
const auto forceIt = forceType.find(object);
|
||||
if (forceIt == forceType.end())
|
||||
return object;
|
||||
|
||||
spv::Id desiredTypeId = forceIt->second;
|
||||
spv::Id objectTypeId = builder.getTypeId(object);
|
||||
assert(builder.isPointerType(objectTypeId));
|
||||
objectTypeId = builder.getContainedTypeId(objectTypeId);
|
||||
if (builder.isVectorType(objectTypeId) &&
|
||||
builder.getScalarTypeWidth(builder.getContainedTypeId(objectTypeId)) == 32) {
|
||||
if (builder.getScalarTypeWidth(desiredTypeId) == 64) {
|
||||
// handle 32-bit v.xy* -> 64-bit
|
||||
builder.clearAccessChain();
|
||||
builder.setAccessChainLValue(object);
|
||||
object = builder.accessChainLoad(spv::NoPrecision, spv::DecorationMax, objectTypeId);
|
||||
std::vector<spv::Id> components;
|
||||
components.push_back(builder.createCompositeExtract(object, builder.getContainedTypeId(objectTypeId), 0));
|
||||
components.push_back(builder.createCompositeExtract(object, builder.getContainedTypeId(objectTypeId), 1));
|
||||
|
||||
spv::Id vecType = builder.makeVectorType(builder.getContainedTypeId(objectTypeId), 2);
|
||||
return builder.createUnaryOp(spv::OpBitcast, desiredTypeId,
|
||||
builder.createCompositeConstruct(vecType, components));
|
||||
} else {
|
||||
logger->missingFunctionality("forcing 32-bit vector type to non 64-bit scalar");
|
||||
}
|
||||
} else {
|
||||
logger->missingFunctionality("forcing non 32-bit vector type");
|
||||
}
|
||||
|
||||
return object;
|
||||
}
|
||||
|
||||
bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node)
|
||||
{
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
@@ -3037,7 +3119,7 @@ bool TGlslangToSpvTraverser::visitBranch(glslang::TVisit /* visit */, glslang::T
|
||||
return false;
|
||||
}
|
||||
|
||||
spv::Id TGlslangToSpvTraverser::createSpvVariable(const glslang::TIntermSymbol* node)
|
||||
spv::Id TGlslangToSpvTraverser::createSpvVariable(const glslang::TIntermSymbol* node, spv::Id forcedType)
|
||||
{
|
||||
// First, steer off constants, which are not SPIR-V variables, but
|
||||
// can still have a mapping to a SPIR-V Id.
|
||||
@@ -3050,7 +3132,8 @@ spv::Id TGlslangToSpvTraverser::createSpvVariable(const glslang::TIntermSymbol*
|
||||
|
||||
// Now, handle actual variables
|
||||
spv::StorageClass storageClass = TranslateStorageClass(node->getType());
|
||||
spv::Id spvType = convertGlslangToSpvType(node->getType());
|
||||
spv::Id spvType = forcedType == spv::NoType ? convertGlslangToSpvType(node->getType())
|
||||
: forcedType;
|
||||
|
||||
const bool contains16BitType = node->getType().containsBasicType(glslang::EbtFloat16) ||
|
||||
node->getType().containsBasicType(glslang::EbtInt16) ||
|
||||
@@ -7543,8 +7626,12 @@ spv::Id TGlslangToSpvTraverser::getSymbolId(const glslang::TIntermSymbol* symbol
|
||||
}
|
||||
|
||||
// it was not found, create it
|
||||
id = createSpvVariable(symbol);
|
||||
spv::BuiltIn builtIn = TranslateBuiltInDecoration(symbol->getQualifier().builtIn, false);
|
||||
auto forcedType = getForcedType(builtIn, symbol->getType());
|
||||
id = createSpvVariable(symbol, forcedType.first);
|
||||
symbolValues[symbol->getId()] = id;
|
||||
if (forcedType.second != spv::NoType)
|
||||
forceType[id] = forcedType.second;
|
||||
|
||||
if (symbol->getBasicType() != glslang::EbtBlock) {
|
||||
builder.addDecoration(id, TranslatePrecisionDecoration(symbol->getType()));
|
||||
@@ -7604,10 +7691,10 @@ spv::Id TGlslangToSpvTraverser::getSymbolId(const glslang::TIntermSymbol* symbol
|
||||
builder.addDecoration(id, memory[i]);
|
||||
}
|
||||
|
||||
// built-in variable decorations
|
||||
spv::BuiltIn builtIn = TranslateBuiltInDecoration(symbol->getQualifier().builtIn, false);
|
||||
if (builtIn != spv::BuiltInMax)
|
||||
// add built-in variable decoration
|
||||
if (builtIn != spv::BuiltInMax) {
|
||||
builder.addDecoration(id, spv::DecorationBuiltIn, (int)builtIn);
|
||||
}
|
||||
|
||||
// nonuniform
|
||||
builder.addDecoration(id, TranslateNonUniformDecoration(symbol->getType().getQualifier()));
|
||||
|
||||
Reference in New Issue
Block a user