Make RecursiveASTVisitor fully data recursive on Stmts, with
minimal disruption on its clients.

Unlike the previous data-recursive scheme, Traverse*Stmt methods are
always getting called. The base methods of RecursiveASTVisitor will enqueue
the sub-statements instead of calling TraverseStmt on them.

Clients that override a Traverse*Stmt method and call TraverseStmt will
still function as function-recursive traversal; if a client wants to
enqueue a sub-statement in its override method it can do it like this:

[inside the override method]
StmtQueueAction StmtQueue(*this);
StmtQueue.queue(Stmt->getSubStmt());

Should address rdar://11179167.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@156141 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h
index a9a98d7..bd234bb 100644
--- a/include/clang/AST/RecursiveASTVisitor.h
+++ b/include/clang/AST/RecursiveASTVisitor.h
@@ -114,6 +114,9 @@
 /// node are grouped together.  In other words, Visit*() methods for
 /// different nodes are never interleaved.
 ///
+/// Stmts are traversed internally using a data queue to avoid a stack overflow
+/// with hugely nested ASTs.
+///
 /// Clients of this visitor should subclass the visitor (providing
 /// themselves as the template argument, using the curiously recurring
 /// template pattern) and override any of the Traverse*, WalkUpFrom*,
@@ -148,13 +151,6 @@
   /// TypeLocs.
   bool shouldWalkTypesOfTypeLocs() const { return true; }
 
-  /// \brief Return whether \param S should be traversed using data recursion
-  /// to avoid a stack overflow with extreme cases.
-  bool shouldUseDataRecursionFor(Stmt *S) const {
-    return isa<BinaryOperator>(S) || isa<UnaryOperator>(S) ||
-           isa<CaseStmt>(S) || isa<CXXOperatorCallExpr>(S);
-  }
-
   /// \brief Recursively visit a statement or expression, by
   /// dispatching to Traverse*() based on the argument's dynamic type.
   ///
@@ -267,7 +263,8 @@
 #define OPERATOR(NAME)                                           \
   bool TraverseUnary##NAME(UnaryOperator *S) {                  \
     TRY_TO(WalkUpFromUnary##NAME(S));                           \
-    TRY_TO(TraverseStmt(S->getSubExpr()));                      \
+    StmtQueueAction StmtQueue(*this);                           \
+    StmtQueue.queue(S->getSubExpr());                           \
     return true;                                                \
   }                                                             \
   bool WalkUpFromUnary##NAME(UnaryOperator *S) {                \
@@ -286,8 +283,9 @@
 #define GENERAL_BINOP_FALLBACK(NAME, BINOP_TYPE)                \
   bool TraverseBin##NAME(BINOP_TYPE *S) {                       \
     TRY_TO(WalkUpFromBin##NAME(S));                             \
-    TRY_TO(TraverseStmt(S->getLHS()));                          \
-    TRY_TO(TraverseStmt(S->getRHS()));                          \
+    StmtQueueAction StmtQueue(*this);                           \
+    StmtQueue.queue(S->getLHS());                               \
+    StmtQueue.queue(S->getRHS());                               \
     return true;                                                \
   }                                                             \
   bool WalkUpFromBin##NAME(BINOP_TYPE *S) {                     \
@@ -405,110 +403,47 @@
   bool TraverseFunctionHelper(FunctionDecl *D);
   bool TraverseVarHelper(VarDecl *D);
 
