Vulkan: support dynamic indices in array of arrays

Expands existing struct-sampler rewrite to flatten arrays of arrays.
This allows us to support dynamically-uniform array indexing, which is
core in ES 3.2.

Samplers inside (possibly nested) structs are broken apart as before,
and then if the type resulting from merging the array sizes of the field
and its containing structs is an array of array, the array is flattened.

Also adds an offset parameter to functions taking in arrays to account
for this translation.

As a result of outer array sizes leaking into function signatures,
functions taking arrays of different sizes are duplicated according to
how the function is invoked.

Bug: angleproject:3604
Change-Id: Ic9373fd12a38f19bd811eac92e281055a63c1901
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/1744177
Commit-Queue: James Dong <dongja@google.com>
Reviewed-by: Shahbaz Youssefi <syoussefi@chromium.org>
diff --git a/include/GLSLANG/ShaderLang.h b/include/GLSLANG/ShaderLang.h
index e8da567..284f1e6 100644
--- a/include/GLSLANG/ShaderLang.h
+++ b/include/GLSLANG/ShaderLang.h
@@ -26,7 +26,7 @@
 
 // Version number for shader translation API.
 // It is incremented every time the API changes.
-#define ANGLE_SH_VERSION 213
+#define ANGLE_SH_VERSION 214
 
 enum ShShaderSpec
 {
@@ -300,6 +300,10 @@
 // If requested, validates the AST after every transformation.  Useful for debugging.
 const ShCompileOptions SH_VALIDATE_AST = UINT64_C(1) << 46;
 
+// Use old version of RewriteStructSamplers, which doesn't produce as many
+// sampler arrays in parameters. This causes a few tests to pass on Android.
+const ShCompileOptions SH_USE_OLD_REWRITE_STRUCT_SAMPLERS = UINT64_C(1) << 47;
+
 // Defines alternate strategies for implementing array index clamping.
 enum ShArrayIndexClampingStrategy
 {
diff --git a/include/platform/FeaturesVk.h b/include/platform/FeaturesVk.h
index 9ca5c39..4fd61ae 100644
--- a/include/platform/FeaturesVk.h
+++ b/include/platform/FeaturesVk.h
@@ -200,6 +200,14 @@
         "disallow_seamful_cube_map_emulation", FeatureCategory::VulkanWorkarounds,
         "Seamful cube map emulation misbehaves on the AMD windows driver, so it's disallowed",
         &members, "http://anglebug.com/3243"};
+
+    // Qualcomm shader compiler doesn't support sampler arrays as parameters, so
+    // revert to old RewriteStructSamplers behavior, which produces fewer.
+    Feature forceOldRewriteStructSamplers = {
+        "force_old_rewrite_struct_samplers", FeatureCategory::VulkanWorkarounds,
+        "Qualcomm shader compiler doesn't support sampler arrays as parameters, so "
+        "revert to old RewriteStructSamplers behavior, which produces fewer.",
+        &members, "http://anglebug.com/2703"};
 };
 
 inline FeaturesVk::FeaturesVk()  = default;
diff --git a/src/compiler.gni b/src/compiler.gni
index 929a081..089fe2b 100644
--- a/src/compiler.gni
+++ b/src/compiler.gni
@@ -168,6 +168,7 @@
   "src/compiler/translator/tree_ops/RewriteExpressionsWithShaderStorageBlock.h",
   "src/compiler/translator/tree_ops/RewriteStructSamplers.cpp",
   "src/compiler/translator/tree_ops/RewriteStructSamplers.h",
+  "src/compiler/translator/tree_ops/RewriteStructSamplersOld.cpp",
   "src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.cpp",
   "src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.h",
   "src/compiler/translator/tree_ops/RewriteRowMajorMatrices.cpp",
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 17c23f0..c551615 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -268,6 +268,10 @@
     return false;
 }
 
+TIntermBranch::TIntermBranch(const TIntermBranch &node)
+    : TIntermBranch(node.mFlowOp, node.mExpression->deepCopy())
+{}
+
 size_t TIntermBranch::getChildCount() const
 {
     return (mExpression ? 1 : 0);
@@ -401,6 +405,14 @@
     return replaceChildNodeInternal(original, replacement);
 }
 
+TIntermBlock::TIntermBlock(const TIntermBlock &node)
+{
+    for (TIntermNode *node : node.mStatements)
+    {
+        mStatements.push_back(node->deepCopy());
+    }
+}
+
 size_t TIntermBlock::getChildCount() const
 {
     return mStatements.size();
@@ -954,6 +966,8 @@
     return false;
 }
 
+TIntermCase::TIntermCase(const TIntermCase &node) : TIntermCase(node.mCondition->deepCopy()) {}
+
 size_t TIntermCase::getChildCount() const
 {
     return (mCondition ? 1 : 0);
@@ -1326,6 +1340,11 @@
     setLine(line);
 }
 
+TIntermInvariantDeclaration::TIntermInvariantDeclaration(const TIntermInvariantDeclaration &node)
+    : TIntermInvariantDeclaration(static_cast<TIntermSymbol *>(node.mSymbol->deepCopy()),
+                                  node.mLine)
+{}
+
 TIntermTernary::TIntermTernary(TIntermTyped *cond,
                                TIntermTyped *trueExpression,
                                TIntermTyped *falseExpression)
@@ -1357,6 +1376,14 @@
     }
 }
 
+TIntermLoop::TIntermLoop(const TIntermLoop &node)
+    : TIntermLoop(node.mType,
+                  node.mInit->deepCopy(),
+                  node.mCond->deepCopy(),
+                  node.mExpr->deepCopy(),
+                  node.mBody->deepCopy())
+{}
+
 TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlock *falseB)
     : TIntermNode(), mCondition(cond), mTrueBlock(trueB), mFalseBlock(falseB)
 {
@@ -1368,6 +1395,12 @@
     }
 }
 
+TIntermIfElse::TIntermIfElse(const TIntermIfElse &node)
+    : TIntermIfElse(node.mCondition->deepCopy(),
+                    node.mTrueBlock->deepCopy(),
+                    node.mFalseBlock ? node.mFalseBlock->deepCopy() : nullptr)
+{}
+
 TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList)
     : TIntermNode(), mInit(init), mStatementList(statementList)
 {
@@ -1375,6 +1408,10 @@
     ASSERT(mStatementList);
 }
 
+TIntermSwitch::TIntermSwitch(const TIntermSwitch &node)
+    : TIntermSwitch(node.mInit->deepCopy(), node.mStatementList->deepCopy())
+{}
+
 void TIntermSwitch::setStatementList(TIntermBlock *statementList)
 {
     ASSERT(statementList);
@@ -3772,6 +3809,10 @@
     : mDirective(directive), mCommand(std::move(command))
 {}
 
+TIntermPreprocessorDirective::TIntermPreprocessorDirective(const TIntermPreprocessorDirective &node)
+    : TIntermPreprocessorDirective(node.mDirective, node.mCommand)
+{}
+
 TIntermPreprocessorDirective::~TIntermPreprocessorDirective() = default;
 
 size_t TIntermPreprocessorDirective::getChildCount() const
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 33a8a6d..561b2f1 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -104,6 +104,8 @@
     virtual TIntermBranch *getAsBranchNode() { return nullptr; }
     virtual TIntermPreprocessorDirective *getAsPreprocessorDirective() { return nullptr; }
 
+    virtual TIntermNode *deepCopy() const = 0;
+
     virtual size_t getChildCount() const                  = 0;
     virtual TIntermNode *getChildNode(size_t index) const = 0;
     // Replace a child node. Return true if |original| is a child
@@ -131,7 +133,7 @@
   public:
     TIntermTyped() {}
 
-    virtual TIntermTyped *deepCopy() const = 0;
+    virtual TIntermTyped *deepCopy() const override = 0;
 
     TIntermTyped *getAsTyped() override { return this; }
 
@@ -211,12 +213,17 @@
     void setExpression(TIntermTyped *expression) { mExpr = expression; }
     void setBody(TIntermBlock *body) { mBody = body; }
 
+    virtual TIntermLoop *deepCopy() const override { return new TIntermLoop(*this); }
+
   protected:
     TLoopType mType;
     TIntermNode *mInit;   // for-loop initialization
     TIntermTyped *mCond;  // loop exit condition
     TIntermTyped *mExpr;  // for-loop expression
     TIntermBlock *mBody;  // loop body
+
+  private:
+    TIntermLoop(const TIntermLoop &);
 };
 
 //
@@ -237,9 +244,14 @@
     TOperator getFlowOp() { return mFlowOp; }
     TIntermTyped *getExpression() { return mExpression; }
 
+    virtual TIntermBranch *deepCopy() const override { return new TIntermBranch(*this); }
+
   protected:
     TOperator mFlowOp;
     TIntermTyped *mExpression;  // zero except for "return exp;" statements
+
+  private:
+    TIntermBranch(const TIntermBranch &);
 };
 
 // Nodes that correspond to variable symbols in the source code. These may be regular variables or
@@ -676,8 +688,13 @@
     TIntermSequence *getSequence() override { return &mStatements; }
     const TIntermSequence *getSequence() const override { return &mStatements; }
 
+    TIntermBlock *deepCopy() const override { return new TIntermBlock(*this); }
+
   protected:
     TIntermSequence mStatements;
+
+  private:
+    TIntermBlock(const TIntermBlock &);
 };
 
 // Function prototype. May be in the AST either as a function prototype declaration or as a part of
@@ -740,6 +757,12 @@
 
     const TFunction *getFunction() const { return mPrototype->getFunction(); }
 
+    TIntermNode *deepCopy() const override
+    {
+        UNREACHABLE();
+        return nullptr;
+    }
+
   private:
     TIntermFunctionPrototype *mPrototype;
     TIntermBlock *mBody;
@@ -767,6 +790,12 @@
     TIntermSequence *getSequence() override { return &mDeclarators; }
     const TIntermSequence *getSequence() const override { return &mDeclarators; }
 
+    TIntermNode *deepCopy() const override
+    {
+        UNREACHABLE();
+        return nullptr;
+    }
+
   protected:
     TIntermSequence mDeclarators;
 };
@@ -786,8 +815,15 @@
     TIntermNode *getChildNode(size_t index) const final;
     bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
 
+    TIntermInvariantDeclaration *deepCopy() const override
+    {
+        return new TIntermInvariantDeclaration(*this);
+    }
+
   private:
     TIntermSymbol *mSymbol;
+
+    TIntermInvariantDeclaration(const TIntermInvariantDeclaration &);
 };
 
 // For ternary operators like a ? b : c.
@@ -845,10 +881,15 @@
     TIntermBlock *getTrueBlock() const { return mTrueBlock; }
     TIntermBlock *getFalseBlock() const { return mFalseBlock; }
 
+    TIntermIfElse *deepCopy() const override { return new TIntermIfElse(*this); }
+
   protected:
     TIntermTyped *mCondition;
     TIntermBlock *mTrueBlock;
     TIntermBlock *mFalseBlock;
+
+  private:
+    TIntermIfElse(const TIntermIfElse &);
 };
 
 //
@@ -872,9 +913,14 @@
     // Must be called with a non-null statementList.
     void setStatementList(TIntermBlock *statementList);
 
+    TIntermSwitch *deepCopy() const override { return new TIntermSwitch(*this); }
+
   protected:
     TIntermTyped *mInit;
     TIntermBlock *mStatementList;
+
+  private:
+    TIntermSwitch(const TIntermSwitch &);
 };
 
 //
@@ -895,8 +941,13 @@
     bool hasCondition() const { return mCondition != nullptr; }
     TIntermTyped *getCondition() const { return mCondition; }
 
+    TIntermCase *deepCopy() const override { return new TIntermCase(*this); }
+
   protected:
     TIntermTyped *mCondition;
+
+  private:
+    TIntermCase(const TIntermCase &);
 };
 
 //
@@ -930,9 +981,16 @@
     PreprocessorDirective getDirective() const { return mDirective; }
     const ImmutableString &getCommand() const { return mCommand; }
 
+    TIntermPreprocessorDirective *deepCopy() const override
+    {
+        return new TIntermPreprocessorDirective(*this);
+    }
+
   private:
     PreprocessorDirective mDirective;
     ImmutableString mCommand;
+
+    TIntermPreprocessorDirective(const TIntermPreprocessorDirective &);
 };
 
 }  // namespace sh
diff --git a/src/compiler/translator/Symbol.cpp b/src/compiler/translator/Symbol.cpp
index a6d865b..d0389af 100644
--- a/src/compiler/translator/Symbol.cpp
+++ b/src/compiler/translator/Symbol.cpp
@@ -222,12 +222,28 @@
     return SymbolType() == SymbolType::BuiltIn && name().beginsWith(kAtomicCounterName);
 }
 
