blob: 970778c4568edef4500a59da3c385fd0acb59d37 [file] [log] [blame]
//
// Copyright 2018 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RewriteStructSamplers: Extract samplers from structs.
//
#include "compiler/translator/tree_ops/RewriteStructSamplers.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
namespace sh
{
namespace
{
// Used to map one structure type to another (one where the samplers are removed).
struct StructureData
{
    // The structure this was replaced with. If nullptr, it means the structure is removed (because
    // it had all samplers). Initialized to nullptr so a freshly default-constructed entry (e.g. one
    // created through a map's operator[]) never holds an indeterminate pointer.
    const TStructure *modified = nullptr;

    // Indexed by the field index of original structure, to get the field index of the modified
    // structure. For example:
    //
    //     struct Original
    //     {
    //         sampler2D s1;
    //         vec4 f1;
    //         sampler2D s2;
    //         sampler2D s3;
    //         vec4 f2;
    //     };
    //
    //     struct Modified
    //     {
    //         vec4 f1;
    //         vec4 f2;
    //     };
    //
    //     fieldMap:
    //         0 -> Invalid
    //         1 -> 0
    //         2 -> Invalid
    //         3 -> Invalid
    //         4 -> 1
    //
    // "Invalid" entries are fields with no counterpart in the modified structure; they keep
    // whatever fill value the map was resized with (std::numeric_limits<int>::max() in this file).
    TVector<int> fieldMap;
};
// Maps each original structure that contains samplers to its sampler-stripped replacement data.
using StructureMap = angle::HashMap<const TStructure *, StructureData>;
// Maps an original struct-typed uniform to the uniform that replaces it (whose type is the
// modified structure). Uniforms whose structure was removed entirely have no entry.
using StructureUniformMap = angle::HashMap<const TVariable *, const TVariable *>;
// Maps a generated sampler name of the form "<uniform>_<field>[_<nestedField>...]" to the
// standalone sampler uniform that was extracted for it.
using ExtractedSamplerMap = angle::HashMap<std::string, const TVariable *>;
// Forward declaration: RewriteExpressionVisitBinaryHelper calls this function, which in turn
// (through RewriteIndexExpression and RewriteExpressionTraverser) calls back into the helper.
TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
TCompiler *compiler,
TIntermBinary *node,
const StructureMap &structureMap,
const StructureUniformMap &structureUniformMap,
const ExtractedSamplerMap &extractedSamplers);
// Decides whether |node| is a struct field selection that must be rewritten and, if so, builds
// and returns the replacement expression. Returns nullptr when |node| should be left alone, in
// which case the calling traverser keeps descending into its children.
//
// A rewrite is needed when the selected field is a sampler (it now lives in its own extracted
// uniform) or when the indexed structure was replaced by a sampler-free version (its field
// indices may have shifted).
TIntermTyped *RewriteExpressionVisitBinaryHelper(TCompiler *compiler,
                                                 TIntermBinary *node,
                                                 const StructureMap &structureMap,
                                                 const StructureUniformMap &structureUniformMap,
                                                 const ExtractedSamplerMap &extractedSamplers)
{
    // Only struct field selections (EOpIndexDirectStruct) are candidates.
    if (node->getOp() == EOpIndexDirectStruct)
    {
        const TStructure *indexedStruct = node->getLeft()->getType().getStruct();
        ASSERT(indexedStruct);

        const bool selectsSampler = node->getType().isSampler();
        const bool structIsMapped = structureMap.find(indexedStruct) != structureMap.end();

        if (selectsSampler || structIsMapped)
        {
            // Rebuild the whole selection chain: samplers are re-indexed with the array indices
            // of the parent structs, other fields get their remapped field indices. A struct that
            // holds a sampler is always present in the map.
            ASSERT(structIsMapped);
            return RewriteModifiedStructFieldSelectionExpression(
                compiler, node, structureMap, structureUniformMap, extractedSamplers);
        }
    }

    return nullptr;
}
// Given an expression, this traverser calculates a new expression where sampler-in-structs are
// replaced with their extracted ones, and field indices are adjusted for the rest of the fields.
// In particular, this is run on the right node of EOpIndexIndirect binary nodes, so that the
// expression in the index gets a chance to go through this transformation.
class RewriteExpressionTraverser final : public TIntermTraverser
{
public:
explicit RewriteExpressionTraverser(TCompiler *compiler,
const StructureMap &structureMap,
const StructureUniformMap &structureUniformMap,
const ExtractedSamplerMap &extractedSamplers)
: TIntermTraverser(true, false, false),
mCompiler(compiler),
mStructureMap(structureMap),
mStructureUniformMap(structureUniformMap),
mExtractedSamplers(extractedSamplers)
{}
bool visitBinary(Visit visit, TIntermBinary *node) override
{
TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
if (rewritten == nullptr)
{
return true;
}
queueReplacement(rewritten, OriginalNode::IS_DROPPED);
// Don't iterate as the expression is rewritten.
return false;
}
void visitSymbol(TIntermSymbol *node) override
{
// It's impossible to reach here with a symbol that needs replacement.
// MonomorphizeUnsupportedFunctionsInVulkanGLSL makes sure that whole structs containing
// samplers are not passed to functions, so any instance of the struct uniform is
// necessarily indexed right away. visitBinary should have already taken care of it.
ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
}
private:
TCompiler *mCompiler;
// See RewriteStructSamplersTraverser.
const StructureMap &mStructureMap;
const StructureUniformMap &mStructureUniformMap;
const ExtractedSamplerMap &mExtractedSamplers;
};
// Rewrites, in place, the index expression of an EOpIndexIndirect node: every struct field
// selection found *below* the root of |expression| is replaced as described for
// RewriteExpressionTraverser. The root node itself is never replaced here (the traverser can only
// replace nodes that have a parent).
void RewriteIndexExpression(TCompiler *compiler,
                            TIntermTyped *expression,
                            const StructureMap &structureMap,
                            const StructureUniformMap &structureUniformMap,
                            const ExtractedSamplerMap &extractedSamplers)
{
    RewriteExpressionTraverser indexRewriter(compiler, structureMap, structureUniformMap,
                                             extractedSamplers);
    expression->traverse(&indexRewriter);

    const bool treeUpdated = indexRewriter.updateTree(compiler, expression);
    ASSERT(treeUpdated);
}
// Given an expression such as the following:
//
//                                      EOpIndexDirectStruct (sampler)
//                                           /          \
//                                  EOpIndex*            field index
//                                   /     \
//                 EOpIndexDirectStruct    index 2
//                      /          \
//             EOpIndex*            field index
//              /     \
// EOpIndexDirectStruct    index 1
//      /          \
// Uniform Struct   field index
//
// produces:
//
//                      EOpIndex*
//                       /     \
//                EOpIndex*    index 2
//                 /     \
// extracted sampler     index 1
//
// Alternatively, if the expression is as such:
//
//                                      EOpIndexDirectStruct
//                                           /          \
//           (modified struct type) EOpIndex*            field index
//                                   /     \
//                 EOpIndexDirectStruct    index 2
//                      /          \
//             EOpIndex*            field index
//              /     \
// EOpIndexDirectStruct    index 1
//      /          \
// Uniform Struct   field index
//
// produces (with the base uniform replaced by its modified-struct counterpart):
//
//                                      EOpIndexDirectStruct
//                                           /          \
//                                  EOpIndex*            mapped field index
//                                   /     \
//                 EOpIndexDirectStruct    index 2
//                      /          \
//             EOpIndex*            mapped field index
//              /     \
// EOpIndexDirectStruct    index 1
//      /          \
// Uniform Struct   mapped field index
//
// Builds the replacement for the selection chain rooted at |node| (see the diagrams above) and
// returns it. From the original chain only the array-index operands are reused; callers in this
// file queue the result as a replacement for |node| with the original dropped.
TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
    TCompiler *compiler,
    TIntermBinary *node,
    const StructureMap &structureMap,
    const StructureUniformMap &structureUniformMap,
    const ExtractedSamplerMap &extractedSamplers)
{
    ASSERT(node->getOp() == EOpIndexDirectStruct);

    const bool isSampler = node->getType().isSampler();

    TIntermSymbol *baseUniform = nullptr;
    std::string samplerName;
    TVector<TIntermBinary *> indexNodeStack;

    // Walk down the chain until the base uniform is reached, remembering every index node
    // (outermost first). For samplers, also build the name under which the sampler was extracted:
    // <uniform>_<field>_<nestedField>...
    TIntermBinary *iter = node;
    while (baseUniform == nullptr)
    {
        // Every link is expected to be an index node whose left operand is either another index
        // node or the base uniform symbol.
        ASSERT(iter != nullptr);
        indexNodeStack.push_back(iter);
        baseUniform = iter->getLeft()->getAsSymbolNode();

        if (isSampler)
        {
            if (iter->getOp() == EOpIndexDirectStruct)
            {
                // When indexed into a struct, get the field name instead and construct the
                // sampler name.
                samplerName.insert(0, iter->getIndexStructFieldName().data());
                samplerName.insert(0, "_");
            }
            if (baseUniform)
            {
                // If left is a symbol, we have reached the end of the chain. Use the struct name
                // to finish building the name of the sampler.
                samplerName.insert(0, baseUniform->variable().name().data());
            }
        }

        iter = iter->getLeft()->getAsBinaryNode();
    }

    // Start from the new base: the extracted sampler, or the uniform of the modified struct type.
    TIntermTyped *rewritten = nullptr;
    if (isSampler)
    {
        ASSERT(extractedSamplers.find(samplerName) != extractedSamplers.end());
        rewritten = new TIntermSymbol(extractedSamplers.at(samplerName));
    }
    else
    {
        const TVariable *baseUniformVar = &baseUniform->variable();
        ASSERT(structureUniformMap.find(baseUniformVar) != structureUniformMap.end());
        rewritten = new TIntermSymbol(structureUniformMap.at(baseUniformVar));
    }

    // Iterate again and build the expression from bottom up.
    for (auto it = indexNodeStack.rbegin(); it != indexNodeStack.rend(); ++it)
    {
        TIntermBinary *indexNode = *it;
        switch (indexNode->getOp())
        {
            case EOpIndexDirectStruct:
                // For samplers, struct selections disappear: the extracted sampler already
                // accounts for them, only array indices are kept.
                if (!isSampler)
                {
                    // Remap the field.
                    const TStructure *structure = indexNode->getLeft()->getType().getStruct();
                    ASSERT(structureMap.find(structure) != structureMap.end());
                    const StructureData &structureData = structureMap.at(structure);

                    TIntermConstantUnion *asConstantUnion =
                        indexNode->getRight()->getAsConstantUnion();
                    ASSERT(asConstantUnion);
                    const int fieldIndex = asConstantUnion->getIConst(0);
                    ASSERT(fieldIndex >= 0 &&
                           fieldIndex < static_cast<int>(structureData.fieldMap.size()));

                    const int mappedFieldIndex = structureData.fieldMap[fieldIndex];
                    // The selected field must still exist in the modified struct.
                    ASSERT(structureData.modified != nullptr && mappedFieldIndex >= 0 &&
                           mappedFieldIndex <
                               static_cast<int>(structureData.modified->fields().size()));

                    rewritten = new TIntermBinary(EOpIndexDirectStruct, rewritten,
                                                  CreateIndexNode(mappedFieldIndex));
                }
                break;
            case EOpIndexDirect:
                rewritten = new TIntermBinary(EOpIndexDirect, rewritten, indexNode->getRight());
                break;
            case EOpIndexIndirect:
            {
                // The index may itself use a sampler inside a struct, or a field of a struct
                // that's remapped. When the index's *root* is such a field selection (e.g.
                // |u.idx| where |u|'s struct type was modified), RewriteExpressionTraverser can't
                // replace it (the root has no parent), so rewrite it directly here. Otherwise,
                // let the traverser fix up the nodes inside the index expression.
                TIntermTyped *indexExpression = indexNode->getRight();
                TIntermBinary *indexAsBinary  = indexExpression->getAsBinaryNode();
                TIntermTyped *rewrittenIndex =
                    indexAsBinary != nullptr
                        ? RewriteExpressionVisitBinaryHelper(compiler, indexAsBinary, structureMap,
                                                             structureUniformMap,
                                                             extractedSamplers)
                        : nullptr;

                if (rewrittenIndex != nullptr)
                {
                    indexExpression = rewrittenIndex;
                }
                else
                {
                    RewriteIndexExpression(compiler, indexExpression, structureMap,
                                           structureUniformMap, extractedSamplers);
                }

                rewritten = new TIntermBinary(EOpIndexIndirect, rewritten, indexExpression);
                break;
            }
            default:
                UNREACHABLE();
                break;
        }
    }

    return rewritten;
}
// Main traverser. At global scope it strips samplers out of struct declarations and declares a
// standalone uniform for every sampler found in a struct uniform; everywhere, it rewrites the
// expressions that select fields of such structs.
class RewriteStructSamplersTraverser final : public TIntermTraverser
{
  public:
    explicit RewriteStructSamplersTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
        : TIntermTraverser(true, false, false, symbolTable),
          mCompiler(compiler),
          mRemovedUniformsCount(0)
    {}

