| // |
| // Copyright 2021 The ANGLE Project Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| // |
| // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with |
| // parameters that are incompatible with both Vulkan GLSL and Metal. |
| // |
| |
| #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h" |
| |
| #include "compiler/translator/ImmutableStringBuilder.h" |
| #include "compiler/translator/SymbolTable.h" |
| #include "compiler/translator/tree_util/IntermNode_util.h" |
| #include "compiler/translator/tree_util/IntermTraverse.h" |
| #include "compiler/translator/tree_util/ReplaceVariable.h" |
| |
| namespace sh |
| { |
| namespace |
| { |
| struct Argument |
| { |
| size_t argumentIndex; |
| TIntermTyped *argument; |
| }; |
| |
| struct FunctionData |
| { |
| // Whether the original function is used. If this is false, the function can be removed because |
| // all callers have been modified. |
| bool isOriginalUsed; |
| // The original definition of the function, used to create the monomorphized version. |
| TIntermFunctionDefinition *originalDefinition; |
| // List of monomorphized versions of this function. They will be added next to the original |
| // version (or replace it). |
| TVector<TIntermFunctionDefinition *> monomorphizedDefinitions; |
| }; |
| |
| using FunctionMap = angle::HashMap<const TFunction *, FunctionData>; |
| |
| // Traverse the function definitions and initialize the map. Allows visitAggregate to have access |
| // to TIntermFunctionDefinition even when the function is only forward declared at that point. |
| void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut) |
| { |
| TIntermSequence &sequence = *root->getSequence(); |
| |
| for (TIntermNode *node : sequence) |
| { |
| TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition(); |
| if (asFuncDef != nullptr) |
| { |
| const TFunction *function = asFuncDef->getFunction(); |
| ASSERT(function && functionMapOut->find(function) == functionMapOut->end()); |
| (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}}; |
| } |
| } |
| } |
| |
| const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut) |
| { |
| *isSamplerInStructOut = false; |
| |
| while (node->getAsBinaryNode()) |
| { |
| TIntermBinary *asBinary = node->getAsBinaryNode(); |
| |
| TOperator op = asBinary->getOp(); |
| |
| // No opaque uniform can be inside an interface block. |
| if (op == EOpIndexDirectInterfaceBlock) |
| { |
| return nullptr; |
| } |
| |
| if (op == EOpIndexDirectStruct) |
| { |
| *isSamplerInStructOut = true; |
| } |
| |
| node = asBinary->getLeft(); |
| } |
| |
| // Only interested in uniform opaque types. If a function call within another function uses |
| // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the |
| // calling function is monomorphized. |
| if (node->getType().getQualifier() != EvqUniform) |
| { |
| return nullptr; |
| } |
| |
| ASSERT(IsOpaqueType(node->getType().getBasicType()) || |
| node->getType().isStructureContainingSamplers()); |
| |
| TIntermSymbol *asSymbol = node->getAsSymbolNode(); |
| ASSERT(asSymbol); |
| |
| return &asSymbol->variable(); |
| } |
| |
| TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable, |
| TIntermTyped *node, |
| TIntermSequence *replacementIndices) |
| { |
| TIntermTyped *withoutSideEffects = node->deepCopy(); |
| |
| for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary; |
| asBinary = asBinary->getLeft()->getAsBinaryNode()) |
| { |
| TOperator op = asBinary->getOp(); |
| TIntermTyped *index = asBinary->getRight(); |
| |
| if (op == EOpIndexDirectStruct) |
| { |
| break; |
| } |
| |
| // No side effects with constant expressions. |
| if (op == EOpIndexDirect) |
| { |
| ASSERT(index->getAsConstantUnion()); |
| continue; |
| } |
| |
| ASSERT(op == EOpIndexIndirect); |
| |
| // If the index is a symbol, there's no side effect, so leave it as-is. |
| if (index->getAsSymbolNode()) |
| { |
| continue; |
| } |
| |
| // Otherwise create a temp variable initialized with the index and use that temp variable as |
| // the index. |
| TIntermDeclaration *tempDecl = nullptr; |
| TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl); |
| |
| replacementIndices->push_back(tempDecl); |
| asBinary->replaceChildNode(index, new TIntermSymbol(tempVar)); |
| } |
| |
| return withoutSideEffects; |
| } |
| |
| void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments, |
| const TVector<Argument> &replacedArguments, |
| TIntermSequence *substituteArgsOut) |
| { |
| size_t nextReplacedArg = 0; |
| for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex) |
| { |
| if (nextReplacedArg >= replacedArguments.size() || |
| argIndex != replacedArguments[nextReplacedArg].argumentIndex) |
| { |
| // Not replaced, keep argument as is. |
| substituteArgsOut->push_back(originalCallArguments[argIndex]); |
| } |
| else |
| { |
| TIntermTyped *argument = replacedArguments[nextReplacedArg].argument; |
| |
| // Iterate over indices of the argument and create a new arg for every non-const |
| // index. Note that the index itself may be an expression, and it may require further |
| // substitution in the next pass. |
| while (argument->getAsBinaryNode()) |
| { |
| TIntermBinary *asBinary = argument->getAsBinaryNode(); |
| if (asBinary->getOp() == EOpIndexIndirect) |
| { |
| TIntermTyped *index = asBinary->getRight(); |
| substituteArgsOut->push_back(index->deepCopy()); |
| } |
| argument = asBinary->getLeft(); |
| } |
| |
| ++nextReplacedArg; |
| } |
| } |
| } |
| |
| const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable, |
| const TFunction *original, |
| TVector<Argument> *replacedArguments, |
| VariableReplacementMap *argumentMapOut) |
| { |
| TFunction *substituteFunction = |
| new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal, |
| &original->getReturnType(), original->isKnownToNotHaveSideEffects()); |
| |
| size_t nextReplacedArg = 0; |
| for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex) |
| { |
| const TVariable *originalParam = original->getParam(paramIndex); |
| |
| if (nextReplacedArg >= replacedArguments->size() || |
| paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex) |
| { |
| TVariable *substituteArgument = |
| new TVariable(symbolTable, originalParam->name(), &originalParam->getType(), |
| originalParam->symbolType()); |
| // Not replaced, add an identical parameter. |
| substituteFunction->addParameter(substituteArgument); |
| (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument); |
| } |
| else |
| { |
| TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument; |
| (*argumentMapOut)[originalParam] = substituteArgument; |
| |
| // Iterate over indices of the argument and create a new parameter for every non-const |
| // index (which may be an expression). Replace the symbol in the argument with a |
| // variable of the index type. This is later used to replace the parameter in the |
| // function body. |
| while (substituteArgument->getAsBinaryNode()) |
| { |
| TIntermBinary *asBinary = substituteArgument->getAsBinaryNode(); |
| if (asBinary->getOp() == EOpIndexIndirect) |
| { |
| TIntermTyped *index = asBinary->getRight(); |
| TType *indexType = new TType(index->getType()); |
| indexType->setQualifier(EvqParamIn); |
| |
| TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType, |
| SymbolType::AngleInternal); |
| substituteFunction->addParameter(param); |
| |
| // The argument now uses the function parameters as indices. |
| asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param)); |
| } |
| substituteArgument = asBinary->getLeft(); |
| } |
| |
| ++nextReplacedArg; |
| } |
| } |
| |
| return substituteFunction; |
| } |
| |
| class MonomorphizeTraverser final : public TIntermTraverser |
| { |
| public: |
| explicit MonomorphizeTraverser(TCompiler *compiler, |
| TSymbolTable *symbolTable, |
| const ShCompileOptions &compileOptions, |
| UnsupportedFunctionArgsBitSet unsupportedFunctionArgs, |
| FunctionMap *functionMap) |
| : TIntermTraverser(true, false, false, symbolTable), |
| mCompiler(compiler), |
| mCompileOptions(compileOptions), |
| mUnsupportedFunctionArgs(unsupportedFunctionArgs), |
| mFunctionMap(functionMap) |
| {} |
| |
| bool visitAggregate(Visit visit, TIntermAggregate *node) override |
| { |
| if (node->getOp() != EOpCallFunctionInAST) |
| { |
| return true; |
| } |
| |
| const TFunction *function = node->getFunction(); |
| ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end()); |
| |
| FunctionData &data = (*mFunctionMap)[function]; |
| |
| TIntermFunctionDefinition *monomorphized = |
| processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed); |
| if (monomorphized) |
| { |
| data.monomorphizedDefinitions.push_back(monomorphized); |
| } |
| |
| return true; |
| } |
| |
| bool getAnyMonomorphized() const { return mAnyMonomorphized; } |
| |
| private: |
| bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const |
| { |
| // Only interested in opaque uniforms and structs that contain samplers. |
| const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType()); |
| const bool isStructContainingSamplers = |
| funcArgument->getType().isStructureContainingSamplers(); |
| if (!isOpaqueType && !isStructContainingSamplers) |
| { |
| return false; |
| } |
| |
| // If not uniform (the variable was itself a function parameter), don't process it in |
| // this pass, as we don't know which actual uniform it corresponds to. |
| bool isSamplerInStruct = false; |
| const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct); |
| if (uniform == nullptr) |
| { |
| return false; |
| } |
| |
| const TType &type = uniform->getType(); |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers]) |
| { |
| // Monomorphize if the parameter is a structure that contains samplers (so in |
| // RewriteStructSamplers we don't need to rewrite the functions to accept multiple |
| // parameters split from the struct). |
| if (isStructContainingSamplers) |
| { |
| return true; |
| } |
| } |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage]) |
| { |
| // Monomorphize if: |
| // |
| // - The opaque uniform is a sampler in a struct (which can create an array-of-array |
| // situation), and the function expects an array of samplers, or |
| // |
| // - The opaque uniform is an array of array of sampler or image, and it's partially |
| // subscripted (i.e. the function itself expects an array) |
| // |
| const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray(); |
| const bool isArrayOfArrayOfSamplerOrImage = |
| (type.isSampler() || type.isImage()) && type.isArrayOfArrays(); |
| if (isSamplerInStruct && isParameterArrayOfOpaqueType) |
| { |
| return true; |
| } |
| if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType) |
| { |
| return true; |
| } |
| } |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter]) |
| { |
| if (type.isAtomicCounter()) |
| { |
| return true; |
| } |
| } |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::SamplerCubeEmulation]) |
| { |
| // Monomorphize if the opaque uniform is a samplerCube and ES2's cube sampling emulation |
| // is requested. |
| if (type.isSamplerCube() && mCompileOptions.emulateSeamfulCubeMapSampling) |
| { |
| return true; |
| } |
| } |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image]) |
| { |
| if (type.isImage()) |
| { |
| return true; |
| } |
| } |
| |
| if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage]) |
| { |
| if (type.isPixelLocal()) |
| { |
| return true; |
| } |
| } |
| |
| return false; |
| } |
| |
| TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall, |
| TIntermFunctionDefinition *originalDefinition, |
| bool *isOriginalUsedOut) |
| { |
| const TFunction *function = functionCall->getFunction(); |
| const TIntermSequence &callArguments = *functionCall->getSequence(); |
| |
| TVector<Argument> replacedArguments; |
| TIntermSequence replacementIndices; |
| |
| // Go through function call arguments, and see if any is used in an unsupported way. |
| for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex) |
| { |
| TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped(); |
| const TVariable *funcArgument = function->getParam(argIndex); |
| if (isUnsupportedArgument(callArgument, funcArgument)) |
| { |
| // Copy the argument and extract the side effects. |
| TIntermTyped *argument = |
| ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices); |
| |
| replacedArguments.push_back({argIndex, argument}); |
| } |
| } |
| |
| if (replacedArguments.empty()) |
| { |
| *isOriginalUsedOut = true; |
| return nullptr; |
| } |
| |
| mAnyMonomorphized = true; |
| |
| insertStatementsInParentBlock(replacementIndices); |
| |
| // Create the arguments for the substitute function call. Done before monomorphizing the |
| // function, which transforms the arguments to what needs to be replaced in the function |
| // body. |
| TIntermSequence newCallArgs; |
| CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs); |
| |
| // Duplicate the function and substitute the replaced arguments with only the non-const |
| // indices. Additionally, substitute the non-const indices of arguments with the new |
| // function parameters. |
| VariableReplacementMap argumentMap; |
| const TFunction *monomorphized = |
| MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap); |
| |
| // Replace this function call with a call to the new one. |
| queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs), |
| OriginalNode::IS_DROPPED); |
| |
| // Create a new function definition, with the body of the old function but with the replaced |
| // parameters substituted with the calling expressions. |
| TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized); |
| TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy(); |
| GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap); |
| bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap); |
| ASSERT(valid); |
| |
| return new TIntermFunctionDefinition(substitutePrototype, substituteBlock); |
| } |
| |
| TCompiler *mCompiler; |
| const ShCompileOptions &mCompileOptions; |
| UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs; |
| bool mAnyMonomorphized = false; |
| |
| // Map of original to monomorphized functions. |
| FunctionMap *mFunctionMap; |
| }; |
| |
| class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser |
| { |
| public: |
| explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable, |
| const FunctionMap &functionMap) |
| : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap) |
| {} |
| |
| void visitFunctionPrototype(TIntermFunctionPrototype *node) override |
| { |
| const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr; |
| if (isInFunctionDefinition) |
| { |
| return; |
| } |
| |
| // Add to and possibly replace the function prototype with replacement prototypes. |
| const TFunction *function = node->getFunction(); |
| ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end()); |
| |
| const FunctionData &data = mFunctionMap.at(function); |
| |
| // If nothing to do, leave it be. |
| if (data.monomorphizedDefinitions.empty()) |
| { |
| ASSERT(data.isOriginalUsed); |
| return; |
| } |
| |
| // Replace the prototype with itself (if function is still used) as well as any |
| // monomorphized versions. |
| TIntermSequence replacement; |
| if (data.isOriginalUsed) |
| { |
| replacement.push_back(node); |
| } |
| for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions) |
| { |
| replacement.push_back(new TIntermFunctionPrototype( |
| monomorphizedDefinition->getFunctionPrototype()->getFunction())); |
| } |
| mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, |
| std::move(replacement)); |
| } |
| |
| bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override |
| { |
| // Add to and possibly replace the function definition with replacement definitions. |
| const TFunction *function = node->getFunction(); |
| ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end()); |
| |
| const FunctionData &data = mFunctionMap.at(function); |
| |
| // If nothing to do, leave it be. |
| if (data.monomorphizedDefinitions.empty()) |
| { |
| ASSERT(data.isOriginalUsed || function->name() == "main"); |
| return false; |
| } |
| |
| // Replace the definition with itself (if function is still used) as well as any |
| // monomorphized versions. |
| TIntermSequence replacement; |
| if (data.isOriginalUsed) |
| { |
| replacement.push_back(node); |
| } |
| for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions) |
| { |
| replacement.push_back(monomorphizedDefinition); |
| } |
| mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, |
| std::move(replacement)); |
| |
| return false; |
| } |
| |
| private: |
| const FunctionMap &mFunctionMap; |
| }; |
| |
| void SortDeclarations(TIntermBlock *root) |
| { |
| TIntermSequence *original = root->getSequence(); |
| |
| TIntermSequence replacement; |
| TIntermSequence functionDefs; |
| |
| // Accumulate non-function-definition declarations in |replacement| and function definitions in |
| // |functionDefs|. |
| for (TIntermNode *node : *original) |
| { |
| if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode()) |
| { |
| functionDefs.push_back(node); |
| } |
| else |
| { |
| replacement.push_back(node); |
| } |
| } |
| |
| // Append function definitions to |replacement|. |
| replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end()); |
| |
| // Replace root's sequence with |replacement|. |
| root->replaceAllChildren(replacement); |
| } |
| |
| bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler, |
| TIntermBlock *root, |
| TSymbolTable *symbolTable, |
| const ShCompileOptions &compileOptions, |
| UnsupportedFunctionArgsBitSet unsupportedFunctionArgs) |
| { |
| // First, sort out the declarations such that all non-function declarations are placed before |
| // function definitions. This way when the function is replaced with one that references said |
| // declarations (i.e. uniforms), the uniform declaration is already present above it. |
| SortDeclarations(root); |
| |
| while (true) |
| { |
| FunctionMap functionMap; |
| InitializeFunctionMap(root, &functionMap); |
| |
| MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions, |
| unsupportedFunctionArgs, &functionMap); |
| root->traverse(&monomorphizer); |
| |
| if (!monomorphizer.getAnyMonomorphized()) |
| { |
| break; |
| } |
| |
| if (!monomorphizer.updateTree(compiler, root)) |
| { |
| return false; |
| } |
| |
| UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap); |
| root->traverse(&functionUpdater); |
| |
| if (!functionUpdater.updateTree(compiler, root)) |
| { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| } // anonymous namespace |
| |
| bool MonomorphizeUnsupportedFunctions(TCompiler *compiler, |
| TIntermBlock *root, |
| TSymbolTable *symbolTable, |
| const ShCompileOptions &compileOptions, |
| UnsupportedFunctionArgsBitSet unsupportedFunctionArgs) |
| { |
| // This function actually applies multiple transformation, and the AST may not be valid until |
| // the transformations are entirely done. Some validation is momentarily disabled. |
| bool enableValidateFunctionCall = compiler->disableValidateFunctionCall(); |
| |
| bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions, |
| unsupportedFunctionArgs); |
| |
| compiler->restoreValidateFunctionCall(enableValidateFunctionCall); |
| return result && compiler->validateAST(root); |
| } |
| } // namespace sh |