Make class loaders weak roots

Storing the class loaders as weak roots in the class linker prevents
them from keeping their classes alive. However, we still visit them as
strong roots for now, so that no accidental class unloading occurs
until the logic to free from the linear alloc is complete.
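
For reference, the new VisitClassLoaders helper walks the weak globals
roughly like this (same names and calls as in the diff below):

  Thread* const self = Thread::Current();
  JavaVMExt* const vm = self->GetJniEnv()->vm;
  for (jweak weak_root : class_loaders_) {
    mirror::ClassLoader* const class_loader = down_cast<mirror::ClassLoader*>(
        vm->DecodeWeakGlobal(self, weak_root));
    // Skip class loaders whose weak globals have already been cleared.
    if (class_loader != nullptr) {
      visitor->Visit(class_loader);
    }
  }

Cleared weak roots are only deleted and pruned from the list by
VisitClassLoadersAndRemoveClearedLoaders, which requires exclusive
access to classlinker_classes_lock_.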

Bug: 22720414

Change-Id: I57466236d9ce6fd064dda9a30ce8ab68094fb8b0
diff --git a/runtime/class_linker.cc b/runtime/class_linker.cc
index 5f2c944..73da2cb 100644
--- a/runtime/class_linker.cc
+++ b/runtime/class_linker.cc
@@ -1295,7 +1295,8 @@
 }
 
 void ClassLinker::VisitClassRoots(RootVisitor* visitor, VisitRootFlags flags) {
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  Thread* const self = Thread::Current();
+  WriterMutexLock mu(self, *Locks::classlinker_classes_lock_);
   BufferedRootVisitor<kDefaultBufferedRootCount> buffered_visitor(
       visitor, RootInfo(kRootStickyClass));
   if ((flags & kVisitRootFlagAllRoots) != 0) {
@@ -1315,9 +1316,13 @@
     // Need to make sure to not copy ArtMethods without doing read barriers since the roots are
     // marked concurrently and we don't hold the classlinker_classes_lock_ when we do the copy.
     boot_class_table_.VisitRoots(buffered_visitor);
-    for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-      // May be null for boot ClassLoader.
-      root.VisitRoot(visitor, RootInfo(kRootVMInternal));
+    // TODO: Avoid marking these to enable class unloading.
+    JavaVMExt* const vm = Runtime::Current()->GetJavaVM();
+    for (jweak weak_root : class_loaders_) {
+      mirror::Object* class_loader =
+          down_cast<mirror::ClassLoader*>(vm->DecodeWeakGlobal(self, weak_root));
+      // Don't need to update anything since the class loaders will be updated by SweepSystemWeaks.
+      visitor->VisitRootIfNonNull(&class_loader, RootInfo(kRootVMInternal));
     }
   } else if ((flags & kVisitRootFlagNewRoots) != 0) {
     for (auto& root : new_class_roots_) {
@@ -1353,14 +1358,31 @@
   }
 }
 
+class VisitClassLoaderClassesVisitor : public ClassLoaderVisitor {
+ public:
+  explicit VisitClassLoaderClassesVisitor(ClassVisitor* visitor)
+      : visitor_(visitor),
+        done_(false) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
+    if (!done_ && class_table != nullptr && !class_table->Visit(visitor_)) {
+      // If Visit on the class table returns false, it means that we don't need to continue.
+      done_ = true;
+    }
+  }
+
+ private:
+  ClassVisitor* const visitor_;
+  // If done is true then we don't need to do any more visiting.
+  bool done_;
+};
+
 void ClassLinker::VisitClassesInternal(ClassVisitor* visitor) {
   if (boot_class_table_.Visit(visitor)) {
-    for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-      ClassTable* const class_table = root.Read()->GetClassTable();
-      if (class_table != nullptr && !class_table->Visit(visitor)) {
-        return;
-      }
-    }
+    VisitClassLoaderClassesVisitor loader_visitor(visitor);
+    VisitClassLoaders(&loader_visitor);
   }
 }
 
@@ -1479,10 +1501,17 @@
   mirror::LongArray::ResetArrayClass();
   mirror::ShortArray::ResetArrayClass();
   STLDeleteElements(&oat_files_);
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    delete class_table;
+  Thread* const self = Thread::Current();
+  JavaVMExt* const vm = Runtime::Current()->GetJavaVM();
+  for (jweak weak_root : class_loaders_) {
+    auto* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {
+      delete class_loader->GetClassTable();
+    }
+    vm->DeleteWeakGlobalRef(self, weak_root);
   }
+  class_loaders_.clear();
 }
 
 mirror::PointerArray* ClassLinker::AllocPointerArray(Thread* self, size_t length) {
@@ -2611,8 +2640,7 @@
                                                   bool allow_failure) {
   // Search assuming unique-ness of dex file.
   JavaVMExt* const vm = self->GetJniEnv()->vm;
-  for (jobject weak_root : dex_caches_) {
-    DCHECK_EQ(GetIndirectRefKind(weak_root), kWeakGlobal);
+  for (jweak weak_root : dex_caches_) {
     mirror::DexCache* dex_cache = down_cast<mirror::DexCache*>(
         vm->DecodeWeakGlobal(self, weak_root));
     if (dex_cache != nullptr && dex_cache->GetDexFile() == &dex_file) {
@@ -2985,15 +3013,25 @@
   dex_cache_image_class_lookup_required_ = false;
 }
 
-void ClassLinker::MoveClassTableToPreZygote() {
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
-  boot_class_table_.FreezeSnapshot();
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
+class MoveClassTableToPreZygoteVisitor : public ClassLoaderVisitor {
+ public:
+  explicit MoveClassTableToPreZygoteVisitor() {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      REQUIRES(Locks::classlinker_classes_lock_)
+      SHARED_REQUIRES(Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
     if (class_table != nullptr) {
       class_table->FreezeSnapshot();
     }
   }
+};
+
+void ClassLinker::MoveClassTableToPreZygote() {
+  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  boot_class_table_.FreezeSnapshot();
+  MoveClassTableToPreZygoteVisitor visitor;
+  VisitClassLoadersAndRemoveClearedLoaders(&visitor);
 }
 
 mirror::Class* ClassLinker::LookupClassFromImage(const char* descriptor) {
@@ -3019,25 +3057,43 @@
   return nullptr;
 }
 
+// Look up classes by hash and descriptor and put all matching ones in the result array.
+class LookupClassesVisitor : public ClassLoaderVisitor {
+ public:
+  LookupClassesVisitor(const char* descriptor, size_t hash, std::vector<mirror::Class*>* result)
+      : descriptor_(descriptor),
+        hash_(hash),
+        result_(result) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
+    mirror::Class* klass = class_table->Lookup(descriptor_, hash_);
+    if (klass != nullptr) {
+      result_->push_back(klass);
+    }
+  }
+
+ private:
+  const char* const descriptor_;
+  const size_t hash_;
+  std::vector<mirror::Class*>* const result_;
+};
+
 void ClassLinker::LookupClasses(const char* descriptor, std::vector<mirror::Class*>& result) {
   result.clear();
   if (dex_cache_image_class_lookup_required_) {
     MoveImageClassesToClassTable();
   }
-  WriterMutexLock mu(Thread::Current(), *Locks::classlinker_classes_lock_);
+  Thread* const self = Thread::Current();
+  ReaderMutexLock mu(self, *Locks::classlinker_classes_lock_);
   const size_t hash = ComputeModifiedUtf8Hash(descriptor);
   mirror::Class* klass = boot_class_table_.Lookup(descriptor, hash);
   if (klass != nullptr) {
     result.push_back(klass);
   }
-  for (GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    // There can only be one class with the same descriptor per class loader.
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    klass = class_table->Lookup(descriptor, hash);
-    if (klass != nullptr) {
-      result.push_back(klass);
-    }
-  }
+  LookupClassesVisitor visitor(descriptor, hash, &result);
+  VisitClassLoaders(&visitor);
 }
 
 void ClassLinker::VerifyClass(Thread* self, Handle<mirror::Class> klass) {
@@ -4109,7 +4165,8 @@
   ClassTable* class_table = class_loader->GetClassTable();
   if (class_table == nullptr) {
     class_table = new ClassTable;
-    class_loaders_.push_back(class_loader);
+    Thread* const self = Thread::Current();
+    class_loaders_.push_back(self->GetJniEnv()->vm->AddWeakGlobalRef(self, class_loader));
     // Don't already have a class table, add it to the class loader.
     class_loader->SetClassTable(class_table);
   }
@@ -5875,26 +5932,33 @@
      << NumNonZygoteClasses() << "\n";
 }
 
-size_t ClassLinker::NumZygoteClasses() const {
-  size_t sum = boot_class_table_.NumZygoteClasses();
-  for (const GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
+class CountClassesVisitor : public ClassLoaderVisitor {
+ public:
+  CountClassesVisitor() : num_zygote_classes(0), num_non_zygote_classes(0) {}
+
+  void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) OVERRIDE {
+    ClassTable* const class_table = class_loader->GetClassTable();
     if (class_table != nullptr) {
-      sum += class_table->NumZygoteClasses();
+      num_zygote_classes += class_table->NumZygoteClasses();
+      num_non_zygote_classes += class_table->NumNonZygoteClasses();
     }
   }
-  return sum;
+
+  size_t num_zygote_classes;
+  size_t num_non_zygote_classes;
+};
+
+size_t ClassLinker::NumZygoteClasses() const {
+  CountClassesVisitor visitor;
+  VisitClassLoaders(&visitor);
+  return visitor.num_zygote_classes + boot_class_table_.NumZygoteClasses();
 }
 
 size_t ClassLinker::NumNonZygoteClasses() const {
-  size_t sum = boot_class_table_.NumNonZygoteClasses();
-  for (const GcRoot<mirror::ClassLoader>& root : class_loaders_) {
-    ClassTable* const class_table = root.Read()->GetClassTable();
-    if (class_table != nullptr) {
-      sum += class_table->NumNonZygoteClasses();
-    }
-  }
-  return sum;
+  CountClassesVisitor visitor;
+  VisitClassLoaders(&visitor);
+  return visitor.num_non_zygote_classes + boot_class_table_.NumNonZygoteClasses();
 }
 
 size_t ClassLinker::NumLoadedClasses() {
@@ -6107,4 +6171,35 @@
   find_array_class_cache_next_victim_ = 0;
 }
 
+void ClassLinker::VisitClassLoadersAndRemoveClearedLoaders(ClassLoaderVisitor* visitor) {
+  Thread* const self = Thread::Current();
+  Locks::classlinker_classes_lock_->AssertExclusiveHeld(self);
+  JavaVMExt* const vm = self->GetJniEnv()->vm;
+  for (auto it = class_loaders_.begin(); it != class_loaders_.end();) {
+    const jweak weak_root = *it;
+    mirror::ClassLoader* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {
+      visitor->Visit(class_loader);
+      ++it;
+    } else {
+      // Remove the cleared weak reference from the array.
+      vm->DeleteWeakGlobalRef(self, weak_root);
+      it = class_loaders_.erase(it);
+    }
+  }
+}
+
+void ClassLinker::VisitClassLoaders(ClassLoaderVisitor* visitor) const {
+  Thread* const self = Thread::Current();
+  JavaVMExt* const vm = self->GetJniEnv()->vm;
+  for (jweak weak_root : class_loaders_) {
+    mirror::ClassLoader* const class_loader = down_cast<mirror::ClassLoader*>(
+        vm->DecodeWeakGlobal(self, weak_root));
+    if (class_loader != nullptr) {
+      visitor->Visit(class_loader);
+    }
+  }
+}
+
 }  // namespace art
diff --git a/runtime/class_linker.h b/runtime/class_linker.h
index 17aa48a..fee7066 100644
--- a/runtime/class_linker.h
+++ b/runtime/class_linker.h
@@ -59,6 +59,13 @@
 
 enum VisitRootFlags : uint8_t;
 
+class ClassLoaderVisitor {
+ public:
+  virtual ~ClassLoaderVisitor() {}
+  virtual void Visit(mirror::ClassLoader* class_loader)
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_) = 0;
+};
+
 class ClassLinker {
  public:
   // Well known mirror::Class roots accessed via GetClassRoot.
@@ -540,8 +547,17 @@
   void DropFindArrayClassCache() SHARED_REQUIRES(Locks::mutator_lock_);
 
  private:
+  // The RemoveClearedLoaders version deletes the weak globals for cleared class loaders and
+  // removes them from the class_loaders_ list. It requires exclusive access to the
+  // classlinker_classes_lock_ since it modifies the class_loaders_ list.
+  void VisitClassLoadersAndRemoveClearedLoaders(ClassLoaderVisitor* visitor)
+      REQUIRES(Locks::classlinker_classes_lock_)
+      SHARED_REQUIRES(Locks::mutator_lock_);
+  void VisitClassLoaders(ClassLoaderVisitor* visitor) const
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_);
+
   void VisitClassesInternal(ClassVisitor* visitor)
-      REQUIRES(Locks::classlinker_classes_lock_) SHARED_REQUIRES(Locks::mutator_lock_);
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_);
 
   // Returns the number of zygote and image classes.
   size_t NumZygoteClasses() const
@@ -726,7 +743,7 @@
   size_t GetDexCacheCount() SHARED_REQUIRES(Locks::mutator_lock_, dex_lock_) {
     return dex_caches_.size();
   }
-  const std::list<jobject>& GetDexCaches() SHARED_REQUIRES(Locks::mutator_lock_, dex_lock_) {
+  const std::list<jweak>& GetDexCaches() SHARED_REQUIRES(Locks::mutator_lock_, dex_lock_) {
     return dex_caches_;
   }
 
@@ -805,12 +822,12 @@
   mutable ReaderWriterMutex dex_lock_ DEFAULT_MUTEX_ACQUIRED_AFTER;
   // JNI weak globals to allow dex caches to get unloaded. We lazily delete weak globals when we
   // register new dex files.
-  std::list<jobject> dex_caches_ GUARDED_BY(dex_lock_);
+  std::list<jweak> dex_caches_ GUARDED_BY(dex_lock_);
   std::vector<const OatFile*> oat_files_ GUARDED_BY(dex_lock_);
 
-  // This contains the class laoders which have class tables. It is populated by
-  // InsertClassTableForClassLoader.
-  std::vector<GcRoot<mirror::ClassLoader>> class_loaders_
+  // This contains the class loaders which have class tables. It is populated by
+  // InsertClassTableForClassLoader. Weak roots to enable class unloading.
+  std::list<jweak> class_loaders_
       GUARDED_BY(Locks::classlinker_classes_lock_);
 
   // Boot class path table. Since the class loader for this is null.
diff --git a/runtime/class_table.h b/runtime/class_table.h
index 6b18d90..727392e 100644
--- a/runtime/class_table.h
+++ b/runtime/class_table.h
@@ -58,10 +58,10 @@
       REQUIRES(Locks::classlinker_classes_lock_) SHARED_REQUIRES(Locks::mutator_lock_);
 
   // Returns the number of classes in previous snapshots.
-  size_t NumZygoteClasses() const REQUIRES(Locks::classlinker_classes_lock_);
+  size_t NumZygoteClasses() const SHARED_REQUIRES(Locks::classlinker_classes_lock_);
 
   // Returns all of the classes in the latest snapshot.
-  size_t NumNonZygoteClasses() const REQUIRES(Locks::classlinker_classes_lock_);
+  size_t NumNonZygoteClasses() const SHARED_REQUIRES(Locks::classlinker_classes_lock_);
 
   // Update a class in the table with the new class. Returns the existing class which was replaced.
   mirror::Class* UpdateClass(const char* descriptor, mirror::Class* new_klass, size_t hash)
@@ -79,7 +79,7 @@
 
   // Return false if the callback told us to exit.
   bool Visit(ClassVisitor* visitor)
-      REQUIRES(Locks::classlinker_classes_lock_) SHARED_REQUIRES(Locks::mutator_lock_);
+      SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_);
 
   mirror::Class* Lookup(const char* descriptor, size_t hash)
       SHARED_REQUIRES(Locks::classlinker_classes_lock_, Locks::mutator_lock_);
diff --git a/runtime/gc_root.h b/runtime/gc_root.h
index 83471e6..477e67b 100644
--- a/runtime/gc_root.h
+++ b/runtime/gc_root.h
@@ -90,16 +90,16 @@
   virtual ~RootVisitor() { }
 
   // Single root version, not overridable.
-  ALWAYS_INLINE void VisitRoot(mirror::Object** roots, const RootInfo& info)
+  ALWAYS_INLINE void VisitRoot(mirror::Object** root, const RootInfo& info)
       SHARED_REQUIRES(Locks::mutator_lock_) {
-    VisitRoots(&roots, 1, info);
+    VisitRoots(&root, 1, info);
   }
 
   // Single root version, not overridable.
-  ALWAYS_INLINE void VisitRootIfNonNull(mirror::Object** roots, const RootInfo& info)
+  ALWAYS_INLINE void VisitRootIfNonNull(mirror::Object** root, const RootInfo& info)
       SHARED_REQUIRES(Locks::mutator_lock_) {
-    if (*roots != nullptr) {
-      VisitRoot(roots, info);
+    if (*root != nullptr) {
+      VisitRoot(root, info);
     }
   }
 
diff --git a/runtime/java_vm_ext.cc b/runtime/java_vm_ext.cc
index d6c798a..92eef39 100644
--- a/runtime/java_vm_ext.cc
+++ b/runtime/java_vm_ext.cc
@@ -592,6 +592,7 @@
   // This only applies in the case where MayAccessWeakGlobals goes from false to true. In the other
   // case, it may be racy, this is benign since DecodeWeakGlobalLocked does the correct behavior
   // if MayAccessWeakGlobals is false.
+  DCHECK_EQ(GetIndirectRefKind(ref), kWeakGlobal);
   if (LIKELY(MayAccessWeakGlobalsUnlocked(self))) {
     return weak_globals_.SynchronizedGet(ref);
   }