Implement IterateOverInstances JVMTI function

This JVMTI function can be useful since it allows one to easily
iterate over all instances of a class, including subclasses.

Bug: 32635074
Test: ./test.py --host

Change-Id: I31c3b54ca599964c64aa31fbf666253e9b8000c4
diff --git a/openjdkjvmti/OpenjdkJvmTi.cc b/openjdkjvmti/OpenjdkJvmTi.cc
index 59f61e2..3213bbe 100644
--- a/openjdkjvmti/OpenjdkJvmTi.cc
+++ b/openjdkjvmti/OpenjdkJvmTi.cc
@@ -506,13 +506,15 @@
 
   static jvmtiError IterateOverInstancesOfClass(
       jvmtiEnv* env,
-      jclass klass ATTRIBUTE_UNUSED,
-      jvmtiHeapObjectFilter object_filter ATTRIBUTE_UNUSED,
-      jvmtiHeapObjectCallback heap_object_callback ATTRIBUTE_UNUSED,
-      const void* user_data ATTRIBUTE_UNUSED) {
+      jclass klass,
+      jvmtiHeapObjectFilter object_filter,
+      jvmtiHeapObjectCallback heap_object_callback,
+      const void* user_data) {
     ENSURE_VALID_ENV(env);
     ENSURE_HAS_CAP(env, can_tag_objects);
-    return ERR(NOT_IMPLEMENTED);
+    HeapUtil heap_util(ArtJvmTiEnv::AsArtJvmTiEnv(env)->object_tag_table.get());
+    return heap_util.IterateOverInstancesOfClass(
+        env, klass, object_filter, heap_object_callback, user_data);
   }
 
   static jvmtiError GetLocalObject(jvmtiEnv* env,
diff --git a/openjdkjvmti/ti_heap.cc b/openjdkjvmti/ti_heap.cc
index d23370b..d1583e5 100644
--- a/openjdkjvmti/ti_heap.cc
+++ b/openjdkjvmti/ti_heap.cc
@@ -653,6 +653,70 @@
   art::Runtime::Current()->RemoveSystemWeakHolder(&gIndexCachingTable);
 }
 
+jvmtiError HeapUtil::IterateOverInstancesOfClass(jvmtiEnv* env,
+                                                 jclass klass,
+                                                 jvmtiHeapObjectFilter filter,
+                                                 jvmtiHeapObjectCallback cb,
+                                                 const void* user_data) {
+  if (cb == nullptr || klass == nullptr) {
+    return ERR(NULL_POINTER);
+  }
+
+  art::Thread* self = art::Thread::Current();
+  art::ScopedObjectAccess soa(self);      // Now we know we have the shared lock.
+  art::StackHandleScope<1> hs(self);
+
+  art::ObjPtr<art::mirror::Object> klass_ptr(soa.Decode<art::mirror::Class>(klass));
+  if (!klass_ptr->IsClass()) {
+    return ERR(INVALID_CLASS);
+  }
+  art::Handle<art::mirror::Class> filter_klass(hs.NewHandle(klass_ptr->AsClass()));
+  if (filter_klass->IsInterface()) {
+    // nothing is an 'instance' of an interface so just return without walking anything.
+    return OK;
+  }
+
+  ObjectTagTable* tag_table = ArtJvmTiEnv::AsArtJvmTiEnv(env)->object_tag_table.get();
+  bool stop_reports = false;
+  auto visitor = [&](art::mirror::Object* obj) REQUIRES_SHARED(art::Locks::mutator_lock_) {
+    // Early return, as we can't really stop visiting.
+    if (stop_reports) {
+      return;
+    }
+
+    art::ScopedAssertNoThreadSuspension no_suspension("IterateOverInstancesOfClass");
+
+    art::ObjPtr<art::mirror::Class> klass = obj->GetClass();
+
+    if (filter_klass != nullptr && !filter_klass->IsAssignableFrom(klass)) {
+      return;
+    }
+
+    jlong tag = 0;
+    tag_table->GetTag(obj, &tag);
+    if ((filter != JVMTI_HEAP_OBJECT_EITHER) &&
+        ((tag == 0 && filter == JVMTI_HEAP_OBJECT_TAGGED) ||
+         (tag != 0 && filter == JVMTI_HEAP_OBJECT_UNTAGGED))) {
+      return;
+    }
+
+    jlong class_tag = 0;
+    tag_table->GetTag(klass.Ptr(), &class_tag);
+
+    jlong saved_tag = tag;
+    jint ret = cb(class_tag, obj->SizeOf(), &tag, const_cast<void*>(user_data));
+
+    stop_reports = (ret == JVMTI_ITERATION_ABORT);
+
+    if (tag != saved_tag) {
+      tag_table->Set(obj, tag);
+    }
+  };
+  art::Runtime::Current()->GetHeap()->VisitObjects(visitor);
+
+  return OK;
+}
+
 template <typename T>
 static jvmtiError DoIterateThroughHeap(T fn,
                                        jvmtiEnv* env,
diff --git a/openjdkjvmti/ti_heap.h b/openjdkjvmti/ti_heap.h
index 62761b5..382d80f 100644
--- a/openjdkjvmti/ti_heap.h
+++ b/openjdkjvmti/ti_heap.h
@@ -30,6 +30,12 @@
 
   jvmtiError GetLoadedClasses(jvmtiEnv* env, jint* class_count_ptr, jclass** classes_ptr);
 
+  jvmtiError IterateOverInstancesOfClass(jvmtiEnv* env,
+                                         jclass klass,
+                                         jvmtiHeapObjectFilter filter,
+                                         jvmtiHeapObjectCallback cb,
+                                         const void* user_data);
+
   jvmtiError IterateThroughHeap(jvmtiEnv* env,
                                 jint heap_filter,
                                 jclass klass,
diff --git a/test/906-iterate-heap/iterate_heap.cc b/test/906-iterate-heap/iterate_heap.cc
index 57c0274..d2f69ef 100644
--- a/test/906-iterate-heap/iterate_heap.cc
+++ b/test/906-iterate-heap/iterate_heap.cc
@@ -418,5 +418,21 @@
   return (status & JVMTI_CLASS_STATUS_INITIALIZED) != 0;
 }
 
+extern "C" JNIEXPORT jint JNICALL Java_art_Test906_iterateOverInstancesCount(
+    JNIEnv* env, jclass, jclass target) {
+  jint cnt = 0;
+  auto count_func = [](jlong, jlong, jlong*, void* user_data) -> jvmtiIterationControl {
+    *reinterpret_cast<jint*>(user_data) += 1;
+    return JVMTI_ITERATION_CONTINUE;
+  };
+  JvmtiErrorToException(env,
+                        jvmti_env,
+                        jvmti_env->IterateOverInstancesOfClass(target,
+                                                               JVMTI_HEAP_OBJECT_EITHER,
+                                                               count_func,
+                                                               &cnt));
+  return cnt;
+}
+
 }  // namespace Test906IterateHeap
 }  // namespace art