-bool TFunction::hasSamplerInStructParams() const
+bool TFunction::hasSamplerInStructOrArrayParams() const
 {
     for (size_t paramIndex = 0; paramIndex < mParamCount; ++paramIndex)
     {
         const TVariable *param = getParam(paramIndex);
-        if (param->getType().isStructureContainingSamplers())
+        if (param->getType().isStructureContainingSamplers() ||
+            (param->getType().isArray() && param->getType().isSampler()))
+        {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool TFunction::hasSamplerInStructOrArrayOfArrayParams() const
+{
+    for (size_t paramIndex = 0; paramIndex < mParamCount; ++paramIndex)
+    {
+        const TVariable *param = getParam(paramIndex);
+        if (param->getType().isStructureContainingSamplers() ||
+            (param->getType().isArrayOfArrays() && param->getType().isSampler()))
         {
             return true;
         }
diff --git a/src/compiler/translator/Symbol.h b/src/compiler/translator/Symbol.h
index e317250..b9cf908 100644
--- a/src/compiler/translator/Symbol.h
+++ b/src/compiler/translator/Symbol.h
@@ -234,7 +234,8 @@
     bool isMain() const;
     bool isImageFunction() const;
     bool isAtomicCounterFunction() const;
-    bool hasSamplerInStructParams() const;
+    bool hasSamplerInStructOrArrayParams() const;
+    bool hasSamplerInStructOrArrayOfArrayParams() const;
 
     // Note: Only to be used for static built-in functions!
     constexpr TFunction(const TSymbolUniqueId &id,
diff --git a/src/compiler/translator/TranslatorVulkan.cpp b/src/compiler/translator/TranslatorVulkan.cpp
index 5a99428..1770fb4 100644
--- a/src/compiler/translator/TranslatorVulkan.cpp
+++ b/src/compiler/translator/TranslatorVulkan.cpp
@@ -671,9 +671,9 @@
     }
 
     // Write out default uniforms into a uniform block assigned to a specific set/binding.
-    int defaultUniformCount        = 0;
-    int structTypesUsedForUniforms = 0;
-    int atomicCounterCount         = 0;
+    int defaultUniformCount           = 0;
+    int aggregateTypesUsedForUniforms = 0;
+    int atomicCounterCount            = 0;
     for (const auto &uniform : getUniforms())
     {
         if (!uniform.isBuiltIn() && uniform.staticUse && !gl::IsOpaqueType(uniform.type))
@@ -681,9 +681,9 @@
             ++defaultUniformCount;
         }
 
-        if (uniform.isStruct())
+        if (uniform.isStruct() || uniform.isArrayOfArrays())
         {
-            ++structTypesUsedForUniforms;
+            ++aggregateTypesUsedForUniforms;
         }
 
         if (gl::IsAtomicCounterType(uniform.type))
@@ -694,15 +694,28 @@
 
     // TODO(lucferron): Refactor this function to do fewer tree traversals.
     // http://anglebug.com/2461
-    if (structTypesUsedForUniforms > 0)
+    if (aggregateTypesUsedForUniforms > 0)
     {
         if (!NameEmbeddedStructUniforms(this, root, &getSymbolTable()))
         {
             return false;
         }
 
-        int removedUniformsCount = 0;
-        if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount))
+        bool rewriteStructSamplersResult;
+        int removedUniformsCount;
+
+        if (compileOptions & SH_USE_OLD_REWRITE_STRUCT_SAMPLERS)
+        {
+            rewriteStructSamplersResult =
+                RewriteStructSamplersOld(this, root, &getSymbolTable(), &removedUniformsCount);
+        }
+        else
+        {
+            rewriteStructSamplersResult =
+                RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount);
+        }
+
+        if (!rewriteStructSamplersResult)
         {
             return false;
         }
diff --git a/src/compiler/translator/Types.cpp b/src/compiler/translator/Types.cpp
index 879dee5..a6d7ba9 100644
--- a/src/compiler/translator/Types.cpp
+++ b/src/compiler/translator/Types.cpp
@@ -716,6 +716,19 @@
     }
 }
 
+void TType::toArrayBaseType()
+{
+    if (mArraySizes == nullptr)
+    {
+        return;
+    }
+    if (mArraySizes->size() > 0)
+    {
+        mArraySizes->clear();
+    }
+    invalidateMangledName();
+}
+
 void TType::setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn)
 {
     if (mInterfaceBlock != interfaceBlockIn)
diff --git a/src/compiler/translator/Types.h b/src/compiler/translator/Types.h
index be718e4..22a475e 100644
--- a/src/compiler/translator/Types.h
+++ b/src/compiler/translator/Types.h
@@ -214,6 +214,8 @@
 
     // Note that the array element type might still be an array type in GLSL ES version >= 3.10.
     void toArrayElementType();
+    // Removes all array sizes.
+    void toArrayBaseType();
 
     const TInterfaceBlock *getInterfaceBlock() const { return mInterfaceBlock; }
     void setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn);
diff --git a/src/compiler/translator/blocklayout.cpp b/src/compiler/translator/blocklayout.cpp
index 70cf436..019a0e3 100644
--- a/src/compiler/translator/blocklayout.cpp
+++ b/src/compiler/translator/blocklayout.cpp
@@ -406,6 +406,7 @@
         mNameStack.push_back(arrayVar.name);
         mMappedNameStack.push_back(arrayVar.mappedName);
     }
+    mArraySizeStack.push_back(arrayVar.getOutermostArraySize());
 }
 
 void VariableNameVisitor::exitArray(const ShaderVariable &arrayVar)
@@ -415,6 +416,7 @@
         mNameStack.pop_back();
         mMappedNameStack.pop_back();
     }
+    mArraySizeStack.pop_back();
 }
 
 void VariableNameVisitor::enterArrayElement(const ShaderVariable &arrayVar,
@@ -461,7 +463,7 @@
         mMappedNameStack.pop_back();
     }
 
-    visitNamedSampler(sampler, name, mappedName);
+    visitNamedSampler(sampler, name, mappedName, mArraySizeStack);
 }
 
 void VariableNameVisitor::visitVariable(const ShaderVariable &variable, bool isRowMajor)
@@ -481,7 +483,7 @@
         mMappedNameStack.pop_back();
     }
 
-    visitNamedVariable(variable, isRowMajor, name, mappedName);
+    visitNamedVariable(variable, isRowMajor, name, mappedName, mArraySizeStack);
 }
 
 // BlockEncoderVisitor implementation.
@@ -554,7 +556,8 @@
 void BlockEncoderVisitor::visitNamedVariable(const ShaderVariable &variable,
                                              bool isRowMajor,
                                              const std::string &name,
-                                             const std::string &mappedName)
+                                             const std::string &mappedName,
+                                             const std::vector<unsigned int> &arraySizes)
 {
     std::vector<unsigned int> innermostArraySize;
 
diff --git a/src/compiler/translator/blocklayout.h b/src/compiler/translator/blocklayout.h
index 1b38ea5..2003c49 100644
--- a/src/compiler/translator/blocklayout.h
+++ b/src/compiler/translator/blocklayout.h
@@ -230,12 +230,14 @@
   protected:
     virtual void visitNamedSampler(const sh::ShaderVariable &sampler,
                                    const std::string &name,
-                                   const std::string &mappedName)
+                                   const std::string &mappedName,
+                                   const std::vector<unsigned int> &arraySizes)
     {}
     virtual void visitNamedVariable(const ShaderVariable &variable,
                                     bool isRowMajor,
                                     const std::string &name,
-                                    const std::string &mappedName) = 0;
+                                    const std::string &mappedName,
+                                    const std::vector<unsigned int> &arraySizes) = 0;
 
     std::string collapseNameStack() const;
     std::string collapseMappedNameStack() const;
@@ -246,6 +248,7 @@
 
     std::vector<std::string> mNameStack;
     std::vector<std::string> mMappedNameStack;
+    std::vector<unsigned int> mArraySizeStack;
 };
 
 class BlockEncoderVisitor : public VariableNameVisitor
@@ -264,7 +267,8 @@
     void visitNamedVariable(const ShaderVariable &variable,
                             bool isRowMajor,
                             const std::string &name,
-                            const std::string &mappedName) override;
+                            const std::string &mappedName,
+                            const std::vector<unsigned int> &arraySizes) override;
 
     virtual void encodeVariable(const ShaderVariable &variable,
                                 const BlockMemberInfo &variableInfo,
diff --git a/src/compiler/translator/tree_ops/RewriteStructSamplers.cpp b/src/compiler/translator/tree_ops/RewriteStructSamplers.cpp
index 7ca3cce..e862d3d 100644
--- a/src/compiler/translator/tree_ops/RewriteStructSamplers.cpp
+++ b/src/compiler/translator/tree_ops/RewriteStructSamplers.cpp
@@ -9,7 +9,9 @@
 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
 
 #include "compiler/translator/ImmutableStringBuilder.h"
+#include "compiler/translator/StaticType.h"
 #include "compiler/translator/SymbolTable.h"
+#include "compiler/translator/tree_util/IntermNode_util.h"
 #include "compiler/translator/tree_util/IntermTraverse.h"
 
 namespace sh
@@ -81,10 +83,228 @@
     return nullptr;
 }
 
-// Maximum string size of a hex unsigned int.
-constexpr size_t kHexSize = ImmutableStringBuilder::GetHexCharCount<unsigned int>();
+void GenerateArrayStrides(const std::vector<size_t> &arraySizes,
+                          std::vector<size_t> *arrayStridesOut)
+{
+    auto &strides = *arrayStridesOut;
 
-class Traverser final : public TIntermTraverser
+    ASSERT(strides.empty());
+    strides.reserve(arraySizes.size() + 1);
+
+    size_t currentStride = 1;
+    strides.push_back(1);
+    for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
+    {
+        currentStride *= *it;
+        strides.push_back(currentStride);
+    }
+}
+
+// This returns an expression representing the correct index using the array
+// index operations in node.
+static TIntermTyped *GetIndexExpressionFromTypedNode(TIntermTyped *node,
+                                                     const std::vector<size_t> &strides,
+                                                     TIntermTyped *offset)
+{
+    TIntermTyped *result      = offset;
+    TIntermTyped *currentNode = node;
+
+    auto it = strides.end();
+    --it;
+    // If this is being used as an argument, not all indices may be present;
+    // count how many indices are there.
+    while (currentNode->getAsBinaryNode())
+    {
+        TIntermBinary *asBinary = currentNode->getAsBinaryNode();
+
+        switch (asBinary->getOp())
+        {
+            case EOpIndexDirectStruct:
+                break;
+
+            case EOpIndexDirect:
+            case EOpIndexIndirect:
+                --it;
+                break;
+
+            default:
+                UNREACHABLE();
+                break;
+        }
+
+        currentNode = asBinary->getLeft();
+    }
+
+    currentNode = node;
+
+    while (currentNode->getAsBinaryNode())
+    {
+        TIntermBinary *asBinary = currentNode->getAsBinaryNode();
+
+        switch (asBinary->getOp())
+        {
+            case EOpIndexDirectStruct:
+                break;
+
+            case EOpIndexDirect:
+            case EOpIndexIndirect:
+            {
+                TIntermBinary *multiply =
+                    new TIntermBinary(EOpMul, CreateIndexNode(static_cast<int>(*it++)),
+                                      asBinary->getRight()->deepCopy());
+                result = new TIntermBinary(EOpAdd, result, multiply);
+                break;
+            }
+
+            default:
+                UNREACHABLE();
+                break;
+        }
+
+        currentNode = asBinary->getLeft();
+    }
+
+    return result;
+}
+
+// Structures for keeping track of function instantiations.
+
+// An instantiation is keyed by the flattened sizes of the sampler arrays.
+typedef std::vector<size_t> Instantiation;
+
+struct InstantiationHash
+{
+    size_t operator()(const Instantiation &v) const noexcept
+    {
+        std::hash<size_t> hasher;
+        size_t seed = 0;
+        for (size_t x : v)
+        {
+            seed ^= hasher(x) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+        }
+        return seed;
+    }
+};
+
+// Map from each function to a "set" of instantiations.
+// We store a TFunction for each instantiation as its value.
+typedef std::map<ImmutableString, std::unordered_map<Instantiation, TFunction *, InstantiationHash>>
+    FunctionInstantiations;
+
+typedef std::unordered_map<const TFunction *, const TFunction *> FunctionMap;
+
+// Generates a new function from the given function using the given
+// instantiation; generatedInstantiations can be null.
+TFunction *GenerateFunctionFromArguments(const TFunction *function,
+                                         const TIntermSequence *arguments,
+                                         TSymbolTable *symbolTable,
+                                         FunctionInstantiations *functionInstantiations,
+                                         FunctionMap *functionMap,
+                                         const FunctionInstantiations *generatedInstantiations)
+{
+    // Collect sizes of array arguments.
+    Instantiation instantiation;
+    for (TIntermNode *node : *arguments)
+    {
+        const TType &type = node->getAsTyped()->getType();
+        if (type.isArray() && type.isSampler())
+        {
+            ASSERT(type.getNumArraySizes() == 1);
+            instantiation.push_back((*type.getArraySizes())[0]);
+        }
+    }
+
+    if (generatedInstantiations)
+    {
+        auto it1 = generatedInstantiations->find(function->name());
+        if (it1 != generatedInstantiations->end())
+        {
+            const auto &map = it1->second;
+            auto it2        = map.find(instantiation);
+            if (it2 != map.end())
+            {
+                return it2->second;
+            }
+        }
+    }
+
+    TFunction **newFunction = &(*functionInstantiations)[function->name()][instantiation];
+
+    if (!*newFunction)
+    {
+        *newFunction =
+            new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
+                          &function->getReturnType(), function->isKnownToNotHaveSideEffects());
+        (*functionMap)[*newFunction] = function;
+        // Insert parameters from updated function.
+        TFunction *updatedFunction = symbolTable->findUserDefinedFunction(function->name());
+        size_t paramCount          = updatedFunction->getParamCount();
+        auto it                    = instantiation.begin();
+        for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
+        {
+            const TVariable *param = updatedFunction->getParam(paramIndex);
+            const TType &paramType = param->getType();
+            if (paramType.isArray() && paramType.isSampler())
+            {
+                TType *replacementType = new TType(paramType);
+                size_t arraySize       = *it++;
+                replacementType->setArraySize(0, static_cast<unsigned int>(arraySize));
+                param =
+                    new TVariable(symbolTable, param->name(), replacementType, param->symbolType());
+            }
+            (*newFunction)->addParameter(param);
+        }
+    }
+    return *newFunction;
+}
+
+class ArrayTraverser
+{
+  public:
+    ArrayTraverser() { mCumulativeArraySizeStack.push_back(1); }
+
+    void enterArray(const TType &arrayType)
+    {
+        if (!arrayType.isArray())
+            return;
+        size_t currentArraySize = mCumulativeArraySizeStack.back();
+        const auto &arraySizes  = *arrayType.getArraySizes();
+        for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
+        {
+            unsigned int arraySize = *it;
+            currentArraySize *= arraySize;
+            mArraySizeStack.push_back(arraySize);
+            mCumulativeArraySizeStack.push_back(currentArraySize);
+        }
+    }
+
+    void exitArray(const TType &arrayType)
+    {
+        if (!arrayType.isArray())
+            return;
+        mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes());
+        mCumulativeArraySizeStack.resize(mCumulativeArraySizeStack.size() -
+                                         arrayType.getNumArraySizes());
+    }
+
+  protected:
+    std::vector<size_t> mArraySizeStack;
+    // The first element is 1; each successive element is the previous
+    // multiplied by the size of the next nested array in the current sampler.
+    // For example, with sampler2D foo[3][6], we would have {1, 3, 18}.
+    std::vector<size_t> mCumulativeArraySizeStack;
+};
+
+struct VariableExtraData
+{
+    // The value consists of strides, starting from the outermost array.
+    // For example, with sampler2D foo[3][6], we would have {1, 6, 18}.
+    std::unordered_map<const TVariable *, std::vector<size_t>> arrayStrideMap;
+    // For each generated array parameter, holds the offset parameter.
+    std::unordered_map<const TVariable *, const TVariable *> paramOffsetMap;
+};
+
+class Traverser final : public TIntermTraverser, public ArrayTraverser
 {
   public:
     explicit Traverser(TSymbolTable *symbolTable)
@@ -98,7 +318,7 @@
     int removedUniformsCount() const { return mRemovedUniformsCount; }
 
     // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
-    // stripped struct sampler.
+    // stripped struct sampler. Flattens all arrays, including default uniforms.
     bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
     {
         if (visit != PreVisit)
@@ -133,6 +353,17 @@
             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
         }
 
+        if (type.isSampler() && type.isArray())
+        {
+            TIntermSequence *newSequence = new TIntermSequence;
+            TIntermSymbol *asSymbol      = declarator->getAsSymbolNode();
+            ASSERT(asSymbol);
+            const TVariable &variable = asSymbol->variable();
+            ASSERT(variable.symbolType() != SymbolType::Empty);
+            extractSampler(variable.name(), variable.getType(), newSequence, 0);
+            mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
+        }
+
         return true;
     }
 
