ART: Correctly handle temporary classes in class-load events (3/3)

When a temporary class is given out in a ClassLoad event, all stored
references need to be fixed up before publishing a ClassPrepare event.

This CL handles objects stored in the heap.

Bug: 31684920
Test: m test-art-host-run-test-912-classes
Change-Id: Ia0456c81fd848618e637b93301edf4dbc8d848f2
diff --git a/runtime/openjdkjvmti/ti_class.cc b/runtime/openjdkjvmti/ti_class.cc
index cf3cfa4..fc4b6fe 100644
--- a/runtime/openjdkjvmti/ti_class.cc
+++ b/runtime/openjdkjvmti/ti_class.cc
@@ -50,6 +50,8 @@
 #include "mirror/array-inl.h"
 #include "mirror/class-inl.h"
 #include "mirror/class_ext.h"
+#include "mirror/object_reference.h"
+#include "mirror/object-inl.h"
 #include "runtime.h"
 #include "runtime_callbacks.h"
 #include "ScopedLocalRef.h"
@@ -347,6 +349,7 @@
 
       FixupGlobalReferenceTables(input, output);
       FixupLocalReferenceTables(self, input, output);
+      FixupHeap(input, output);
     }
     if (heap->IsGcConcurrentAndMoving()) {
       heap->DecrementDisableMovingGC(self);
@@ -385,8 +388,7 @@
     art::mirror::Class* output_;
   };
 
-  void FixupGlobalReferenceTables(art::mirror::Class* input,
-                                  art::mirror::Class* output)
+  void FixupGlobalReferenceTables(art::mirror::Class* input, art::mirror::Class* output)
       REQUIRES(art::Locks::mutator_lock_) {
     art::JavaVMExt* java_vm = art::Runtime::Current()->GetJavaVM();
 
@@ -441,6 +443,53 @@
     art::Runtime::Current()->GetThreadList()->ForEach(LocalUpdate::Callback, &local_upd);
   }
 
+  void FixupHeap(art::mirror::Class* input, art::mirror::Class* output)
+        REQUIRES(art::Locks::mutator_lock_) {
+    class HeapFixupVisitor {
+     public:
+      HeapFixupVisitor(const art::mirror::Class* root_input, art::mirror::Class* root_output)
+                : input_(root_input), output_(root_output) {}
+
+      void operator()(art::mirror::Object* src,
+                      art::MemberOffset field_offset,
+                      bool is_static ATTRIBUTE_UNUSED) const
+          REQUIRES_SHARED(art::Locks::mutator_lock_) {
+        art::mirror::HeapReference<art::mirror::Object>* trg =
+          src->GetFieldObjectReferenceAddr(field_offset);
+        if (trg->AsMirrorPtr() == input_) {
+          DCHECK_NE(field_offset.Uint32Value(), 0u);  // This shouldn't be the class field of
+                                                      // an object.
+          trg->Assign(output_);
+        }
+      }
+
+      void VisitRoot(art::mirror::CompressedReference<art::mirror::Object>* root ATTRIBUTE_UNUSED)
+      const {
+        LOG(FATAL) << "Unreachable";
+      }
+
+      void VisitRootIfNonNull(
+          art::mirror::CompressedReference<art::mirror::Object>* root ATTRIBUTE_UNUSED) const {
+        LOG(FATAL) << "Unreachable";
+      }
+
+      static void AllObjectsCallback(art::mirror::Object* obj, void* arg)
+          REQUIRES_SHARED(art::Locks::mutator_lock_) {
+        HeapFixupVisitor* hfv = reinterpret_cast<HeapFixupVisitor*>(arg);
+
+        // Visit references, not native roots.
+        obj->VisitReferences<false>(*hfv, art::VoidFunctor());
+      }
+
+     private:
+      const art::mirror::Class* input_;
+      art::mirror::Class* output_;
+    };
+    HeapFixupVisitor hfv(input, output);
+    art::Runtime::Current()->GetHeap()->VisitObjectsPaused(HeapFixupVisitor::AllObjectsCallback,
+                                                           &hfv);
+  }
+
   // A set of all the temp classes we have handed out. We have to fix up references to these.
   // For simplicity, we store the temp classes as JNI global references in a vector. Normally a
   // Prepare event will closely follow, so the vector should be small.
