Thread safety: completeing the implementation of shared/exclusive locks required attributes

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@139804 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/ThreadSafety.cpp b/lib/Analysis/ThreadSafety.cpp
index 504a24c..e8bae52 100644
--- a/lib/Analysis/ThreadSafety.cpp
+++ b/lib/Analysis/ThreadSafety.cpp
@@ -119,6 +119,10 @@
   iterator end() {
     return Blocks.rend();
   }
+
+  bool empty() {
+    return begin() == end();
+  }
 };
 
 /// \brief A MutexID object uniquely identifies a particular mutex, and
@@ -640,7 +644,7 @@
 }
 
 /// \brief Returns the location of the first Stmt in a Block.
-static SourceLocation getFirstStmtLocation(CFGBlock *Block) {
+static SourceLocation getFirstStmtLocation(const CFGBlock *Block) {
   SourceLocation Loc;
   for (CFGBlock::const_iterator BI = Block->begin(), BE = Block->end();
        BI != BE; ++BI) {
@@ -649,13 +653,26 @@
       if (Loc.isValid()) return Loc;
     }
   }
-  if (Stmt *S = Block->getTerminator().getStmt()) {
+  if (const Stmt *S = Block->getTerminator().getStmt()) {
     Loc = S->getLocStart();
     if (Loc.isValid()) return Loc;
   }
   return Loc;
 }
 
+static Lockset addLock(ThreadSafetyHandler &Handler,
+                       Lockset::Factory &LocksetFactory,
+                       Lockset &LSet, Expr *LockExp, LockKind LK,
+                       SourceLocation Loc) {
+  MutexID Mutex(LockExp, 0);
+  if (!Mutex.isValid()) {
+    Handler.handleInvalidLockExp(LockExp->getExprLoc());
+    return LSet;
+  }
+  LockData NewLock(Loc, LK);
+  return LocksetFactory.add(LSet, Mutex, NewLock);
+}
+
 namespace clang {
 namespace thread_safety {
 /// \brief Check a function's CFG for thread-safety violations.
@@ -684,6 +701,32 @@
   TopologicallySortedCFG SortedGraph(CFGraph);
   CFGBlockSet VisitedBlocks(CFGraph);
 
+  if (!SortedGraph.empty() && D->hasAttrs()) {
+    const CFGBlock *FirstBlock = *SortedGraph.begin();
+    Lockset &InitialLockset = EntryLocksets[FirstBlock->getBlockID()];
+    const AttrVec &ArgAttrs = D->getAttrs();
+    for(unsigned i = 0; i < ArgAttrs.size(); ++i) {
+      Attr *Attr = ArgAttrs[i];
+      if (SharedLocksRequiredAttr *SLRAttr
+            = dyn_cast<SharedLocksRequiredAttr>(Attr)) {
+        for (SharedLocksRequiredAttr::args_iterator
+            SLRIter = SLRAttr->args_begin(),
+            SLREnd = SLRAttr->args_end(); SLRIter != SLREnd; ++SLRIter)
+          InitialLockset = addLock(Handler, LocksetFactory, InitialLockset,
+                                   *SLRIter, LK_Shared,
+                                   getFirstStmtLocation(FirstBlock));
+      } else if (ExclusiveLocksRequiredAttr *ELRAttr
+                   = dyn_cast<ExclusiveLocksRequiredAttr>(Attr)) {
+        for (ExclusiveLocksRequiredAttr::args_iterator
+            ELRIter = ELRAttr->args_begin(),
+            ELREnd = ELRAttr->args_end(); ELRIter != ELREnd; ++ELRIter)
+          InitialLockset = addLock(Handler, LocksetFactory, InitialLockset,
+                                   *ELRIter, LK_Exclusive,
+                                   getFirstStmtLocation(FirstBlock));
+      }
+    }
+  }
+
   for (TopologicallySortedCFG::iterator I = SortedGraph.begin(),
        E = SortedGraph.end(); I!= E; ++I) {
     const CFGBlock *CurrBlock = *I;
diff --git a/test/SemaCXX/warn-thread-safety-analysis.cpp b/test/SemaCXX/warn-thread-safety-analysis.cpp
index cca2f3a..5107d5b 100644
--- a/test/SemaCXX/warn-thread-safety-analysis.cpp
+++ b/test/SemaCXX/warn-thread-safety-analysis.cpp
@@ -662,6 +662,16 @@
   Bar.aa_elr_fun_s();
 }
 
+void es_fun_9() __attribute__((shared_locks_required(aa_mu)));
+void es_fun_9() {
+  Bar.aa_elr_fun_s();
+}
+
+void es_fun_10() __attribute__((exclusive_locks_required(aa_mu)));
+void es_fun_10() {
+  Bar.aa_elr_fun_s();
+}
+
 void es_bad_0() {
   Bar.aa_elr_fun(); // \
     // expected-warning {{calling function 'aa_elr_fun' requires exclusive lock on 'aa_mu'}}