ART: Correctly handle temporary classes in class-load events (1/3)

When a temporary class is given out in a ClassLoad event, all stored
references need to be fixed up before publishing a ClassPrepare event.

This CL handles objects stored as global references.

Bug: 31684920
Test: m test-art-host-run-test-912-classes
Change-Id: I2f79c7943e13c0db9ad7cb9cd60450ff6373be4f
diff --git a/runtime/openjdkjvmti/ti_class.cc b/runtime/openjdkjvmti/ti_class.cc
index c14fd84..a1efb97 100644
--- a/runtime/openjdkjvmti/ti_class.cc
+++ b/runtime/openjdkjvmti/ti_class.cc
@@ -42,6 +42,8 @@
 #include "class_linker.h"
 #include "common_throws.h"
 #include "events-inl.h"
+#include "gc/heap.h"
+#include "gc_root.h"
 #include "handle.h"
 #include "jni_env_ext-inl.h"
 #include "jni_internal.h"
@@ -261,15 +263,22 @@
             thread_jni.get(),
             jklass.get());
       }
-      AddTempClass(thread, jklass.get());
+      if (klass->IsTemp()) {
+        AddTempClass(thread, jklass.get());
+      }
     }
   }
 
-  void ClassPrepare(art::Handle<art::mirror::Class> temp_klass ATTRIBUTE_UNUSED,
+  void ClassPrepare(art::Handle<art::mirror::Class> temp_klass,
                     art::Handle<art::mirror::Class> klass)
       REQUIRES_SHARED(art::Locks::mutator_lock_) {
     if (event_handler->IsEventEnabledAnywhere(ArtJvmtiEvent::kClassPrepare)) {
       art::Thread* thread = art::Thread::Current();
+      if (temp_klass.Get() != klass.Get()) {
+        DCHECK(temp_klass->IsTemp());
+        DCHECK(temp_klass->IsRetired());
+        HandleTempClass(thread, temp_klass, klass);
+      }
       ScopedLocalRef<jclass> jklass(thread->GetJniEnv(),
                                     thread->GetJniEnv()->AddLocalReference<jclass>(klass.Get()));
       ScopedLocalRef<jthread> thread_jni(
@@ -285,10 +294,12 @@
 
   void AddTempClass(art::Thread* self, jclass klass) {
     std::unique_lock<std::mutex> mu(temp_classes_lock);
-    temp_classes.push_back(reinterpret_cast<jclass>(self->GetJniEnv()->NewGlobalRef(klass)));
+    jclass global_klass = reinterpret_cast<jclass>(self->GetJniEnv()->NewGlobalRef(klass));
+    temp_classes.push_back(global_klass);
   }
 
-  void HandleTempClass(art::Handle<art::mirror::Class> temp_klass,
+  void HandleTempClass(art::Thread* self,
+                       art::Handle<art::mirror::Class> temp_klass,
                        art::Handle<art::mirror::Class> klass)
       REQUIRES_SHARED(art::Locks::mutator_lock_) {
     std::unique_lock<std::mutex> mu(temp_classes_lock);
@@ -296,19 +307,99 @@
       return;
     }
 
-    art::Thread* self = art::Thread::Current();
     for (auto it = temp_classes.begin(); it != temp_classes.end(); ++it) {
       if (temp_klass.Get() == art::ObjPtr<art::mirror::Class>::DownCast(self->DecodeJObject(*it))) {
+        self->GetJniEnv()->DeleteGlobalRef(*it);
         temp_classes.erase(it);
-        FixupTempClass(temp_klass, klass);
+        FixupTempClass(self, temp_klass, klass);
+        break;
       }
     }
   }
 
-  void FixupTempClass(art::Handle<art::mirror::Class> temp_klass ATTRIBUTE_UNUSED,
-                      art::Handle<art::mirror::Class> klass ATTRIBUTE_UNUSED)
+  void FixupTempClass(art::Thread* self,
+                      art::Handle<art::mirror::Class> temp_klass,
+                      art::Handle<art::mirror::Class> klass)
      REQUIRES_SHARED(art::Locks::mutator_lock_) {
-    // TODO: Implement.
+    // Suspend everything.
+    art::gc::Heap* heap = art::Runtime::Current()->GetHeap();
+    if (heap->IsGcConcurrentAndMoving()) {
+      // Need to take a heap dump while GC isn't running. See the
+      // comment in Heap::VisitObjects().
+      heap->IncrementDisableMovingGC(self);
+    }
+    {
+      art::ScopedThreadSuspension sts(self, art::kWaitingForVisitObjects);
+      art::ScopedSuspendAll ssa("FixupTempClass");
+
+      art::mirror::Class* input = temp_klass.Get();
+      art::mirror::Class* output = klass.Get();
+
+      FixupGlobalReferenceTables(input, output);
+    }
+    if (heap->IsGcConcurrentAndMoving()) {
+      heap->DecrementDisableMovingGC(self);
+    }
+  }
+
+  void FixupGlobalReferenceTables(art::mirror::Class* input,
+                                  art::mirror::Class* output)
+      REQUIRES(art::Locks::mutator_lock_) {
+    art::JavaVMExt* java_vm = art::Runtime::Current()->GetJavaVM();
+
+    // Fix up the global table with a root visitor.
+    class GlobalUpdate : public art::RootVisitor {
+     public:
+      GlobalUpdate(art::mirror::Class* root_input, art::mirror::Class* root_output)
+          : input_(root_input), output_(root_output) {}
+
+      void VisitRoots(art::mirror::Object*** roots,
+                      size_t count,
+                      const art::RootInfo& info ATTRIBUTE_UNUSED)
+          OVERRIDE {
+        for (size_t i = 0; i != count; ++i) {
+          if (*roots[i] == input_) {
+            *roots[i] = output_;
+          }
+        }
+      }
+
+      void VisitRoots(art::mirror::CompressedReference<art::mirror::Object>** roots,
+                      size_t count,
+                      const art::RootInfo& info ATTRIBUTE_UNUSED)
+          OVERRIDE REQUIRES_SHARED(art::Locks::mutator_lock_) {
+        for (size_t i = 0; i != count; ++i) {
+          if (roots[i]->AsMirrorPtr() == input_) {
+            roots[i]->Assign(output_);
+          }
+        }
+      }
+
+     private:
+      const art::mirror::Class* input_;
+      art::mirror::Class* output_;
+    };
+    GlobalUpdate global_update(input, output);
+    java_vm->VisitRoots(&global_update);
+
+    class WeakGlobalUpdate : public art::IsMarkedVisitor {
+     public:
+      WeakGlobalUpdate(art::mirror::Class* root_input, art::mirror::Class* root_output)
+          : input_(root_input), output_(root_output) {}
+
+      art::mirror::Object* IsMarked(art::mirror::Object* obj) OVERRIDE {
+        if (obj == input_) {
+          return output_;
+        }
+        return obj;
+      }
+
+     private:
+      const art::mirror::Class* input_;
+      art::mirror::Class* output_;
+    };
+    WeakGlobalUpdate weak_global_update(input, output);
+    java_vm->SweepJniWeakGlobals(&weak_global_update);
   }
 
   // A set of all the temp classes we have handed out. We have to fix up references to these.
diff --git a/test/912-classes/classes.cc b/test/912-classes/classes.cc
index e659ea3..6c12522 100644
--- a/test/912-classes/classes.cc
+++ b/test/912-classes/classes.cc
@@ -430,5 +430,70 @@
   return found ? JNI_TRUE : JNI_FALSE;
 }
 
+class ClassLoadPrepareEquality {
+ public:
+  static constexpr const char* kClassName = "LMain$ClassE;";
+
+  static void JNICALL ClassLoadCallback(jvmtiEnv* jenv,
+                                        JNIEnv* jni_env,
+                                        jthread thread ATTRIBUTE_UNUSED,
+                                        jclass klass) {
+    std::string name = GetClassName(jenv, jni_env, klass);
+    if (name == kClassName) {
+      found_ = true;
+      stored_class_ = jni_env->NewGlobalRef(klass);
+      weakly_stored_class_ = jni_env->NewWeakGlobalRef(klass);
+    }
+  }
+
+  static void JNICALL ClassPrepareCallback(jvmtiEnv* jenv,
+                                           JNIEnv* jni_env,
+                                           jthread thread ATTRIBUTE_UNUSED,
+                                           jclass klass) {
+    std::string name = GetClassName(jenv, jni_env, klass);
+    if (name == kClassName) {
+      CHECK(stored_class_ != nullptr);
+      CHECK(jni_env->IsSameObject(stored_class_, klass));
+      CHECK(jni_env->IsSameObject(weakly_stored_class_, klass));
+      compared_ = true;
+    }
+  }
+
+  static void CheckFound() {
+    CHECK(found_);
+    CHECK(compared_);
+  }
+
+  static void Free(JNIEnv* env) {
+    if (stored_class_ != nullptr) {
+      env->DeleteGlobalRef(stored_class_);
+      DCHECK(weakly_stored_class_ != nullptr);
+      env->DeleteWeakGlobalRef(weakly_stored_class_);
+    }
+  }
+
+ private:
+  static jobject stored_class_;
+  static jweak weakly_stored_class_;
+  static bool found_;
+  static bool compared_;
+};
+jobject ClassLoadPrepareEquality::stored_class_ = nullptr;
+jweak ClassLoadPrepareEquality::weakly_stored_class_ = nullptr;
+bool ClassLoadPrepareEquality::found_ = false;
+bool ClassLoadPrepareEquality::compared_ = false;
+
+extern "C" JNIEXPORT void JNICALL Java_Main_enableClassLoadPrepareEqualityEvents(
+    JNIEnv* env, jclass Main_klass ATTRIBUTE_UNUSED, jboolean b) {
+  EnableEvents(env,
+               b,
+               ClassLoadPrepareEquality::ClassLoadCallback,
+               ClassLoadPrepareEquality::ClassPrepareCallback);
+  if (b == JNI_FALSE) {
+    ClassLoadPrepareEquality::Free(env);
+    ClassLoadPrepareEquality::CheckFound();
+  }
+}
+
 }  // namespace Test912Classes
 }  // namespace art
diff --git a/test/912-classes/src/Main.java b/test/912-classes/src/Main.java
index e3aceb9..c1de679 100644
--- a/test/912-classes/src/Main.java
+++ b/test/912-classes/src/Main.java
@@ -290,6 +290,8 @@
     if (hasJit() && !isLoadedClass("Main$ClassD")) {
       testClassEventsJit();
     }
+
+    testClassLoadPrepareEquality();
   }
 
   private static void testClassEventsJit() throws Exception {
@@ -312,6 +314,14 @@
     }
   }
 
+  private static void testClassLoadPrepareEquality() throws Exception {
+    enableClassLoadPrepareEqualityEvents(true);
+
+    Class.forName("Main$ClassE");
+
+    enableClassLoadPrepareEqualityEvents(false);
+  }
+
   private static void printClassLoaderClasses(ClassLoader cl) {
     for (;;) {
       if (cl == null || !cl.getClass().getName().startsWith("dalvik.system")) {
@@ -383,6 +393,8 @@
   private static native void enableClassLoadSeenEvents(boolean b);
   private static native boolean hadLoadEvent();
 
+  private static native void enableClassLoadPrepareEqualityEvents(boolean b);
+
   private static class TestForNonInit {
     public static double dummy = Math.random();  // So it can't be compile-time initialized.
   }
@@ -409,6 +421,13 @@
     static int x = 1;
   }
 
+  public static class ClassE {
+    public void foo() {
+    }
+    public void bar() {
+    }
+  }
+
   private static final String DEX1 = System.getenv("DEX_LOCATION") + "/912-classes.jar";
   private static final String DEX2 = System.getenv("DEX_LOCATION") + "/912-classes-ex.jar";