iwa/source/util/reflect_glsl.cpp
2024-04-06 14:11:26 +02:00

588 lines
20 KiB
C++

#include "iwa/util/reflect_glsl.hpp"
#include <glslang/Include/InfoSink.h>
#include <glslang/Public/ShaderLang.h>
#include <glslang/MachineIndependent/localintermediate.h>
#include <glslang/Public/ResourceLimits.h>
namespace iwa
{
namespace
{
class MetaCollectingTraverser : public glslang::TIntermTraverser
{
private:
ShaderMeta& meta;
vk::ShaderStageFlagBits shaderType;
public:
inline MetaCollectingTraverser(ShaderMeta& meta_, vk::ShaderStageFlagBits shaderType_) : meta(meta_), shaderType(shaderType_)
{}
bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override;
bool visitUnary(glslang::TVisit, glslang::TIntermUnary* node) override;
bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override;
bool visitSelection(glslang::TVisit, glslang::TIntermSelection* node) override;
void visitConstantUnion(glslang::TIntermConstantUnion* node) override;
void visitSymbol(glslang::TIntermSymbol* node) override;
bool visitLoop(glslang::TVisit, glslang::TIntermLoop* node) override;
bool visitBranch(glslang::TVisit, glslang::TIntermBranch* node) override;
bool visitSwitch(glslang::TVisit, glslang::TIntermSwitch* node) override;
};
/// Maps a scalar glslang type to the single-component 32/64-bit Vulkan format used to
/// describe it. Only int, uint, float and double are supported; anything else is fatal.
vk::Format convertGlslangBaseType(const glslang::TType& type)
{
	const glslang::TBasicType basicType = type.getBasicType();
	if (basicType == glslang::EbtInt)
	{
		return vk::Format::eR32Sint;
	}
	if (basicType == glslang::EbtUint)
	{
		return vk::Format::eR32Uint;
	}
	if (basicType == glslang::EbtFloat)
	{
		return vk::Format::eR32Sfloat;
	}
	if (basicType == glslang::EbtDouble)
	{
		return vk::Format::eR64Sfloat;
	}
	logAndDie("Don't know how to convert Glslang basic type :*(");
}
/// Maps a GLSL image `layout(<format>)` qualifier to the equivalent Vulkan format.
///
/// The mapping follows the Vulkan specification's table "Compatibility Between SPIR-V
/// Image Formats and Vulkan Formats". `ElfNone` (no format qualifier) yields
/// vk::Format::eUndefined. Size-only formats and the enum's guard/count markers have no
/// Vulkan equivalent and terminate via logAndDie().
vk::Format convertGlslangLayoutFormat(glslang::TLayoutFormat layoutFormat)
{
	switch (layoutFormat)
	{
	case glslang::TLayoutFormat::ElfNone:
		return vk::Format::eUndefined;
	// Float image
	case glslang::TLayoutFormat::ElfRgba32f:
		return vk::Format::eR32G32B32A32Sfloat;
	case glslang::TLayoutFormat::ElfRgba16f:
		return vk::Format::eR16G16B16A16Sfloat;
	case glslang::TLayoutFormat::ElfR32f:
		return vk::Format::eR32Sfloat;
	case glslang::TLayoutFormat::ElfRgba8:
		return vk::Format::eR8G8B8A8Unorm;
	case glslang::TLayoutFormat::ElfRgba8Snorm:
		return vk::Format::eR8G8B8A8Snorm;
	case glslang::TLayoutFormat::ElfRg32f:
		return vk::Format::eR32G32Sfloat;
	case glslang::TLayoutFormat::ElfRg16f:
		return vk::Format::eR16G16Sfloat;
	case glslang::TLayoutFormat::ElfR11fG11fB10f:
		// Vulkan names packed-format components from most to least significant bit,
		// hence the reversed component order.
		return vk::Format::eB10G11R11UfloatPack32;
	case glslang::TLayoutFormat::ElfR16f:
		return vk::Format::eR16Sfloat;
	case glslang::TLayoutFormat::ElfRgba16:
		return vk::Format::eR16G16B16A16Unorm;
	case glslang::TLayoutFormat::ElfRgb10A2:
		// rgb10_a2 is an unsigned normalized format (was wrongly mapped to an SNORM ARGB format).
		return vk::Format::eA2B10G10R10UnormPack32;
	case glslang::TLayoutFormat::ElfRg16:
		return vk::Format::eR16G16Unorm;
	case glslang::TLayoutFormat::ElfRg8:
		return vk::Format::eR8G8Unorm;
	case glslang::TLayoutFormat::ElfR16:
		return vk::Format::eR16Unorm;
	case glslang::TLayoutFormat::ElfR8:
		return vk::Format::eR8Unorm;
	case glslang::TLayoutFormat::ElfRgba16Snorm:
		return vk::Format::eR16G16B16A16Snorm;
	case glslang::TLayoutFormat::ElfRg16Snorm:
		return vk::Format::eR16G16Snorm;
	case glslang::TLayoutFormat::ElfRg8Snorm:
		return vk::Format::eR8G8Snorm;
	case glslang::TLayoutFormat::ElfR16Snorm:
		return vk::Format::eR16Snorm;
	case glslang::TLayoutFormat::ElfR8Snorm:
		return vk::Format::eR8Snorm;
	// Int image
	case glslang::TLayoutFormat::ElfRgba32i:
		return vk::Format::eR32G32B32A32Sint;
	case glslang::TLayoutFormat::ElfRgba16i:
		return vk::Format::eR16G16B16A16Sint;
	case glslang::TLayoutFormat::ElfRgba8i:
		return vk::Format::eR8G8B8A8Sint;
	case glslang::TLayoutFormat::ElfR32i:
		return vk::Format::eR32Sint;
	case glslang::TLayoutFormat::ElfRg32i:
		return vk::Format::eR32G32Sint;
	case glslang::TLayoutFormat::ElfRg16i:
		return vk::Format::eR16G16Sint;
	case glslang::TLayoutFormat::ElfRg8i:
		return vk::Format::eR8G8Sint;
	case glslang::TLayoutFormat::ElfR16i:
		return vk::Format::eR16Sint;
	case glslang::TLayoutFormat::ElfR8i:
		return vk::Format::eR8Sint;
	case glslang::TLayoutFormat::ElfR64i:
		return vk::Format::eR64Sint;
	// Uint image
	case glslang::TLayoutFormat::ElfRgba32ui:
		return vk::Format::eR32G32B32A32Uint;
	case glslang::TLayoutFormat::ElfRgba16ui:
		return vk::Format::eR16G16B16A16Uint;
	case glslang::TLayoutFormat::ElfRgba8ui:
		return vk::Format::eR8G8B8A8Uint;
	case glslang::TLayoutFormat::ElfR32ui:
		return vk::Format::eR32Uint;
	case glslang::TLayoutFormat::ElfRg32ui:
		return vk::Format::eR32G32Uint;
	case glslang::TLayoutFormat::ElfRg16ui:
		return vk::Format::eR16G16Uint;
	case glslang::TLayoutFormat::ElfRgb10a2ui:
		// Same packed component order as rgb10_a2 above.
		return vk::Format::eA2B10G10R10UintPack32;
	case glslang::TLayoutFormat::ElfRg8ui:
		return vk::Format::eR8G8Uint;
	case glslang::TLayoutFormat::ElfR16ui:
		return vk::Format::eR16Uint;
	case glslang::TLayoutFormat::ElfR8ui:
		return vk::Format::eR8Uint;
	case glslang::TLayoutFormat::ElfR64ui:
		return vk::Format::eR64Uint;
	// other/unknown: no Vulkan equivalent
	case glslang::TLayoutFormat::ElfSize1x8:
	case glslang::TLayoutFormat::ElfSize1x16:
	case glslang::TLayoutFormat::ElfSize1x32:
	case glslang::TLayoutFormat::ElfSize2x32:
	case glslang::TLayoutFormat::ElfSize4x32:
	case glslang::TLayoutFormat::ElfEsFloatGuard:
	case glslang::TLayoutFormat::ElfFloatGuard:
	case glslang::TLayoutFormat::ElfEsIntGuard:
	case glslang::TLayoutFormat::ElfIntGuard:
	case glslang::TLayoutFormat::ElfEsUintGuard:
	case glslang::TLayoutFormat::ElfExtSizeGuard:
	case glslang::TLayoutFormat::ElfCount:
		break;
	}
	logAndDie("Unexpected format in convertGlslangLayoutFormat().");
}
/// Maps a 2-, 3- or 4-component vector of float, double, int or uint to the Vulkan format
/// with one 32-bit (64-bit for double) channel per component. Any other combination
/// (including bool vectors) is fatal.
vk::Format convertGlslangVectorType(glslang::TBasicType basicType, int vectorSize)
{
	// Each table holds the formats for 2, 3 and 4 components, indexed by (vectorSize - 2).
	static constexpr vk::Format FLOAT_VECTORS[] = {
		vk::Format::eR32G32Sfloat, vk::Format::eR32G32B32Sfloat, vk::Format::eR32G32B32A32Sfloat
	};
	static constexpr vk::Format DOUBLE_VECTORS[] = {
		vk::Format::eR64G64Sfloat, vk::Format::eR64G64B64Sfloat, vk::Format::eR64G64B64A64Sfloat
	};
	static constexpr vk::Format INT_VECTORS[] = {
		vk::Format::eR32G32Sint, vk::Format::eR32G32B32Sint, vk::Format::eR32G32B32A32Sint
	};
	static constexpr vk::Format UINT_VECTORS[] = {
		vk::Format::eR32G32Uint, vk::Format::eR32G32B32Uint, vk::Format::eR32G32B32A32Uint
	};

	const vk::Format* candidates = nullptr;
	switch (basicType)
	{
	case glslang::EbtFloat:
		candidates = FLOAT_VECTORS;
		break;
	case glslang::EbtDouble:
		candidates = DOUBLE_VECTORS;
		break;
	case glslang::EbtInt:
		candidates = INT_VECTORS;
		break;
	case glslang::EbtUint:
		candidates = UINT_VECTORS;
		break;
	default:
		break;
	}

	const bool supportedSize = vectorSize >= 2 && vectorSize <= 4;
	if (candidates != nullptr && supportedSize)
	{
		return candidates[vectorSize - 2];
	}
	logAndDie("Don't know how to convert Glslang vector type :(");
}
/// Convenience overload: converts a glslang vector type by its component type and count.
vk::Format convertGlslangVectorType(const glslang::TType& type)
{
	assert(type.isVector()); // callers must only pass vector types
	const glslang::TBasicType componentType = type.getBasicType();
	const int componentCount = type.getVectorSize();
	return convertGlslangVectorType(componentType, componentCount);
}
/// Maps a square glslang matrix type (mat2, mat3, mat4) to ShaderVariableMatrixType.
/// Non-square matrices are rejected by an assertion; other column counts are fatal.
ShaderVariableMatrixType convertGlslangMatrixType(const glslang::TType& type)
{
	assert(type.isMatrix());
	const int columns = type.getMatrixCols();
	assert(columns == type.getMatrixRows()); // only square matrices are supported yet
	if (columns == 2)
	{
		return ShaderVariableMatrixType::MAT2;
	}
	if (columns == 3)
	{
		return ShaderVariableMatrixType::MAT3;
	}
	if (columns == 4)
	{
		return ShaderVariableMatrixType::MAT4;
	}
	logAndDie("Don't know how to convert Glslang matrix type -.-");
}
/// Maps a glslang sampler/image dimensionality to ImageDim. Only 1D, 2D, 3D and cube
/// dimensions are supported; everything else (rect, buffer, subpass, ...) is fatal.
ImageDim convertGlslangSamplerDim(glslang::TSamplerDim dim)
{
	if (dim == glslang::TSamplerDim::Esd1D)
	{
		return ImageDim::ONE;
	}
	if (dim == glslang::TSamplerDim::Esd2D)
	{
		return ImageDim::TWO;
	}
	if (dim == glslang::TSamplerDim::Esd3D)
	{
		return ImageDim::THREE;
	}
	if (dim == glslang::TSamplerDim::EsdCube)
	{
		return ImageDim::CUBE;
	}
	logAndDie("Don't know how to convert Glslang sampler dimensions ...");
}
/// Converts a glslang type into the engine's ShaderVariableType description.
///
/// Vectors and scalars become SIMPLE, square matrices MATRIX, structs/blocks STRUCT
/// (recursively) and samplers/images IMAGE. Struct member offsets are the running sum of
/// the preceding member sizes, unless a member carries an explicit `layout(offset = N)`,
/// which takes precedence. Alignment/padding rules (std140/std430) are not applied yet.
/// Outer-unsized arrays (e.g. runtime-sized SSBO members, `sampler2D tex[]`) are reported
/// with arraySize 0 and dynamicArraySize = true; sized arrays report their outer size.
ShaderVariableType convertGlslangType(const glslang::TType& type)
{
	ShaderVariableType result;
	if (type.isVector())
	{
		result.baseType = ShaderVariableBaseType::SIMPLE;
		result.simple.format = convertGlslangVectorType(type);
	}
	else if (type.isMatrix())
	{
		result.baseType = ShaderVariableBaseType::MATRIX;
		result.matrixType = convertGlslangMatrixType(type);
	}
	else if (type.isStruct())
	{
		const std::size_t numMembers = type.getStruct()->size();
		result.baseType = ShaderVariableBaseType::STRUCT;
		result.struct_.members.reserve(numMembers);
		std::size_t currentOffset = 0;
		for (const glslang::TTypeLoc& typeLoc : *type.getStruct())
		{
			auto& memberQualifier = typeLoc.type->getQualifier();
			ShaderVariableStructMember& member = result.struct_.members.emplace_back();
			member.name = typeLoc.type->getFieldName();
			member.type = convertGlslangType(*typeLoc.type);
			// An explicit layout(offset = N) (e.g. in push constant blocks shared between
			// stages) overrides the running offset.
			if (memberQualifier.hasOffset())
			{
				currentOffset = static_cast<std::size_t>(memberQualifier.layoutOffset);
			}
			member.offset = currentOffset;
			if (memberQualifier.hasSemantic())
			{
				member.semantic = memberQualifier.layoutSemantic;
			}
			if (memberQualifier.hasSemanticIndex())
			{
				member.semanticIdx = memberQualifier.layoutSemanticIndex;
			}
			currentOffset = member.offset + calcShaderTypeSize(member.type); // TODO: padding
		}
	}
	else if (type.getBasicType() == glslang::EbtSampler)
	{
		const glslang::TSampler& sampler = type.getSampler();
		result.baseType = ShaderVariableBaseType::IMAGE;
		result.image.dimensions = convertGlslangSamplerDim(sampler.dim);
		result.image.format = convertGlslangLayoutFormat(type.getQualifier().layoutFormat);
	}
	else
	{
		result.baseType = ShaderVariableBaseType::SIMPLE;
		result.simple.format = convertGlslangBaseType(type);
	}

	if (type.isArray())
	{
		// isUnsizedArray() checks whether the outer dimension has no size. The previous
		// isArrayVariablyIndexed() check only reported whether the array was indexed with
		// a non-constant index, which misclassified dynamically indexed sized arrays and
		// constant-indexed (or unused) runtime arrays.
		if (type.isUnsizedArray())
		{
			result.arraySize = 0;
			result.dynamicArraySize = true;
		}
		else
		{
			assert(type.getArraySizes()->getNumDims() == 1); // don't support multi dimensional arrays yet
			result.arraySize = type.getOuterArraySize();
		}
	}
	return result;
}
/// Determines the Vulkan descriptor type used to bind a uniform/buffer-qualified variable.
///
/// Sampler-typed variables follow the GL_KHR_vulkan_glsl mapping:
///  - subpassInput*               -> input attachment
///  - sampler / samplerShadow     -> sampler
///  - samplerBuffer/textureBuffer -> uniform texel buffer
///  - imageBuffer                 -> storage texel buffer
///  - sampler2D etc. (combined)   -> combined image sampler
///  - image2D etc.                -> storage image
///  - texture2D etc. (separate)   -> sampled image
/// Block types map to a uniform buffer when uniform-qualified and to a storage buffer otherwise.
/// Any other type is fatal.
vk::DescriptorType getGlslangDescriptorType(const glslang::TType& type)
{
	if (type.getBasicType() == glslang::EbtSampler)
	{
		const glslang::TSampler& sampler = type.getSampler();
		// Checked before isImage(): glslang marks subpass inputs as images, but isImage()
		// excludes them.
		if (sampler.dim == glslang::EsdSubpass)
		{
			return vk::DescriptorType::eInputAttachment;
		}
		if (sampler.isPureSampler())
		{
			return vk::DescriptorType::eSampler;
		}
		// Buffer-dimensioned textures/images are texel buffers, even when glslang marks
		// them as combined (samplerBuffer).
		if (sampler.dim == glslang::EsdBuffer)
		{
			return sampler.isImage() ? vk::DescriptorType::eStorageTexelBuffer : vk::DescriptorType::eUniformTexelBuffer;
		}
		if (sampler.combined)
		{
			return vk::DescriptorType::eCombinedImageSampler;
		}
		if (sampler.isImage())
		{
			return vk::DescriptorType::eStorageImage;
		}
		// Remaining case: a separate texture (texture2D, ...) without a sampler.
		return vk::DescriptorType::eSampledImage;
	}
	if (type.isStruct())
	{
		if (type.getQualifier().isUniform())
		{
			return vk::DescriptorType::eUniformBuffer;
		}
		return vk::DescriptorType::eStorageBuffer;
	}
	logAndDie("No idea what to do with this type :/");
}
bool MetaCollectingTraverser::visitBinary(glslang::TVisit, glslang::TIntermBinary* node)
{
(void) node;
return false;
}
bool MetaCollectingTraverser::visitUnary(glslang::TVisit, glslang::TIntermUnary* node)
{
(void) node;
return false;
}
bool MetaCollectingTraverser::visitAggregate(glslang::TVisit /* visit */, glslang::TIntermAggregate* node)
{
	// Descend only into sequences and the linker-object list (which holds the declarations
	// visitSymbol() inspects). Function definitions and all other aggregates are skipped.
	const glslang::TOperator op = node->getOp();
	return op == glslang::EOpSequence || op == glslang::EOpLinkerObjects;
}
bool MetaCollectingTraverser::visitSelection(glslang::TVisit, glslang::TIntermSelection* node)
{
(void) node;
return false;
}
void MetaCollectingTraverser::visitConstantUnion(glslang::TIntermConstantUnion* node)
{
(void) node;
}
void MetaCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node)
{
const bool isLinkerObject = getParentNode()
&& getParentNode()->getAsAggregate()
&& getParentNode()->getAsAggregate()->getOp() == glslang::EOpLinkerObjects;
if (isLinkerObject)
{
if (node->getQualifier().builtIn)
{
return;
}
if (node->getQualifier().isUniformOrBuffer())
{
if (node->getQualifier().isPushConstant())
{
ShaderPushConstantBlock pushConstantBlock;
pushConstantBlock.type = convertGlslangType(node->getType());
assert(pushConstantBlock.type.baseType == ShaderVariableBaseType::STRUCT);
meta.extendPushConstant(pushConstantBlock, ShaderTypeBits::make(shaderType));
return;
}
const unsigned setIdx = node->getQualifier().hasSet() ? node->getQualifier().layoutSet : UNSPECIFIED_INDEX;
const unsigned binding = node->getQualifier().hasBinding() ? node->getQualifier().layoutBinding : UNSPECIFIED_INDEX;
ShaderVariableSet& set = meta.getOrCreateInterfaceVariableSet(setIdx);
assert(setIdx == UNSPECIFIED_INDEX || !set.getVariableAtBindingOpt(binding)); // multiple bindings at the same index?
set.usedInStages.set(shaderType, true);
ShaderVariable& var = set.variables.emplace_back();
var.binding = binding;
var.name = node->getName();
if (node->getQualifier().hasSemantic())
{
var.semantic = node->getQualifier().layoutSemantic;
}
if (node->getQualifier().hasSemanticIndex())
{
var.semanticIndex = node->getQualifier().layoutSemanticIndex;
}
// uniform blocks are identified by the name of their type
if (var.name.empty() || var.name.starts_with("anon@"))
{
const glslang::TString& typeName = node->getType().getTypeName();
if (!typeName.empty())
{
var.name = typeName;
}
}
var.descriptorType = getGlslangDescriptorType(node->getType());
var.type = convertGlslangType(node->getType());
}
else if (node->getQualifier().storage == glslang::EvqVaryingIn)
{
ShaderAttribute attribute;
attribute.stage = shaderType;
attribute.type = convertGlslangType(node->getType());
attribute.location = node->getQualifier().hasLocation() ? node->getQualifier().layoutLocation : UNSPECIFIED_INDEX;
attribute.name = node->getName();
if (node->getQualifier().hasSemantic())
{
attribute.semantic = node->getQualifier().layoutSemantic;
}
if (node->getQualifier().hasSemanticIndex())
{
attribute.semanticIndex = node->getQualifier().layoutSemanticIndex;
}
meta.addInputAttribute(std::move(attribute));
}
else if (node->getQualifier().storage == glslang::EvqVaryingOut)
{
ShaderAttribute attribute;
attribute.stage = shaderType;
attribute.type = convertGlslangType(node->getType());
attribute.location = node->getQualifier().hasLocation() ? node->getQualifier().layoutLocation : UNSPECIFIED_INDEX;
attribute.name = node->getName();
if (node->getQualifier().hasSemantic())
{
attribute.semantic = node->getQualifier().layoutSemantic;
}
if (node->getQualifier().hasSemanticIndex())
{
attribute.semanticIndex = node->getQualifier().layoutSemanticIndex;
}
meta.addOutputAttribute(std::move(attribute));
}
}
}
bool MetaCollectingTraverser::visitLoop(glslang::TVisit, glslang::TIntermLoop* node)
{
(void) node;
return false;
}
bool MetaCollectingTraverser::visitBranch(glslang::TVisit, glslang::TIntermBranch* node)
{
(void) node;
return false;
}
bool MetaCollectingTraverser::visitSwitch(glslang::TVisit, glslang::TIntermSwitch* node)
{
(void) node;
return false;
}
/// Translates a glslang stage identifier into the corresponding Vulkan shader stage bit.
/// EShLangCount (or any other unmapped value) is fatal.
vk::ShaderStageFlagBits shaderStageFromGlslang(EShLanguage language)
{
	struct StageMapping
	{
		EShLanguage glslangStage;
		vk::ShaderStageFlagBits vulkanStage;
	};
	static constexpr StageMapping STAGE_MAPPINGS[] = {
		{EShLangVertex, vk::ShaderStageFlagBits::eVertex},
		{EShLangTessControl, vk::ShaderStageFlagBits::eTessellationControl},
		{EShLangTessEvaluation, vk::ShaderStageFlagBits::eTessellationEvaluation},
		{EShLangGeometry, vk::ShaderStageFlagBits::eGeometry},
		{EShLangFragment, vk::ShaderStageFlagBits::eFragment},
		{EShLangCompute, vk::ShaderStageFlagBits::eCompute},
		{EShLangRayGen, vk::ShaderStageFlagBits::eRaygenKHR},
		{EShLangIntersect, vk::ShaderStageFlagBits::eIntersectionKHR},
		{EShLangAnyHit, vk::ShaderStageFlagBits::eAnyHitKHR},
		{EShLangClosestHit, vk::ShaderStageFlagBits::eClosestHitKHR},
		{EShLangMiss, vk::ShaderStageFlagBits::eMissKHR},
		{EShLangCallable, vk::ShaderStageFlagBits::eCallableKHR},
		{EShLangTask, vk::ShaderStageFlagBits::eTaskEXT},
		{EShLangMesh, vk::ShaderStageFlagBits::eMeshEXT},
	};

	for (const StageMapping& mapping : STAGE_MAPPINGS)
	{
		if (mapping.glslangStage == language)
		{
			return mapping.vulkanStage;
		}
	}
	logAndDie("Invalid value passed to shaderStageFromGlslang!");
}
}
/// Reflects a single (parsed) glslang shader, tagging the result with the shader's stage.
ShaderMeta reflectShader(glslang::TShader& shader)
{
	const vk::ShaderStageFlagBits stage = shaderStageFromGlslang(shader.getStage());
	glslang::TIntermediate& intermediate = *shader.getIntermediate();
	return reflectIntermediate(intermediate, stage);
}
/// Reflects every stage present in a glslang program and merges the per-stage results
/// via ShaderMeta::extend(). Stages without an intermediate are skipped.
ShaderMeta reflectProgram(glslang::TProgram& program)
{
	ShaderMeta combined;
	for (int langIdx = 0; langIdx < EShLangCount; ++langIdx)
	{
		const auto language = static_cast<EShLanguage>(langIdx);
		glslang::TIntermediate* const stageIntermediate = program.getIntermediate(language);
		if (stageIntermediate != nullptr)
		{
			combined.extend(reflectIntermediate(*stageIntermediate, shaderStageFromGlslang(language)));
		}
	}
	return combined;
}
/// Walks the AST of one glslang intermediate and collects its interface into a ShaderMeta.
///
/// @param intermediate the stage's intermediate; it must have an AST (tree root), i.e. the
///                     shader must have been parsed successfully. Otherwise this is fatal.
/// @param stage        stage recorded in the result; for compute shaders the local
///                     workgroup size is copied as well.
ShaderMeta reflectIntermediate(glslang::TIntermediate& intermediate, vk::ShaderStageFlagBits stage)
{
	// getTreeRoot() is null when parsing never happened or failed; dereferencing it was UB.
	glslang::TIntermNode* const treeRoot = intermediate.getTreeRoot();
	if (treeRoot == nullptr)
	{
		logAndDie("Cannot reflect a shader that has no AST; was it parsed successfully?");
	}

	ShaderMeta meta;
	MetaCollectingTraverser traverser(meta, stage);
	treeRoot->traverse(&traverser);
	meta.stages.set(stage, true);
	if (stage == vk::ShaderStageFlagBits::eCompute)
	{
		meta.localSizeX = static_cast<unsigned>(intermediate.getLocalSize(0));
		meta.localSizeY = static_cast<unsigned>(intermediate.getLocalSize(1));
		meta.localSizeZ = static_cast<unsigned>(intermediate.getLocalSize(2));
	}
	return meta;
}
} // namespace iwa