diff --git a/test/912-classes/classes.cc b/test/912-classes/classes.cc
index b727453..c92e49f 100644
--- a/test/912-classes/classes.cc
+++ b/test/912-classes/classes.cc
@@ -433,6 +433,9 @@
 class ClassLoadPrepareEquality {
  public:
   static constexpr const char* kClassName = "LMain$ClassE;";
+  static constexpr const char* kStorageClassName = "Main$ClassF";
+  static constexpr const char* kStorageFieldName = "STATIC";
+  static constexpr const char* kStorageFieldSig = "Ljava/lang/Object;";
 
   static void JNICALL ClassLoadCallback(jvmtiEnv* jenv,
                                         JNIEnv* jni_env,
@@ -446,6 +449,8 @@
       // The following is bad and relies on implementation details. But otherwise a test would be
       // a lot more complicated.
       local_stored_class_ = jni_env->NewLocalRef(klass);
+      // Store the value into a field in the heap.
+      SetOrCompare(jni_env, klass, true);
     }
   }
 
@@ -459,10 +464,26 @@
       CHECK(jni_env->IsSameObject(stored_class_, klass));
       CHECK(jni_env->IsSameObject(weakly_stored_class_, klass));
       CHECK(jni_env->IsSameObject(local_stored_class_, klass));
+      // Look up the value in a field in the heap.
+      SetOrCompare(jni_env, klass, false);
       compared_ = true;
     }
   }
 
+  static void SetOrCompare(JNIEnv* jni_env, jobject value, bool set) {
+    CHECK(storage_class_ != nullptr);
+    jfieldID field = jni_env->GetStaticFieldID(storage_class_, kStorageFieldName, kStorageFieldSig);
+    CHECK(field != nullptr);
+
+    if (set) {
+      jni_env->SetStaticObjectField(storage_class_, field, value);
+      CHECK(!jni_env->ExceptionCheck());
+    } else {
+      ScopedLocalRef<jobject> stored(jni_env, jni_env->GetStaticObjectField(storage_class_, field));
+      CHECK(jni_env->IsSameObject(value, stored.get()));
+    }
+  }
+
   static void CheckFound() {
     CHECK(found_);
     CHECK(compared_);
@@ -477,6 +498,8 @@
     }
   }
 
+  static jclass storage_class_;
+
  private:
   static jobject stored_class_;
   static jweak weakly_stored_class_;
@@ -484,12 +507,19 @@
   static bool found_;
   static bool compared_;
 };
+jclass ClassLoadPrepareEquality::storage_class_ = nullptr;
 jobject ClassLoadPrepareEquality::stored_class_ = nullptr;
 jweak ClassLoadPrepareEquality::weakly_stored_class_ = nullptr;
 jobject ClassLoadPrepareEquality::local_stored_class_ = nullptr;
 bool ClassLoadPrepareEquality::found_ = false;
 bool ClassLoadPrepareEquality::compared_ = false;
 
+extern "C" JNIEXPORT void JNICALL Java_Main_setEqualityEventStorageClass(
+    JNIEnv* env, jclass Main_klass ATTRIBUTE_UNUSED, jclass klass) {
+  ClassLoadPrepareEquality::storage_class_ =
+      reinterpret_cast<jclass>(env->NewGlobalRef(klass));
+}
+
 extern "C" JNIEXPORT void JNICALL Java_Main_enableClassLoadPrepareEqualityEvents(
     JNIEnv* env, jclass Main_klass ATTRIBUTE_UNUSED, jboolean b) {
   EnableEvents(env,
@@ -499,6 +529,8 @@
   if (b == JNI_FALSE) {
     ClassLoadPrepareEquality::Free(env);
     ClassLoadPrepareEquality::CheckFound();
+    env->DeleteGlobalRef(ClassLoadPrepareEquality::storage_class_);
+    ClassLoadPrepareEquality::storage_class_ = nullptr;
   }
 }
 
diff --git a/test/912-classes/src/Main.java b/test/912-classes/src/Main.java
index c1de679..52a5194 100644
--- a/test/912-classes/src/Main.java
+++ b/test/912-classes/src/Main.java
@@ -315,6 +315,8 @@
   }
 
   private static void testClassLoadPrepareEquality() throws Exception {
+    setEqualityEventStorageClass(ClassF.class);
+
     enableClassLoadPrepareEqualityEvents(true);
 
     Class.forName("Main$ClassE");
@@ -393,6 +395,7 @@
   private static native void enableClassLoadSeenEvents(boolean b);
   private static native boolean hadLoadEvent();
 
+  private static native void setEqualityEventStorageClass(Class<?> c);
   private static native void enableClassLoadPrepareEqualityEvents(boolean b);
 
   private static class TestForNonInit {
@@ -428,6 +431,10 @@
     }
   }
 
+  public static class ClassF {
+    public static Object STATIC = null;
+  }
+
   private static final String DEX1 = System.getenv("DEX_LOCATION") + "/912-classes.jar";
   private static final String DEX2 = System.getenv("DEX_LOCATION") + "/912-classes-ex.jar";