#include "iwa/util/shader_meta.hpp"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "iwa/log.hpp"
#include "iwa/util/glsl_compiler.hpp"
#include "iwa/util/vertex_layout.hpp"
#include "iwa/util/vkutil.hpp"

namespace
{
// Placeholder hash combiner; currently only referenced from the disabled (#if 0) hashing code below.
template<typename T>
inline std::size_t calcCrcSizeAppend(T, std::size_t) noexcept
{
    MIJIN_TRAP(); // TODO
    return 0;
}
}

namespace iwa
{
namespace
{
/// Converts the per-stage bits of a ShaderTypeBits into the equivalent Vulkan stage flag mask.
vk::ShaderStageFlags typeBitsToVkStages(ShaderTypeBits bits)
{
    vk::ShaderStageFlags flags = {};
    if (bits.compute)
    {
        flags |= vk::ShaderStageFlagBits::eCompute;
    }
    if (bits.vertex)
    {
        flags |= vk::ShaderStageFlagBits::eVertex;
    }
    if (bits.fragment)
    {
        flags |= vk::ShaderStageFlagBits::eFragment;
    }
    if (bits.rayGeneration)
    {
        flags |= vk::ShaderStageFlagBits::eRaygenKHR;
    }
    if (bits.rayClosestHit)
    {
        flags |= vk::ShaderStageFlagBits::eClosestHitKHR;
    }
    if (bits.rayAnyHit)
    {
        flags |= vk::ShaderStageFlagBits::eAnyHitKHR;
    }
    if (bits.rayMiss)
    {
        flags |= vk::ShaderStageFlagBits::eMissKHR;
    }
    if (bits.rayIntersection)
    {
        flags |= vk::ShaderStageFlagBits::eIntersectionKHR;
    }
    if (bits.callable)
    {
        flags |= vk::ShaderStageFlagBits::eCallableKHR;
    }
    return flags;
}

/// Merges `attribute` into `attributes`, keeping the list sorted by (stage, location).
///
/// If an attribute with the same stage and the same explicit location already exists, nothing is
/// inserted; the existing entry must have the same type, otherwise the process is terminated via
/// logAndDie. Attributes with an UNSPECIFIED_INDEX location are never considered duplicates.
void addShaderAttribute(std::vector<ShaderAttribute>& attributes, ShaderAttribute&& attribute)
{
    bool doInsert = true;
    // Deliberately scans every entry (no early exit) so any conflicting duplicate is reported.
    for (const ShaderAttribute& myAttribute : attributes)
    {
        if (myAttribute.stage == attribute.stage && myAttribute.location == attribute.location
            && myAttribute.location != UNSPECIFIED_INDEX)
        {
            // same location, type must be the same
            if (myAttribute.type != attribute.type)
            {
                logAndDie("Attempting to merge incompatible shader metas, attributes {} and {} are incompatible.\n{} != {}",
                          myAttribute.name, attribute.name, myAttribute.type, attribute.type);
            }
            doInsert = false; // attribute already exists, don't insert
        }
    }
    if (!doInsert)
    {
        return;
    }

    // Insert before the first entry that sorts after the new one (stage first, then location).
    const auto insertPos = std::ranges::find_if(attributes, [&attribute](const ShaderAttribute& existing)
    {
        return static_cast<unsigned>(existing.stage) > static_cast<unsigned>(attribute.stage)
            || (existing.stage == attribute.stage && existing.location > attribute.location);
    });
    attributes.insert(insertPos, std::move(attribute));
}
}

ShaderVariableStructType::ShaderVariableStructType() {} // NOLINT(modernize-use-equals-default)
ShaderVariableStructType::~ShaderVariableStructType() {} // NOLINT(modernize-use-equals-default)

/// Merges another push constant block (and the stages using it) into this meta.
///
/// An incoming block of base type NONE is ignored. If this meta has no block yet, the incoming one is
/// adopted as-is. Otherwise both must be structs. Members at the same offset must have identical types.
/// Members at different offsets must not overlap. New members are inserted sorted by offset.
/// Violations terminate via logAndDie.
void ShaderMeta::extendPushConstant(ShaderPushConstantBlock pushConstantBlock_, ShaderTypeBits pushConstantStages_)
{
    if (pushConstantBlock_.type.baseType == ShaderVariableBaseType::NONE)
    {
        return;
    }
    if (pushConstantBlock.type.baseType == ShaderVariableBaseType::NONE)
    {
        pushConstantBlock = std::move(pushConstantBlock_);
        pushConstantStages = pushConstantStages_;
        return;
    }

    // now comes the actual merging
    assert(pushConstantBlock.type.baseType == ShaderVariableBaseType::STRUCT);
    assert(pushConstantBlock_.type.baseType == ShaderVariableBaseType::STRUCT);
    assert(pushConstantStages_);

    for (ShaderVariableStructMember& member : pushConstantBlock_.type.struct_.members)
    {
        bool doInsert = true;
        for (const ShaderVariableStructMember& myMember : pushConstantBlock.type.struct_.members)
        {
            if (myMember.offset == member.offset)
            {
                // same offset, type must be the same
                if (myMember.type != member.type)
                {
                    logAndDie("Attempting to merge incompatible push constant blocks, members {} and {} are incompatible.\n{} != {}",
                              myMember.name, member.name, myMember.type, member.type);
                }
                doInsert = false; // member already exists, don't insert
                continue;
            }

            // otherwise check for overlaps
            if ((myMember.offset < member.offset && myMember.offset + calcShaderTypeSize(myMember.type) > member.offset)
                || (myMember.offset > member.offset && myMember.offset < member.offset + calcShaderTypeSize(member.type)))
            {
                logAndDie("Attempting to merge incompatible push constant blocks, members {} and {} are overlapping.",
                          myMember.name, member.name);
            }
        }
        if (!doInsert)
        {
            continue;
        }

        // keep members sorted by offset
        auto it = pushConstantBlock.type.struct_.members.begin();
        for (; it != pushConstantBlock.type.struct_.members.end(); ++it)
        {
            if (it->offset > member.offset)
            {
                break; // insert here
            }
        }
        pushConstantBlock.type.struct_.members.insert(it, std::move(member));
    }
    pushConstantStages |= pushConstantStages_;
}

/// Merges an input attribute into inputAttributes (see addShaderAttribute for the merge rules).
void ShaderMeta::addInputAttribute(ShaderAttribute attribute)
{
    addShaderAttribute(inputAttributes, std::move(attribute));
}

/// Merges an output attribute into outputAttributes (see addShaderAttribute for the merge rules).
void ShaderMeta::addOutputAttribute(ShaderAttribute attribute)
{
    addShaderAttribute(outputAttributes, std::move(attribute));
}

/// Creates a descriptor set layout object on `device` from the collected bindings and flags.
/// `bindings` and `bindingFlags` are parallel arrays and must have the same length.
ObjectPtr<DescriptorSetLayout> DescriptorSetMeta::createDescriptorSetLayout(Device& device) const
{
    assert(bindings.size() == bindingFlags.size());
    return device.createChild<DescriptorSetLayout>(DescriptorSetLayoutCreationArgs{
        .bindings = bindings,
        .bindingFlags = bindingFlags,
        .flags = flags,
    });
}

/// Allocates one descriptor set from `pool` for every stored descriptor set layout, in layout order.
std::vector<ObjectPtr<DescriptorSet>> PipelineAndDescriptorSetLayouts::createDescriptorSets(DescriptorPool& pool) const
{
    std::vector<ObjectPtr<DescriptorSet>> result;
    result.reserve(descriptorSetLayouts.size());
    for (const ObjectPtr<DescriptorSetLayout>& layout : descriptorSetLayouts)
    {
        result.push_back(pool.allocateDescriptorSet({
            .layout = layout
        }));
    }
    return result;
}

/// Allocates a single descriptor set from `pool` using the layout at `setIdx`.
ObjectPtr<DescriptorSet> PipelineAndDescriptorSetLayouts::createDescriptorSet(DescriptorPool& pool, unsigned setIdx) const
{
    MIJIN_ASSERT(setIdx < descriptorSetLayouts.size(), "Invalid set index.");
    return pool.allocateDescriptorSet({
        .layout = descriptorSetLayouts[setIdx]
    });
}

/// Creates every descriptor set layout described by this meta plus a pipeline layout referencing them.
/// A push constant range is only attached when its stage flags are non-empty.
PipelineAndDescriptorSetLayouts PipelineLayoutMeta::createPipelineLayout(Device& device) const
{
    std::vector<ObjectPtr<DescriptorSetLayout>> descSetLayouts;
    descSetLayouts.reserve(descriptorSets.size());

    for (const DescriptorSetMeta& dslMeta : descriptorSets)
    {
        descSetLayouts.push_back(dslMeta.createDescriptorSetLayout(device));
    }

    std::vector<vk::PushConstantRange> pushConstantRanges;
    if (pushConstantRange.stageFlags)
    {
        pushConstantRanges.push_back(pushConstantRange);
    }

    // setLayouts gets a copy: the layouts are needed both here and in the returned struct.
    ObjectPtr<PipelineLayout> pipelineLayout = device.createChild<PipelineLayout>(PipelineLayoutCreationArgs{
        .setLayouts = descSetLayouts,
        .pushConstantRanges = std::move(pushConstantRanges)
    });
    return {
        .descriptorSetLayouts = std::move(descSetLayouts),
        .pipelineLayout = std::move(pipelineLayout)
    };
}

/// Checks that `other` describes the same resource as this variable.
///
/// Binding, descriptor type and type mismatches are collected and logged, then the process is aborted.
/// A name mismatch only produces a warning.
void ShaderVariable::verifyCompatible(const ShaderVariable& other) const
{
    std::vector<std::string> errors;
    if (other.binding != binding)
    {
        errors.push_back(fmt::format("Variable bindings do not match: {} != {}.", binding, other.binding)); // NOLINT
    }
    if (other.descriptorType != descriptorType)
    {
        errors.push_back(fmt::format("Descriptor types do not match: {} != {}.",
                                     magic_enum::enum_name(descriptorType),
                                     magic_enum::enum_name(other.descriptorType)));
    }
    if (other.name != name)
    {
        logMsg("Warning: shader variable names do not match, variable will only be referrable to by one of them!\n({} != {})",
               name, other.name);
    }
    if (other.type != type)
    {
        errors.push_back(fmt::format("Variable types do not match: {} != {}.", type, other.type));
    }

    if (errors.empty())
    {
        return;
    }

    logMsg("Error(s) verifying shader variable compatibility:");
    for (const std::string& error : errors)
    {
        // Pass the message as an argument, never as the format string: it may contain '{' / '}'
        // (e.g. from a formatted type) which would otherwise be interpreted as replacement fields.
        logMsg("{}", error);
    }
    std::abort();
}

/// Hash combination for a shader variable. Not implemented yet: traps when called.
std::size_t ShaderVariable::calcHash(std::size_t appendTo) const
{
    (void) appendTo;
    MIJIN_TRAP(); // TODO
    return 0;
#if 0
    std::size_t hash = appendTo;
    hash = type.calcHash(hash);
    hash = calcCrcSizeAppend(descriptorType, hash);
    hash = calcCrcSizeAppend(binding, hash);
    hash = calcCrcSizeAppend(name, hash);
    return hash;
#endif
}

#if 0
ShaderSource ShaderSource::fromFile(std::string fileName, std::string name)
{
    (void) fileName;
    (void) name;
    MIJIN_TRAP(); // TODO
    return {};
    std::string code = readFileText(fileName);
    return {
        .code = std::move(code),
        .fileName = std::move(fileName),
#if !defined(KAZAN_RELEASE)
        .name = std::move(name)
#endif
    };
}
#endif

/// Looks up a variable by name; on success fills set/binding index into `outResult` and returns true.
bool ShaderVariableSet::find(std::string_view varName, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariable& var : variables)
    {
        if (var.name == varName)
        {
            outResult.setIndex = setIndex;
            outResult.bindIndex = var.binding;
            return true;
        }
    }
    return false;
}

