Guard traversers used during parsing against stack overflow

Traversers used during parsing can be vulnerable to stack overflow
since the AST has not yet been validated for max depth. Make sure to
check for traversal depth in traversers used during parsing.

We set the maximum traversal depth in ValidateGlobalInitializer and
ValidateSwitchStatementList to 256, which matches the default value
for validating general AST complexity. The depth check is on
regardless of compiler options. In case the traversers go over the
maximum traversal depth, they fail validation.

BUG=angleproject:2453
TEST=angle_unittests

Change-Id: I89ba576e8ef69663ba35d7b9050a6da319f1757c
Reviewed-on: https://chromium-review.googlesource.com/995795
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/IsASTDepthBelowLimit.cpp b/src/compiler/translator/IsASTDepthBelowLimit.cpp
index 756c194..73cb9d1 100644
--- a/src/compiler/translator/IsASTDepthBelowLimit.cpp
+++ b/src/compiler/translator/IsASTDepthBelowLimit.cpp
@@ -18,24 +18,10 @@
 class MaxDepthTraverser : public TIntermTraverser
 {
   public:
-    MaxDepthTraverser(int depthLimit) : TIntermTraverser(true, true, false), mDepthLimit(depthLimit)
+    MaxDepthTraverser(int depthLimit) : TIntermTraverser(true, false, false, nullptr)
     {
+        setMaxAllowedDepth(depthLimit);
     }
-
-    bool visitBinary(Visit, TIntermBinary *) override { return depthCheck(); }
-    bool visitUnary(Visit, TIntermUnary *) override { return depthCheck(); }
-    bool visitTernary(Visit, TIntermTernary *) override { return depthCheck(); }
-    bool visitSwizzle(Visit, TIntermSwizzle *) override { return depthCheck(); }
-    bool visitIfElse(Visit, TIntermIfElse *) override { return depthCheck(); }
-    bool visitAggregate(Visit, TIntermAggregate *) override { return depthCheck(); }
-    bool visitBlock(Visit, TIntermBlock *) override { return depthCheck(); }
-    bool visitLoop(Visit, TIntermLoop *) override { return depthCheck(); }
-    bool visitBranch(Visit, TIntermBranch *) override { return depthCheck(); }
-
-  protected:
-    bool depthCheck() const { return mMaxDepth < mDepthLimit; }
-
-    int mDepthLimit;
 };
 
 }  // anonymous namespace
diff --git a/src/compiler/translator/OutputGLSLBase.cpp b/src/compiler/translator/OutputGLSLBase.cpp
index e780f40..553d6c1 100644
--- a/src/compiler/translator/OutputGLSLBase.cpp
+++ b/src/compiler/translator/OutputGLSLBase.cpp
@@ -846,7 +846,7 @@
 {
     TInfoSinkBase &out = objSink();
     // Scope the blocks except when at the global scope.
-    if (mDepth > 0)
+    if (getCurrentTraversalDepth() > 0)
     {
         out << "{\n";
     }
@@ -863,7 +863,7 @@
     }
 
     // Scope the blocks except when at the global scope.
-    if (mDepth > 0)
+    if (getCurrentTraversalDepth() > 0)
     {
         out << "}\n";
     }
diff --git a/src/compiler/translator/OutputTree.cpp b/src/compiler/translator/OutputTree.cpp
index 8d2c127..adcd242 100644
--- a/src/compiler/translator/OutputTree.cpp
+++ b/src/compiler/translator/OutputTree.cpp
@@ -31,7 +31,10 @@
 class TOutputTraverser : public TIntermTraverser
 {
   public:
-    TOutputTraverser(TInfoSinkBase &out) : TIntermTraverser(true, false, false), mOut(out) {}
+    TOutputTraverser(TInfoSinkBase &out)
+        : TIntermTraverser(true, false, false), mOut(out), mIndentDepth(0)
+    {
+    }
 
   protected:
     void visitSymbol(TIntermSymbol *) override;
@@ -52,7 +55,10 @@
     bool visitLoop(Visit visit, TIntermLoop *) override;
     bool visitBranch(Visit visit, TIntermBranch *) override;
 
+    int getCurrentIndentDepth() const { return mIndentDepth + getCurrentTraversalDepth(); }
+
     TInfoSinkBase &mOut;
+    int mIndentDepth;
 };
 
 //
