#include "iwa/util/reflect_glsl.hpp"

#include <cassert>
#include <cstddef>
#include <utility>

#include <glslang/Include/intermediate.h>
#include <glslang/MachineIndependent/localintermediate.h>
#include <glslang/Public/ShaderLang.h>

namespace iwa
{
namespace
{
/// Intermediate-tree traverser that records the reflection data of one shader stage
/// (uniform/buffer interface variables, push constant blocks and stage in/out attributes)
/// into a ShaderMeta.
///
/// Only top-level sequences and the linker-objects aggregate are descended into
/// (see visitAggregate); function bodies and all expression nodes are skipped.
class MetaCollectingTraverser : public glslang::TIntermTraverser
{
private:
    ShaderMeta& meta;
    vk::ShaderStageFlagBits shaderType;
public:
    MetaCollectingTraverser(ShaderMeta& meta_, vk::ShaderStageFlagBits shaderType_) : meta(meta_), shaderType(shaderType_) {}

    bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override;
    bool visitUnary(glslang::TVisit, glslang::TIntermUnary* node) override;
    bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override;
    bool visitSelection(glslang::TVisit, glslang::TIntermSelection* node) override;
    void visitConstantUnion(glslang::TIntermConstantUnion* node) override;
    void visitSymbol(glslang::TIntermSymbol* node) override;
    bool visitLoop(glslang::TVisit, glslang::TIntermLoop* node) override;
    bool visitBranch(glslang::TVisit, glslang::TIntermBranch* node) override;
    bool visitSwitch(glslang::TVisit, glslang::TIntermSwitch* node) override;
};

/// Maps a scalar glslang type (int, uint, float, double) to the matching single-channel
/// Vulkan format. Any other basic type terminates via logAndDie.
vk::Format convertGlslangBaseType(const glslang::TType& type)
{
    switch (type.getBasicType())
    {
        case glslang::EbtInt:
            return vk::Format::eR32Sint;
        case glslang::EbtUint:
            return vk::Format::eR32Uint;
        case glslang::EbtFloat:
            return vk::Format::eR32Sfloat;
        case glslang::EbtDouble:
            return vk::Format::eR64Sfloat;
        default:
            break;
    }
    logAndDie("Don't know how to convert Glslang basic type :*(");
}

/// Maps a GLSL image layout format qualifier (e.g. `layout(rgba8)`) to the Vulkan format.
/// The mapping follows the Vulkan specification's table of SPIR-V image formats and their
/// compatible Vulkan formats. `ElfNone` (no qualifier) yields vk::Format::eUndefined; the
/// glslang guard/size/count enumerators terminate via logAndDie.
vk::Format convertGlslangLayoutFormat(glslang::TLayoutFormat layoutFormat)
{
    switch (layoutFormat)
    {
        case glslang::TLayoutFormat::ElfNone: return vk::Format::eUndefined;

        // Float image
        case glslang::TLayoutFormat::ElfRgba32f: return vk::Format::eR32G32B32A32Sfloat;
        case glslang::TLayoutFormat::ElfRgba16f: return vk::Format::eR16G16B16A16Sfloat;
        case glslang::TLayoutFormat::ElfR32f: return vk::Format::eR32Sfloat;
        case glslang::TLayoutFormat::ElfRgba8: return vk::Format::eR8G8B8A8Unorm;
        case glslang::TLayoutFormat::ElfRgba8Snorm: return vk::Format::eR8G8B8A8Snorm;
        case glslang::TLayoutFormat::ElfRg32f: return vk::Format::eR32G32Sfloat;
        case glslang::TLayoutFormat::ElfRg16f: return vk::Format::eR16G16Sfloat;
        case glslang::TLayoutFormat::ElfR11fG11fB10f: return vk::Format::eB10G11R11UfloatPack32;
        case glslang::TLayoutFormat::ElfR16f: return vk::Format::eR16Sfloat;
        case glslang::TLayoutFormat::ElfRgba16: return vk::Format::eR16G16B16A16Unorm;
        // rgb10_a2 is an unsigned normalized format with R in the low bits -> A2B10G10R10.
        case glslang::TLayoutFormat::ElfRgb10A2: return vk::Format::eA2B10G10R10UnormPack32;
        case glslang::TLayoutFormat::ElfRg16: return vk::Format::eR16G16Unorm;
        case glslang::TLayoutFormat::ElfRg8: return vk::Format::eR8G8Unorm;
        case glslang::TLayoutFormat::ElfR16: return vk::Format::eR16Unorm;
        case glslang::TLayoutFormat::ElfR8: return vk::Format::eR8Unorm;
        case glslang::TLayoutFormat::ElfRgba16Snorm: return vk::Format::eR16G16B16A16Snorm;
        case glslang::TLayoutFormat::ElfRg16Snorm: return vk::Format::eR16G16Snorm;
        case glslang::TLayoutFormat::ElfRg8Snorm: return vk::Format::eR8G8Snorm;
        case glslang::TLayoutFormat::ElfR16Snorm: return vk::Format::eR16Snorm;
        case glslang::TLayoutFormat::ElfR8Snorm: return vk::Format::eR8Snorm;

        // Int image
        case glslang::TLayoutFormat::ElfRgba32i: return vk::Format::eR32G32B32A32Sint;
        case glslang::TLayoutFormat::ElfRgba16i: return vk::Format::eR16G16B16A16Sint;
        case glslang::TLayoutFormat::ElfRgba8i: return vk::Format::eR8G8B8A8Sint;
        case glslang::TLayoutFormat::ElfR32i: return vk::Format::eR32Sint;
        case glslang::TLayoutFormat::ElfRg32i: return vk::Format::eR32G32Sint;
        case glslang::TLayoutFormat::ElfRg16i: return vk::Format::eR16G16Sint;
        case glslang::TLayoutFormat::ElfRg8i: return vk::Format::eR8G8Sint;
        case glslang::TLayoutFormat::ElfR16i: return vk::Format::eR16Sint;
        case glslang::TLayoutFormat::ElfR8i: return vk::Format::eR8Sint;
        case glslang::TLayoutFormat::ElfR64i: return vk::Format::eR64Sint;

        // Uint image
        case glslang::TLayoutFormat::ElfRgba32ui: return vk::Format::eR32G32B32A32Uint;
        case glslang::TLayoutFormat::ElfRgba16ui: return vk::Format::eR16G16B16A16Uint;
        case glslang::TLayoutFormat::ElfRgba8ui: return vk::Format::eR8G8B8A8Uint;
        case glslang::TLayoutFormat::ElfR32ui: return vk::Format::eR32Uint;
        case glslang::TLayoutFormat::ElfRg32ui: return vk::Format::eR32G32Uint;
        case glslang::TLayoutFormat::ElfRg16ui: return vk::Format::eR16G16Uint;
        case glslang::TLayoutFormat::ElfRgb10a2ui: return vk::Format::eA2B10G10R10UintPack32;
        case glslang::TLayoutFormat::ElfRg8ui: return vk::Format::eR8G8Uint;
        case glslang::TLayoutFormat::ElfR16ui: return vk::Format::eR16Uint;
        case glslang::TLayoutFormat::ElfR8ui: return vk::Format::eR8Uint;
        case glslang::TLayoutFormat::ElfR64ui: return vk::Format::eR64Uint;

        // other/unknown
        case glslang::TLayoutFormat::ElfSize1x8:
        case glslang::TLayoutFormat::ElfSize1x16:
        case glslang::TLayoutFormat::ElfSize1x32:
        case glslang::TLayoutFormat::ElfSize2x32:
        case glslang::TLayoutFormat::ElfSize4x32:
        case glslang::TLayoutFormat::ElfEsFloatGuard:
        case glslang::TLayoutFormat::ElfFloatGuard:
        case glslang::TLayoutFormat::ElfEsIntGuard:
        case glslang::TLayoutFormat::ElfIntGuard:
        case glslang::TLayoutFormat::ElfEsUintGuard:
        case glslang::TLayoutFormat::ElfExtSizeGuard:
        case glslang::TLayoutFormat::ElfCount:
            break;
    }
    logAndDie("Unexpected format in convertGlslangLayoutFormat()."); // : {}", layoutFormat);
}

/// Maps a 2-, 3- or 4-component float/double/int/uint vector to the matching Vulkan format.
/// Other combinations (including bool vectors) terminate via logAndDie.
vk::Format convertGlslangVectorType(glslang::TBasicType basicType, int vectorSize)
{
    switch (basicType)
    {
        case glslang::EbtFloat:
            switch (vectorSize)
            {
                case 2: return vk::Format::eR32G32Sfloat;
                case 3: return vk::Format::eR32G32B32Sfloat;
                case 4: return vk::Format::eR32G32B32A32Sfloat;
                default: break;
            }
            break;
        case glslang::EbtDouble:
            switch (vectorSize)
            {
                case 2: return vk::Format::eR64G64Sfloat;
                case 3: return vk::Format::eR64G64B64Sfloat;
                case 4: return vk::Format::eR64G64B64A64Sfloat;
                default: break;
            }
            break;
        case glslang::EbtInt:
            switch (vectorSize)
            {
                case 2: return vk::Format::eR32G32Sint;
                case 3: return vk::Format::eR32G32B32Sint;
                case 4: return vk::Format::eR32G32B32A32Sint;
                default: break;
            }
            break;
        case glslang::EbtUint:
            switch (vectorSize)
            {
                case 2: return vk::Format::eR32G32Uint;
                case 3: return vk::Format::eR32G32B32Uint;
                case 4: return vk::Format::eR32G32B32A32Uint;
                default: break;
            }
            break;
        case glslang::EbtBool: // NOLINT(bugprone-branch-clone) TODO: ???
            break;
        default:
            break;
    }
    logAndDie("Don't know how to convert Glslang vector type :(");
}

/// Convenience overload: converts a glslang vector type (asserted) to its Vulkan format.
vk::Format convertGlslangVectorType(const glslang::TType& type)
{
    assert(type.isVector());
    return convertGlslangVectorType(type.getBasicType(), type.getVectorSize());
}

/// Maps a square 2x2/3x3/4x4 glslang matrix type to ShaderVariableMatrixType.
/// Non-square matrices are only caught by the assert; other sizes terminate via logAndDie.
ShaderVariableMatrixType convertGlslangMatrixType(const glslang::TType& type)
{
    assert(type.isMatrix());
    assert(type.getMatrixCols() == type.getMatrixRows()); // only supported types yet...

    switch (type.getMatrixCols())
    {
        case 2: return ShaderVariableMatrixType::MAT2;
        case 3: return ShaderVariableMatrixType::MAT3;
        case 4: return ShaderVariableMatrixType::MAT4;
        default: break;
    }
    logAndDie("Don't know how to convert Glslang matrix type -.-");
}

/// Maps a glslang sampler dimensionality (1D, 2D, 3D, cube) to ImageDim.
/// Other dimensions (e.g. rect, buffer, subpass) terminate via logAndDie.
ImageDim convertGlslangSamplerDim(glslang::TSamplerDim dim)
{
    switch (dim)
    {
        case glslang::TSamplerDim::Esd1D: return ImageDim::ONE;
        case glslang::TSamplerDim::Esd2D: return ImageDim::TWO;
        case glslang::TSamplerDim::Esd3D: return ImageDim::THREE;
        case glslang::TSamplerDim::EsdCube: return ImageDim::CUBE;
        default: break;
    }
    logAndDie("Don't know how to convert Glslang sampler dimensions ...");
}

/// Converts a glslang type into the project's ShaderVariableType.
///
/// Vectors, matrices, structs (converted recursively, member by member), samplers/images
/// and scalars are handled; unsupported types terminate inside the helper converters.
/// For arrays, a single dimension is recorded in arraySize / dynamicArraySize.
ShaderVariableType convertGlslangType(const glslang::TType& type)
{
    ShaderVariableType result;
    if (type.isVector())
    {
        result.baseType = ShaderVariableBaseType::SIMPLE;
        result.simple.format = convertGlslangVectorType(type);
    }
    else if (type.isMatrix())
    {
        result.baseType = ShaderVariableBaseType::MATRIX;
        result.matrixType = convertGlslangMatrixType(type);
    }
    else if (type.isStruct())
    {
        const std::size_t numMembers = type.getStruct()->size();

        result.baseType = ShaderVariableBaseType::STRUCT;
        result.struct_.members.reserve(numMembers);

        std::size_t currentOffset = 0;
        for (const glslang::TTypeLoc& typeLoc : *type.getStruct())
        {
            const glslang::TType& memberType = *typeLoc.type;
            const glslang::TQualifier& memberQualifier = memberType.getQualifier();

            ShaderVariableStructMember& member = result.struct_.members.emplace_back();
            member.name = memberType.getFieldName();
            member.type = convertGlslangType(memberType);
            if (memberQualifier.hasOffset())
            {
                // glslang stores an explicit `layout(offset = N)` here, and (in the glslang
                // versions reviewed) also the std140/std430/scalar offsets it computes for
                // block members, alignment padding included.
                member.offset = static_cast<std::size_t>(memberQualifier.layoutOffset);
            }
            else
            {
                // No offset known: pack tightly behind the previous member (no padding).
                member.offset = currentOffset;
            }
            if (memberQualifier.hasSemantic())
            {
                member.semantic = memberQualifier.layoutSemantic;
            }
            if (memberQualifier.hasSemanticIndex())
            {
                member.semanticIdx = memberQualifier.layoutSemanticIndex;
            }
            currentOffset = member.offset + calcShaderTypeSize(member.type);
        }
    }
    else if (type.getBasicType() == glslang::EbtSampler)
    {
        const glslang::TSampler& sampler = type.getSampler();
        result.baseType = ShaderVariableBaseType::IMAGE;
        result.image.dimensions = convertGlslangSamplerDim(sampler.dim);
        result.image.format = convertGlslangLayoutFormat(type.getQualifier().layoutFormat);
    }
    else
    {
        result.baseType = ShaderVariableBaseType::SIMPLE;
        result.simple.format = convertGlslangBaseType(type);
    }

    if (type.isArray())
    {
        // NOTE(review): "variably indexed" is used as the marker for runtime-sized arrays;
        // confirm that unsized arrays indexed only with constants are handled as intended.
        if (type.isArrayVariablyIndexed())
        {
            result.arraySize = 0;
            result.dynamicArraySize = true;
        }
        else
        {
            assert(type.getArraySizes()->getNumDims() == 1); // don't support multi dimensional arrays yet
            result.arraySize = type.getOuterArraySize();
        }
    }
    return result;
}

/// Picks the Vulkan descriptor type for a uniform/buffer-qualified variable:
/// combined samplers -> combined image sampler, images -> storage image,
/// uniform blocks -> uniform buffer, other blocks -> storage buffer.
/// Anything else terminates via logAndDie.
vk::DescriptorType getGlslangDescriptorType(const glslang::TType& type)
{
    if (type.getBasicType() == glslang::EbtSampler)
    {
        if (type.getSampler().combined)
        {
            return vk::DescriptorType::eCombinedImageSampler;
        }
        if (type.getSampler().isImage())
        {
            return vk::DescriptorType::eStorageImage;
        }
    }
    else if (type.isStruct())
    {
        if (type.getQualifier().isUniform())
        {
            return vk::DescriptorType::eUniformBuffer;
        }
        return vk::DescriptorType::eStorageBuffer;
    }
    logAndDie("No idea what to do with this type :/");
}

/// Builds the reflection record of a stage input or output variable of stage `stage`.
/// Location and semantic information is taken from the symbol's layout qualifiers
/// (UNSPECIFIED_INDEX when no location is given).
ShaderAttribute makeShaderAttribute(const glslang::TIntermSymbol& node, vk::ShaderStageFlagBits stage)
{
    const glslang::TQualifier& qualifier = node.getQualifier();

    ShaderAttribute attribute;
    attribute.stage = stage;
    attribute.type = convertGlslangType(node.getType());
    attribute.location = qualifier.hasLocation() ? qualifier.layoutLocation : UNSPECIFIED_INDEX;
    attribute.name = node.getName();
    if (qualifier.hasSemantic())
    {
        attribute.semantic = qualifier.layoutSemantic;
    }
    if (qualifier.hasSemanticIndex())
    {
        attribute.semanticIndex = qualifier.layoutSemanticIndex;
    }
    return attribute;
}

bool MetaCollectingTraverser::visitBinary(glslang::TVisit /* visit */, glslang::TIntermBinary* /* node */)
{
    return false;
}

bool MetaCollectingTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* /* node */)
{
    return false;
}

/// Descends only into sequences and the linker-objects list (which holds the global
/// declarations this traverser reflects). Function definitions and every other aggregate
/// are not traversed.
bool MetaCollectingTraverser::visitAggregate(glslang::TVisit /* visit */, glslang::TIntermAggregate* node)
{
    switch (node->getOp())
    {
        case glslang::EOpSequence:
            return true;
        case glslang::EOpFunction:
            break;
        case glslang::EOpLinkerObjects:
            return true;
        default:
            break;
    }
    return false;
}

bool MetaCollectingTraverser::visitSelection(glslang::TVisit /* visit */, glslang::TIntermSelection* /* node */)
{
    return false;
}

void MetaCollectingTraverser::visitConstantUnion(glslang::TIntermConstantUnion* /* node */)
{
}

/// Records a global declaration listed under the linker-objects aggregate:
/// - push constant blocks are merged into meta's push constant range,
/// - other uniform/buffer variables become interface variables of their descriptor set,
/// - stage inputs/outputs become input/output attributes.
/// Built-in variables and symbols outside the linker-objects list are ignored.
void MetaCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node)
{
    glslang::TIntermNode* const parent = getParentNode();
    const bool isLinkerObject = parent != nullptr
        && parent->getAsAggregate() != nullptr
        && parent->getAsAggregate()->getOp() == glslang::EOpLinkerObjects;
    if (!isLinkerObject)
    {
        return;
    }

    const glslang::TQualifier& qualifier = node->getQualifier();
    if (qualifier.builtIn)
    {
        return;
    }

    if (qualifier.isUniformOrBuffer())
    {
        if (qualifier.isPushConstant())
        {
            ShaderPushConstantBlock pushConstantBlock;
            pushConstantBlock.type = convertGlslangType(node->getType());
            assert(pushConstantBlock.type.baseType == ShaderVariableBaseType::STRUCT);
            meta.extendPushConstant(pushConstantBlock, ShaderTypeBits::make(shaderType));
            return;
        }

        const unsigned setIdx = qualifier.hasSet() ? qualifier.layoutSet : UNSPECIFIED_INDEX;
        const unsigned binding = qualifier.hasBinding() ? qualifier.layoutBinding : UNSPECIFIED_INDEX;
        ShaderVariableSet& set = meta.getOrCreateInterfaceVariableSet(setIdx);
        assert(setIdx == UNSPECIFIED_INDEX || !set.getVariableAtBindingOpt(binding)); // multiple bindings at the same index?
        set.usedInStages.set(shaderType, true);

        ShaderVariable& var = set.variables.emplace_back();
        var.binding = binding;
        var.name = node->getName();
        if (qualifier.hasSemantic())
        {
            var.semantic = qualifier.layoutSemantic;
        }
        if (qualifier.hasSemanticIndex())
        {
            var.semanticIndex = qualifier.layoutSemanticIndex;
        }

        // uniform blocks are identified by the name of their type
        if (var.name.empty() || var.name.starts_with("anon@"))
        {
            const glslang::TString& typeName = node->getType().getTypeName();
            if (!typeName.empty())
            {
                var.name = typeName;
            }
        }
        var.descriptorType = getGlslangDescriptorType(node->getType());
        var.type = convertGlslangType(node->getType());
    }
    else if (qualifier.storage == glslang::EvqVaryingIn)
    {
        meta.addInputAttribute(makeShaderAttribute(*node, shaderType));
    }
    else if (qualifier.storage == glslang::EvqVaryingOut)
    {
        meta.addOutputAttribute(makeShaderAttribute(*node, shaderType));
    }
}