/// Looks up a variable by (semantic, semantic index); on success fills `outResult` and returns true.
bool ShaderVariableSet::find(unsigned semantic, unsigned semanticIdx, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariable& var : variables)
    {
        if (var.semantic == semantic && var.semanticIndex == semanticIdx)
        {
            outResult.setIndex = setIndex;
            outResult.bindIndex = var.binding;
            return true;
        }
    }
    return false;
}

/// Returns the variable bound at `bindingIdx`; terminates via logAndDie if there is none.
const ShaderVariable& ShaderVariableSet::getVariableAtBinding(unsigned bindingIdx) const
{
    for (const ShaderVariable& var : variables)
    {
        if (var.binding == bindingIdx)
        {
            return var;
        }
    }
    logAndDie("Could not find shader variable with binding {}!", bindingIdx);
}

/// Returns the variable bound at `bindingIdx`, or nullptr if there is none.
const ShaderVariable* ShaderVariableSet::getVariableAtBindingOpt(unsigned bindingIdx) const
{
    for (const ShaderVariable& var : variables)
    {
        if (var.binding == bindingIdx)
        {
            return &var;
        }
    }
    return nullptr;
}

/// Returns the variable with the given (semantic, semantic index), or nullptr if there is none.
const ShaderVariable* ShaderVariableSet::getVariableAtSemanticOpt(unsigned semantic, unsigned semanticIdx) const
{
    for (const ShaderVariable& var : variables)
    {
        if (var.semantic == semantic && var.semanticIndex == semanticIdx)
        {
            return &var;
        }
    }
    return nullptr;
}