    // Number of struct uniforms that were dropped entirely because no non-sampler field remained
    // in their structure.
    int removedUniformsCount() const { return mRemovedUniformsCount; }

    // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
    // stripped struct sampler.
    bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
    {
        // Only global declarations are rewritten. Returning true for the others lets the
        // traversal visit their children (e.g. initializers that may use struct uniforms).
        if (!mInGlobalScope)
        {
            return true;
        }

        const TIntermSequence &sequence = *(decl->getSequence());
        TIntermTyped *declarator        = sequence.front()->getAsTyped();
        const TType &type               = declarator->getType();

        if (!type.isStructureContainingSamplers())
        {
            return false;
        }

        TIntermSequence newSequence;
        const TStructure *structure = type.getStruct();

        if (type.isStructSpecifier())
        {
            // If this is just a struct definition (not a uniform variable declaration of a
            // struct type), just remove the samplers. They are not instantiated yet.
            ASSERT(structure && mStructureMap.find(structure) == mStructureMap.end());
            stripStructSpecifierSamplers(structure, &newSequence);
        }
        else
        {
            // If the structure is defined at the same time, create the mapping to the stripped
            // version first.
            if (mStructureMap.find(structure) == mStructureMap.end())
            {
                stripStructSpecifierSamplers(structure, &newSequence);
            }

            // Then, extract the samplers from the struct and create global-scope variables
            // instead.
            TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
            ASSERT(asSymbol);
            const TVariable &variable = asSymbol->variable();
            ASSERT(variable.symbolType() != SymbolType::Empty);
            extractStructSamplerUniforms(variable, structure, &newSequence);
        }

        // Replace the original declaration with the new declarations (possibly none at all).
        mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl,
                                        std::move(newSequence));
        return false;
    }

    // Same logic as RewriteExpressionTraverser::visitBinary. This traverser runs from the
    // TIntermBlock root, so any binary node it replaces has a parent.
    bool visitBinary(Visit visit, TIntermBinary *node) override
    {
        TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
            mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
        if (rewritten == nullptr)
        {
            return true;
        }

        queueReplacement(rewritten, OriginalNode::IS_DROPPED);
        // Don't iterate as the expression is rewritten.
        return false;
    }

    // Same check as RewriteExpressionTraverser::visitSymbol: replaced struct uniforms are only
    // ever reached through a field selection that visitBinary already rewrote.
    void visitSymbol(TIntermSymbol *node) override
    {
        ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
    }

  private:
    // Records in mStructureMap a sampler-free version of |structure| (or that it was removed
    // entirely, when no field is left), and appends the declaration of the non-empty modified
    // struct to |newSequence|.
    void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
    {
        TFieldList *newFieldList = new TFieldList;
        ASSERT(structure->containsSamplers());

        // Add this struct to the struct map.
        ASSERT(mStructureMap.find(structure) == mStructureMap.end());
        StructureData *modifiedData = &mStructureMap[structure];

        modifiedData->modified = nullptr;
        // Fields that don't survive keep this sentinel ("Invalid").
        modifiedData->fieldMap.resize(structure->fields().size(), std::numeric_limits<int>::max());

        for (size_t fieldIndex = 0; fieldIndex < structure->fields().size(); ++fieldIndex)
        {
            const TField *field    = structure->fields()[fieldIndex];
            const TType &fieldType = *field->type();

            // If the field is a sampler, or a struct that's entirely removed, skip it.
            if (fieldType.isSampler() || isRemovedStructType(fieldType))
            {
                continue;
            }

            TType *newType = nullptr;
            if (fieldType.isStructureContainingSamplers())
            {
                // A struct that's replaced: create a new field of the replaced type, keeping the
                // field's array dimensions.
                const TStructure *fieldStruct = fieldType.getStruct();
                ASSERT(mStructureMap.find(fieldStruct) != mStructureMap.end());
                // at() rather than operator[]: a lookup must never insert a default entry, which
                // would also risk invalidating |modifiedData|.
                const TStructure *modifiedStruct = mStructureMap.at(fieldStruct).modified;
                ASSERT(modifiedStruct);

                newType = new TType(modifiedStruct, true);
                if (fieldType.isArray())
                {
                    newType->makeArrays(fieldType.getArraySizes());
                }
            }
            else
            {
                // Otherwise, duplicate the field as is.
                newType = new TType(fieldType);
            }

            // Record the mapping of the field indices, so future EOpIndexDirectStruct's into
            // this struct can be fixed up.
            modifiedData->fieldMap[fieldIndex] = static_cast<int>(newFieldList->size());
            TField *newField =
                new TField(newType, field->name(), field->line(), field->symbolType());
            newFieldList->push_back(newField);
        }

        // Prune empty structs.
        if (newFieldList->empty())
        {
            return;
        }

        // Declare a new struct with the same name and the new fields.
        modifiedData->modified =
            new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType());
        TType *newStructType = new TType(modifiedData->modified, true);
        TVariable *newStructVar =
            new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
        TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);

        TIntermDeclaration *structDecl = new TIntermDeclaration;
        structDecl->appendDeclarator(newStructRef);

        newSequence->push_back(structDecl);
    }

    // Returns true if the type is a struct that was removed because we extracted all the members.
    bool isRemovedStructType(const TType &type) const
    {
        const TStructure *structure = type.getStruct();
        if (structure == nullptr)
        {
            // Not a struct
            return false;
        }

        // A struct is removed if it is in the map, but doesn't have a replacement struct.
        auto iter = mStructureMap.find(structure);
        return iter != mStructureMap.end() && iter->second.modified == nullptr;
    }

    // Removes samplers from the struct uniform |variable| (whose type is |structure|, possibly
    // arrayed). Every sampler found, including in nested structs, is declared as a standalone
    // uniform. If the struct keeps non-sampler fields, a uniform of the modified struct type
    // replaces the original. All new declarations are appended to |newSequence|.
    void extractStructSamplerUniforms(const TVariable &variable,
                                      const TStructure *structure,
                                      TIntermSequence *newSequence)
    {
        ASSERT(structure->containsSamplers());
        ASSERT(mStructureMap.find(structure) != mStructureMap.end());

        const TType &type = variable.getType();

        // The uniform's own array dimensions apply to every sampler extracted from it.
        enterArray(type);

        for (const TField *field : structure->fields())
        {
            extractFieldSamplers(variable.name().data(), field, newSequence);
        }

        // If there's a replacement structure (because there are non-sampler fields in the struct),
        // add a declaration with that type.
        const TStructure *modified = mStructureMap.at(structure).modified;
        if (modified != nullptr)
        {
            TType *newType = new TType(modified, false);
            if (type.isArray())
            {
                newType->makeArrays(type.getArraySizes());
            }
            newType->setQualifier(EvqUniform);

            const TVariable *newVariable =
                new TVariable(mSymbolTable, variable.name(), newType, variable.symbolType());

            TIntermDeclaration *newDecl = new TIntermDeclaration();
            newDecl->appendDeclarator(new TIntermSymbol(newVariable));
            newSequence->push_back(newDecl);

            ASSERT(mStructureUniformMap.find(&variable) == mStructureUniformMap.end());
            mStructureUniformMap[&variable] = newVariable;
        }
        else
        {
            // Only samplers: the uniform disappears entirely.
            mRemovedUniformsCount++;
        }

        exitArray(type);
    }

    // Extracts samplers from a field of a struct. Works with nested structs and arrays. |prefix|
    // is the generated name of the enclosing struct instance; the field name is appended to it
    // with an underscore.
    void extractFieldSamplers(const std::string &prefix,
                              const TField *field,
                              TIntermSequence *newSequence)
    {
        const TType &fieldType = *field->type();
        if (!fieldType.isSampler() && !fieldType.isStructureContainingSamplers())
        {
            // Plain data field: it stays in the modified struct.
            return;
        }

        const std::string newPrefix = prefix + "_" + field->name().data();
        if (fieldType.isSampler())
        {
            extractSampler(newPrefix, fieldType, newSequence);
            return;
        }

        enterArray(fieldType);
        for (const TField *nestedField : fieldType.getStruct()->fields())
        {
            extractFieldSamplers(newPrefix, nestedField, newSequence);
        }
        exitArray(fieldType);
    }

    // Appends the accumulated array sizes to |sizesOut| in reverse stack order. mArraySizeStack
    // has the outermost size in the front, so the result lists the outermost size last.
    void generateArraySizesFromStack(TVector<unsigned int> *sizesOut) const
    {
        sizesOut->insert(sizesOut->end(), mArraySizeStack.rbegin(), mArraySizeStack.rend());
    }

    // Extracts a sampler from a struct. Declares the new extracted sampler.
    void extractSampler(const std::string &newName,
                        const TType &fieldType,
                        TIntermSequence *newSequence)
    {
        ASSERT(fieldType.isSampler());

        TType *newType = new TType(fieldType);

        // Add array dimensions accumulated so far due to struct arrays. Note that to support
        // nested arrays, mArraySizeStack has the outermost size in the front. |makeArrays| thus
        // expects this in reverse order.
        TVector<unsigned int> parentArraySizes;
        generateArraySizesFromStack(&parentArraySizes);
        newType->makeArrays(parentArraySizes);

        ImmutableStringBuilder nameBuilder(newName.size() + 1);
        nameBuilder << newName;

        newType->setQualifier(EvqUniform);
        TVariable *newVariable =
            new TVariable(mSymbolTable, nameBuilder, newType, SymbolType::AngleInternal);
        TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);

        TIntermDeclaration *samplerDecl = new TIntermDeclaration;
        samplerDecl->appendDeclarator(newSymbol);

        newSequence->push_back(samplerDecl);

        // TODO: Use a temp name instead of generating a name as currently done. There is no
        // guarantee that these generated names cannot clash. Create a mapping from the previous
        // name to the name assigned to the temp variable so ShaderVariable::mappedName can be
        // updated post-transformation. http://anglebug.com/4301
        ASSERT(mExtractedSamplers.find(newName) == mExtractedSamplers.end());
        mExtractedSamplers[newName] = newVariable;
    }

    // Pushes the array sizes of |arrayType| onto mArraySizeStack. getArraySizes() is iterated in
    // reverse so that outer sizes are pushed before inner ones.
    void enterArray(const TType &arrayType)
    {
        const TSpan<const unsigned int> &arraySizes = arrayType.getArraySizes();
        for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
        {
            mArraySizeStack.push_back(*it);
        }
    }

    // Pops the array sizes pushed by the matching enterArray(arrayType).
    void exitArray(const TType &arrayType)
    {
        mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes());
    }

    TCompiler *mCompiler;
    int mRemovedUniformsCount;

    // Map structures with samplers to ones that have their samplers removed.
    StructureMap mStructureMap;

    // Map uniform variables of structure type that are replaced with another variable.
    StructureUniformMap mStructureUniformMap;

    // Map a constructed sampler name to its variable. Used to replace an expression that uses this
    // sampler with the extracted one.
    ExtractedSamplerMap mExtractedSamplers;

    // A stack of array sizes. Used to figure out the array dimensions of the extracted sampler,
    // for example when it's nested in an array of structs in an array of structs.
    TVector<unsigned int> mArraySizeStack;
};
} // anonymous namespace
// Entry point of the pass. Extracts samplers out of structs in the shader rooted at |root|,
// stores in |removedUniformsCountOut| how many struct uniforms were dropped entirely, and returns
// the result of applying the queued tree changes (TIntermTraverser::updateTree).
bool RewriteStructSamplers(TCompiler *compiler,
                           TIntermBlock *root,
                           TSymbolTable *symbolTable,
                           int *removedUniformsCountOut)
{
    RewriteStructSamplersTraverser rewriter(compiler, symbolTable);
    root->traverse(&rewriter);

    const int removedCount   = rewriter.removedUniformsCount();
    *removedUniformsCountOut = removedCount;

    return rewriter.updateTree(compiler, root);
}
} // namespace sh