@@ -141,15 +372,41 @@
     {
         if (visit != PreVisit)
             return true;
+        // If the node isn't a sampler or if this isn't the outermost access,
+        // continue.
+        if (!node->getType().isSampler() || node->getType().isArray())
+        {
+            return true;
+        }
 
-        if (node->getOp() == EOpIndexDirectStruct && node->getType().isSampler())
+        if (node->getOp() == EOpIndexDirect || node->getOp() == EOpIndexIndirect ||
+            node->getOp() == EOpIndexDirectStruct)
         {
             ImmutableString newName = GetStructSamplerNameFromTypedNode(node);
             const TVariable *samplerReplacement =
                 static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName));
             ASSERT(samplerReplacement);
 
-            TIntermSymbol *replacement = new TIntermSymbol(samplerReplacement);
+            TIntermTyped *replacement = new TIntermSymbol(samplerReplacement);
+
+            if (replacement->isArray())
+            {
+                // Add in an indirect index if contained in an array
+                const auto &strides = mVariableExtraData.arrayStrideMap[samplerReplacement];
+                ASSERT(!strides.empty());
+                if (strides.size() > 1)
+                {
+                    auto it = mVariableExtraData.paramOffsetMap.find(samplerReplacement);
+
+                    TIntermTyped *offset =
+                        it == mVariableExtraData.paramOffsetMap.end()
+                            ? static_cast<TIntermTyped *>(CreateIndexNode(0))
+                            : static_cast<TIntermTyped *>(new TIntermSymbol(it->second));
+
+                    TIntermTyped *index = GetIndexExpressionFromTypedNode(node, strides, offset);
+                    replacement         = new TIntermBinary(EOpIndexIndirect, replacement, index);
+                }
+            }
 
             queueReplacement(replacement, OriginalNode::IS_DROPPED);
             return true;
@@ -165,7 +422,7 @@
     {
         const TFunction *function = node->getFunction();
 
-        if (!function->hasSamplerInStructParams())
+        if (!function->hasSamplerInStructOrArrayParams())
         {
             return;
         }
@@ -183,7 +440,7 @@
             function = newFunction;
         }
 
-        ASSERT(!function->hasSamplerInStructParams());
+        ASSERT(!function->hasSamplerInStructOrArrayOfArrayParams());
         TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function);
         queueReplacement(newProto, OriginalNode::IS_DROPPED);
     }
@@ -213,19 +470,28 @@
             return true;
 
         const TFunction *function = node->getFunction();
-        if (!function->hasSamplerInStructParams())
+        if (!function->hasSamplerInStructOrArrayParams())
             return true;
 
         ASSERT(node->getOp() == EOpCallFunctionInAST);
-        TFunction *newFunction        = mSymbolTable->findUserDefinedFunction(function->name());
         TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence());
 
+        TFunction *newFunction = GenerateFunctionFromArguments(
+            function, newArguments, mSymbolTable, &mFunctionInstantiations, &mFunctionMap, nullptr);
+
         TIntermAggregate *newCall =
             TIntermAggregate::CreateFunctionCall(*newFunction, newArguments);
         queueReplacement(newCall, OriginalNode::IS_DROPPED);
         return true;
     }
 
+    FunctionInstantiations *getFunctionInstantiations() { return &mFunctionInstantiations; }
+
+    std::unordered_map<const TFunction *, const TFunction *> *getFunctionMap()
+    {
+        return &mFunctionMap;
+    }
+
   private:
     // This returns the name of a struct sampler reference. References are always TIntermBinary.
     static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node)
@@ -239,14 +505,6 @@
 
             switch (asBinary->getOp())
             {
-                case EOpIndexDirect:
-                {
-                    const int index = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
-                    const std::string strInt = Str(index);
-                    stringBuilder.insert(0, strInt);
-                    stringBuilder.insert(0, "_");
-                    break;
-                }
                 case EOpIndexDirectStruct:
                 {
                     stringBuilder.insert(0, asBinary->getIndexStructFieldName().data());
@@ -254,6 +512,10 @@
                     break;
                 }
 
+                case EOpIndexDirect:
+                case EOpIndexIndirect:
+                    break;
+
                 default:
                     UNREACHABLE();
                     break;
@@ -344,6 +606,8 @@
 
         size_t nonSamplerCount = 0;
 
+        enterArray(variable.getType());
+
         for (const TField *field : structure->fields())
         {
             nonSamplerCount +=
@@ -359,6 +623,8 @@
         {
             mRemovedUniformsCount++;
         }
+
+        exitArray(variable.getType());
     }
 
     // Extracts samplers from a field of a struct. Works with nested structs and arrays.
@@ -367,23 +633,6 @@
                                 const TType &containingType,
                                 TIntermSequence *newSequence)
     {
-        if (containingType.isArray())
-        {
-            size_t nonSamplerCount = 0;
-
-            // Name the samplers internally as varName_<index>_fieldName
-            const TVector<unsigned int> &arraySizes = *containingType.getArraySizes();
-            for (unsigned int arrayElement = 0; arrayElement < arraySizes[0]; ++arrayElement)
-            {
-                ImmutableStringBuilder stringBuilder(prefix.length() + kHexSize + 1);
-                stringBuilder << prefix << "_";
-                stringBuilder.appendHex(arrayElement);
-                nonSamplerCount = extractFieldSamplersImpl(stringBuilder, field, newSequence);
-            }
-
-            return nonSamplerCount;
-        }
-
         return extractFieldSamplersImpl(prefix, field, newSequence);
     }
 
@@ -403,16 +652,18 @@
 
             if (fieldType.isSampler())
             {
-                extractSampler(newPrefix, fieldType, newSequence);
+                extractSampler(newPrefix, fieldType, newSequence, 0);
             }
             else
             {
+                enterArray(fieldType);
                 const TStructure *structure = fieldType.getStruct();
                 for (const TField *nestedField : structure->fields())
                 {
                     nonSamplerCount +=
                         extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence);
                 }
+                exitArray(fieldType);
             }
         }
         else
@@ -426,9 +677,20 @@
     // Extracts a sampler from a struct. Declares the new extracted sampler.
     void extractSampler(const ImmutableString &newName,
                         const TType &fieldType,
-                        TIntermSequence *newSequence) const
+                        TIntermSequence *newSequence,
+                        size_t arrayLevel)
     {
+        enterArray(fieldType);
+
         TType *newType = new TType(fieldType);
+        while (newType->isArray())
+        {
+            newType->toArrayElementType();
+        }
+        if (!mArraySizeStack.empty())
+        {
+            newType->makeArray(static_cast<unsigned int>(mCumulativeArraySizeStack.back()));
+        }
         newType->setQualifier(EvqUniform);
         TVariable *newVariable =
             new TVariable(mSymbolTable, newName, newType, SymbolType::AngleInternal);
@@ -440,22 +702,17 @@
         newSequence->push_back(samplerDecl);
 
         mSymbolTable->declareInternal(newVariable);
+
+        GenerateArrayStrides(mArraySizeStack, &mVariableExtraData.arrayStrideMap[newVariable]);
+
+        exitArray(fieldType);
     }
 
     // Returns the chained name of a sampler uniform field.
-    static ImmutableString GetFieldName(const ImmutableString &paramName,
-                                        const TField *field,
-                                        unsigned arrayIndex)
+    static ImmutableString GetFieldName(const ImmutableString &paramName, const TField *field)
     {
-        ImmutableStringBuilder nameBuilder(paramName.length() + kHexSize + 2 +
-                                           field->name().length());
+        ImmutableStringBuilder nameBuilder(paramName.length() + 1 + field->name().length());
         nameBuilder << paramName << "_";
-
-        if (arrayIndex < std::numeric_limits<unsigned>::max())
-        {
-            nameBuilder.appendHex(arrayIndex);
-            nameBuilder << "_";
-        }
         nameBuilder << field->name();
 
         return nameBuilder;
@@ -463,7 +720,7 @@
 
     // A pattern that visits every parameter of a function call. Uses different handlers for struct
     // parameters, struct sampler parameters, and non-struct parameters.
-    class StructSamplerFunctionVisitor : angle::NonCopyable
+    class StructSamplerFunctionVisitor : angle::NonCopyable, public ArrayTraverser
     {
       public:
         StructSamplerFunctionVisitor()          = default;
@@ -481,11 +738,16 @@
                 if (paramType.isStructureContainingSamplers())
                 {
                     const ImmutableString &baseName = getNameFromIndex(function, paramIndex);
-                    if (traverseStructContainingSamplers(baseName, paramType))
+                    if (traverseStructContainingSamplers(baseName, paramType, paramIndex))
                     {
                         visitStructParam(function, paramIndex);
                     }
                 }
+                else if (paramType.isArray() && paramType.isSampler())
+                {
+                    const ImmutableString &paramName = getNameFromIndex(function, paramIndex);
+                    traverseLeafSampler(paramName, paramType, paramIndex);
+                }
                 else
                 {
                     visitNonStructParam(function, paramIndex);
@@ -494,22 +756,26 @@
         }
 
         virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0;
+        // Also includes samplers in arrays of arrays.
         virtual void visitSamplerInStructParam(const ImmutableString &name,
-                                               const TField *field)                            = 0;
+                                               const TType *type,
+                                               size_t paramIndex)                              = 0;
         virtual void visitStructParam(const TFunction *function, size_t paramIndex)            = 0;
         virtual void visitNonStructParam(const TFunction *function, size_t paramIndex)         = 0;
 
       private:
         bool traverseStructContainingSamplers(const ImmutableString &baseName,
-                                              const TType &structType)
+                                              const TType &structType,
+                                              size_t paramIndex)
         {
             bool hasNonSamplerFields    = false;
             const TStructure *structure = structType.getStruct();
+            enterArray(structType);
             for (const TField *field : structure->fields())
             {
                 if (field->type()->isStructureContainingSamplers() || field->type()->isSampler())
                 {
-                    if (traverseSamplerInStruct(baseName, structType, field))
+                    if (traverseSamplerInStruct(baseName, structType, field, paramIndex))
                     {
                         hasNonSamplerFields = true;
                     }
@@ -519,54 +785,42 @@
                     hasNonSamplerFields = true;
                 }
             }
+            exitArray(structType);
             return hasNonSamplerFields;
         }
 
         bool traverseSamplerInStruct(const ImmutableString &baseName,
                                      const TType &baseType,
-                                     const TField *field)
+                                     const TField *field,
+                                     size_t paramIndex)
         {
             bool hasNonSamplerParams = false;
 
-            if (baseType.isArray())
+            if (field->type()->isStructureContainingSamplers())
             {
-                const TVector<unsigned int> &arraySizes = *baseType.getArraySizes();
-                ASSERT(arraySizes.size() == 1);
-
-                for (unsigned int arrayIndex = 0; arrayIndex < arraySizes[0]; ++arrayIndex)
-                {
-                    ImmutableString name = GetFieldName(baseName, field, arrayIndex);
-
-                    if (field->type()->isStructureContainingSamplers())
-                    {
-                        if (traverseStructContainingSamplers(name, *field->type()))
-                        {
-                            hasNonSamplerParams = true;
-                        }
-                    }
-                    else
-                    {
-                        ASSERT(field->type()->isSampler());
-                        visitSamplerInStructParam(name, field);
-                    }
-                }
-            }
-            else if (field->type()->isStructureContainingSamplers())
-            {
-                ImmutableString name =
-                    GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
-                hasNonSamplerParams = traverseStructContainingSamplers(name, *field->type());
+                ImmutableString name = GetFieldName(baseName, field);
+                hasNonSamplerParams =
+                    traverseStructContainingSamplers(name, *field->type(), paramIndex);
             }
             else
             {
                 ASSERT(field->type()->isSampler());
-                ImmutableString name =
-                    GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
-                visitSamplerInStructParam(name, field);
+                ImmutableString name = GetFieldName(baseName, field);
+                traverseLeafSampler(name, *field->type(), paramIndex);
             }
 
             return hasNonSamplerParams;
         }
+
+        void traverseLeafSampler(const ImmutableString &samplerName,
+                                 const TType &samplerType,
+                                 size_t paramIndex)
+        {
+            enterArray(samplerType);
+            visitSamplerInStructParam(samplerName, &samplerType, paramIndex);
+            exitArray(samplerType);
+            return;
+        }
     };
 
     // A visitor that replaces functions with struct sampler references. The struct sampler
@@ -574,8 +828,8 @@
     class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor
     {
       public:
-        CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable)
-            : mSymbolTable(symbolTable), mNewFunction(nullptr)
+        CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable, VariableExtraData *extraData)
+            : mSymbolTable(symbolTable), mNewFunction(nullptr), mExtraData(extraData)
         {}
 
         ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
@@ -593,12 +847,31 @@
             StructSamplerFunctionVisitor::traverse(function);
         }
 