/// Folds the hash of every variable (in order) into `appendTo`.
std::size_t ShaderVariableSet::calcHash(std::size_t appendTo) const
{
    std::size_t hash = appendTo;
    for (const ShaderVariable& var : variables)
    {
        hash = var.calcHash(hash);
    }
    return hash;
}

/// Merges all information of `other` into this meta.
///
/// Variables are matched by binding (or by semantic when the binding is unspecified) and verified for
/// compatibility. Unmatched ones are appended. Attributes, push constants and stages are merged. The
/// compute local size is adopted if unset here; a conflicting non-zero size terminates via logAndDie.
/// Invalidates the cached hash.
void ShaderMeta::extend(ShaderMeta other)
{
    for (ShaderVariableSet& set : other.interfaceVariableSets)
    {
        ShaderVariableSet& mySet = getOrCreateInterfaceVariableSet(set.setIndex);
        mySet.usedInStages.bits |= set.usedInStages.bits;

        for (ShaderVariable& variable : set.variables)
        {
            const ShaderVariable* myVariable = nullptr;
            if (variable.binding != UNSPECIFIED_INDEX)
            {
                myVariable = mySet.getVariableAtBindingOpt(variable.binding);
            }
            else if (variable.semantic != UNSPECIFIED_INDEX)
            {
                myVariable = mySet.getVariableAtSemanticOpt(variable.semantic, variable.semanticIndex);
            }
            if (myVariable)
            {
                myVariable->verifyCompatible(variable);
                continue;
            }
            mySet.variables.push_back(std::move(variable));
        }
    }

    for (ShaderAttribute& attribute : other.inputAttributes)
    {
        addInputAttribute(std::move(attribute));
    }
    for (ShaderAttribute& attribute : other.outputAttributes)
    {
        addOutputAttribute(std::move(attribute));
    }
    // `other` is a by-value local, so its push constant block can be moved instead of copied.
    extendPushConstant(std::move(other.pushConstantBlock), other.pushConstantStages);
    stages |= other.stages;

    if (localSizeX == 0 && localSizeY == 0 && localSizeZ == 0)
    {
        localSizeX = other.localSizeX;
        localSizeY = other.localSizeY;
        localSizeZ = other.localSizeZ;
    }
    else if ((other.localSizeX != 0 || other.localSizeY != 0 || other.localSizeZ != 0)
             && (localSizeX != other.localSizeX || localSizeY != other.localSizeY || localSizeZ != other.localSizeZ))
    {
        logAndDie("Error merging shader metas, conflicting local size!");
    }

    hash = 0;
}