-  struct EnqueueJob {
-    Stmt *S;
-    Stmt::child_iterator StmtIt;
+  typedef SmallVector<Stmt *, 16> StmtsTy;
+  typedef SmallVector<StmtsTy *, 4> QueuesTy;
+  
+  QueuesTy Queues;
 
-    EnqueueJob(Stmt *S) : S(S), StmtIt() {}
+  class NewQueueRAII {
+    RecursiveASTVisitor &RAV;
+  public:
+    NewQueueRAII(StmtsTy &queue, RecursiveASTVisitor &RAV) : RAV(RAV) {
+      RAV.Queues.push_back(&queue);
+    }
+    ~NewQueueRAII() {
+      RAV.Queues.pop_back();
+    }
   };
-  bool dataTraverse(Stmt *S);
-  bool dataTraverseNode(Stmt *S, bool &EnqueueChildren);
+
+  StmtsTy &getCurrentQueue() {
+    assert(!Queues.empty() && "base TraverseStmt was never called?");
+    return *Queues.back();
+  }
+
+public:
+  class StmtQueueAction {
+    StmtsTy &CurrQueue;
+    SmallVector<Stmt *, 8> Stmts;
+  public:
+    explicit StmtQueueAction(RecursiveASTVisitor &RAV)
+      : CurrQueue(RAV.getCurrentQueue()) { }
+
+    void queue(Stmt *S) {
+      Stmts.push_back(S);
+    }
+
+    ~StmtQueueAction() {
+      for (SmallVector<Stmt *, 8>::reverse_iterator
+             RI = Stmts.rbegin(), RE = Stmts.rend(); RI != RE; ++RI)
+        CurrQueue.push_back(*RI);
+    }
+  };
 };
 