-        void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
+        void visitSamplerInStructParam(const ImmutableString &name,
+                                       const TType *type,
+                                       size_t paramIndex) override
         {
+            if (mArraySizeStack.size() > 0)
+            {
+                TType *newType = new TType(*type);
+                newType->toArrayBaseType();
+                newType->makeArray(static_cast<unsigned int>(mCumulativeArraySizeStack.back()));
+                type = newType;
+            }
             TVariable *fieldSampler =
-                new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
+                new TVariable(mSymbolTable, name, type, SymbolType::AngleInternal);
             mNewFunction->addParameter(fieldSampler);
             mSymbolTable->declareInternal(fieldSampler);
+            if (mArraySizeStack.size() > 0)
+            {
+                // Also declare an offset parameter.
+                const TType *intType     = StaticType::GetBasic<EbtInt>();
+                TVariable *samplerOffset = new TVariable(mSymbolTable, kEmptyImmutableString,
+                                                         intType, SymbolType::AngleInternal);
+                mNewFunction->addParameter(samplerOffset);
+                GenerateArrayStrides(mArraySizeStack, &mExtraData->arrayStrideMap[fieldSampler]);
+                mExtraData->paramOffsetMap[fieldSampler] = samplerOffset;
+            }
         }
 
         void visitStructParam(const TFunction *function, size_t paramIndex) override
@@ -621,11 +894,12 @@
       private:
         TSymbolTable *mSymbolTable;
         TFunction *mNewFunction;
+        VariableExtraData *mExtraData;
     };
 
-    TFunction *createStructSamplerFunction(const TFunction *function) const
+    TFunction *createStructSamplerFunction(const TFunction *function)
     {
-        CreateStructSamplerFunctionVisitor visitor(mSymbolTable);
+        CreateStructSamplerFunctionVisitor visitor(mSymbolTable, &mVariableExtraData);
         visitor.traverse(function);
         return visitor.getNewFunction();
     }
@@ -634,8 +908,13 @@
     class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor
     {
       public:
-        GetSamplerArgumentsVisitor(TSymbolTable *symbolTable, const TIntermSequence *arguments)
-            : mSymbolTable(symbolTable), mArguments(arguments), mNewArguments(new TIntermSequence)
+        GetSamplerArgumentsVisitor(TSymbolTable *symbolTable,
+                                   const TIntermSequence *arguments,
+                                   VariableExtraData *extraData)
+            : mSymbolTable(symbolTable),
+              mArguments(arguments),
+              mNewArguments(new TIntermSequence),
+              mExtraData(extraData)
         {}
 
         ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
@@ -644,12 +923,49 @@
             return GetStructSamplerNameFromTypedNode(argument);
         }
 
-        void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
+        void visitSamplerInStructParam(const ImmutableString &name,
+                                       const TType *type,
+                                       size_t paramIndex) override
         {
-            TVariable *argSampler =
-                new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
+            const TVariable *argSampler =
+                static_cast<const TVariable *>(mSymbolTable->findUserDefined(name));
+            ASSERT(argSampler);
+
+            TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
+
+            auto it = mExtraData->paramOffsetMap.find(argSampler);
+            TIntermTyped *argOffset =
+                it == mExtraData->paramOffsetMap.end()
+                    ? static_cast<TIntermTyped *>(CreateIndexNode(0))
+                    : static_cast<TIntermTyped *>(new TIntermSymbol(it->second));
+
+            TIntermTyped *finalOffset = GetIndexExpressionFromTypedNode(
+                argument, mExtraData->arrayStrideMap[argSampler], argOffset);
+
             TIntermSymbol *argSymbol = new TIntermSymbol(argSampler);
+
+            // If we have a regular sampler inside a struct (possibly an array
+            // of structs), handle this case separately.
+            if (!type->isArray() && mArraySizeStack.size() == 0)
+            {
+                if (argSampler->getType().isArray())
+                {
+                    TIntermTyped *argIndex =
+                        new TIntermBinary(EOpIndexIndirect, argSymbol, finalOffset);
+                    mNewArguments->push_back(argIndex);
+                }
+                else
+                {
+                    mNewArguments->push_back(argSymbol);
+                }
+                return;
+            }
+
             mNewArguments->push_back(argSymbol);
+
+            mNewArguments->push_back(finalOffset);
+            // If array, we need to calculate the offset based on what indices
+            // are present in the argument.
         }
 
         void visitStructParam(const TFunction *function, size_t paramIndex) override
@@ -673,18 +989,203 @@
         TSymbolTable *mSymbolTable;
         const TIntermSequence *mArguments;
         TIntermSequence *mNewArguments;
+        VariableExtraData *mExtraData;
     };
 
     TIntermSequence *getStructSamplerArguments(const TFunction *function,
-                                               const TIntermSequence *arguments) const
+                                               const TIntermSequence *arguments)
     {
-        GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments);
+        GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments, &mVariableExtraData);
         visitor.traverse(function);
         return visitor.getNewArguments();
     }
 
     int mRemovedUniformsCount;
     std::set<ImmutableString> mRemovedStructs;
+    FunctionInstantiations mFunctionInstantiations;
+    FunctionMap mFunctionMap;
+    VariableExtraData mVariableExtraData;
+};
+
+class MonomorphizeTraverser final : public TIntermTraverser
+{
+  public:
+    typedef std::unordered_map<const TVariable *, const TVariable *> VariableReplacementMap;
+
+    explicit MonomorphizeTraverser(
+        TCompiler *compiler,
+        TSymbolTable *symbolTable,
+        FunctionInstantiations *functionInstantiations,
+        std::unordered_map<const TFunction *, const TFunction *> *functionMap)
+        : TIntermTraverser(true, false, true, symbolTable),
+          mFunctionInstantiations(*functionInstantiations),
+          mFunctionMap(functionMap),
+          mCompiler(compiler),
+          mSubpassesSucceeded(true)
+    {}
+
+    void switchToPending()
+    {
+        mFunctionInstantiations.clear();
+        mFunctionInstantiations.swap(mPendingInstantiations);
+    }
+
+    bool hasPending()
+    {
+        if (mPendingInstantiations.empty())
+            return false;
+        for (auto &entry : mPendingInstantiations)
+        {
+            if (!entry.second.empty())
+            {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    bool subpassesSucceeded() { return mSubpassesSucceeded; }
+
+    void visitFunctionPrototype(TIntermFunctionPrototype *node) override
+    {
+        mReplacementPrototypes.clear();
+        const TFunction *function = node->getFunction();
+
+        auto &generatedMap = mGeneratedInstantiations[function->name()];
+
+        auto it = mFunctionInstantiations.find(function->name());
+        if (it == mFunctionInstantiations.end())
+            return;
+        for (const auto &instantiation : it->second)
+        {
+            TFunction *replacementFunction = instantiation.second;
+            mReplacementPrototypes.push_back(new TIntermFunctionPrototype(replacementFunction));
+            generatedMap[instantiation.first] = replacementFunction;
+        }
+        if (!mInFunctionDefinition)
+        {
+            insertStatementsInParentBlock(mReplacementPrototypes);
+        }
+    }
+
+    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
+    {
+        mInFunctionDefinition = visit == PreVisit;
+        if (visit != PostVisit)
+            return true;
+        TIntermSequence replacements;
+        const TFunction *function = node->getFunction();
+        size_t numParameters      = function->getParamCount();
+
+        for (TIntermNode *replacementNode : mReplacementPrototypes)
+        {
+            TIntermFunctionPrototype *replacementPrototype =
+                replacementNode->getAsFunctionPrototypeNode();
+            const TFunction *replacementFunction = replacementPrototype->getFunction();
+
+            // Replace function parameters with correct array sizes.
+            VariableReplacementMap variableReplacementMap;
+            ASSERT(replacementPrototype->getFunction()->getParamCount() == numParameters);
+            for (size_t i = 0; i < numParameters; i++)
+            {
+                const TVariable *origParam = function->getParam(i);
+                const TVariable *newParam  = replacementFunction->getParam(i);
+                if (origParam != newParam)
+                {
+                    variableReplacementMap[origParam] = newParam;
+                }
+            }
+
+            TIntermBlock *body = node->getBody()->deepCopy();
+            ReplaceVariablesTraverser replaceVariables(mSymbolTable, &variableReplacementMap);
+            body->traverse(&replaceVariables);
+            mSubpassesSucceeded &= replaceVariables.updateTree(mCompiler, body);
+            CollectNewInstantiationsTraverser collectNewInstantiations(
+                mSymbolTable, &mPendingInstantiations, &mGeneratedInstantiations, mFunctionMap);
+            body->traverse(&collectNewInstantiations);
+            mSubpassesSucceeded &= collectNewInstantiations.updateTree(mCompiler, body);
+            replacements.push_back(new TIntermFunctionDefinition(replacementPrototype, body));
+        }
+        insertStatementsInParentBlock(replacements);
+        return true;
+    }
+
+  private:
+    bool mInFunctionDefinition;
+    FunctionInstantiations mFunctionInstantiations;
+    // Set of already-generated instantiations.
+    FunctionInstantiations mGeneratedInstantiations;
+    // New instantiations caused by other instantiations.
+    FunctionInstantiations mPendingInstantiations;
+    std::unordered_map<const TFunction *, const TFunction *> *mFunctionMap;
+    TIntermSequence mReplacementPrototypes;
+    TCompiler *mCompiler;
+    bool mSubpassesSucceeded;
+
+    class ReplaceVariablesTraverser : public TIntermTraverser
+    {
+      public:
+        explicit ReplaceVariablesTraverser(TSymbolTable *symbolTable,
+                                           VariableReplacementMap *variableReplacementMap)
+            : TIntermTraverser(true, false, false, symbolTable),
+              mVariableReplacementMap(variableReplacementMap)
+        {}
+
+        void visitSymbol(TIntermSymbol *node) override
+        {
+            const TVariable *variable = &node->variable();
+            auto it                   = mVariableReplacementMap->find(variable);
+            if (it != mVariableReplacementMap->end())
+            {
+                queueReplacement(new TIntermSymbol(it->second), OriginalNode::IS_DROPPED);
+            }
+        }
+
+      private:
+        VariableReplacementMap *mVariableReplacementMap;
+    };
+
+    class CollectNewInstantiationsTraverser : public TIntermTraverser
+    {
+      public:
+        explicit CollectNewInstantiationsTraverser(
+            TSymbolTable *symbolTable,
+            FunctionInstantiations *pendingInstantiations,
+            FunctionInstantiations *generatedInstantiations,
+            std::unordered_map<const TFunction *, const TFunction *> *functionMap)
+            : TIntermTraverser(true, false, false, symbolTable),
+              mPendingInstantiations(pendingInstantiations),
+              mGeneratedInstantiations(generatedInstantiations),
+              mFunctionMap(functionMap)
+        {}
+
+        bool visitAggregate(Visit visit, TIntermAggregate *node) override
+        {
+            if (!node->isFunctionCall())
+                return true;
+            const TFunction *function = node->getFunction();
+            const TFunction *oldFunction;
+            {
+                auto it = mFunctionMap->find(function);
+                if (it == mFunctionMap->end())
+                    return true;
+                oldFunction = it->second;
+            }
+            ASSERT(node->getOp() == EOpCallFunctionInAST);
+            TIntermSequence *arguments = node->getSequence();
+            TFunction *newFunction     = GenerateFunctionFromArguments(
+                oldFunction, arguments, mSymbolTable, mPendingInstantiations, mFunctionMap,
+                mGeneratedInstantiations);
+            queueReplacement(TIntermAggregate::CreateFunctionCall(*newFunction, arguments),
+                             OriginalNode::IS_DROPPED);
+            return true;
+        }
+
+      private:
+        FunctionInstantiations *mPendingInstantiations;
+        FunctionInstantiations *mGeneratedInstantiations;
+        std::unordered_map<const TFunction *, const TFunction *> *mFunctionMap;
+    };
 };
 }  // anonymous namespace
 
@@ -695,8 +1196,45 @@
 {
     Traverser rewriteStructSamplers(symbolTable);
     root->traverse(&rewriteStructSamplers);
+    if (!rewriteStructSamplers.updateTree(compiler, root))
+    {
+        return false;
+    }
     *removedUniformsCountOut = rewriteStructSamplers.removedUniformsCount();
 
-    return rewriteStructSamplers.updateTree(compiler, root);
+    if (rewriteStructSamplers.getFunctionInstantiations()->empty())
+    {
+        return true;
+    }
+
+    MonomorphizeTraverser monomorphizeFunctions(compiler, symbolTable,
+                                                rewriteStructSamplers.getFunctionInstantiations(),
+                                                rewriteStructSamplers.getFunctionMap());
+    root->traverse(&monomorphizeFunctions);
+    if (!monomorphizeFunctions.subpassesSucceeded())
+    {
+        return false;
+    }
+    if (!monomorphizeFunctions.updateTree(compiler, root))
+    {
+        return false;
+    }
+
+    // Generate instantiations caused by other instantiations.
+    while (monomorphizeFunctions.hasPending())
+    {
+        monomorphizeFunctions.switchToPending();
+        root->traverse(&monomorphizeFunctions);
+        if (!monomorphizeFunctions.subpassesSucceeded())
+        {
+            return false;
+        }
+        if (!monomorphizeFunctions.updateTree(compiler, root))
+        {
+            return false;
+        }
+    }
+
+    return true;
 }
 }  // namespace sh
diff --git a/src/compiler/translator/tree_ops/RewriteStructSamplers.h b/src/compiler/translator/tree_ops/RewriteStructSamplers.h
index 62ba6dd..a7c0ccc 100644
--- a/src/compiler/translator/tree_ops/RewriteStructSamplers.h
+++ b/src/compiler/translator/tree_ops/RewriteStructSamplers.h
@@ -33,6 +33,10 @@
                                             TIntermBlock *root,
                                             TSymbolTable *symbolTable,
                                             int *removedUniformsCountOut);
+ANGLE_NO_DISCARD bool RewriteStructSamplersOld(TCompiler *compier,
+                                               TIntermBlock *root,
+                                               TSymbolTable *symbolTable,
+                                               int *removedUniformsCountOut);
 }  // namespace sh
 
 #endif  // COMPILER_TRANSLATOR_TREEOPS_REWRITESTRUCTSAMPLERS_H_