/// Searches all interface variable sets for a variable with the given name.
bool ShaderMeta::findInterfaceVariable(std::string_view varName, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariableSet& set : interfaceVariableSets)
    {
        if (set.find(varName, outResult))
        {
            return true;
        }
    }
    return false;
}

/// Searches all interface variable sets for a variable with the given (semantic, semantic index).
bool ShaderMeta::findInterfaceVariable(unsigned semantic, unsigned semanticIdx, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariableSet& set : interfaceVariableSets)
    {
        if (set.find(semantic, semanticIdx, outResult))
        {
            return true;
        }
    }
    return false;
}

/// Returns the interface variable set with index `setIdx`; asserts that it exists.
const ShaderVariableSet& ShaderMeta::getInterfaceVariableSet(unsigned setIdx) const
{
    const ShaderVariableSet* variableSet = getInterfaceVariableSetOpt(setIdx);
    MIJIN_ASSERT(variableSet != nullptr, "Could not find interface variable set.");
    return *variableSet;
}

/// Returns the interface variable set with index `setIdx`, or nullptr if there is none.
const ShaderVariableSet* ShaderMeta::getInterfaceVariableSetOpt(unsigned setIdx) const
{
    for (const ShaderVariableSet& set : interfaceVariableSets)
    {
        if (set.setIndex == setIdx)
        {
            return &set;
        }
    }
    return nullptr;
}