bool MetaCollectingTraverser::visitLoop(glslang::TVisit /* visit */, glslang::TIntermLoop* /* node */)
{
    return false;
}

bool MetaCollectingTraverser::visitBranch(glslang::TVisit /* visit */, glslang::TIntermBranch* /* node */)
{
    return false;
}

bool MetaCollectingTraverser::visitSwitch(glslang::TVisit /* visit */, glslang::TIntermSwitch* /* node */)
{
    return false;
}

/// Maps a glslang shader stage to the corresponding Vulkan shader stage flag bit.
/// EShLangCount (not a real stage) terminates via logAndDie.
vk::ShaderStageFlagBits shaderStageFromGlslang(EShLanguage language)
{
    switch (language)
    {
        case EShLangVertex: return vk::ShaderStageFlagBits::eVertex;
        case EShLangTessControl: return vk::ShaderStageFlagBits::eTessellationControl;
        case EShLangTessEvaluation: return vk::ShaderStageFlagBits::eTessellationEvaluation;
        case EShLangGeometry: return vk::ShaderStageFlagBits::eGeometry;
        case EShLangFragment: return vk::ShaderStageFlagBits::eFragment;
        case EShLangCompute: return vk::ShaderStageFlagBits::eCompute;
        case EShLangRayGen: return vk::ShaderStageFlagBits::eRaygenKHR;
        case EShLangIntersect: return vk::ShaderStageFlagBits::eIntersectionKHR;
        case EShLangAnyHit: return vk::ShaderStageFlagBits::eAnyHitKHR;
        case EShLangClosestHit: return vk::ShaderStageFlagBits::eClosestHitKHR;
        case EShLangMiss: return vk::ShaderStageFlagBits::eMissKHR;
        case EShLangCallable: return vk::ShaderStageFlagBits::eCallableKHR;
        case EShLangTask: return vk::ShaderStageFlagBits::eTaskEXT;
        case EShLangMesh: return vk::ShaderStageFlagBits::eMeshEXT;
        case EShLangCount: break; // fall through
    }
    logAndDie("Invalid value passed to shaderStageFromGlslang!");
}
} // namespace