diff --git a/src/compiler/translator/tree_ops/RewriteStructSamplersOld.cpp b/src/compiler/translator/tree_ops/RewriteStructSamplersOld.cpp
new file mode 100644
index 0000000..5d54a19
--- /dev/null
+++ b/src/compiler/translator/tree_ops/RewriteStructSamplersOld.cpp
@@ -0,0 +1,705 @@
+//
+// Copyright 2018 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// RewriteStructSamplers: Extract structs from samplers.
+//
+
+#include "compiler/translator/tree_ops/RewriteStructSamplers.h"
+
+#include "compiler/translator/ImmutableStringBuilder.h"
+#include "compiler/translator/SymbolTable.h"
+#include "compiler/translator/tree_util/IntermTraverse.h"
+
+namespace sh
+{
+namespace
+{
+// Helper method to get the sampler extracted struct type of a parameter.
+TType *GetStructSamplerParameterType(TSymbolTable *symbolTable, const TVariable &param)
+{
+    const TStructure *structure = param.getType().getStruct();
+    const TSymbol *structSymbol = symbolTable->findUserDefined(structure->name());
+    ASSERT(structSymbol && structSymbol->isStruct());
+    const TStructure *structVar = static_cast<const TStructure *>(structSymbol);
+    TType *structType           = new TType(structVar, false);
+
+    if (param.getType().isArray())
+    {
+        structType->makeArrays(*param.getType().getArraySizes());
+    }
+
+    ASSERT(!structType->isStructureContainingSamplers());
+
+    return structType;
+}
+
+TIntermSymbol *ReplaceTypeOfSymbolNode(TIntermSymbol *symbolNode, TSymbolTable *symbolTable)
+{
+    const TVariable &oldVariable = symbolNode->variable();
+
+    TType *newType = GetStructSamplerParameterType(symbolTable, oldVariable);
+
+    TVariable *newVariable =
+        new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
+                      oldVariable.extension(), newType);
+    return new TIntermSymbol(newVariable);
+}
+
+TIntermTyped *ReplaceTypeOfTypedStructNode(TIntermTyped *argument, TSymbolTable *symbolTable)
+{
+    TIntermSymbol *asSymbol = argument->getAsSymbolNode();
+    if (asSymbol)
+    {
+        ASSERT(asSymbol->getType().getStruct());
+        return ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
+    }
+
+    TIntermTyped *replacement = argument->deepCopy();
+    TIntermBinary *binary     = replacement->getAsBinaryNode();
+    ASSERT(binary);
+
+    while (binary)
+    {
+        ASSERT(binary->getOp() == EOpIndexDirectStruct || binary->getOp() == EOpIndexDirect);
+
+        asSymbol = binary->getLeft()->getAsSymbolNode();
+
+        if (asSymbol)
+        {
+            ASSERT(asSymbol->getType().getStruct());
+            TIntermSymbol *newSymbol = ReplaceTypeOfSymbolNode(asSymbol, symbolTable);
+            binary->replaceChildNode(binary->getLeft(), newSymbol);
+            return replacement;
+        }
+
+        binary = binary->getLeft()->getAsBinaryNode();
+    }
+
+    UNREACHABLE();
+    return nullptr;
+}
+
+// Maximum string size of a hex unsigned int.
+constexpr size_t kHexSize = ImmutableStringBuilder::GetHexCharCount<unsigned int>();
+
+class Traverser final : public TIntermTraverser
+{
+  public:
+    explicit Traverser(TSymbolTable *symbolTable)
+        : TIntermTraverser(true, false, true, symbolTable), mRemovedUniformsCount(0)
+    {
+        mSymbolTable->push();
+    }
+
+    ~Traverser() override { mSymbolTable->pop(); }
+
+    int removedUniformsCount() const { return mRemovedUniformsCount; }
+
+    // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
+    // stripped struct sampler.
+    bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
+    {
+        if (visit != PreVisit)
+            return true;
+
+        if (!mInGlobalScope)
+        {
+            return true;
+        }
+
+        const TIntermSequence &sequence = *(decl->getSequence());
+        TIntermTyped *declarator        = sequence.front()->getAsTyped();
+        const TType &type               = declarator->getType();
+
+        if (type.isStructureContainingSamplers())
+        {
+            TIntermSequence *newSequence = new TIntermSequence;
+
+            if (type.isStructSpecifier())
+            {
+                stripStructSpecifierSamplers(type.getStruct(), newSequence);
+            }
+            else
+            {
+                TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
+                ASSERT(asSymbol);
+                const TVariable &variable = asSymbol->variable();
+                ASSERT(variable.symbolType() != SymbolType::Empty);
+                extractStructSamplerUniforms(decl, variable, type.getStruct(), newSequence);
+            }
+
+            mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl, *newSequence);
+        }
+
+        return true;
+    }
+
+    // Each struct sampler reference is replaced with a reference to the new extracted sampler.
+    bool visitBinary(Visit visit, TIntermBinary *node) override
+    {
+        if (visit != PreVisit)
+            return true;
+
+        if (node->getOp() == EOpIndexDirectStruct && node->getType().isSampler())
+        {
+            ImmutableString newName = GetStructSamplerNameFromTypedNode(node);
+            const TVariable *samplerReplacement =
+                static_cast<const TVariable *>(mSymbolTable->findUserDefined(newName));
+            ASSERT(samplerReplacement);
+
+            TIntermSymbol *replacement = new TIntermSymbol(samplerReplacement);
+
+            queueReplacement(replacement, OriginalNode::IS_DROPPED);
+            return true;
+        }
+
+        return true;
+    }
+
+    // In we are passing references to structs containing samplers we must new additional
+    // arguments. For each extracted struct sampler a new argument is added. This chains to nested
+    // structs.
+    void visitFunctionPrototype(TIntermFunctionPrototype *node) override
+    {
+        const TFunction *function = node->getFunction();
+
+        if (!function->hasSamplerInStructOrArrayOfArrayParams())
+        {
+            return;
+        }
+
+        const TSymbol *foundFunction = mSymbolTable->findUserDefined(function->name());
+        if (foundFunction)
+        {
+            ASSERT(foundFunction->isFunction());
+            function = static_cast<const TFunction *>(foundFunction);
+        }
+        else
+        {
+            TFunction *newFunction = createStructSamplerFunction(function);
+            mSymbolTable->declareUserDefinedFunction(newFunction, true);
+            function = newFunction;
+        }
+
+        ASSERT(!function->hasSamplerInStructOrArrayOfArrayParams());
+        TIntermFunctionPrototype *newProto = new TIntermFunctionPrototype(function);
+        queueReplacement(newProto, OriginalNode::IS_DROPPED);
+    }
+
+    // We insert a new scope for each function definition so we can track the new parameters.
+    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
+    {
+        if (visit == PreVisit)
+        {
+            mSymbolTable->push();
+        }
+        else
+        {
+            ASSERT(visit == PostVisit);
+            mSymbolTable->pop();
+        }
+        return true;
+    }
+
+    // For function call nodes we pass references to the extracted struct samplers in that scope.
+    bool visitAggregate(Visit visit, TIntermAggregate *node) override
+    {
+        if (visit != PreVisit)
+            return true;
+
+        if (!node->isFunctionCall())
+            return true;
+
+        const TFunction *function = node->getFunction();
+        if (!function->hasSamplerInStructOrArrayOfArrayParams())
+            return true;
+
+        ASSERT(node->getOp() == EOpCallFunctionInAST);
+        TFunction *newFunction        = mSymbolTable->findUserDefinedFunction(function->name());
+        TIntermSequence *newArguments = getStructSamplerArguments(function, node->getSequence());
+
+        TIntermAggregate *newCall =
+            TIntermAggregate::CreateFunctionCall(*newFunction, newArguments);
+        queueReplacement(newCall, OriginalNode::IS_DROPPED);
+        return true;
+    }
+
+  private:
+    // This returns the name of a struct sampler reference. References are always TIntermBinary.
+    static ImmutableString GetStructSamplerNameFromTypedNode(TIntermTyped *node)
+    {
+        std::string stringBuilder;
+
+        TIntermTyped *currentNode = node;
+        while (currentNode->getAsBinaryNode())
+        {
+            TIntermBinary *asBinary = currentNode->getAsBinaryNode();
+
+            switch (asBinary->getOp())
+            {
+                case EOpIndexDirect:
+                {
+                    const int index = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
+                    const std::string strInt = Str(index);
+                    stringBuilder.insert(0, strInt);
+                    stringBuilder.insert(0, "_");
+                    break;
+                }
+                case EOpIndexDirectStruct:
+                {
+                    stringBuilder.insert(0, asBinary->getIndexStructFieldName().data());
+                    stringBuilder.insert(0, "_");
+                    break;
+                }
+
+                default:
+                    UNREACHABLE();
+                    break;
+            }
+
+            currentNode = asBinary->getLeft();
+        }
+
+        const ImmutableString &variableName = currentNode->getAsSymbolNode()->variable().name();
+        stringBuilder.insert(0, variableName.data());
+
+        return stringBuilder;
+    }
+
+    // Removes all the struct samplers from a struct specifier.
+    void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
+    {
+        TFieldList *newFieldList = new TFieldList;
+        ASSERT(structure->containsSamplers());
+
+        for (const TField *field : structure->fields())
+        {
+            const TType &fieldType = *field->type();
+            if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
+            {
+                TType *newType = nullptr;
+
+                if (fieldType.isStructureContainingSamplers())
+                {
+                    const TSymbol *structSymbol =
+                        mSymbolTable->findUserDefined(fieldType.getStruct()->name());
+                    ASSERT(structSymbol && structSymbol->isStruct());
+                    const TStructure *fieldStruct = static_cast<const TStructure *>(structSymbol);
+                    newType                       = new TType(fieldStruct, true);
+                    if (fieldType.isArray())
+                    {
+                        newType->makeArrays(*fieldType.getArraySizes());
+                    }
+                }
+                else
+                {
+                    newType = new TType(fieldType);
+                }
+
+                TField *newField =
+                    new TField(newType, field->name(), field->line(), field->symbolType());
+                newFieldList->push_back(newField);
+            }
+        }
+
+        // Prune empty structs.
+        if (newFieldList->empty())
+        {
+            mRemovedStructs.insert(structure->name());
+            return;
+        }
+
+        TStructure *newStruct =
+            new TStructure(mSymbolTable, structure->name(), newFieldList, structure->symbolType());
+        TType *newStructType = new TType(newStruct, true);
+        TVariable *newStructVar =
+            new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
+        TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
+
+        TIntermDeclaration *structDecl = new TIntermDeclaration;
+        structDecl->appendDeclarator(newStructRef);
+
+        newSequence->push_back(structDecl);
+
+        mSymbolTable->declare(newStruct);
+    }
+
+    // Returns true if the type is a struct that was removed because we extracted all the members.
+    bool isRemovedStructType(const TType &type) const
+    {
+        const TStructure *structure = type.getStruct();
+        return (structure && (mRemovedStructs.count(structure->name()) > 0));
+    }
+
+    // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
+    // defined sampler uniform.
+    void extractStructSamplerUniforms(TIntermDeclaration *oldDeclaration,
+                                      const TVariable &variable,
+                                      const TStructure *structure,
+                                      TIntermSequence *newSequence)
+    {
+        ASSERT(structure->containsSamplers());
+
+        size_t nonSamplerCount = 0;
+
+        for (const TField *field : structure->fields())
+        {
+            nonSamplerCount +=
+                extractFieldSamplers(variable.name(), field, variable.getType(), newSequence);
+        }
+
+        if (nonSamplerCount > 0)
+        {
+            // Keep the old declaration around if it has other members.
+            newSequence->push_back(oldDeclaration);
+        }
+        else
+        {
+            mRemovedUniformsCount++;
+        }
+    }
+
+    // Extracts samplers from a field of a struct. Works with nested structs and arrays.
+    size_t extractFieldSamplers(const ImmutableString &prefix,
+                                const TField *field,
+                                const TType &containingType,
+                                TIntermSequence *newSequence)
+    {
+        if (containingType.isArray())
+        {
+            size_t nonSamplerCount = 0;
+
+            // Name the samplers internally as varName_<index>_fieldName
+            const TVector<unsigned int> &arraySizes = *containingType.getArraySizes();
+            for (unsigned int arrayElement = 0; arrayElement < arraySizes[0]; ++arrayElement)
+            {
+                ImmutableStringBuilder stringBuilder(prefix.length() + kHexSize + 1);
+                stringBuilder << prefix << "_";
+                stringBuilder.appendHex(arrayElement);
+                nonSamplerCount = extractFieldSamplersImpl(stringBuilder, field, newSequence);
+            }
+
+            return nonSamplerCount;
+        }
+
+        return extractFieldSamplersImpl(prefix, field, newSequence);
+    }
+
+    // Extracts samplers from a field of a struct. Works with nested structs and arrays.
+    size_t extractFieldSamplersImpl(const ImmutableString &prefix,
+                                    const TField *field,
+                                    TIntermSequence *newSequence)
+    {
+        size_t nonSamplerCount = 0;
+
+        const TType &fieldType = *field->type();
+        if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
+        {
+            ImmutableStringBuilder stringBuilder(prefix.length() + field->name().length() + 1);
+            stringBuilder << prefix << "_" << field->name();
+            ImmutableString newPrefix(stringBuilder);
+
+            if (fieldType.isSampler())
+            {
+                extractSampler(newPrefix, fieldType, newSequence);
+            }
+            else
+            {
+                const TStructure *structure = fieldType.getStruct();
+                for (const TField *nestedField : structure->fields())
+                {
+                    nonSamplerCount +=
+                        extractFieldSamplers(newPrefix, nestedField, fieldType, newSequence);
+                }
+            }
+        }
+        else
+        {
+            nonSamplerCount++;
+        }
+
+        return nonSamplerCount;
+    }
+
+    // Extracts a sampler from a struct. Declares the new extracted sampler.
+    void extractSampler(const ImmutableString &newName,
+                        const TType &fieldType,
+                        TIntermSequence *newSequence) const
+    {
+        TType *newType = new TType(fieldType);
+        newType->setQualifier(EvqUniform);
+        TVariable *newVariable =
+            new TVariable(mSymbolTable, newName, newType, SymbolType::AngleInternal);
+        TIntermSymbol *newRef = new TIntermSymbol(newVariable);
+
+        TIntermDeclaration *samplerDecl = new TIntermDeclaration;
+        samplerDecl->appendDeclarator(newRef);
+
+        newSequence->push_back(samplerDecl);
+
+        mSymbolTable->declareInternal(newVariable);
+    }
+
+    // Returns the chained name of a sampler uniform field.
+    static ImmutableString GetFieldName(const ImmutableString &paramName,
+                                        const TField *field,
+                                        unsigned arrayIndex)
+    {
+        ImmutableStringBuilder nameBuilder(paramName.length() + kHexSize + 2 +
+                                           field->name().length());
+        nameBuilder << paramName << "_";
+
+        if (arrayIndex < std::numeric_limits<unsigned>::max())
+        {
+            nameBuilder.appendHex(arrayIndex);
+            nameBuilder << "_";
+        }
+        nameBuilder << field->name();
+
+        return nameBuilder;
+    }
+
+    // A pattern that visits every parameter of a function call. Uses different handlers for struct
+    // parameters, struct sampler parameters, and non-struct parameters.
+    class StructSamplerFunctionVisitor : angle::NonCopyable
+    {
+      public:
+        StructSamplerFunctionVisitor()          = default;
+        virtual ~StructSamplerFunctionVisitor() = default;
+
+        virtual void traverse(const TFunction *function)
+        {
+            size_t paramCount = function->getParamCount();
+
+            for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
+            {
+                const TVariable *param = function->getParam(paramIndex);
+                const TType &paramType = param->getType();
+
+                if (paramType.isStructureContainingSamplers())
+                {
+                    const ImmutableString &baseName = getNameFromIndex(function, paramIndex);
+                    if (traverseStructContainingSamplers(baseName, paramType))
+                    {
+                        visitStructParam(function, paramIndex);
+                    }
+                }
+                else
+                {
+                    visitNonStructParam(function, paramIndex);
+                }
+            }
+        }
+
+        virtual ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) = 0;
+        virtual void visitSamplerInStructParam(const ImmutableString &name,
+                                               const TField *field)                            = 0;
+        virtual void visitStructParam(const TFunction *function, size_t paramIndex)            = 0;
+        virtual void visitNonStructParam(const TFunction *function, size_t paramIndex)         = 0;
+
+      private:
+        bool traverseStructContainingSamplers(const ImmutableString &baseName,
+                                              const TType &structType)
+        {
+            bool hasNonSamplerFields    = false;
+            const TStructure *structure = structType.getStruct();
+            for (const TField *field : structure->fields())
+            {
+                if (field->type()->isStructureContainingSamplers() || field->type()->isSampler())
+                {
+                    if (traverseSamplerInStruct(baseName, structType, field))
+                    {
+                        hasNonSamplerFields = true;
+                    }
+                }
+                else
+                {
+                    hasNonSamplerFields = true;
+                }
+            }
+            return hasNonSamplerFields;
+        }
+
+        bool traverseSamplerInStruct(const ImmutableString &baseName,
+                                     const TType &baseType,
+                                     const TField *field)
+        {
+            bool hasNonSamplerParams = false;
+
+            if (baseType.isArray())
+            {
+                const TVector<unsigned int> &arraySizes = *baseType.getArraySizes();
+                ASSERT(arraySizes.size() == 1);
+
+                for (unsigned int arrayIndex = 0; arrayIndex < arraySizes[0]; ++arrayIndex)
+                {
+                    ImmutableString name = GetFieldName(baseName, field, arrayIndex);
+
+                    if (field->type()->isStructureContainingSamplers())
+                    {
+                        if (traverseStructContainingSamplers(name, *field->type()))
+                        {
+                            hasNonSamplerParams = true;
+                        }
+                    }
+                    else
+                    {
+                        ASSERT(field->type()->isSampler());
+                        visitSamplerInStructParam(name, field);
+                    }
+                }
+            }
+            else if (field->type()->isStructureContainingSamplers())
+            {
+                ImmutableString name =
+                    GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
+                hasNonSamplerParams = traverseStructContainingSamplers(name, *field->type());
+            }
+            else
+            {
+                ASSERT(field->type()->isSampler());
+                ImmutableString name =
+                    GetFieldName(baseName, field, std::numeric_limits<unsigned>::max());
+                visitSamplerInStructParam(name, field);
+            }
+
+            return hasNonSamplerParams;
+        }
+    };
+
+    // A visitor that replaces functions with struct sampler references. The struct sampler
+    // references are expanded to include new fields for the structs.
+    class CreateStructSamplerFunctionVisitor final : public StructSamplerFunctionVisitor
+    {
+      public:
+        CreateStructSamplerFunctionVisitor(TSymbolTable *symbolTable)
+            : mSymbolTable(symbolTable), mNewFunction(nullptr)
+        {}
+
+        ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
+        {
+            const TVariable *param = function->getParam(paramIndex);
+            return param->name();
+        }
+
+        void traverse(const TFunction *function) override
+        {
+            mNewFunction =
+                new TFunction(mSymbolTable, function->name(), function->symbolType(),
+                              &function->getReturnType(), function->isKnownToNotHaveSideEffects());
+
+            StructSamplerFunctionVisitor::traverse(function);
+        }
+
+        void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
+        {
+            TVariable *fieldSampler =
+                new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
+            mNewFunction->addParameter(fieldSampler);
+            mSymbolTable->declareInternal(fieldSampler);
+        }
+
+        void visitStructParam(const TFunction *function, size_t paramIndex) override
+        {
+            const TVariable *param = function->getParam(paramIndex);
+            TType *structType      = GetStructSamplerParameterType(mSymbolTable, *param);
+            TVariable *newParam =
+                new TVariable(mSymbolTable, param->name(), structType, param->symbolType());
+            mNewFunction->addParameter(newParam);
+        }
+
+        void visitNonStructParam(const TFunction *function, size_t paramIndex) override
+        {
+            const TVariable *param = function->getParam(paramIndex);
+            mNewFunction->addParameter(param);
+        }
+
+        TFunction *getNewFunction() const { return mNewFunction; }
+
+      private:
+        TSymbolTable *mSymbolTable;
+        TFunction *mNewFunction;
+    };
+
+    TFunction *createStructSamplerFunction(const TFunction *function) const
+    {
+        CreateStructSamplerFunctionVisitor visitor(mSymbolTable);
+        visitor.traverse(function);
+        return visitor.getNewFunction();
+    }
+
+    // A visitor that replaces function calls with expanded struct sampler parameters.
+    class GetSamplerArgumentsVisitor final : public StructSamplerFunctionVisitor
+    {
+      public:
+        GetSamplerArgumentsVisitor(TSymbolTable *symbolTable, const TIntermSequence *arguments)
+            : mSymbolTable(symbolTable), mArguments(arguments), mNewArguments(new TIntermSequence)
+        {}
+
+        ImmutableString getNameFromIndex(const TFunction *function, size_t paramIndex) override
+        {
+            TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
+            return GetStructSamplerNameFromTypedNode(argument);
+        }
+
+        void visitSamplerInStructParam(const ImmutableString &name, const TField *field) override
+        {
+            TVariable *argSampler =
+                new TVariable(mSymbolTable, name, field->type(), SymbolType::AngleInternal);
+            TIntermSymbol *argSymbol = new TIntermSymbol(argSampler);
+            mNewArguments->push_back(argSymbol);
+        }
+
+        void visitStructParam(const TFunction *function, size_t paramIndex) override
+        {
+            // The tree structure of the parameter is modified to point to the new type. This leaves
+            // the tree in a consistent state.
+            TIntermTyped *argument    = (*mArguments)[paramIndex]->getAsTyped();
+            TIntermTyped *replacement = ReplaceTypeOfTypedStructNode(argument, mSymbolTable);
+            mNewArguments->push_back(replacement);
+        }
+
+        void visitNonStructParam(const TFunction *function, size_t paramIndex) override
+        {
+            TIntermTyped *argument = (*mArguments)[paramIndex]->getAsTyped();
+            mNewArguments->push_back(argument);
+        }
+
+        TIntermSequence *getNewArguments() const { return mNewArguments; }
+
+      private:
+        TSymbolTable *mSymbolTable;
+        const TIntermSequence *mArguments;
+        TIntermSequence *mNewArguments;
+    };
+
+    TIntermSequence *getStructSamplerArguments(const TFunction *function,
+                                               const TIntermSequence *arguments) const
+    {
+        GetSamplerArgumentsVisitor visitor(mSymbolTable, arguments);
+        visitor.traverse(function);
+        return visitor.getNewArguments();
+    }
+
+    int mRemovedUniformsCount;
+    std::set<ImmutableString> mRemovedStructs;
+};
+}  // anonymous namespace
+
+bool RewriteStructSamplersOld(TCompiler *compiler,
+                              TIntermBlock *root,
+                              TSymbolTable *symbolTable,
+                              int *removedUniformsCountOut)
+{
+    Traverser rewriteStructSamplers(symbolTable);
+    root->traverse(&rewriteStructSamplers);
+    if (!rewriteStructSamplers.updateTree(compiler, root))
+    {
+        return false;
+    }
+    *removedUniformsCountOut = rewriteStructSamplers.removedUniformsCount();
+    return true;
+}
+}  // namespace sh
diff --git a/src/libANGLE/ProgramLinkedResources.cpp b/src/libANGLE/ProgramLinkedResources.cpp
index 17d2974..106f4af 100644
--- a/src/libANGLE/ProgramLinkedResources.cpp
+++ b/src/libANGLE/ProgramLinkedResources.cpp
@@ -248,7 +248,8 @@
     void visitNamedVariable(const sh::ShaderVariable &variable,
                             bool isRowMajor,
                             const std::string &name,
-                            const std::string &mappedName) override
+                            const std::string &mappedName,
+                            const std::vector<unsigned int> &arraySizes) override
     {
         // If getBlockMemberInfo returns false, the variable is optimized out.
         sh::BlockMemberInfo variableInfo;
@@ -308,7 +309,8 @@
     void visitNamedVariable(const sh::ShaderVariable &variable,
                             bool isRowMajor,
                             const std::string &name,
-                            const std::string &mappedName) override
+                            const std::string &mappedName,
+                            const std::vector<unsigned int> &arraySizes) override
     {
         if (mSkipEnabled)
             return;
@@ -397,15 +399,17 @@
 
     void visitNamedSampler(const sh::ShaderVariable &sampler,
                            const std::string &name,
-                           const std::string &mappedName) override
+                           const std::string &mappedName,
+                           const std::vector<unsigned int> &arraySizes) override
     {
-        visitNamedVariable(sampler, false, name, mappedName);
+        visitNamedVariable(sampler, false, name, mappedName, arraySizes);
     }
 
     void visitNamedVariable(const sh::ShaderVariable &variable,
                             bool isRowMajor,
                             const std::string &name,
-                            const std::string &mappedName) override
+                            const std::string &mappedName,
+                            const std::vector<unsigned int> &arraySizes) override
     {
         bool isSampler                          = IsSamplerType(variable.type);
         bool isImage                            = IsImageType(variable.type);
@@ -468,6 +472,7 @@
             linkedUniform.mappedName = fullMappedNameWithArrayIndex;
             linkedUniform.active     = mMarkActive;
             linkedUniform.staticUse  = mMarkStaticUse;
+            linkedUniform.outerArraySizes = arraySizes;
             if (variable.hasParentArrayIndex())
             {
                 linkedUniform.setParentArrayIndex(variable.parentArrayIndex());
diff --git a/src/libANGLE/Shader.cpp b/src/libANGLE/Shader.cpp
index 8627a60..09f0039 100644
--- a/src/libANGLE/Shader.cpp
+++ b/src/libANGLE/Shader.cpp
@@ -406,7 +406,15 @@
         // Remove null characters from the source line
         line.erase(std::remove(line.begin(), line.end(), '\0'), line.end());
 
-        shaderStream << "// " << line << std::endl;
+        shaderStream << "// " << line;
+
+        // glslang complains if a comment ends with backslash
+        if (!line.empty() && line.back() == '\\')
+        {
+            shaderStream << "\\";
+        }
+
+        shaderStream << std::endl;
     }
     shaderStream << "\n\n";
     shaderStream << mState.mTranslatedSource;
diff --git a/src/libANGLE/Uniform.cpp b/src/libANGLE/Uniform.cpp
index cc010dd..e724059 100644
--- a/src/libANGLE/Uniform.cpp
+++ b/src/libANGLE/Uniform.cpp
@@ -80,7 +80,8 @@
       ActiveVariable(uniform),
       typeInfo(uniform.typeInfo),
       bufferIndex(uniform.bufferIndex),
-      blockInfo(uniform.blockInfo)
+      blockInfo(uniform.blockInfo),
+      outerArraySizes(uniform.outerArraySizes)
 {}
 
 LinkedUniform &LinkedUniform::operator=(const LinkedUniform &uniform)
@@ -90,6 +91,7 @@
     typeInfo                = uniform.typeInfo;
     bufferIndex             = uniform.bufferIndex;
     blockInfo               = uniform.blockInfo;
+    outerArraySizes         = uniform.outerArraySizes;
     return *this;
 }
 