/// Returns the type of the variable at (setIdx, bindingIdx); terminates if either lookup fails.
const ShaderVariableType& ShaderMeta::getInterfaceVariableType(unsigned setIdx, unsigned bindingIdx) const
{
    return getInterfaceVariableSet(setIdx).getVariableAtBinding(bindingIdx).type;
}

/// Builds the vertex input description by matching vertex-stage input attributes to `namedInput` by name.
/// Every vertex input attribute must be a simple type and must be present in `namedInput`.
VertexInput ShaderMeta::generateVertexInput(const NamedVertexInput& namedInput) const noexcept
{
    VertexInput result{
        .bindings = namedInput.bindings
    };

    for (const ShaderAttribute& attribute : inputAttributes)
    {
        if (attribute.stage != vk::ShaderStageFlagBits::eVertex)
        {
            continue;
        }
        MIJIN_ASSERT_FATAL(attribute.type.baseType == ShaderVariableBaseType::SIMPLE, "Vertex shader input must be a simple type.");
        auto itAttribute = namedInput.attributes.find(attribute.name);
        MIJIN_ASSERT_FATAL(itAttribute != namedInput.attributes.end(), "Missing attribute in input.");
        result.attributes.push_back(vk::VertexInputAttributeDescription{
            .location = attribute.location,
            .binding = itAttribute->second.binding,
            .format = attribute.type.simple.format,
            .offset = itAttribute->second.offset
        });
    }

    return result;
}