-template<typename Derived>
-bool RecursiveASTVisitor<Derived>::dataTraverse(Stmt *S) {
-
-  SmallVector<EnqueueJob, 16> Queue;
-  Queue.push_back(S);
-
-  while (!Queue.empty()) {
-    EnqueueJob &job = Queue.back();
-    Stmt *CurrS = job.S;
-    if (!CurrS) {
-      Queue.pop_back();
-      continue;
-    }
-
-    if (getDerived().shouldUseDataRecursionFor(CurrS)) {
-      if (job.StmtIt == Stmt::child_iterator()) {
-        bool EnqueueChildren = true;
-        if (!dataTraverseNode(CurrS, EnqueueChildren)) return false;
-        if (!EnqueueChildren) {
-          Queue.pop_back();
-          continue;
-        }
-        job.StmtIt = CurrS->child_begin();
-      } else {
-        ++job.StmtIt;
-      }
-
-      if (job.StmtIt != CurrS->child_end())
-        Queue.push_back(*job.StmtIt);
-      else
-        Queue.pop_back();
-      continue;
-    }
-
-    Queue.pop_back();
-    TRY_TO(TraverseStmt(CurrS));
-  }
-
-  return true;
-}
-
-template<typename Derived>
-bool RecursiveASTVisitor<Derived>::dataTraverseNode(Stmt *S,
-                                                    bool &EnqueueChildren) {
-
-  // Dispatch to the corresponding WalkUpFrom* function only if the derived
-  // class didn't override Traverse* (and thus the traversal is trivial).
-  // The cast here is necessary to work around a bug in old versions of g++.
-#define DISPATCH_WALK(NAME, CLASS, VAR) \
-  if (&RecursiveASTVisitor::Traverse##NAME == \
-      (bool (RecursiveASTVisitor::*)(CLASS*))&Derived::Traverse##NAME) \
-    return getDerived().WalkUpFrom##NAME(static_cast<CLASS*>(VAR)); \
-  EnqueueChildren = false; \
-  return getDerived().Traverse##NAME(static_cast<CLASS*>(VAR));
-
-  if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(S)) {
-    switch (BinOp->getOpcode()) {
-#define OPERATOR(NAME) \
-    case BO_##NAME: DISPATCH_WALK(Bin##NAME, BinaryOperator, S);
-
-    BINOP_LIST()
-#undef OPERATOR
-
-#define OPERATOR(NAME)                                          \
-    case BO_##NAME##Assign:                          \
-    DISPATCH_WALK(Bin##NAME##Assign, CompoundAssignOperator, S);
-
-    CAO_LIST()
-#undef OPERATOR
-    }
-  } else if (UnaryOperator *UnOp = dyn_cast<UnaryOperator>(S)) {
-    switch (UnOp->getOpcode()) {
-#define OPERATOR(NAME)                                                  \
-    case UO_##NAME: DISPATCH_WALK(Unary##NAME, UnaryOperator, S);
-
-    UNARYOP_LIST()
-#undef OPERATOR
-    }
-  }
-
-  // Top switch stmt: dispatch to TraverseFooStmt for each concrete FooStmt.
-  switch (S->getStmtClass()) {
-  case Stmt::NoStmtClass: break;
-#define ABSTRACT_STMT(STMT)
-#define STMT(CLASS, PARENT) \
-  case Stmt::CLASS##Class: DISPATCH_WALK(CLASS, CLASS, S);
-#include "clang/AST/StmtNodes.inc"
-  }
-
-#undef DISPATCH_WALK
-
-  return true;
-}
-
 #define DISPATCH(NAME, CLASS, VAR) \
   return getDerived().Traverse##NAME(static_cast<CLASS*>(VAR))
 
@@ -517,47 +452,57 @@
   if (!S)
     return true;
 
-  if (getDerived().shouldUseDataRecursionFor(S))
-    return dataTraverse(S);
+  StmtsTy Queue;
+  Queue.push_back(S);
+  NewQueueRAII NQ(Queue, *this);
 
-  // If we have a binary expr, dispatch to the subcode of the binop.  A smart
-  // optimizer (e.g. LLVM) will fold this comparison into the switch stmt
-  // below.
-  if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(S)) {
-    switch (BinOp->getOpcode()) {
+  while (!Queue.empty()) {
+    S = Queue.pop_back_val();
+    if (!S)
+      continue;
+
+#define DISPATCH_STMT(NAME, CLASS, VAR) \
+    TRY_TO(Traverse##NAME(static_cast<CLASS*>(VAR))); continue
+
+    // If we have a binary expr, dispatch to the subcode of the binop.  A smart
+    // optimizer (e.g. LLVM) will fold this comparison into the switch stmt
+    // below.
+    if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(S)) {
+      switch (BinOp->getOpcode()) {
 #define OPERATOR(NAME) \
-    case BO_##NAME: DISPATCH(Bin##NAME, BinaryOperator, S);
-
-    BINOP_LIST()
+      case BO_##NAME: DISPATCH_STMT(Bin##NAME, BinaryOperator, S);
+  
+      BINOP_LIST()
 #undef OPERATOR
 #undef BINOP_LIST
-
+  
 #define OPERATOR(NAME)                                          \
-    case BO_##NAME##Assign:                          \
-      DISPATCH(Bin##NAME##Assign, CompoundAssignOperator, S);
-
-    CAO_LIST()
+      case BO_##NAME##Assign:                          \
+        DISPATCH_STMT(Bin##NAME##Assign, CompoundAssignOperator, S);
+  
+      CAO_LIST()
 #undef OPERATOR
 #undef CAO_LIST
-    }
-  } else if (UnaryOperator *UnOp = dyn_cast<UnaryOperator>(S)) {
-    switch (UnOp->getOpcode()) {
+      }
+    } else if (UnaryOperator *UnOp = dyn_cast<UnaryOperator>(S)) {
+      switch (UnOp->getOpcode()) {
 #define OPERATOR(NAME)                                                  \
-    case UO_##NAME: DISPATCH(Unary##NAME, UnaryOperator, S);
-
-    UNARYOP_LIST()
+      case UO_##NAME: DISPATCH_STMT(Unary##NAME, UnaryOperator, S);
+  
+      UNARYOP_LIST()
 #undef OPERATOR
 #undef UNARYOP_LIST
+      }
     }
-  }
-
-  // Top switch stmt: dispatch to TraverseFooStmt for each concrete FooStmt.
-  switch (S->getStmtClass()) {
-  case Stmt::NoStmtClass: break;
+  
+    // Top switch stmt: dispatch to TraverseFooStmt for each concrete FooStmt.
+    switch (S->getStmtClass()) {
+    case Stmt::NoStmtClass: break;
 #define ABSTRACT_STMT(STMT)
 #define STMT(CLASS, PARENT) \
-  case Stmt::CLASS##Class: DISPATCH(CLASS, CLASS, S);
+    case Stmt::CLASS##Class: DISPATCH_STMT(CLASS, CLASS, S);
 #include "clang/AST/StmtNodes.inc"
+    }
   }
 
   return true;
@@ -1797,23 +1742,24 @@
 template<typename Derived>                                              \
 bool RecursiveASTVisitor<Derived>::Traverse##STMT (STMT *S) {           \
   TRY_TO(WalkUpFrom##STMT(S));                                          \
+  StmtQueueAction StmtQueue(*this);                                     \
   { CODE; }                                                             \
   for (Stmt::child_range range = S->children(); range; ++range) {       \
-    TRY_TO(TraverseStmt(*range));                                       \
+    StmtQueue.queue(*range);                                            \
   }                                                                     \
   return true;                                                          \
 }
 
 DEF_TRAVERSE_STMT(AsmStmt, {
-    TRY_TO(TraverseStmt(S->getAsmString()));
+    StmtQueue.queue(S->getAsmString());
     for (unsigned I = 0, E = S->getNumInputs(); I < E; ++I) {
-      TRY_TO(TraverseStmt(S->getInputConstraintLiteral(I)));
+      StmtQueue.queue(S->getInputConstraintLiteral(I));
     }
     for (unsigned I = 0, E = S->getNumOutputs(); I < E; ++I) {
-      TRY_TO(TraverseStmt(S->getOutputConstraintLiteral(I)));
+      StmtQueue.queue(S->getOutputConstraintLiteral(I));
     }
     for (unsigned I = 0, E = S->getNumClobbers(); I < E; ++I) {
-      TRY_TO(TraverseStmt(S->getClobber(I)));
+      StmtQueue.queue(S->getClobber(I));
     }
     // children() iterates over inputExpr and outputExpr.
   })
@@ -1942,9 +1888,10 @@
   if (InitListExpr *Syn = S->getSyntacticForm())
     S = Syn;
   TRY_TO(WalkUpFromInitListExpr(S));
+  StmtQueueAction StmtQueue(*this);
   // All we need are the default actions.  FIXME: use a helper function.
   for (Stmt::child_range range = S->children(); range; ++range) {
-    TRY_TO(TraverseStmt(*range));
+    StmtQueue.queue(*range);
   }
   return true;
 }
@@ -1956,11 +1903,12 @@
 bool RecursiveASTVisitor<Derived>::
 TraverseGenericSelectionExpr(GenericSelectionExpr *S) {
   TRY_TO(WalkUpFromGenericSelectionExpr(S));
-  TRY_TO(TraverseStmt(S->getControllingExpr()));
+  StmtQueueAction StmtQueue(*this);
+  StmtQueue.queue(S->getControllingExpr());
   for (unsigned i = 0; i != S->getNumAssocs(); ++i) {
     if (TypeSourceInfo *TS = S->getAssocTypeSourceInfo(i))
       TRY_TO(TraverseTypeLoc(TS->getTypeLoc()));
-    TRY_TO(TraverseStmt(S->getAssocExpr(i)));
+    StmtQueue.queue(S->getAssocExpr(i));
   }
   return true;
 }
@@ -1971,13 +1919,14 @@
 bool RecursiveASTVisitor<Derived>::
 TraversePseudoObjectExpr(PseudoObjectExpr *S) {
   TRY_TO(WalkUpFromPseudoObjectExpr(S));
-  TRY_TO(TraverseStmt(S->getSyntacticForm()));
+  StmtQueueAction StmtQueue(*this);
+  StmtQueue.queue(S->getSyntacticForm());
   for (PseudoObjectExpr::semantics_iterator
          i = S->semantics_begin(), e = S->semantics_end(); i != e; ++i) {
     Expr *sub = *i;
     if (OpaqueValueExpr *OVE = dyn_cast<OpaqueValueExpr>(sub))
       sub = OVE->getSourceExpr();
-    TRY_TO(TraverseStmt(sub));
+    StmtQueue.queue(sub);
   }
   return true;
 }
@@ -2041,7 +1990,7 @@
   })
 
 DEF_TRAVERSE_STMT(ExpressionTraitExpr, {
-    TRY_TO(TraverseStmt(S->getQueriedExpression()));
+    StmtQueue.queue(S->getQueriedExpression());
   })
 
 DEF_TRAVERSE_STMT(VAArgExpr, {
@@ -2081,7 +2030,8 @@
     }
   }
 
-  TRY_TO(TraverseStmt(S->getBody()));
+  StmtQueueAction StmtQueue(*this);
+  StmtQueue.queue(S->getBody());
   return true;
 }