diff --git a/src/libANGLE/Uniform.h b/src/libANGLE/Uniform.h
index 3e16101..a6411b7 100644
--- a/src/libANGLE/Uniform.h
+++ b/src/libANGLE/Uniform.h
@@ -75,6 +75,7 @@
     // Identifies the containing buffer backed resource -- interface block or atomic counter buffer.
     int bufferIndex;
     sh::BlockMemberInfo blockInfo;
+    std::vector<unsigned int> outerArraySizes;
 };
 
 struct BufferVariable : public sh::ShaderVariable, public ActiveVariable
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.cpp b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
index b711fdb..2b8b750 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.cpp
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
@@ -217,7 +217,8 @@
 
     void visitNamedSampler(const sh::ShaderVariable &sampler,
                            const std::string &name,
-                           const std::string &mappedName) override
+                           const std::string &mappedName,
+                           const std::vector<unsigned int> &arraySizes) override
     {
         auto uniformMapEntry = mUniformMapOut->find(name);
         if (uniformMapEntry == mUniformMapOut->end())
diff --git a/src/libANGLE/renderer/vulkan/ContextVk.cpp b/src/libANGLE/renderer/vulkan/ContextVk.cpp
index a485321..7318a0e 100644
--- a/src/libANGLE/renderer/vulkan/ContextVk.cpp
+++ b/src/libANGLE/renderer/vulkan/ContextVk.cpp
@@ -240,6 +240,7 @@
       mIsAnyHostVisibleBufferWritten(false),
       mEmulateSeamfulCubeMapSampling(false),
       mEmulateSeamfulCubeMapSamplingWithSubgroupOps(false),
+      mUseOldRewriteStructSamplers(false),
       mLastCompletedQueueSerial(renderer->nextSerial()),
       mCurrentQueueSerial(renderer->nextSerial()),
       mPoolAllocator(kDefaultPoolAllocatorPageSize, 1),
@@ -446,6 +447,8 @@
     mEmulateSeamfulCubeMapSampling =
         shouldEmulateSeamfulCubeMapSampling(&mEmulateSeamfulCubeMapSamplingWithSubgroupOps);
 
+    mUseOldRewriteStructSamplers = shouldUseOldRewriteStructSamplers();
+
     return angle::Result::Continue;
 }
 