@@ -79,7 +85,7 @@
 
 void TOutputTraverser::visitSymbol(TIntermSymbol *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     if (node->variable().symbolType() == SymbolType::Empty)
     {
@@ -96,7 +102,7 @@
 
 bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "vector swizzle (";
     node->writeOffsetsAsXYZW(&mOut);
     mOut << ")";
@@ -108,7 +114,7 @@
 
 bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     switch (node->getOp())
     {
@@ -270,7 +276,7 @@
         TIntermConstantUnion *intermConstantUnion = node->getRight()->getAsConstantUnion();
         ASSERT(intermConstantUnion);
 
-        OutputTreeText(mOut, intermConstantUnion, mDepth + 1);
+        OutputTreeText(mOut, intermConstantUnion, getCurrentIndentDepth() + 1);
 
         // The following code finds the field name from the constant union
         const TConstantUnion *constantUnion   = intermConstantUnion->getConstantValue();
@@ -294,7 +300,7 @@
 
 bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     switch (node->getOp())
     {
@@ -348,22 +354,21 @@
 
 bool TOutputTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Function Definition:\n";
-    mOut << "\n";
     return true;
 }
 
 bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Invariant Declaration:\n";
     return true;
 }
 
 void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     OutputFunction(mOut, "Function Prototype", node->getFunction());
     mOut << " (" << node->getCompleteString() << ")";
     mOut << "\n";
@@ -371,7 +376,7 @@
     for (size_t i = 0; i < paramCount; ++i)
     {
         const TVariable *param = node->getFunction()->getParam(i);
-        OutputTreeText(mOut, node, mDepth + 1);
+        OutputTreeText(mOut, node, getCurrentIndentDepth() + 1);
         mOut << "parameter: " << param->name() << " (" << param->getType().getCompleteString()
              << ")";
     }
@@ -379,7 +384,7 @@
 
 bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     if (node->getOp() == EOpNull)
     {
@@ -451,7 +456,7 @@
 
 bool TOutputTraverser::visitBlock(Visit visit, TIntermBlock *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Code block\n";
 
     return true;
@@ -459,7 +464,7 @@
 
 bool TOutputTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Declaration\n";
 
     return true;
@@ -467,18 +472,18 @@
 
 bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     mOut << "Ternary selection";
     mOut << " (" << node->getCompleteString() << ")\n";
 
-    ++mDepth;
+    ++mIndentDepth;
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Condition\n";
     node->getCondition()->traverse(this);
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     if (node->getTrueExpression())
     {
         mOut << "true case\n";
@@ -486,29 +491,29 @@
     }
     if (node->getFalseExpression())
     {
-        OutputTreeText(mOut, node, mDepth);
+        OutputTreeText(mOut, node, getCurrentIndentDepth());
         mOut << "false case\n";
         node->getFalseExpression()->traverse(this);
     }
 
-    --mDepth;
+    --mIndentDepth;
 
     return false;
 }
 
 bool TOutputTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     mOut << "If test\n";
 
-    ++mDepth;
+    ++mIndentDepth;
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     mOut << "Condition\n";
     node->getCondition()->traverse(this);
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     if (node->getTrueBlock())
     {
         mOut << "true case\n";
@@ -521,19 +526,19 @@
 
     if (node->getFalseBlock())
     {
-        OutputTreeText(mOut, node, mDepth);
+        OutputTreeText(mOut, node, getCurrentIndentDepth());
         mOut << "false case\n";
         node->getFalseBlock()->traverse(this);
     }
 
-    --mDepth;
+    --mIndentDepth;
 
     return false;
 }
 
 bool TOutputTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     mOut << "Switch\n";
 
@@ -542,7 +547,7 @@
 
 bool TOutputTraverser::visitCase(Visit visit, TIntermCase *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     if (node->getCondition() == nullptr)
     {
@@ -562,7 +567,7 @@
 
     for (size_t i = 0; i < size; i++)
     {
-        OutputTreeText(mOut, node, mDepth);
+        OutputTreeText(mOut, node, getCurrentIndentDepth());
         switch (node->getConstantValue()[i].getType())
         {
             case EbtBool:
@@ -603,16 +608,16 @@
 
 bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     mOut << "Loop with condition ";
     if (node->getType() == ELoopDoWhile)
         mOut << "not ";
     mOut << "tested first\n";
 
-    ++mDepth;
+    ++mIndentDepth;
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     if (node->getCondition())
     {
         mOut << "Loop Condition\n";
@@ -623,7 +628,7 @@
         mOut << "No loop condition\n";
     }
 
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
     if (node->getBody())
     {
         mOut << "Loop Body\n";
@@ -636,19 +641,19 @@
 
     if (node->getExpression())
     {
-        OutputTreeText(mOut, node, mDepth);
+        OutputTreeText(mOut, node, getCurrentIndentDepth());
         mOut << "Loop Terminal Expression\n";
         node->getExpression()->traverse(this);
     }
 
-    --mDepth;
+    --mIndentDepth;
 
     return false;
 }
 
 bool TOutputTraverser::visitBranch(Visit visit, TIntermBranch *node)
 {
-    OutputTreeText(mOut, node, mDepth);
+    OutputTreeText(mOut, node, getCurrentIndentDepth());
 
     switch (node->getFlowOp())
     {
@@ -672,9 +677,9 @@
     if (node->getExpression())
     {
         mOut << " with expression\n";
-        ++mDepth;
+        ++mIndentDepth;
         node->getExpression()->traverse(this);
-        --mDepth;
+        --mIndentDepth;
     }
     else
     {
diff --git a/src/compiler/translator/ValidateGlobalInitializer.cpp b/src/compiler/translator/ValidateGlobalInitializer.cpp
index 6d79bbf..4ffb832 100644
--- a/src/compiler/translator/ValidateGlobalInitializer.cpp
+++ b/src/compiler/translator/ValidateGlobalInitializer.cpp
@@ -14,6 +14,8 @@
 namespace
 {
 
+const int kMaxAllowedTraversalDepth = 256;
+
 class ValidateGlobalInitializerTraverser : public TIntermTraverser
 {
   public:
@@ -25,7 +27,7 @@
     bool visitBinary(Visit visit, TIntermBinary *node) override;
     bool visitUnary(Visit visit, TIntermUnary *node) override;
 
-    bool isValid() const { return mIsValid; }
+    bool isValid() const { return mIsValid && mMaxDepth < mMaxAllowedDepth; }
     bool issueWarning() const { return mIssueWarning; }
 
   private:
@@ -117,11 +119,12 @@
 }
 
 ValidateGlobalInitializerTraverser::ValidateGlobalInitializerTraverser(int shaderVersion)
-    : TIntermTraverser(true, false, false),
+    : TIntermTraverser(true, false, false, nullptr),
       mShaderVersion(shaderVersion),
       mIsValid(true),
       mIssueWarning(false)
 {
+    setMaxAllowedDepth(kMaxAllowedTraversalDepth);
 }
 
 }  // namespace
diff --git a/src/compiler/translator/ValidateSwitch.cpp b/src/compiler/translator/ValidateSwitch.cpp
index 8b6fa72..0dbecc9 100644
--- a/src/compiler/translator/ValidateSwitch.cpp
+++ b/src/compiler/translator/ValidateSwitch.cpp
@@ -15,6 +15,8 @@
 namespace
 {
 
+const int kMaxAllowedTraversalDepth = 256;
+
 class ValidateSwitch : public TIntermTraverser
 {
   public:
@@ -69,7 +71,7 @@
 }
 
 ValidateSwitch::ValidateSwitch(TBasicType switchType, TDiagnostics *diagnostics)
-    : TIntermTraverser(true, false, true),
+    : TIntermTraverser(true, false, true, nullptr),
       mSwitchType(switchType),
       mDiagnostics(diagnostics),
       mCaseTypeMismatch(false),
@@ -81,6 +83,7 @@
       mDefaultCount(0),
       mDuplicateCases(false)
 {
+    setMaxAllowedDepth(kMaxAllowedTraversalDepth);
 }
 
 void ValidateSwitch::visitSymbol(TIntermSymbol *)
@@ -290,8 +293,13 @@
             loc, "no statement between the last label and the end of the switch statement",
             "switch");
     }
+    if (getMaxDepth() >= kMaxAllowedTraversalDepth)
+    {
+        mDiagnostics->error(loc, "too complex expressions inside a switch statement", "switch");
+    }
     return !mStatementBeforeCase && !mLastStatementWasCase && !mCaseInsideControlFlow &&
-           !mCaseTypeMismatch && mDefaultCount <= 1 && !mDuplicateCases;
+           !mCaseTypeMismatch && mDefaultCount <= 1 && !mDuplicateCases &&
+           getMaxDepth() < kMaxAllowedTraversalDepth;
 }
 
 }  // anonymous namespace
diff --git a/src/compiler/translator/tree_util/IntermTraverse.cpp b/src/compiler/translator/tree_util/IntermTraverse.cpp
index 6753f6e..2cdc405 100644
--- a/src/compiler/translator/tree_util/IntermTraverse.cpp
+++ b/src/compiler/translator/tree_util/IntermTraverse.cpp
@@ -110,8 +110,8 @@
     : preVisit(preVisit),
       inVisit(inVisit),
       postVisit(postVisit),
-      mDepth(-1),
       mMaxDepth(0),
+      mMaxAllowedDepth(std::numeric_limits<int>::max()),
       mInGlobalScope(true),
       mSymbolTable(symbolTable)
 {
@@ -121,6 +121,11 @@
 {
 }
 
+void TIntermTraverser::setMaxAllowedDepth(int depth)
+{
+    mMaxAllowedDepth = depth;
+}
+
 const TIntermBlock *TIntermTraverser::getParentBlock() const
 {
     if (!mParentBlockStack.empty())
@@ -215,6 +220,8 @@
 void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -236,6 +243,8 @@
 void TIntermTraverser::traverseBinary(TIntermBinary *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -271,6 +280,8 @@
 void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -335,6 +346,8 @@
 void TIntermTraverser::traverseUnary(TIntermUnary *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -353,6 +366,8 @@
 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -387,6 +402,8 @@
 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -413,6 +430,9 @@
 void TIntermTraverser::traverseBlock(TIntermBlock *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
+
     pushParentBlock(node);
 
     bool visit = true;
@@ -446,6 +466,8 @@
 void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -468,6 +490,8 @@
 void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -496,6 +520,7 @@
 void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+
     visitFunctionPrototype(node);
 }
 
@@ -503,6 +528,8 @@
 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -633,6 +660,8 @@
 void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -680,6 +709,8 @@
 void TIntermTraverser::traverseTernary(TIntermTernary *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -703,6 +734,8 @@
 void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -728,6 +761,8 @@
 void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -753,6 +788,8 @@
 void TIntermTraverser::traverseCase(TIntermCase *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -774,6 +811,8 @@
 void TIntermTraverser::traverseLoop(TIntermLoop *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
@@ -805,6 +844,8 @@
 void TIntermTraverser::traverseBranch(TIntermBranch *node)
 {
     ScopedNodeInTraversalPath addToPath(this, node);
+    if (!addToPath.isWithinDepthLimit())
+        return;
 
     bool visit = true;
 
diff --git a/src/compiler/translator/tree_util/IntermTraverse.h b/src/compiler/translator/tree_util/IntermTraverse.h
index 93f7e69..468f6f1 100644
--- a/src/compiler/translator/tree_util/IntermTraverse.h
+++ b/src/compiler/translator/tree_util/IntermTraverse.h
@@ -98,21 +98,24 @@
     void updateTree();
 
   protected:
+    void setMaxAllowedDepth(int depth);
+
     // Should only be called from traverse*() functions
-    void incrementDepth(TIntermNode *current)
+    bool incrementDepth(TIntermNode *current)
     {
-        mDepth++;
-        mMaxDepth = std::max(mMaxDepth, mDepth);
+        mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size()));
         mPath.push_back(current);
+        return mMaxDepth < mMaxAllowedDepth;
     }
 
     // Should only be called from traverse*() functions
     void decrementDepth()
     {
-        mDepth--;
         mPath.pop_back();
     }
 
+    int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; }
+
     // RAII helper for incrementDepth/decrementDepth
     class ScopedNodeInTraversalPath
     {
@@ -120,12 +123,15 @@
         ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
             : mTraverser(traverser)
         {
-            mTraverser->incrementDepth(current);
+            mWithinDepthLimit = mTraverser->incrementDepth(current);
         }
         ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
 
+        bool isWithinDepthLimit() { return mWithinDepthLimit; }
+
       private:
         TIntermTraverser *mTraverser;
+        bool mWithinDepthLimit;
     };
 
     TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
@@ -196,8 +202,8 @@
     const bool inVisit;
     const bool postVisit;
 
-    int mDepth;
     int mMaxDepth;
+    int mMaxAllowedDepth;
 
     bool mInGlobalScope;
 
diff --git a/src/tests/compiler_tests/ExpressionLimit_test.cpp b/src/tests/compiler_tests/ExpressionLimit_test.cpp
index 8aefe23..54ddcd3 100644
--- a/src/tests/compiler_tests/ExpressionLimit_test.cpp
+++ b/src/tests/compiler_tests/ExpressionLimit_test.cpp
@@ -10,10 +10,8 @@
 #include "gtest/gtest.h"
 #include "GLSLANG/ShaderLang.h"
 
-#define SHADER(Src) #Src
-
 class ExpressionLimitTest : public testing::Test {
-protected:
+  protected:
     static const int kMaxExpressionComplexity = 16;
     static const int kMaxCallStackDepth       = 16;
     static const int kMaxFunctionParameters   = 16;
@@ -21,6 +19,8 @@
     static const char* kCallStackTooDeep;
     static const char* kHasRecursion;
     static const char *kTooManyParameters;
+    static const char *kTooComplexSwitch;
+    static const char *kGlobalVariableInit;
 
     virtual void SetUp()
     {
@@ -51,22 +51,22 @@
         res->MaxFunctionParameters   = kMaxFunctionParameters;
     }
 
-    void GenerateLongExpression(int length, std::stringstream* ss)
+    static void GenerateLongExpression(int length, std::stringstream *ss)
     {
         for (int ii = 0; ii < length; ++ii) {
           *ss << "+ vec4(" << ii << ")";
         }
     }
 
-    std::string GenerateShaderWithLongExpression(int length)
+    static std::string GenerateShaderWithLongExpression(int length)
     {
-        static const char* shaderStart = SHADER(
-            precision mediump float;
+        static const char *shaderStart =
+            R"(precision mediump float;
             uniform vec4 u_color;
             void main()
             {
                gl_FragColor = u_color
-        );
+        )";
 
         std::stringstream ss;
         ss << shaderStart;
@@ -76,10 +76,10 @@
         return ss.str();
     }
 
-    std::string GenerateShaderWithUnusedLongExpression(int length)
+    static std::string GenerateShaderWithUnusedLongExpression(int length)
     {
-        static const char* shaderStart = SHADER(
-            precision mediump float;
+        static const char *shaderStart =
+            R"(precision mediump float;
             uniform vec4 u_color;
             void main()
             {
@@ -87,7 +87,7 @@
             }
             vec4 someFunction() {
               return u_color
-        );
+        )";
 
         std::stringstream ss;
 
@@ -98,15 +98,15 @@
         return ss.str();
     }
 
-    void GenerateDeepFunctionStack(int length, std::stringstream* ss)
+    static void GenerateDeepFunctionStack(int length, std::stringstream *ss)
     {
-        static const char* shaderStart = SHADER(
-            precision mediump float;
+        static const char *shaderStart =
+            R"(precision mediump float;
             uniform vec4 u_color;
             vec4 function0()  {
               return u_color;
             }
-        );
+        )";
 
         *ss << shaderStart;
         for (int ii = 0; ii < length; ++ii) {
@@ -116,7 +116,7 @@
         }
     }
 
-    std::string GenerateShaderWithDeepFunctionStack(int length)
+    static std::string GenerateShaderWithDeepFunctionStack(int length)
     {
         std::stringstream ss;
 
@@ -129,7 +129,7 @@
         return ss.str();
     }
 
-    std::string GenerateShaderWithUnusedDeepFunctionStack(int length)
+    static std::string GenerateShaderWithUnusedDeepFunctionStack(int length)
     {
         std::stringstream ss;
 
@@ -143,7 +143,7 @@
         return ss.str();
     }
 
-    std::string GenerateShaderWithFunctionParameters(int parameters)
+    static std::string GenerateShaderWithFunctionParameters(int parameters)
     {
         std::stringstream ss;
 
@@ -171,6 +171,50 @@
         return ss.str();
     }
 
+    static std::string GenerateShaderWithNestingInsideSwitch(int nesting)
+    {
+        std::stringstream shaderString;
+        shaderString <<
+            R"(#version 300 es
+            uniform int u;
+
+            void main()
+            {
+                int x;
+                switch (u)
+                {
+                    case 0:
+                        x = x)";
+        for (int i = 0; i < nesting; ++i)
+        {
+            shaderString << " + x";
+        }
+        shaderString <<
+            R"(;
+                }  // switch (u)
+            })";
+        return shaderString.str();
+    }
+
+    static std::string GenerateShaderWithNestingInsideGlobalInitializer(int nesting)
+    {
+        std::stringstream shaderString;
+        shaderString <<
+            R"(uniform int u;
+            int x = u)";
+
+        for (int i = 0; i < nesting; ++i)
+        {
+            shaderString << " + u";
+        }
+        shaderString << R"(;
+            void main()
+            {
+                gl_FragColor = vec4(0.0);
+            })";
+        return shaderString.str();
+    }
+
     // Compiles a shader and if there's an error checks for a specific
     // substring in the error log. This way we know the error is specific
     // to the issue we are testing.
@@ -206,6 +250,10 @@
     "Recursive function call in the following call chain";
 const char* ExpressionLimitTest::kTooManyParameters =
     "Function has too many parameters";
+const char *ExpressionLimitTest::kTooComplexSwitch =
+    "too complex expressions inside a switch statement";
+const char *ExpressionLimitTest::kGlobalVariableInit =
+    "global variable initializers must be constant expressions";
 
 TEST_F(ExpressionLimitTest, ExpressionComplexity)
 {
@@ -312,8 +360,8 @@
     ShHandle vertexCompiler = sh::ConstructCompiler(GL_FRAGMENT_SHADER, spec, output, &resources);
     ShCompileOptions compileOptions = 0;
 
-    static const char* shaderWithRecursion0 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion0 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             return someFunc();
@@ -322,10 +370,10 @@
         void main() {
             gl_FragColor = u_color * someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion1 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion1 =
+        R"(precision mediump float;
         uniform vec4 u_color;
 
         vec4 someFunc();
@@ -341,10 +389,10 @@
         void main() {
             gl_FragColor = u_color * someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion2 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion2 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             if (u_color.x > 0.5) {
@@ -357,10 +405,10 @@
         void main() {
             gl_FragColor = someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion3 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion3 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             if (u_color.x > 0.5) {
@@ -373,10 +421,10 @@
         void main() {
             gl_FragColor = someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion4 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion4 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             return (u_color.x > 0.5) ? vec4(1) : someFunc();
@@ -385,10 +433,10 @@
         void main() {
             gl_FragColor = someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion5 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion5 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             return (u_color.x > 0.5) ? someFunc() : vec4(1);
@@ -397,10 +445,10 @@
         void main() {
             gl_FragColor = someFunc();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion6 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion6 =
+        R"(precision mediump float;
         uniform vec4 u_color;
         vec4 someFunc()  {
             return someFunc();
@@ -409,10 +457,10 @@
         void main() {
             gl_FragColor = u_color;
         }
-    );
+    )";
 
-    static const char* shaderWithNoRecursion = SHADER(
-        precision mediump float;
+    static const char *shaderWithNoRecursion =
+        R"(precision mediump float;
         uniform vec4 u_color;
 
         vec3 rgb(int r, int g, int b) {
@@ -424,10 +472,10 @@
             vec3 faceColor2 = rgb(183, 148, 133);
             gl_FragColor = u_color + vec4(hairColor0 + faceColor2, 0);
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion7 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion7 =
+        R"(precision mediump float;
         uniform vec4 u_color;
 
         vec4 function2() {
@@ -443,10 +491,10 @@
         void main() {
             gl_FragColor = function1();
         }
-    );
+    )";
 
-    static const char* shaderWithRecursion8 = SHADER(
-        precision mediump float;
+    static const char *shaderWithRecursion8 =
+        R"(precision mediump float;
         uniform vec4 u_color;
 
         vec4 function1();
@@ -466,7 +514,7 @@
         void main() {
             gl_FragColor = function1();
         }
-    );
+    )";
 
     // Check simple recursions fails.
     EXPECT_TRUE(CheckShaderCompilation(
@@ -539,3 +587,51 @@
         compileOptions & ~SH_LIMIT_EXPRESSION_COMPLEXITY, nullptr));
     sh::Destruct(compiler);
 }
+
+TEST_F(ExpressionLimitTest, NestingInsideSwitch)
+{
+    ShShaderSpec spec     = SH_WEBGL2_SPEC;
+    ShShaderOutput output = SH_ESSL_OUTPUT;
+    ShHandle compiler     = sh::ConstructCompiler(GL_FRAGMENT_SHADER, spec, output, &resources);
+    ShCompileOptions compileOptions = SH_LIMIT_EXPRESSION_COMPLEXITY;
+
+    // Test nesting over the limit fails.
+    EXPECT_TRUE(CheckShaderCompilation(
+        compiler, GenerateShaderWithNestingInsideSwitch(kMaxExpressionComplexity + 1).c_str(),
+        compileOptions, kExpressionTooComplex));
+    // Test nesting over the limit without limit does not fail.
+    EXPECT_TRUE(CheckShaderCompilation(
+        compiler, GenerateShaderWithNestingInsideSwitch(kMaxExpressionComplexity + 1).c_str(),
+        compileOptions & ~SH_LIMIT_EXPRESSION_COMPLEXITY, nullptr));
+    // Test that nesting way over the limit doesn't cause stack overflow but is handled
+    // gracefully.
+    EXPECT_TRUE(CheckShaderCompilation(compiler,
+                                       GenerateShaderWithNestingInsideSwitch(5000).c_str(),
+                                       compileOptions, kTooComplexSwitch));
+    sh::Destruct(compiler);
+}
+
+TEST_F(ExpressionLimitTest, NestingInsideGlobalInitializer)
+{
+    ShShaderSpec spec     = SH_WEBGL_SPEC;
+    ShShaderOutput output = SH_ESSL_OUTPUT;
+    ShHandle compiler     = sh::ConstructCompiler(GL_FRAGMENT_SHADER, spec, output, &resources);
+    ShCompileOptions compileOptions = SH_LIMIT_EXPRESSION_COMPLEXITY;
+
+    // Test nesting over the limit fails.
+    EXPECT_TRUE(CheckShaderCompilation(
+        compiler,
+        GenerateShaderWithNestingInsideGlobalInitializer(kMaxExpressionComplexity + 1).c_str(),
+        compileOptions, kExpressionTooComplex));
+    // Test nesting over the limit without limit does not fail.
+    EXPECT_TRUE(CheckShaderCompilation(
+        compiler,
+        GenerateShaderWithNestingInsideGlobalInitializer(kMaxExpressionComplexity + 1).c_str(),
+        compileOptions & ~SH_LIMIT_EXPRESSION_COMPLEXITY, nullptr));
+    // Test that nesting way over the limit doesn't cause stack overflow but is handled
+    // gracefully.
+    EXPECT_TRUE(CheckShaderCompilation(
+        compiler, GenerateShaderWithNestingInsideGlobalInitializer(5000).c_str(), compileOptions,
+        kGlobalVariableInit));
+    sh::Destruct(compiler);
+}