diff --git a/test/906-iterate-heap/src/art/Test906.java b/test/906-iterate-heap/src/art/Test906.java
index be9663a..190f36f 100644
--- a/test/906-iterate-heap/src/art/Test906.java
+++ b/test/906-iterate-heap/src/art/Test906.java
@@ -18,6 +18,7 @@
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Random;
 
 public class Test906 {
   public static void run() throws Exception {
@@ -69,6 +70,40 @@
     throw lastThrow;
   }
 
+  private static Object[] GenTs(Class<?> k) throws Exception {
+    Object[] ret = new Object[new Random().nextInt(100) + 10];
+    for (int i = 0; i < ret.length; i++) {
+      ret[i] = k.newInstance();
+    }
+    return ret;
+  }
+
+  private static void checkEq(int a, int b) {
+    if (a != b) {
+      Error e = new Error("Failed: Expected equal " + a + " and " + b);
+      System.out.println(e);
+      e.printStackTrace(System.out);
+    }
+  }
+
+  public static class Foo {}
+  public static class Bar extends Foo {}
+  public static class Baz extends Bar {}
+  public static class Alpha extends Bar {}
+  public static class MISSING extends Baz {}
+  private static void testIterateOverInstances() throws Exception {
+    Object[] foos = GenTs(Foo.class);
+    Object[] bars = GenTs(Bar.class);
+    Object[] bazs = GenTs(Baz.class);
+    Object[] alphas = GenTs(Alpha.class);
+    checkEq(0, iterateOverInstancesCount(MISSING.class));
+    checkEq(alphas.length, iterateOverInstancesCount(Alpha.class));
+    checkEq(bazs.length, iterateOverInstancesCount(Baz.class));
+    checkEq(bazs.length + alphas.length + bars.length, iterateOverInstancesCount(Bar.class));
+    checkEq(bazs.length + alphas.length + bars.length + foos.length,
+        iterateOverInstancesCount(Foo.class));
+  }
+
   public static void doTest() throws Exception {
     A a = new A();
     B b = new B();
@@ -86,6 +121,8 @@
 
     testHeapCount();
 
+    testIterateOverInstances();
+
     long classTags[] = new long[100];
     long sizes[] = new long[100];
     long tags[] = new long[100];
@@ -308,6 +345,8 @@
     return Main.getTag(o);
   }
 
+  private static native int iterateOverInstancesCount(Class<?> klass);
+
   private static native boolean checkInitialized(Class<?> klass);
   private static native int iterateThroughHeapCount(int heapFilter,
       Class<?> klassFilter, int stopAfter);