/// Reflects a single (already parsed) shader.
ShaderMeta reflectShader(glslang::TShader& shader)
{
    glslang::TIntermediate* intermediate = shader.getIntermediate();
    assert(intermediate != nullptr); // the shader has to be parsed before it can be reflected
    return reflectIntermediate(*intermediate, shaderStageFromGlslang(shader.getStage()));
}

/// Reflects every stage present in `program` and merges the results into one ShaderMeta.
ShaderMeta reflectProgram(glslang::TProgram& program)
{
    ShaderMeta result;
    for (int stage = 0; stage < EShLangCount; ++stage)
    {
        const auto language = static_cast<EShLanguage>(stage);
        glslang::TIntermediate* intermediate = program.getIntermediate(language);
        if (intermediate == nullptr)
        {
            continue; // stage not part of this program
        }
        result.extend(reflectIntermediate(*intermediate, shaderStageFromGlslang(language)));
    }
    return result;
}

/// Reflects one glslang intermediate tree belonging to shader stage `stage`.
/// For compute shaders the local workgroup size is recorded as well.
ShaderMeta reflectIntermediate(glslang::TIntermediate& intermediate, vk::ShaderStageFlagBits stage)
{
    ShaderMeta meta;
    MetaCollectingTraverser traverser(meta, stage);
    intermediate.getTreeRoot()->traverse(&traverser);

    meta.stages.set(stage, true);

    if (stage == vk::ShaderStageFlagBits::eCompute)
    {
        meta.localSizeX = static_cast<decltype(meta.localSizeX)>(intermediate.getLocalSize(0));
        meta.localSizeY = static_cast<decltype(meta.localSizeY)>(intermediate.getLocalSize(1));
        meta.localSizeZ = static_cast<decltype(meta.localSizeZ)>(intermediate.getLocalSize(2));
    }

    return meta;
}
} // namespace iwa