/// Builds a single-binding (binding 0, per-vertex rate) vertex input description from `layout`, matching
/// vertex-stage input attributes by (semantic, semantic index). Attributes without a semantic are skipped.
VertexInput ShaderMeta::generateVertexInputFromLayout(const VertexLayout& layout) const noexcept
{
    VertexInput result{
        .bindings = {
            vk::VertexInputBindingDescription{
                .binding = 0,
                .stride = layout.stride,
                .inputRate = vk::VertexInputRate::eVertex
            }
        }
    };

    for (const ShaderAttribute& attribute : inputAttributes)
    {
        if (attribute.stage != vk::ShaderStageFlagBits::eVertex)
        {
            continue;
        }
        if (attribute.semantic == UNSPECIFIED_INDEX)
        {
            continue;
        }
        MIJIN_ASSERT_FATAL(attribute.type.baseType == ShaderVariableBaseType::SIMPLE, "Vertex shader input must be a simple type.");
        auto itAttribute = std::ranges::find_if(layout.attributes, [&attribute](const VertexAttribute& attrib) {
            return static_cast<unsigned>(attrib.semantic) == attribute.semantic && attrib.semanticIdx == attribute.semanticIndex;
        });
        MIJIN_ASSERT_FATAL(itAttribute != layout.attributes.end(), "Missing attribute in vertex layout.");
        result.attributes.push_back(vk::VertexInputAttributeDescription{
            .location = attribute.location,
            .binding = 0,
            .format = attribute.type.simple.format,
            .offset = itAttribute->offset
        });
    }

    return result;
}

/// Builds the descriptor set layout description (bindings, per-binding flags, descriptor types indexed
/// by binding) for `set`. All bindings are visible to every stage the set is used in. A positive entry in
/// `args.descriptorCounts` turns a binding into a partially bound descriptor array of that size.
DescriptorSetMeta ShaderMeta::generateDescriptorSetLayout(const ShaderVariableSet& set, const GenerateDescriptorSetLayoutArgs& args) const
{
    DescriptorSetMeta setInfo{
        .flags = args.flags
    };

    for (const ShaderVariable& var : set.variables)
    {
        // descriptorTypes is indexed by binding below, so a placeholder binding must never reach this point.
        MIJIN_ASSERT(var.binding != UNSPECIFIED_INDEX, "Cannot generate a descriptor set layout for a variable without a binding.");

        auto itVar = std::ranges::find_if(setInfo.bindings, [&](const vk::DescriptorSetLayoutBinding& binding) {
            return binding.binding == var.binding;
        });
        assert(itVar == setInfo.bindings.end()); // should have been merged!
        if (itVar != setInfo.bindings.end())
        {
            itVar->stageFlags |= typeBitsToVkStages(set.usedInStages);
            continue; // TODO: verify the bindings are compatible
        }
        vk::DescriptorSetLayoutBinding& binding = setInfo.bindings.emplace_back();
        vk::DescriptorBindingFlags& flags = setInfo.bindingFlags.emplace_back();
        binding.binding = var.binding;
        binding.descriptorType = var.descriptorType;
        binding.descriptorCount = 1;
        binding.stageFlags = typeBitsToVkStages(set.usedInStages);

        // support for dynamically sized descriptors
        auto itCounts = args.descriptorCounts.find(var.binding);
        if (itCounts != args.descriptorCounts.end() && itCounts->second > 0)
        {
            binding.descriptorCount = itCounts->second;
            flags |= vk::DescriptorBindingFlagBits::ePartiallyBound;
        }

        if (setInfo.descriptorTypes.size() <= var.binding)
        {
            setInfo.descriptorTypes.resize(var.binding + 1);
        }
        setInfo.descriptorTypes[var.binding] = var.descriptorType;
    }

    return setInfo;
}