@@ -2947,4 +2950,9 @@
 
     return true;
 }
+
+bool ContextVk::shouldUseOldRewriteStructSamplers() const
+{
+    return mRenderer->getFeatures().forceOldRewriteStructSamplers.enabled;
+}
 }  // namespace rx
diff --git a/src/libANGLE/renderer/vulkan/ContextVk.h b/src/libANGLE/renderer/vulkan/ContextVk.h
index ae82f8d..a0cde83 100644
--- a/src/libANGLE/renderer/vulkan/ContextVk.h
+++ b/src/libANGLE/renderer/vulkan/ContextVk.h
@@ -333,6 +333,8 @@
         return mEmulateSeamfulCubeMapSampling;
     }
 
+    bool useOldRewriteStructSamplers() const { return mUseOldRewriteStructSamplers; }
+
   private:
     // Dirty bits.
     enum DirtyBitType : size_t
@@ -492,6 +494,8 @@
 
     bool shouldEmulateSeamfulCubeMapSampling(bool *useSubgroupOpsOut) const;
 
+    bool shouldUseOldRewriteStructSamplers() const;
+
     vk::PipelineHelper *mCurrentGraphicsPipeline;
     vk::PipelineAndSerial *mCurrentComputePipeline;
     gl::PrimitiveMode mCurrentDrawMode;
@@ -558,6 +562,10 @@
     bool mEmulateSeamfulCubeMapSampling;
     bool mEmulateSeamfulCubeMapSamplingWithSubgroupOps;
 
+    // Whether this context should use the old version of the
+    // RewriteStructSamplers pass.
+    bool mUseOldRewriteStructSamplers;
+
     struct DriverUniformsDescriptorSet
     {
         vk::DynamicBuffer dynamicBuffer;
diff --git a/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp b/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
index 01c6c1a..8dac687 100644
--- a/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
+++ b/src/libANGLE/renderer/vulkan/GlslangWrapper.cpp
@@ -366,7 +366,7 @@
     return shaderSource;
 }
 
-std::string GetMappedSamplerName(const std::string &originalName)
+std::string GetMappedSamplerNameOld(const std::string &originalName)
 {
     std::string samplerName = gl::ParseResourceName(originalName, nullptr);
 
@@ -777,7 +777,8 @@
                                                      bindingStart, shaderSources);
 }
 
-void AssignTextureBindings(const gl::ProgramState &programState,
+void AssignTextureBindings(bool useOldRewriteStructSamplers,
+                           const gl::ProgramState &programState,
                            gl::ShaderMap<IntermediateShaderSource> *shaderSources)
 {
     const std::string texturesDescriptorSet = "set = " + Str(kTextureDescriptorSetIndex);
@@ -789,18 +790,28 @@
     for (unsigned int uniformIndex : programState.getSamplerUniformRange())
     {
         const gl::LinkedUniform &samplerUniform = uniforms[uniformIndex];
+
+        if (!useOldRewriteStructSamplers &&
+            vk::SamplerNameContainsNonZeroArrayElement(samplerUniform.name))
+        {
+            continue;
+        }
+
         const std::string bindingString =
             texturesDescriptorSet + ", binding = " + Str(bindingIndex++);
 
         // Samplers in structs are extracted and renamed.
-        const std::string samplerName = GetMappedSamplerName(samplerUniform.name);
+        const std::string samplerName = useOldRewriteStructSamplers
+                                            ? GetMappedSamplerNameOld(samplerUniform.name)
+                                            : vk::GetMappedSamplerName(samplerUniform.name);
 
         AssignResourceBinding(samplerUniform.activeShaders(), samplerName, bindingString,
                               kUniformQualifier, kUnusedUniformSubstitution, shaderSources);
     }
 }
 
-void CleanupUnusedEntities(const gl::ProgramState &programState,
+void CleanupUnusedEntities(bool useOldRewriteStructSamplers,
+                           const gl::ProgramState &programState,
                            const gl::ProgramLinkedResources &resources,
                            gl::Shader *glVertexShader,
                            gl::ShaderMap<IntermediateShaderSource> *shaderSources)
@@ -847,8 +858,11 @@
     // uniforms to a single line.
     for (const gl::UnusedUniform &unusedUniform : resources.unusedUniforms)
     {
-        std::string uniformName =
-            unusedUniform.isSampler ? GetMappedSamplerName(unusedUniform.name) : unusedUniform.name;
+        std::string uniformName = unusedUniform.isSampler
+                                      ? useOldRewriteStructSamplers
+                                            ? GetMappedSamplerNameOld(unusedUniform.name)
+                                            : vk::GetMappedSamplerName(unusedUniform.name)
+                                      : unusedUniform.name;
 
         for (IntermediateShaderSource &shaderSource : *shaderSources)
         {
@@ -880,7 +894,8 @@
 }
 
 // static
-void GlslangWrapper::GetShaderSource(const gl::ProgramState &programState,
+void GlslangWrapper::GetShaderSource(bool useOldRewriteStructSamplers,
+                                     const gl::ProgramState &programState,
                                      const gl::ProgramLinkedResources &resources,
                                      gl::ShaderMap<std::string> *shaderSourcesOut)
 {
@@ -906,9 +921,9 @@
     }
     AssignUniformBindings(&intermediateSources);
     AssignBufferBindings(programState, &intermediateSources);
-    AssignTextureBindings(programState, &intermediateSources);
+    AssignTextureBindings(useOldRewriteStructSamplers, programState, &intermediateSources);
 
-    CleanupUnusedEntities(programState, resources,
+    CleanupUnusedEntities(useOldRewriteStructSamplers, programState, resources,
                           programState.getAttachedShader(gl::ShaderType::Vertex),
                           &intermediateSources);
 
diff --git a/src/libANGLE/renderer/vulkan/GlslangWrapper.h b/src/libANGLE/renderer/vulkan/GlslangWrapper.h
index 8a69987..de599a4 100644
--- a/src/libANGLE/renderer/vulkan/GlslangWrapper.h
+++ b/src/libANGLE/renderer/vulkan/GlslangWrapper.h
@@ -22,7 +22,8 @@
     static void Initialize();
     static void Release();
 
-    static void GetShaderSource(const gl::ProgramState &programState,
+    static void GetShaderSource(bool useOldRewriteStructSamplers,
+                                const gl::ProgramState &programState,
                                 const gl::ProgramLinkedResources &resources,
                                 gl::ShaderMap<std::string> *shaderSourcesOut);
 
diff --git a/src/libANGLE/renderer/vulkan/ProgramVk.cpp b/src/libANGLE/renderer/vulkan/ProgramVk.cpp
index 291710c..81ba731 100644
--- a/src/libANGLE/renderer/vulkan/ProgramVk.cpp
+++ b/src/libANGLE/renderer/vulkan/ProgramVk.cpp
@@ -503,7 +503,8 @@
     // assignment done in that function.
     linkResources(resources);
 
-    GlslangWrapper::GetShaderSource(mState, resources, &mShaderSources);
+    GlslangWrapper::GetShaderSource(contextVk->useOldRewriteStructSamplers(), mState, resources,
+                                    &mShaderSources);
 
     reset(contextVk);
 
@@ -565,6 +566,7 @@
 
     // Textures:
     vk::DescriptorSetLayoutDesc texturesSetDesc;
+    uint32_t bindingIndex = 0;
 
     for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size();
          ++textureIndex)
@@ -575,11 +577,27 @@
         const gl::LinkedUniform &samplerUniform = mState.getUniforms()[uniformIndex];
 
         // The front-end always binds array sampler units sequentially.
-        const uint32_t arraySize = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size());
+        uint32_t arraySize = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size());
         VkShaderStageFlags activeStages =
             gl_vk::GetShaderStageFlags(samplerUniform.activeShaders());
 
-        texturesSetDesc.update(textureIndex, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, arraySize,
+        if (!contextVk->useOldRewriteStructSamplers())
+        {
+            // 2D arrays are split into multiple 1D arrays when generating
+            // LinkedUniforms. Since they are flattened into one array, ignore the
+            // nonzero elements and expand the array to the total array size.
+            if (vk::SamplerNameContainsNonZeroArrayElement(samplerUniform.name))
+            {
+                continue;
+            }
+
+            for (unsigned int outerArraySize : samplerUniform.outerArraySizes)
+            {
+                arraySize *= outerArraySize;
+            }
+        }
+
+        texturesSetDesc.update(bindingIndex++, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, arraySize,
                                activeStages);
     }
 
@@ -1459,6 +1477,12 @@
 
     bool useSubgroupOps                = false;
     bool emulateSeamfulCubeMapSampling = contextVk->emulateSeamfulCubeMapSampling(&useSubgroupOps);
+    bool useOldRewriteStructSamplers   = contextVk->useOldRewriteStructSamplers();
+
+    std::unordered_map<std::string, uint32_t> mappedSamplerNameToBindingIndex;
+    std::unordered_map<std::string, uint32_t> mappedSamplerNameToArrayOffset;
+
+    uint32_t currentBindingIndex = 0;
 
     for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size();
          ++textureIndex)
@@ -1467,8 +1491,30 @@
 
         ASSERT(!samplerBinding.unreferenced);
 
-        for (uint32_t arrayElement = 0; arrayElement < samplerBinding.boundTextureUnits.size();
-             ++arrayElement)
+        uint32_t uniformIndex = mState.getUniformIndexFromSamplerIndex(textureIndex);
+        const gl::LinkedUniform &samplerUniform = mState.getUniforms()[uniformIndex];
+        std::string mappedSamplerName           = vk::GetMappedSamplerName(samplerUniform.name);
+
+        if (useOldRewriteStructSamplers ||
+            mappedSamplerNameToBindingIndex.emplace(mappedSamplerName, currentBindingIndex).second)
+        {
+            currentBindingIndex++;
+        }
+
+        uint32_t bindingIndex = textureIndex;
+        uint32_t arrayOffset  = 0;
+        uint32_t arraySize    = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size());
+
+        if (!useOldRewriteStructSamplers)
+        {
+            bindingIndex = mappedSamplerNameToBindingIndex[mappedSamplerName];
+            arrayOffset  = mappedSamplerNameToArrayOffset[mappedSamplerName];
+            // Front-end generates array elements in order, so we can just increment
+            // the offset each time we process a nested array.
+            mappedSamplerNameToArrayOffset[mappedSamplerName] += arraySize;
+        }
+
+        for (uint32_t arrayElement = 0; arrayElement < arraySize; ++arrayElement)
         {
             GLuint textureUnit   = samplerBinding.boundTextureUnits[arrayElement];
             TextureVk *textureVk = activeTextures[textureUnit].texture;
@@ -1496,8 +1542,8 @@
             writeInfo.sType            = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
             writeInfo.pNext            = nullptr;
             writeInfo.dstSet           = descriptorSet;
-            writeInfo.dstBinding       = textureIndex;
-            writeInfo.dstArrayElement  = arrayElement;
+            writeInfo.dstBinding       = bindingIndex;
+            writeInfo.dstArrayElement  = arrayOffset + arrayElement;
             writeInfo.descriptorCount  = 1;
             writeInfo.descriptorType   = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
             writeInfo.pImageInfo       = &imageInfo;
diff --git a/src/libANGLE/renderer/vulkan/RendererVk.cpp b/src/libANGLE/renderer/vulkan/RendererVk.cpp
index d4e7ef8..dbdc2d5 100644
--- a/src/libANGLE/renderer/vulkan/RendererVk.cpp
+++ b/src/libANGLE/renderer/vulkan/RendererVk.cpp
@@ -1266,6 +1266,8 @@
         }
 
         mFeatures.bindEmptyForUnusedDescriptorSets.enabled = true;
+
+        mFeatures.forceOldRewriteStructSamplers.enabled = true;
     }
 
     if (IsWindows() && IsIntel(mPhysicalDeviceProperties.vendorID))
diff --git a/src/libANGLE/renderer/vulkan/ShaderVk.cpp b/src/libANGLE/renderer/vulkan/ShaderVk.cpp
index cf59679..ee5f569 100644
--- a/src/libANGLE/renderer/vulkan/ShaderVk.cpp
+++ b/src/libANGLE/renderer/vulkan/ShaderVk.cpp
@@ -53,6 +53,11 @@
         }
     }
 
+    if (contextVk->useOldRewriteStructSamplers())
+    {
+        compileOptions |= SH_USE_OLD_REWRITE_STRUCT_SAMPLERS;
+    }
+
     return compileImpl(context, compilerInstance, mData.getSource(), compileOptions | options);
 }
 
diff --git a/src/libANGLE/renderer/vulkan/vk_utils.cpp b/src/libANGLE/renderer/vulkan/vk_utils.cpp
index e3a248c..ba48388 100644
--- a/src/libANGLE/renderer/vulkan/vk_utils.cpp
+++ b/src/libANGLE/renderer/vulkan/vk_utils.cpp
@@ -526,6 +526,57 @@
     return false;
 }
 
