iwa/source/util/shader_meta.cpp

677 lines
21 KiB
C++

#include "iwa/util/shader_meta.hpp"
#include "iwa/log.hpp"
#include "iwa/util/glsl_compiler.hpp"
#include "iwa/util/vertex_layout.hpp"
#include "iwa/util/vkutil.hpp"
namespace
{
// Placeholder for a helper that is meant to fold a value into a running hash.
// It is not implemented yet: calling it hits MIJIN_TRAP() and yields 0. In the
// visible part of this file it is only referenced from code disabled via `#if 0`.
template<typename T>
inline std::size_t calcCrcSizeAppend(T, std::size_t) noexcept
{
MIJIN_TRAP(); // TODO
return 0;
}
}
namespace iwa
{
namespace
{
// Translates the engine's per-stage bit set into the matching Vulkan shader stage flags.
// Every bit that is set in `bits` contributes exactly one vk::ShaderStageFlagBits value.
vk::ShaderStageFlags typeBitsToVkStages(ShaderTypeBits bits)
{
    vk::ShaderStageFlags stageFlags = {};
    const auto enableIf = [&stageFlags](bool isSet, vk::ShaderStageFlagBits stage)
    {
        if (isSet)
        {
            stageFlags |= stage;
        }
    };

    enableIf(bits.compute, vk::ShaderStageFlagBits::eCompute);
    enableIf(bits.vertex, vk::ShaderStageFlagBits::eVertex);
    enableIf(bits.fragment, vk::ShaderStageFlagBits::eFragment);
    // ray tracing stages
    enableIf(bits.rayGeneration, vk::ShaderStageFlagBits::eRaygenKHR);
    enableIf(bits.rayClosestHit, vk::ShaderStageFlagBits::eClosestHitKHR);
    enableIf(bits.rayAnyHit, vk::ShaderStageFlagBits::eAnyHitKHR);
    enableIf(bits.rayMiss, vk::ShaderStageFlagBits::eMissKHR);
    enableIf(bits.rayIntersection, vk::ShaderStageFlagBits::eIntersectionKHR);
    enableIf(bits.callable, vk::ShaderStageFlagBits::eCallableKHR);
    return stageFlags;
}
// Inserts `attribute` into `attributes`, keeping the list ordered by stage and then by location.
// If an attribute with the same stage and the same (specified) location is already present, the
// types must match (otherwise the process is terminated via logAndDie) and nothing is inserted.
void addShaderAttribute(std::vector<ShaderAttribute>& attributes, ShaderAttribute&& attribute)
{
    // Pass 1: look for entries occupying the same slot and validate them.
    bool alreadyPresent = false;
    for (const ShaderAttribute& existing : attributes)
    {
        const bool sameSlot = existing.stage == attribute.stage
            && existing.location == attribute.location
            && existing.location != UNSPECIFIED_INDEX;
        if (!sameSlot)
        {
            continue;
        }
        // same location, type must be the same
        if (existing.type != attribute.type)
        {
            logAndDie(
                "Attempting to merge incompatible shader metas, attributes {} and {} are incompatible. {} != {}",
                existing.name, attribute.name, existing.type, attribute.type);
        }
        alreadyPresent = true;
    }
    if (alreadyPresent)
    {
        return;
    }

    // Pass 2: find the first entry that has to come after the new attribute.
    auto insertPos = attributes.begin();
    while (insertPos != attributes.end())
    {
        const bool laterStage =
            static_cast<unsigned>(insertPos->stage) > static_cast<unsigned>(attribute.stage);
        const bool sameStageLaterLocation =
            insertPos->stage == attribute.stage && insertPos->location > attribute.location;
        if (laterStage || sameStageLaterLocation)
        {
            break;
        }
        ++insertPos;
    }
    attributes.insert(insertPos, std::move(attribute));
}
}
// Constructor and destructor are deliberately spelled with empty bodies instead of `= default`
// (hence the NOLINT markers). NOTE(review): presumably they must be user-provided and defined in
// this translation unit because of the member types involved - confirm against the header before changing.
ShaderVariableStructType::ShaderVariableStructType() {} // NOLINT(modernize-use-equals-default)
ShaderVariableStructType::~ShaderVariableStructType() {} // NOLINT(modernize-use-equals-default)
// Merges another push constant block into the one stored in this meta.
//
// - A block of base type NONE is ignored.
// - If this meta has no block yet, the given block and `stages` are adopted as-is.
// - Otherwise both blocks must be structs. Each incoming member is either matched against an
//   existing member at the same offset (types must be equal) or inserted so that the member list
//   stays sorted by offset. Mismatching types or overlapping members terminate via logAndDie.
//   Finally `stages` is OR-ed into the stages that use the block.
void ShaderMeta::extendPushConstant(ShaderPushConstantBlock pushConstantBlock_, ShaderTypeBits stages)
{
    // nothing to merge
    if (pushConstantBlock_.type.baseType == ShaderVariableBaseType::NONE)
    {
        return;
    }
    // nothing stored yet, simply take over the incoming block
    if (pushConstantBlock.type.baseType == ShaderVariableBaseType::NONE)
    {
        pushConstantBlock = std::move(pushConstantBlock_);
        pushConstantStages = stages;
        return;
    }

    // now comes the actual merging
    assert(pushConstantBlock.type.baseType == ShaderVariableBaseType::STRUCT);
    assert(pushConstantBlock_.type.baseType == ShaderVariableBaseType::STRUCT);
    assert(stages);

    auto& ownMembers = pushConstantBlock.type.struct_.members;
    for (ShaderVariableStructMember& incoming : pushConstantBlock_.type.struct_.members)
    {
        bool alreadyPresent = false;
        for (const ShaderVariableStructMember& existing : ownMembers)
        {
            if (existing.offset == incoming.offset)
            {
                // same offset, type must be the same
                if (existing.type != incoming.type)
                {
                    logAndDie("Attempting to merge incompatible push constant blocks, members {} and {} are incompatible. {} != {}",
                              existing.name, incoming.name, existing.type, incoming.type);
                }
                alreadyPresent = true; // member already exists, don't insert
                continue;
            }

            // different offsets: make sure neither member reaches into the other one
            const bool existingReachesIntoIncoming =
                existing.offset < incoming.offset
                && existing.offset + calcShaderTypeSize(existing.type) > incoming.offset;
            const bool incomingReachesIntoExisting =
                existing.offset > incoming.offset
                && existing.offset < incoming.offset + calcShaderTypeSize(incoming.type);
            if (existingReachesIntoIncoming || incomingReachesIntoExisting)
            {
                logAndDie("Attempting to merge incompatible push constant blocks, members {} and {} are overlapping.",
                          existing.name, incoming.name);
            }
        }
        if (alreadyPresent)
        {
            continue;
        }

        // keep the members sorted by offset: insert before the first member with a larger offset
        auto insertPos = ownMembers.begin();
        while (insertPos != ownMembers.end() && !(insertPos->offset > incoming.offset))
        {
            ++insertPos;
        }
        ownMembers.insert(insertPos, std::move(incoming));
    }
    pushConstantStages |= stages;
}
// Adds `attribute` to this meta's input attributes via addShaderAttribute(), which keeps the
// list ordered and skips the insert if a compatible attribute already occupies the same slot.
void ShaderMeta::addInputAttribute(ShaderAttribute attribute)
{
addShaderAttribute(inputAttributes, std::move(attribute));
}
// Adds `attribute` to this meta's output attributes via addShaderAttribute(), which keeps the
// list ordered and skips the insert if a compatible attribute already occupies the same slot.
void ShaderMeta::addOutputAttribute(ShaderAttribute attribute)
{
addShaderAttribute(outputAttributes, std::move(attribute));
}
// Creates a DescriptorSetLayout as a child of `device` from the bindings, per-binding flags and
// layout flags stored in this meta.
ObjectPtr<DescriptorSetLayout> DescriptorSetMeta::createDescriptorSetLayout(Device& device) const
{
// there has to be exactly one flags entry per binding
assert(bindings.size() == bindingFlags.size());
return device.createChild<DescriptorSetLayout>(DescriptorSetLayoutCreationArgs{
.bindings = bindings,
.bindingFlags = bindingFlags,
.flags = flags,
});
}
// Allocates one descriptor set from `pool` for every stored descriptor set layout.
// The returned sets are in the same order as `descriptorSetLayouts`.
std::vector<ObjectPtr<DescriptorSet>> PipelineAndDescriptorSetLayouts::createDescriptorSets(DescriptorPool& pool) const
{
    std::vector<ObjectPtr<DescriptorSet>> allocatedSets;
    allocatedSets.reserve(descriptorSetLayouts.size());
    for (auto itLayout = descriptorSetLayouts.begin(); itLayout != descriptorSetLayouts.end(); ++itLayout)
    {
        allocatedSets.push_back(pool.allocateDescriptorSet({
            .layout = *itLayout
        }));
    }
    return allocatedSets;
}
// Allocates a single descriptor set from `pool`, using the layout stored at index `setIdx`.
// `setIdx` must be a valid index into `descriptorSetLayouts` (checked with MIJIN_ASSERT).
ObjectPtr<DescriptorSet> PipelineAndDescriptorSetLayouts::createDescriptorSet(DescriptorPool& pool, unsigned setIdx) const
{
MIJIN_ASSERT(setIdx < descriptorSetLayouts.size(), "Invalid set index.");
return pool.allocateDescriptorSet({
.layout = descriptorSetLayouts[setIdx]
});
}
// Creates the descriptor set layouts described by this meta plus a pipeline layout using them.
// The push constant range is only passed on if it has any stage flags set.
// Returns both the pipeline layout and the descriptor set layouts it was created from.
PipelineAndDescriptorSetLayouts PipelineLayoutMeta::createPipelineLayout(Device& device) const
{
    // one layout per descriptor set meta, in order
    std::vector<ObjectPtr<DescriptorSetLayout>> setLayouts;
    setLayouts.reserve(descriptorSets.size());
    for (const DescriptorSetMeta& setMeta : descriptorSets)
    {
        setLayouts.push_back(setMeta.createDescriptorSetLayout(device));
    }

    // at most one push constant range
    std::vector<vk::PushConstantRange> ranges;
    if (pushConstantRange.stageFlags)
    {
        ranges.push_back(pushConstantRange);
    }

    ObjectPtr<PipelineLayout> layout = device.createChild<PipelineLayout>(PipelineLayoutCreationArgs{
        .setLayouts = setLayouts,
        .pushConstantRanges = std::move(ranges)
    });

    PipelineAndDescriptorSetLayouts layouts{
        .descriptorSetLayouts = std::move(setLayouts),
        .pipelineLayout = std::move(layout)
    };
    return layouts;
}
// Checks that `other` describes the same shader variable as this one.
// Binding, descriptor type and variable type must match; every mismatch is collected and, if there
// is at least one, all of them are logged and the process is aborted. Differing names only
// produce a warning.
void ShaderVariable::verifyCompatible(const ShaderVariable& other) const
{
    std::vector<std::string> problems;

    if (other.binding != binding)
    {
        problems.push_back(fmt::format("Variable bindings do not match: {} != {}.", binding, other.binding)); // NOLINT
    }
    if (other.descriptorType != descriptorType)
    {
        problems.push_back(fmt::format("Descriptor types do not match: {} != {}.",
                                       magic_enum::enum_name(descriptorType),
                                       magic_enum::enum_name(other.descriptorType)));
    }
    if (other.name != name)
    {
        // not fatal, but worth telling the user about
        logMsg("Warning: shader variable names do not match, variable will only be referrable to by one of them! ({} != {})",
               name, other.name);
    }
    if (other.type != type)
    {
        problems.push_back(fmt::format("Variable types do not match: {} != {}.", type, other.type));
    }

    if (!problems.empty())
    {
        logMsg("Error(s) verifying shader variable compatibility:");
        for (const std::string& problem : problems)
        {
            logMsg(problem);
        }
        std::abort();
    }
}
// Meant to fold this variable into the running hash `appendTo` and return the new hash.
// Not implemented yet: the function hits MIJIN_TRAP() and returns 0. The intended
// implementation (type, descriptor type, binding and name) is kept below, disabled via `#if 0`.
std::size_t ShaderVariable::calcHash(std::size_t appendTo) const
{
(void) appendTo; // parameter is unused while the real implementation is disabled
MIJIN_TRAP(); // TODO
return 0;
#if 0
std::size_t hash = appendTo;
hash = type.calcHash(hash);
hash = calcCrcSizeAppend(descriptorType, hash);
hash = calcCrcSizeAppend(binding, hash);
hash = calcCrcSizeAppend(name, hash);
return hash;
#endif
}
// Disabled code: ShaderSource::fromFile() is excluded from compilation by the surrounding `#if 0`.
// Even if re-enabled, it would trap and return an empty source, since everything after the first
// `return` is unreachable.
#if 0
ShaderSource ShaderSource::fromFile(std::string fileName, std::string name)
{
(void) fileName;
(void) name;
MIJIN_TRAP(); // TODO
return {};
// intended implementation: read the file and wrap it into a ShaderSource
std::string code = readFileText(fileName);
return {
.code = std::move(code),
.fileName = std::move(fileName),
#if !defined(KAZAN_RELEASE)
.name = std::move(name)
#endif
};
}
#endif
// Looks up a variable of this set by name.
// On success writes this set's index and the variable's binding to `outResult` and returns true;
// otherwise returns false and leaves `outResult` untouched.
bool ShaderVariableSet::find(std::string_view varName, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariable& candidate : variables)
    {
        if (!(candidate.name == varName))
        {
            continue;
        }
        outResult.setIndex = setIndex;
        outResult.bindIndex = candidate.binding;
        return true;
    }
    return false;
}
// Looks up a variable of this set by semantic and semantic index.
// On success writes this set's index and the variable's binding to `outResult` and returns true;
// otherwise returns false and leaves `outResult` untouched.
bool ShaderVariableSet::find(unsigned semantic, unsigned semanticIdx, ShaderVariableFindResult& outResult) const noexcept
{
    for (const ShaderVariable& candidate : variables)
    {
        const bool matches = candidate.semantic == semantic && candidate.semanticIndex == semanticIdx;
        if (!matches)
        {
            continue;
        }
        outResult.setIndex = setIndex;
        outResult.bindIndex = candidate.binding;
        return true;
    }
    return false;
}
// Returns the first variable of this set that uses binding `bindingIdx`.
// Terminates via logAndDie if there is no such variable.
const ShaderVariable& ShaderVariableSet::getVariableAtBinding(unsigned bindingIdx) const
{
    if (const ShaderVariable* found = getVariableAtBindingOpt(bindingIdx))
    {
        return *found;
    }
    logAndDie("Could not find shader variable with binding {}!", bindingIdx);
}
// Returns a pointer to the first variable of this set that uses binding `bindingIdx`,
// or nullptr if there is none.
const ShaderVariable* ShaderVariableSet::getVariableAtBindingOpt(unsigned bindingIdx) const
{
    const ShaderVariable* found = nullptr;
    for (const ShaderVariable& candidate : variables)
    {
        if (candidate.binding == bindingIdx)
        {
            found = &candidate;
            break;
        }
    }
    return found;
}
// Returns a pointer to the first variable of this set with the given semantic and semantic index,
// or nullptr if there is none.
const ShaderVariable* ShaderVariableSet::getVariableAtSemanticOpt(unsigned semantic, unsigned semanticIdx) const
{
    const ShaderVariable* found = nullptr;
    for (const ShaderVariable& candidate : variables)
    {
        if (candidate.semantic == semantic && candidate.semanticIndex == semanticIdx)
        {
            found = &candidate;
            break;
        }
    }
    return found;
}
std::size_t ShaderVariableSet::calcHash(std::size_t appendTo) const
{
std::size_t hash = appendTo;
for (const ShaderVariable& var : variables) {
hash = var.calcHash(hash);
}
return hash;
}
// Merges the reflection data of `other` into this meta.
//
// - Interface variables are merged per set. A variable that already exists (looked up by binding
//   or, if the binding is unspecified, by semantic) must be compatible, see
//   ShaderVariable::verifyCompatible(); otherwise it is appended to the set.
// - Input/output attributes and the push constant block are merged via their respective helpers.
// - Stage bits are OR-ed together.
// - The local size is taken over if this meta has none; conflicting non-zero local sizes
//   terminate via logAndDie.
// - The cached hash is reset so that it gets recalculated on the next getHash() call.
//
// `other` is taken by value, so its contents can be moved out of freely.
void ShaderMeta::extend(ShaderMeta other)
{
    for (ShaderVariableSet& set : other.interfaceVariableSets)
    {
        ShaderVariableSet& mySet = getOrCreateInterfaceVariableSet(set.setIndex);
        mySet.usedInStages.bits |= set.usedInStages.bits;

        for (ShaderVariable& variable : set.variables)
        {
            // try to find a variable we already know, either by binding or by semantic
            const ShaderVariable* myVariable = nullptr;
            if (variable.binding != UNSPECIFIED_INDEX)
            {
                myVariable = mySet.getVariableAtBindingOpt(variable.binding);
            }
            else if (variable.semantic != UNSPECIFIED_INDEX)
            {
                myVariable = mySet.getVariableAtSemanticOpt(variable.semantic, variable.semanticIndex);
            }

            if (myVariable)
            {
                // already known, only make sure both descriptions agree
                myVariable->verifyCompatible(variable);
                continue;
            }
            mySet.variables.push_back(std::move(variable));
        }
    }

    for (ShaderAttribute& attribute : other.inputAttributes)
    {
        addInputAttribute(std::move(attribute));
    }
    for (ShaderAttribute& attribute : other.outputAttributes)
    {
        addOutputAttribute(std::move(attribute));
    }

    // `other` is our own copy and its push constant block is not used afterwards, so hand it over
    // instead of deep-copying the block (including all of its struct members).
    extendPushConstant(std::move(other.pushConstantBlock), other.pushConstantStages);
    stages |= other.stages;

    if (localSizeX == 0 && localSizeY == 0 && localSizeZ == 0)
    {
        localSizeX = other.localSizeX;
        localSizeY = other.localSizeY;
        localSizeZ = other.localSizeZ;
    }
    else if ((other.localSizeX != 0 || other.localSizeY != 0 || other.localSizeZ != 0) &&
             (localSizeX != other.localSizeX || localSizeY != other.localSizeY || localSizeZ != other.localSizeZ))
    {
        logAndDie("Error merging shader metas, conflicting local size!");
    }

    hash = 0; // invalidate the cached hash
}
bool ShaderMeta::findInterfaceVariable(std::string_view varName, ShaderVariableFindResult& outResult) const noexcept
{
for (const ShaderVariableSet& set : interfaceVariableSets)
{
if (set.find(varName, outResult)) {
return true;
}
}
return false;
}
bool ShaderMeta::findInterfaceVariable(unsigned semantic, unsigned semanticIdx, ShaderVariableFindResult& outResult) const noexcept
{
for (const ShaderVariableSet& set : interfaceVariableSets)
{
if (set.find(semantic, semanticIdx, outResult)) {
return true;
}
}
return false;
}
// Returns the interface variable set with index `setIdx`.
// The set is expected to exist: a missing set triggers MIJIN_ASSERT, and the pointer is
// dereferenced afterwards regardless. Use getInterfaceVariableSetOpt() if the set may be absent.
const ShaderVariableSet& ShaderMeta::getInterfaceVariableSet(unsigned setIdx) const
{
const ShaderVariableSet* variableSet = getInterfaceVariableSetOpt(setIdx);
MIJIN_ASSERT(variableSet != nullptr, "Could not find interface variable set.");
return *variableSet;
}
// Returns a pointer to the first interface variable set whose index equals `setIdx`,
// or nullptr if no such set exists.
const ShaderVariableSet* ShaderMeta::getInterfaceVariableSetOpt(unsigned setIdx) const
{
    const ShaderVariableSet* found = nullptr;
    for (const ShaderVariableSet& candidate : interfaceVariableSets)
    {
        if (candidate.setIndex == setIdx)
        {
            found = &candidate;
            break;
        }
    }
    return found;
}
// Returns the type of the interface variable at binding `bindingIdx` of set `setIdx`.
// Both lookups are the non-optional ones, so set and variable are expected to exist.
const ShaderVariableType& ShaderMeta::getInterfaceVariableType(unsigned setIdx, unsigned bindingIdx) const
{
    const ShaderVariableSet& variableSet = getInterfaceVariableSet(setIdx);
    const ShaderVariable& variable = variableSet.getVariableAtBinding(bindingIdx);
    return variable.type;
}
// Builds the vertex input description for this shader from a name-based description.
// The bindings are copied from `namedInput`; for every vertex stage input attribute of the shader
// the entry with the same name in `namedInput.attributes` supplies binding and offset, while
// location and format come from the shader. Non-simple attribute types and attributes missing
// from `namedInput` are fatal assertion failures.
VertexInput ShaderMeta::generateVertexInput(const NamedVertexInput& namedInput) const noexcept
{
    VertexInput vertexInput{
        .bindings = namedInput.bindings
    };

    for (const ShaderAttribute& attribute : inputAttributes)
    {
        // only vertex shader inputs are part of the vertex input state
        if (attribute.stage != vk::ShaderStageFlagBits::eVertex)
        {
            continue;
        }
        MIJIN_ASSERT_FATAL(attribute.type.baseType == ShaderVariableBaseType::SIMPLE, "Vertex shader input must be a simple type.");

        auto itAttribute = namedInput.attributes.find(attribute.name);
        MIJIN_ASSERT_FATAL(itAttribute != namedInput.attributes.end(), "Missing attribute in input.");

        const vk::VertexInputAttributeDescription description{
            .location = attribute.location,
            .binding = itAttribute->second.binding,
            .format = attribute.type.simple.format,
            .offset = itAttribute->second.offset
        };
        vertexInput.attributes.push_back(description);
    }
    return vertexInput;
}
// Builds the vertex input description for this shader from a vertex layout.
// A single per-vertex binding (binding 0) with the layout's stride is used. Every vertex stage
// input attribute that has a semantic is matched against the layout attribute with the same
// semantic and semantic index, which supplies the offset. Attributes without a semantic are
// skipped; non-simple types and attributes missing from the layout are fatal assertion failures.
VertexInput ShaderMeta::generateVertexInputFromLayout(const VertexLayout& layout) const noexcept
{
    VertexInput vertexInput{
        .bindings = {
            vk::VertexInputBindingDescription{
                .binding = 0,
                .stride = layout.stride,
                .inputRate = vk::VertexInputRate::eVertex
            }
        }
    };

    for (const ShaderAttribute& attribute : inputAttributes)
    {
        // skip inputs of other stages and inputs that cannot be matched by semantic
        if (attribute.stage != vk::ShaderStageFlagBits::eVertex || attribute.semantic == UNSPECIFIED_INDEX)
        {
            continue;
        }
        MIJIN_ASSERT_FATAL(attribute.type.baseType == ShaderVariableBaseType::SIMPLE, "Vertex shader input must be a simple type.");

        auto itAttribute = std::ranges::find_if(layout.attributes, [&attribute](const VertexAttribute& layoutAttribute) {
            const bool sameSemantic = static_cast<unsigned>(layoutAttribute.semantic) == attribute.semantic;
            return sameSemantic && layoutAttribute.semanticIdx == attribute.semanticIndex;
        });
        MIJIN_ASSERT_FATAL(itAttribute != layout.attributes.end(), "Missing attribute in vertex layout.");

        const vk::VertexInputAttributeDescription description{
            .location = attribute.location,
            .binding = 0,
            .format = attribute.type.simple.format,
            .offset = itAttribute->offset
        };
        vertexInput.attributes.push_back(description);
    }
    return vertexInput;
}
// Generates the descriptor set layout description for one interface variable set.
// Each variable yields one binding (descriptor count 1, stages taken from the set). If
// `args.descriptorCounts` contains a positive count for a binding, that count is used and the
// binding is flagged as partially bound. The descriptor type of every binding is additionally
// recorded in `descriptorTypes`, indexed by binding number.
DescriptorSetMeta ShaderMeta::generateDescriptorSetLayout(const ShaderVariableSet& set, const GenerateDescriptorSetLayoutArgs& args) const
{
    DescriptorSetMeta setInfo{
        .flags = args.flags
    };

    for (const ShaderVariable& variable : set.variables)
    {
        auto itVar = std::ranges::find_if(setInfo.bindings, [&](const vk::DescriptorSetLayoutBinding& existing) {
            return existing.binding == variable.binding;
        });
        assert(itVar == setInfo.bindings.end()); // should have been merged!
        if (itVar != setInfo.bindings.end())
        {
            // fallback when assertions are disabled: extend the stages of the existing binding
            itVar->stageFlags |= typeBitsToVkStages(set.usedInStages);
            continue; // TODO: verify the bindings are compatible
        }

        vk::DescriptorSetLayoutBinding& newBinding = setInfo.bindings.emplace_back();
        vk::DescriptorBindingFlags& newFlags = setInfo.bindingFlags.emplace_back();
        newBinding.binding = variable.binding;
        newBinding.descriptorType = variable.descriptorType;
        newBinding.descriptorCount = 1;
        newBinding.stageFlags = typeBitsToVkStages(set.usedInStages);

        // support for dynamically sized descriptors
        const auto itCount = args.descriptorCounts.find(variable.binding);
        const bool hasCustomCount = itCount != args.descriptorCounts.end() && itCount->second > 0;
        if (hasCustomCount)
        {
            newBinding.descriptorCount = itCount->second;
            newFlags |= vk::DescriptorBindingFlagBits::ePartiallyBound;
        }

        // remember the descriptor type, growing the lookup table if necessary
        if (setInfo.descriptorTypes.size() <= variable.binding)
        {
            setInfo.descriptorTypes.resize(variable.binding + 1);
        }
        setInfo.descriptorTypes[variable.binding] = variable.descriptorType;
    }
    return setInfo;
}
// Generates the pipeline layout description for this shader meta.
//
// For every interface variable set a descriptor set layout is generated and stored at the set's
// index (the list grows as needed, so gaps stay default-constructed). Per-set generation options
// are looked up in `args.descriptorSets`; sets without an entry use default options.
// If a push constant block is present, the push constant range covers the block's offset plus
// its size and is visible to all stages that use the block.
PipelineLayoutMeta ShaderMeta::generatePipelineLayout(const GeneratePipelineLayoutArgs& args) const
{
    // options used for sets that have no explicit entry in args.descriptorSets
    static const GenerateDescriptorSetLayoutArgs NO_DESCRIPTOR_SET_ARGS = {};

    PipelineLayoutMeta result;
    for (const ShaderVariableSet& set : interfaceVariableSets)
    {
        if (set.setIndex >= result.descriptorSets.size())
        {
            result.descriptorSets.resize(set.setIndex + 1);
        }

        const auto itSet = args.descriptorSets.find(set.setIndex);
        // bind by reference: both alternatives outlive this iteration, no need to copy the options
        const GenerateDescriptorSetLayoutArgs& setArgs =
            itSet != args.descriptorSets.end()
                ? itSet->second
                : NO_DESCRIPTOR_SET_ARGS;
        result.descriptorSets[set.setIndex] = generateDescriptorSetLayout(set, setArgs);
    }

    if (pushConstantBlock.type.baseType != ShaderVariableBaseType::NONE)
    {
        assert(pushConstantStages);
        result.pushConstantRange.stageFlags = typeBitsToVkStages(pushConstantStages);
        result.pushConstantRange.size = pushConstantBlock.offset + calcShaderTypeSize(pushConstantBlock.type);
    }
    return result;
}
// Returns true if this meta carries no reflection data at all: no interface variable sets,
// no input/output attributes, no push constants and a local size of zero.
bool ShaderMeta::empty() const
{
    // reminder to revisit this function whenever members are added to ShaderMeta
    static_assert(ShaderMeta::STRUCT_VERSION == 1, "Update me");

    const bool noVariablesOrAttributes = interfaceVariableSets.empty()
        && inputAttributes.empty()
        && outputAttributes.empty();
    if (!noVariablesOrAttributes)
    {
        return false;
    }

    const bool noPushConstants = pushConstantStages == ShaderTypeBits()
        && pushConstantBlock.type.baseType == ShaderVariableBaseType::NONE;
    if (!noPushConstants)
    {
        return false;
    }

    return localSizeX == 0
        && localSizeY == 0
        && localSizeZ == 0;
}
// Returns the cached hash of this meta, with 0 meaning "not calculated yet".
// The actual calculation is not implemented: on a cache miss the hash is set to the placeholder
// value 1 and MIJIN_TRAP() is hit. The intended implementation is kept below, disabled via `#if 0`.
// NOTE(review): `hash` is written from a const member function, so it is presumably declared mutable - confirm in the header.
std::size_t ShaderMeta::getHash() const
{
if (hash == 0)
{
hash = 1; // TODO
MIJIN_TRAP();
#if 0
for (const ShaderVariableSet& variableSet : interfaceVariableSets) {
hash = variableSet.calcHash(hash);
}
hash = calcCrcSizeAppend(pushConstantStages.bits, hash);
hash = pushConstantBlock.type.calcHash(hash);
hash = calcCrcSizeAppend(pushConstantBlock.offset, hash);
hash = calcCrcSizeAppend(localSizeX, hash);
hash = calcCrcSizeAppend(localSizeY, hash);
hash = calcCrcSizeAppend(localSizeZ, hash);
#endif
}
return hash;
}
// Calculates the size in bytes of a shader variable type.
//
// - SIMPLE: size of the Vulkan format.
// - MATRIX: 16, 36 or 64 bytes for MAT2, MAT3 and MAT4 respectively.
// - STRUCT: offset of the last member plus that member's size (the struct must not be empty).
// Any other base or matrix type terminates via logAndDie.
// Unless `ignoreArraySize` is set, the element size is multiplied by the type's array size.
unsigned calcShaderTypeSize(const ShaderVariableType& type, bool ignoreArraySize) noexcept
{
    unsigned elementSize = 0;

    if (type.baseType == ShaderVariableBaseType::SIMPLE)
    {
        elementSize = vkFormatSize(type.simple.format);
    }
    else if (type.baseType == ShaderVariableBaseType::MATRIX)
    {
        switch (type.matrixType)
        {
            case ShaderVariableMatrixType::MAT2:
                elementSize = 16;
                break;
            case ShaderVariableMatrixType::MAT3:
                elementSize = 36;
                break;
            case ShaderVariableMatrixType::MAT4:
                elementSize = 64;
                break;
            default:
                logAndDie("Lol, what's this?");
        }
    }
    else if (type.baseType == ShaderVariableBaseType::STRUCT)
    {
        assert(!type.struct_.members.empty());
        // the struct ends where its last member ends
        const auto& lastMember = type.struct_.members.back();
        elementSize = static_cast<unsigned>(lastMember.offset + calcShaderTypeSize(lastMember.type));
    }
    else
    {
        logAndDie("How would I know?");
    }

    if (ignoreArraySize)
    {
        return elementSize;
    }
    elementSize *= type.arraySize;
    return elementSize;
}
} // namespace iwa