/// Builds the pipeline layout description: one DescriptorSetMeta per set index (gaps are left
/// default-constructed) plus a push constant range covering the merged push constant block, if any.
PipelineLayoutMeta ShaderMeta::generatePipelineLayout(const GeneratePipelineLayoutArgs& args) const
{
    static const GenerateDescriptorSetLayoutArgs NO_DESCRIPTOR_SET_ARGS = {};

    PipelineLayoutMeta result;
    for (const ShaderVariableSet& set : interfaceVariableSets)
    {
        if (set.setIndex >= result.descriptorSets.size())
        {
            result.descriptorSets.resize(set.setIndex + 1);
        }
        auto itSet = args.descriptorSets.find(set.setIndex);
        // Both operands are const lvalues of the same type, so this binds without copying the args.
        const GenerateDescriptorSetLayoutArgs& setArgs = itSet != args.descriptorSets.end()
            ? itSet->second
            : NO_DESCRIPTOR_SET_ARGS;
        result.descriptorSets[set.setIndex] = generateDescriptorSetLayout(set, setArgs);
    }

    if (pushConstantBlock.type.baseType != ShaderVariableBaseType::NONE)
    {
        assert(pushConstantStages);
        result.pushConstantRange.stageFlags = typeBitsToVkStages(pushConstantStages);
        result.pushConstantRange.size = pushConstantBlock.offset + calcShaderTypeSize(pushConstantBlock.type);
    }

    return result;
}

/// Returns true if this meta carries no information at all.
bool ShaderMeta::empty() const
{
    static_assert(ShaderMeta::STRUCT_VERSION == 1, "Update me");

    return interfaceVariableSets.empty()
        && inputAttributes.empty()
        && outputAttributes.empty()
        && pushConstantStages == ShaderTypeBits()
        && pushConstantBlock.type.baseType == ShaderVariableBaseType::NONE
        && localSizeX == 0
        && localSizeY == 0
        && localSizeZ == 0;
}

/// Returns the cached hash of this meta. Hash computation is not implemented yet: traps on first use.
std::size_t ShaderMeta::getHash() const
{
    if (hash == 0)
    {
        hash = 1;
        // TODO
        MIJIN_TRAP();
#if 0
        for (const ShaderVariableSet& variableSet : interfaceVariableSets)
        {
            hash = variableSet.calcHash(hash);
        }
        hash = calcCrcSizeAppend(pushConstantStages.bits, hash);
        hash = pushConstantBlock.type.calcHash(hash);
        hash = calcCrcSizeAppend(pushConstantBlock.offset, hash);
        hash = calcCrcSizeAppend(localSizeX, hash);
        hash = calcCrcSizeAppend(localSizeY, hash);
        hash = calcCrcSizeAppend(localSizeZ, hash);
#endif
    }
    return hash;
}

/// Computes the byte size of a shader variable type.
///
/// @param type             the type to measure
/// @param ignoreArraySize  if false, the result is multiplied by type.arraySize
/// @return size in bytes; a struct's size is the end of its last member (members are assumed to be
///         sorted by offset); an empty struct yields 0
unsigned calcShaderTypeSize(const ShaderVariableType& type, bool ignoreArraySize) noexcept
{
    unsigned size = 0;
    switch (type.baseType)
    {
        case ShaderVariableBaseType::SIMPLE:
            size = vkFormatSize(type.simple.format);
            break;
        case ShaderVariableBaseType::MATRIX:
            // NOTE(review): these sizes assume tightly packed float columns (scalar layout). std140/std430
            // pad vec3 columns to 16 bytes (and std140 also vec2 columns) - confirm which layout callers rely on.
            switch (type.matrixType)
            {
                case ShaderVariableMatrixType::MAT2:
                    size = 16;
                    break;
                case ShaderVariableMatrixType::MAT3:
                    size = 36;
                    break;
                case ShaderVariableMatrixType::MAT4:
                    size = 64;
                    break;
                default:
                    logAndDie("Lol, what's this?");
            }
            break;
        case ShaderVariableBaseType::STRUCT:
            assert(!type.struct_.members.empty());
            // Guard release builds: calling back() on an empty vector is undefined behavior.
            if (type.struct_.members.empty())
            {
                size = 0;
                break;
            }
            size = static_cast<unsigned>(type.struct_.members.back().offset + calcShaderTypeSize(type.struct_.members.back().type));
            break;
        default:
            logAndDie("How would I know?");
    }
    if (!ignoreArraySize)
    {
        size *= type.arraySize;
    }
    return size;
}
} // namespace iwa