+bool SamplerNameContainsNonZeroArrayElement(const std::string &name)
+{
+    constexpr char kZERO_ELEMENT[] = "[0]";
+
+    size_t start = 0;
+    while (true)
+    {
+        start = name.find(kZERO_ELEMENT[0], start);
+        if (start == std::string::npos)
+        {
+            break;
+        }
+        if (name.compare(start, strlen(kZERO_ELEMENT), kZERO_ELEMENT) != 0)
+        {
+            return true;
+        }
+        start++;
+    }
+    return false;
+}
+
+std::string GetMappedSamplerName(const std::string &originalName)
+{
+    std::string samplerName = originalName;
+
+    // Samplers in structs are extracted.
+    std::replace(samplerName.begin(), samplerName.end(), '.', '_');
+
+    // Remove array elements
+    auto out = samplerName.begin();
+    for (auto in = samplerName.begin(); in != samplerName.end(); in++)
+    {
+        if (*in == '[')
+        {
+            while (*in != ']')
+            {
+                in++;
+                ASSERT(in != samplerName.end());
+            }
+        }
+        else
+        {
+            *out++ = *in;
+        }
+    }
+
+    samplerName.erase(out, samplerName.end());
+
+    return samplerName;
+}
+
 }  // namespace vk
 
 // VK_EXT_debug_utils
diff --git a/src/libANGLE/renderer/vulkan/vk_utils.h b/src/libANGLE/renderer/vulkan/vk_utils.h
index b813d20..ecd047d 100644
--- a/src/libANGLE/renderer/vulkan/vk_utils.h
+++ b/src/libANGLE/renderer/vulkan/vk_utils.h
@@ -528,6 +528,9 @@
     std::vector<T> mObjectFreeList;
 };
 
+bool SamplerNameContainsNonZeroArrayElement(const std::string &name);
+std::string GetMappedSamplerName(const std::string &originalName);
+
 }  // namespace vk
 
 // List of function pointers for used extensions.
diff --git a/src/tests/gl_tests/GLSLTest.cpp b/src/tests/gl_tests/GLSLTest.cpp
index 22a54b0..daee387 100644
--- a/src/tests/gl_tests/GLSLTest.cpp
+++ b/src/tests/gl_tests/GLSLTest.cpp
@@ -3086,8 +3086,10 @@
 // Test that arrays of arrays of samplers work as expected.
 TEST_P(GLSLTest_ES31, ArraysOfArraysSampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
     constexpr char kFS[] =
         "#version 310 es\n"
         "precision mediump float;\n"
@@ -3174,8 +3176,10 @@
 // Test that arrays of arrays of samplers inside structs work as expected.
 TEST_P(GLSLTest_ES31, StructArrayArraySampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
     constexpr char kFS[] =
         "#version 310 es\n"
         "precision mediump float;\n"
@@ -3226,8 +3230,10 @@
 // Test that an array of structs with arrays of arrays of samplers works.
 TEST_P(GLSLTest_ES31, ArrayStructArrayArraySampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
     GLint numTextures;
     glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
     ANGLE_SKIP_TEST_IF(numTextures < 2 * (2 * 2 + 2 * 2));
@@ -3296,8 +3302,10 @@
 // Test that a complex chain of structs and arrays of samplers works as expected.
 TEST_P(GLSLTest_ES31, ComplexStructArraySampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
     GLint numTextures;
     glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
     ANGLE_SKIP_TEST_IF(numTextures < 2 * 3 * (2 + 3));
@@ -3383,8 +3391,10 @@
 
 TEST_P(GLSLTest_ES31, ArraysOfArraysStructDifferentTypesSampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
     GLint numTextures;
     glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
     ANGLE_SKIP_TEST_IF(numTextures < 3 * (2 + 2));
@@ -3459,10 +3469,11 @@
 // Test that arrays of arrays of samplers as parameters works as expected.
 TEST_P(GLSLTest_ES31, ParameterArraysOfArraysSampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
     // anglebug.com/3832 - no sampler array params on Android
     ANGLE_SKIP_TEST_IF(IsAndroid() && IsOpenGLES());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
     constexpr char kFS[] =
         "#version 310 es\n"
         "precision mediump float;\n"
@@ -3470,6 +3481,7 @@
         "uniform mediump isampler2D test[2][3];\n"
         "const vec2 ZERO = vec2(0.0, 0.0);\n"
         "\n"
+        "bool check(isampler2D data[2][3]);\n"
         "bool check(isampler2D data[2][3]) {\n"
         "#define DO_CHECK(i,j) \\\n"
         "    if (texture(data[i][j], ZERO) != ivec4(i+1, j+1, 0, 1)) { \\\n"
@@ -3519,10 +3531,11 @@
 // Test that structs with arrays of arrays of samplers as parameters works as expected.
 TEST_P(GLSLTest_ES31, ParameterStructArrayArraySampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
     // anglebug.com/3832 - no sampler array params on Android
     ANGLE_SKIP_TEST_IF(IsAndroid() && IsOpenGLES());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
     constexpr char kFS[] =
         "#version 310 es\n"
         "precision mediump float;\n"
@@ -3581,10 +3594,11 @@
 // as parameters works as expected.
 TEST_P(GLSLTest_ES31, ParameterArrayArrayStructArrayArraySampler)
 {
-    // anglebug.com/3604 - Vulkan doesn't support 2D arrays of samplers
-    ANGLE_SKIP_TEST_IF(IsVulkan());
     // anglebug.com/3832 - no sampler array params on Android
     ANGLE_SKIP_TEST_IF(IsAndroid() && IsOpenGLES());
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
     GLint numTextures;
     glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
     ANGLE_SKIP_TEST_IF(numTextures < 3 * 2 * 2 * 2);
@@ -3655,6 +3669,176 @@
     EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
 }
 
+// Test that 3D arrays with sub-arrays passed as parameters works as expected.
+TEST_P(GLSLTest_ES31, ParameterArrayArrayArraySampler)
+{
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
+    GLint numTextures;
+    glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
+    ANGLE_SKIP_TEST_IF(numTextures < 2 * 3 * 4 + 4);
+    // anglebug.com/3832 - no sampler array params on Android
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsOpenGLES());
+    // Seems like this is failing on Windows Intel?
+    ANGLE_SKIP_TEST_IF(IsWindows() && IsIntel() && IsOpenGL());
+    constexpr char kFS[] =
+        "#version 310 es\n"
+        "precision mediump float;\n"
+        "out vec4 my_FragColor;\n"
+        "uniform mediump isampler2D test[2][3][4];\n"
+        "uniform mediump isampler2D test2[4];\n"
+        "const vec2 ZERO = vec2(0.0, 0.0);\n"
+        "\n"
+        "bool check1D(isampler2D arr[4], int x, int y) {\n"
+        "    if (texture(arr[0], ZERO) != ivec4(x, y, 0, 0)+1) return false;\n"
+        "    if (texture(arr[1], ZERO) != ivec4(x, y, 1, 0)+1) return false;\n"
+        "    if (texture(arr[2], ZERO) != ivec4(x, y, 2, 0)+1) return false;\n"
+        "    if (texture(arr[3], ZERO) != ivec4(x, y, 3, 0)+1) return false;\n"
+        "    return true;\n"
+        "}\n"
+        "bool check2D(isampler2D arr[3][4], int x) {\n"
+        "    if (!check1D(arr[0], x, 0)) return false;\n"
+        "    if (!check1D(arr[1], x, 1)) return false;\n"
+        "    if (!check1D(arr[2], x, 2)) return false;\n"
+        "    return true;\n"
+        "}\n"
+        "bool check3D(isampler2D arr[2][3][4]) {\n"
+        "    if (!check2D(arr[0], 0)) return false;\n"
+        "    if (!check2D(arr[1], 1)) return false;\n"
+        "    return true;\n"
+        "}\n"
+        "void main() {\n"
+        "    bool passed = check3D(test) && check1D(test2, 7, 8);\n"
+        "    my_FragColor = passed ? vec4(0.0, 1.0, 0.0, 1.0) : vec4(1.0, 0.0, 0.0, 1.0);\n"
+        "}\n";
+
+    ANGLE_GL_PROGRAM(program, essl31_shaders::vs::Simple(), kFS);
+    glUseProgram(program.get());
+    GLTexture textures1[2][3][4];
+    GLTexture textures2[4];
+    for (int i = 0; i < 2; i++)
+    {
+        for (int j = 0; j < 3; j++)
+        {
+            for (int k = 0; k < 4; k++)
+            {
+                // First generate the texture
+                int textureUnit = k + 4 * (j + 3 * i);
+                glActiveTexture(GL_TEXTURE0 + textureUnit);
+                glBindTexture(GL_TEXTURE_2D, textures1[i][j][k]);
+                GLint texData[3] = {i + 1, j + 1, k + 1};
+                glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB32I, 1, 1, 0, GL_RGB_INTEGER, GL_INT,
+                             &texData[0]);
+                glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+                glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+                // Then send it as a uniform
+                std::stringstream uniformName;
+                uniformName << "test[" << i << "][" << j << "][" << k << "]";
+                GLint uniformLocation =
+                    glGetUniformLocation(program.get(), uniformName.str().c_str());
+                // All array indices should be used.
+                EXPECT_NE(uniformLocation, -1);
+                glUniform1i(uniformLocation, textureUnit);
+            }
+        }
+    }
+    for (int k = 0; k < 4; k++)
+    {
+        // First generate the texture
+        int textureUnit = 2 * 3 * 4 + k;
+        glActiveTexture(GL_TEXTURE0 + textureUnit);
+        glBindTexture(GL_TEXTURE_2D, textures2[k]);
+        GLint texData[3] = {7 + 1, 8 + 1, k + 1};
+        glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB32I, 1, 1, 0, GL_RGB_INTEGER, GL_INT, &texData[0]);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+        // Then send it as a uniform
+        std::stringstream uniformName;
+        uniformName << "test2[" << k << "]";
+        GLint uniformLocation = glGetUniformLocation(program.get(), uniformName.str().c_str());
+        // All array indices should be used.
+        EXPECT_NE(uniformLocation, -1);
+        glUniform1i(uniformLocation, textureUnit);
+    }
+    drawQuad(program.get(), essl31_shaders::PositionAttrib(), 0.5f);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+}
+
+// Test that names do not collide when translating arrays of arrays of samplers.
+TEST_P(GLSLTest_ES31, ArraysOfArraysNameCollisionSampler)
+{
+    ANGLE_SKIP_TEST_IF(IsVulkan());  // anglebug.com/3604 - rewriter can create name collisions
+    GLint numTextures;
+    glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS, &numTextures);
+    ANGLE_SKIP_TEST_IF(numTextures < 2 * 2 + 3 * 3 + 4 * 4);
+    // anglebug.com/3832 - no sampler array params on Android
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsOpenGLES());
+    constexpr char kFS[] =
+        "#version 310 es\n"
+        "precision mediump sampler2D;\n"
+        "precision mediump float;\n"
+        "uniform sampler2D test_field1_field2[2][2];\n"
+        "struct S1 { sampler2D field2[3][3]; }; uniform S1 test_field1;\n"
+        "struct S2 { sampler2D field1_field2[4][4]; }; uniform S2 test;\n"
+        "vec4 func1(sampler2D param_field1_field2[2][2],\n"
+        "           int param_field1_field2_offset,\n"
+        "           S1 param_field1,\n"
+        "           S2 param) {\n"
+        "    return vec4(0.0, 1.0, 0.0, 0.0);\n"
+        "}\n"
+        "out vec4 my_FragColor;\n"
+        "void main() {\n"
+        "    my_FragColor = vec4(0.0, 0.0, 0.0, 1.0);\n"
+        "    my_FragColor += func1(test_field1_field2, 0, test_field1, test);\n"
+        "    vec2 uv = vec2(0.0);\n"
+        "    my_FragColor += texture(test_field1_field2[0][0], uv) +\n"
+        "                    texture(test_field1.field2[0][0], uv) +\n"
+        "                    texture(test.field1_field2[0][0], uv);\n"
+        "}\n";
+    ANGLE_GL_PROGRAM(program, essl31_shaders::vs::Simple(), kFS);
+    glActiveTexture(GL_TEXTURE0);
+    GLTexture tex;
+    glBindTexture(GL_TEXTURE_2D, tex);
+    GLint zero = 0;
+    glTexImage2D(GL_TEXTURE_2D, 0, GL_RED, 1, 1, 0, GL_RED, GL_UNSIGNED_BYTE, &zero);
+    drawQuad(program.get(), essl31_shaders::PositionAttrib(), 0.5f);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+}
+
+// Test that regular arrays are unmodified.
+TEST_P(GLSLTest_ES31, BasicTypeArrayAndArrayOfSampler)
+{
+    // anglebug.com/2703 - QC doesn't support arrays of samplers as parameters,
+    // so sampler array of array handling is disabled
+    ANGLE_SKIP_TEST_IF(IsAndroid() && IsVulkan());
+
+    constexpr char kFS[] =
+        "#version 310 es\n"
+        "precision mediump sampler2D;\n"
+        "precision mediump float;\n"
+        "uniform sampler2D sampler_array[2][2];\n"
+        "uniform int array[3][2];\n"
+        "vec4 func1(int param[2],\n"
+        "           int param2[3]) {\n"
+        "    return vec4(0.0, 1.0, 0.0, 0.0);\n"
+        "}\n"
+        "out vec4 my_FragColor;\n"
+        "void main() {\n"
+        "    my_FragColor = texture(sampler_array[0][0], vec2(0.0));\n"
+        "    my_FragColor += func1(array[1], int[](1, 2, 3));\n"
+        "}\n";
+    ANGLE_GL_PROGRAM(program, essl31_shaders::vs::Simple(), kFS);
+    glActiveTexture(GL_TEXTURE0);
+    GLTexture tex;
+    glBindTexture(GL_TEXTURE_2D, tex);
+    GLint zero = 0;
+    glTexImage2D(GL_TEXTURE_2D, 0, GL_RED, 1, 1, 0, GL_RED, GL_UNSIGNED_BYTE, &zero);
+    drawQuad(program.get(), essl31_shaders::PositionAttrib(), 0.5f);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+}
+
 // This test covers a bug (and associated workaround) with nested sampling operations in the HLSL
 // compiler DLL.
 TEST_P(GLSLTest_ES3, NestedSamplingOperation)