ART: Add string reporting

Add support for string_primitive_value_callback.

Bug: 31385354
Test: m test-art-host-run-test-906-iterate-heap
Test: m test-art-host-run-test-913-heaps
Change-Id: I69f68fd07869ba3a156a84fcb806821fce1d7c03
diff --git a/runtime/openjdkjvmti/ti_heap.cc b/runtime/openjdkjvmti/ti_heap.cc
index fe3e52b..7efeea7 100644
--- a/runtime/openjdkjvmti/ti_heap.cc
+++ b/runtime/openjdkjvmti/ti_heap.cc
@@ -38,14 +38,69 @@
 
 namespace openjdkjvmti {
 
+namespace {
+
+// Report the contents of a string, if a callback is set.
+jint ReportString(art::ObjPtr<art::mirror::Object> obj,
+                  jvmtiEnv* env,
+                  ObjectTagTable* tag_table,
+                  const jvmtiHeapCallbacks* cb,
+                  const void* user_data) REQUIRES_SHARED(art::Locks::mutator_lock_) {
+  if (UNLIKELY(cb->string_primitive_value_callback != nullptr) && obj->IsString()) {
+    art::ObjPtr<art::mirror::String> str = obj->AsString();
+    int32_t string_length = str->GetLength();
+    jvmtiError alloc_error;
+    JvmtiUniquePtr<uint16_t[]> data = AllocJvmtiUniquePtr<uint16_t[]>(env,
+                                                                      string_length,
+                                                                      &alloc_error);
+    if (data == nullptr) {
+      // TODO: Not really sure what to do here. Should we abort the iteration and go all the way
+      //       back? For now just warn.
+      LOG(WARNING) << "Unable to allocate buffer for string reporting! Silently dropping value.";
+      return 0;
+    }
+
+    if (str->IsCompressed()) {
+      uint8_t* compressed_data = str->GetValueCompressed();
+      for (int32_t i = 0; i != string_length; ++i) {
+        data[i] = compressed_data[i];
+      }
+    } else {
+      // Can copy directly.
+      memcpy(data.get(), str->GetValue(), string_length * sizeof(uint16_t));
+    }
+
+    const jlong class_tag = tag_table->GetTagOrZero(obj->GetClass());
+    jlong string_tag = tag_table->GetTagOrZero(obj.Ptr());
+    const jlong saved_string_tag = string_tag;
+
+    jint result = cb->string_primitive_value_callback(class_tag,
+                                                      obj->SizeOf(),
+                                                      &string_tag,
+                                                      data.get(),
+                                                      string_length,
+                                                      const_cast<void*>(user_data));
+    if (string_tag != saved_string_tag) {
+      tag_table->Set(obj.Ptr(), string_tag);
+    }
+
+    return result;
+  }
+  return 0;
+}
+
+}  // namespace
+
 struct IterateThroughHeapData {
   IterateThroughHeapData(HeapUtil* _heap_util,
+                         jvmtiEnv* _env,
                          jint heap_filter,
                          art::ObjPtr<art::mirror::Class> klass,
                          const jvmtiHeapCallbacks* _callbacks,
                          const void* _user_data)
       : heap_util(_heap_util),
         filter_klass(klass),
+        env(_env),
         callbacks(_callbacks),
         user_data(_user_data),
         filter_out_tagged((heap_filter & JVMTI_HEAP_FILTER_TAGGED) != 0),
@@ -78,6 +133,7 @@
 
   HeapUtil* heap_util;
   art::ObjPtr<art::mirror::Class> filter_klass;
+  jvmtiEnv* env;
   const jvmtiHeapCallbacks* callbacks;
   const void* user_data;
   const bool filter_out_tagged;
@@ -111,8 +167,6 @@
     return;
   }
 
-  // TODO: Handle array_primitive_value_callback.
-
   if (ithd->filter_klass != nullptr) {
     if (ithd->filter_klass != klass) {
       return;
@@ -139,11 +193,20 @@
 
   ithd->stop_reports = (ret & JVMTI_VISIT_ABORT) != 0;
 
-  // TODO Implement array primitive and string primitive callback.
+  if (!ithd->stop_reports) {
+    jint string_ret = ReportString(obj,
+                                   ithd->env,
+                                   ithd->heap_util->GetTags(),
+                                   ithd->callbacks,
+                                   ithd->user_data);
+    ithd->stop_reports = (string_ret & JVMTI_VISIT_ABORT) != 0;
+  }
+
+  // TODO Implement array primitive callback.
   // TODO Implement primitive field callback.
 }
 
-jvmtiError HeapUtil::IterateThroughHeap(jvmtiEnv* env ATTRIBUTE_UNUSED,
+jvmtiError HeapUtil::IterateThroughHeap(jvmtiEnv* env,
                                         jint heap_filter,
                                         jclass klass,
                                         const jvmtiHeapCallbacks* callbacks,
@@ -161,6 +224,7 @@
   art::ScopedObjectAccess soa(self);      // Now we know we have the shared lock.
 
   IterateThroughHeapData ithd(this,
+                              env,
                               heap_filter,
                               soa.Decode<art::mirror::Class>(klass),
                               callbacks,
@@ -174,10 +238,12 @@
 class FollowReferencesHelper FINAL {
  public:
   FollowReferencesHelper(HeapUtil* h,
+                         jvmtiEnv* jvmti_env,
                          art::ObjPtr<art::mirror::Object> initial_object,
                          const jvmtiHeapCallbacks* callbacks,
                          const void* user_data)
-      : tag_table_(h->GetTags()),
+      : env(jvmti_env),
+        tag_table_(h->GetTags()),
         initial_object_(initial_object),
         callbacks_(callbacks),
         user_data_(user_data),
@@ -467,6 +533,11 @@
     obj->VisitReferences<false>(visitor, art::VoidFunctor());
 
     stop_reports_ = visitor.stop_reports;
+
+    if (!stop_reports_) {
+      jint string_ret = ReportString(obj, env, tag_table_, callbacks_, user_data_);
+      stop_reports_ = (string_ret & JVMTI_VISIT_ABORT) != 0;
+    }
   }
 
   void VisitArray(art::mirror::Object* array)
@@ -655,6 +726,7 @@
     return result;
   }
 
+  jvmtiEnv* env;
   ObjectTagTable* tag_table_;
   art::ObjPtr<art::mirror::Object> initial_object_;
   const jvmtiHeapCallbacks* callbacks_;
@@ -671,7 +743,7 @@
   friend class CollectAndReportRootsVisitor;
 };
 
-jvmtiError HeapUtil::FollowReferences(jvmtiEnv* env ATTRIBUTE_UNUSED,
+jvmtiError HeapUtil::FollowReferences(jvmtiEnv* env,
                                       jint heap_filter ATTRIBUTE_UNUSED,
                                       jclass klass ATTRIBUTE_UNUSED,
                                       jobject initial_object,
@@ -700,6 +772,7 @@
     art::ScopedSuspendAll ssa("FollowReferences");
 
     FollowReferencesHelper frh(this,
+                               env,
                                self->DecodeJObject(initial_object),
                                callbacks,
                                user_data);
diff --git a/test/906-iterate-heap/expected.txt b/test/906-iterate-heap/expected.txt
index 72cd47d..d636286 100644
--- a/test/906-iterate-heap/expected.txt
+++ b/test/906-iterate-heap/expected.txt
@@ -1,2 +1,4 @@
-[{tag=1, class-tag=0, size=8, length=-1}, {tag=2, class-tag=100, size=8, length=-1}, {tag=3, class-tag=100, size=8, length=-1}, {tag=4, class-tag=0, size=32, length=5}, {tag=100, class-tag=0, size=<class>, length=-1}]
-[{tag=11, class-tag=0, size=8, length=-1}, {tag=12, class-tag=110, size=8, length=-1}, {tag=13, class-tag=110, size=8, length=-1}, {tag=14, class-tag=0, size=32, length=5}, {tag=110, class-tag=0, size=<class>, length=-1}]
+[{tag=1, class-tag=0, size=8, length=-1}, {tag=2, class-tag=100, size=8, length=-1}, {tag=3, class-tag=100, size=8, length=-1}, {tag=4, class-tag=0, size=32, length=5}, {tag=5, class-tag=0, size=40, length=-1}, {tag=100, class-tag=0, size=<class>, length=-1}]
+[{tag=11, class-tag=0, size=8, length=-1}, {tag=12, class-tag=110, size=8, length=-1}, {tag=13, class-tag=110, size=8, length=-1}, {tag=14, class-tag=0, size=32, length=5}, {tag=15, class-tag=0, size=40, length=-1}, {tag=110, class-tag=0, size=<class>, length=-1}]
+15@0 ( 40, 'Hello World')
+16
diff --git a/test/906-iterate-heap/iterate_heap.cc b/test/906-iterate-heap/iterate_heap.cc
index 1362d47..0a0c68a 100644
--- a/test/906-iterate-heap/iterate_heap.cc
+++ b/test/906-iterate-heap/iterate_heap.cc
@@ -14,17 +14,21 @@
  * limitations under the License.
  */
 
+#include "inttypes.h"
+
 #include <iostream>
 #include <pthread.h>
 #include <stdio.h>
 #include <vector>
 
+#include "android-base/stringprintf.h"
 #include "base/logging.h"
 #include "jni.h"
 #include "openjdkjvmti/jvmti.h"
 #include "ScopedPrimitiveArray.h"
 #include "ti-agent/common_helper.h"
 #include "ti-agent/common_load.h"
+#include "utf.h"
 
 namespace art {
 namespace Test906IterateHeap {
@@ -172,5 +176,61 @@
   Run(heap_filter, klass_filter, &config);
 }
 
+extern "C" JNIEXPORT jstring JNICALL Java_Main_iterateThroughHeapString(
+    JNIEnv* env, jclass klass ATTRIBUTE_UNUSED, jlong tag) {
+  struct FindStringCallbacks {
+    explicit FindStringCallbacks(jlong t) : tag_to_find(t) {}
+
+    static jint JNICALL  HeapIterationCallback(jlong class_tag ATTRIBUTE_UNUSED,
+                                               jlong size ATTRIBUTE_UNUSED,
+                                               jlong* tag_ptr ATTRIBUTE_UNUSED,
+                                               jint length ATTRIBUTE_UNUSED,
+                                               void* user_data ATTRIBUTE_UNUSED) {
+      return 0;
+    }
+
+    static jint JNICALL StringValueCallback(jlong class_tag,
+                                            jlong size,
+                                            jlong* tag_ptr,
+                                            const jchar* value,
+                                            jint value_length,
+                                            void* user_data) {
+      FindStringCallbacks* p = reinterpret_cast<FindStringCallbacks*>(user_data);
+      if (*tag_ptr == p->tag_to_find) {
+        size_t utf_byte_count = CountUtf8Bytes(value, value_length);
+        std::unique_ptr<char[]> mod_utf(new char[utf_byte_count + 1]);
+        memset(mod_utf.get(), 0, utf_byte_count + 1);
+        ConvertUtf16ToModifiedUtf8(mod_utf.get(), utf_byte_count, value, value_length);
+        if (!p->data.empty()) {
+          p->data += "\n";
+        }
+        p->data += android::base::StringPrintf("%" PRId64 "@%" PRId64 " (% " PRId64 ", '%s')",
+                                               *tag_ptr,
+                                               class_tag,
+                                               size,
+                                               mod_utf.get());
+        // Update the tag to test whether that works.
+        *tag_ptr = *tag_ptr + 1;
+      }
+      return 0;
+    }
+
+    std::string data;
+    const jlong tag_to_find;
+  };
+
+  jvmtiHeapCallbacks callbacks;
+  memset(&callbacks, 0, sizeof(jvmtiHeapCallbacks));
+  callbacks.heap_iteration_callback = FindStringCallbacks::HeapIterationCallback;
+  callbacks.string_primitive_value_callback = FindStringCallbacks::StringValueCallback;
+
+  FindStringCallbacks fsc(tag);
+  jvmtiError ret = jvmti_env->IterateThroughHeap(0, nullptr, &callbacks, &fsc);
+  if (JvmtiErrorToException(env, ret)) {
+    return nullptr;
+  }
+  return env->NewStringUTF(fsc.data.c_str());
+}
+
 }  // namespace Test906IterateHeap
 }  // namespace art
diff --git a/test/906-iterate-heap/src/Main.java b/test/906-iterate-heap/src/Main.java
index cab27be..755d23c 100644
--- a/test/906-iterate-heap/src/Main.java
+++ b/test/906-iterate-heap/src/Main.java
@@ -28,11 +28,13 @@
     B b2 = new B();
     C c = new C();
     A[] aArray = new A[5];
+    String s = "Hello World";
 
     setTag(a, 1);
     setTag(b, 2);
     setTag(b2, 3);
     setTag(aArray, 4);
+    setTag(s, 5);
     setTag(B.class, 100);
 
     int all = iterateThroughHeapCount(0, null, Integer.MAX_VALUE);
@@ -50,7 +52,7 @@
       throw new IllegalStateException("By class: " + all + " != " + taggedClass + " + " +
           untaggedClass);
     }
-    if (tagged != 5) {
+    if (tagged != 6) {
       throw new IllegalStateException(tagged + " tagged objects");
     }
     if (taggedClass != 2) {
@@ -74,6 +76,9 @@
     iterateThroughHeapAdd(HEAP_FILTER_OUT_UNTAGGED, null);
     n = iterateThroughHeapData(HEAP_FILTER_OUT_UNTAGGED, null, classTags, sizes, tags, lengths);
     System.out.println(sort(n, classTags, sizes, tags, lengths));
+
+    System.out.println(iterateThroughHeapString(getTag(s)));
+    System.out.println(getTag(s));
   }
 
   static class A {
@@ -141,4 +146,5 @@
       Class<?> klassFilter, long classTags[], long sizes[], long tags[], int lengths[]);
   private static native int iterateThroughHeapAdd(int heapFilter,
       Class<?> klassFilter);
+  private static native String iterateThroughHeapString(long tag);
 }
diff --git a/test/913-heaps/expected.txt b/test/913-heaps/expected.txt
index 340cd70..3125d2b 100644
--- a/test/913-heaps/expected.txt
+++ b/test/913-heaps/expected.txt
@@ -79,3 +79,5 @@
 5@1002 --(field@28)--> 1@1000 [size=16, length=-1]
 6@1000 --(class)--> 1000@0 [size=123, length=-1]
 ---
+[1@0 ( 40, 'HelloWorld')]
+2
diff --git a/test/913-heaps/heaps.cc b/test/913-heaps/heaps.cc
index 6759919..0c2361a 100644
--- a/test/913-heaps/heaps.cc
+++ b/test/913-heaps/heaps.cc
@@ -493,5 +493,67 @@
   return ret;
 }
 
+extern "C" JNIEXPORT jobjectArray JNICALL Java_Main_followReferencesString(
+    JNIEnv* env, jclass klass ATTRIBUTE_UNUSED, jobject initial_object) {
+  struct FindStringCallbacks {
+    static jint JNICALL FollowReferencesCallback(
+        jvmtiHeapReferenceKind reference_kind ATTRIBUTE_UNUSED,
+        const jvmtiHeapReferenceInfo* reference_info ATTRIBUTE_UNUSED,
+        jlong class_tag ATTRIBUTE_UNUSED,
+        jlong referrer_class_tag ATTRIBUTE_UNUSED,
+        jlong size ATTRIBUTE_UNUSED,
+        jlong* tag_ptr ATTRIBUTE_UNUSED,
+        jlong* referrer_tag_ptr ATTRIBUTE_UNUSED,
+        jint length ATTRIBUTE_UNUSED,
+        void* user_data ATTRIBUTE_UNUSED) {
+      return JVMTI_VISIT_OBJECTS;  // Continue visiting.
+    }
+
+    static jint JNICALL StringValueCallback(jlong class_tag,
+                                            jlong size,
+                                            jlong* tag_ptr,
+                                            const jchar* value,
+                                            jint value_length,
+                                            void* user_data) {
+      FindStringCallbacks* p = reinterpret_cast<FindStringCallbacks*>(user_data);
+      if (*tag_ptr != 0) {
+        size_t utf_byte_count = CountUtf8Bytes(value, value_length);
+        std::unique_ptr<char[]> mod_utf(new char[utf_byte_count + 1]);
+        memset(mod_utf.get(), 0, utf_byte_count + 1);
+        ConvertUtf16ToModifiedUtf8(mod_utf.get(), utf_byte_count, value, value_length);
+        p->data.push_back(android::base::StringPrintf("%" PRId64 "@%" PRId64 " (% " PRId64 ", '%s')",
+                                                      *tag_ptr,
+                                                      class_tag,
+                                                      size,
+                                                      mod_utf.get()));
+        // Update the tag to test whether that works.
+        *tag_ptr = *tag_ptr + 1;
+      }
+      return 0;
+    }
+
+    std::vector<std::string> data;
+  };
+
+  jvmtiHeapCallbacks callbacks;
+  memset(&callbacks, 0, sizeof(jvmtiHeapCallbacks));
+  callbacks.heap_reference_callback = FindStringCallbacks::FollowReferencesCallback;
+  callbacks.string_primitive_value_callback = FindStringCallbacks::StringValueCallback;
+
+  FindStringCallbacks fsc;
+  jvmtiError ret = jvmti_env->FollowReferences(0, nullptr, initial_object, &callbacks, &fsc);
+  if (JvmtiErrorToException(env, ret)) {
+    return nullptr;
+  }
+
+  jobjectArray retArray = CreateObjectArray(env,
+                                            static_cast<jint>(fsc.data.size()),
+                                            "java/lang/String",
+                                            [&](jint i) {
+                                              return env->NewStringUTF(fsc.data[i].c_str());
+                                            });
+  return retArray;
+}
+
 }  // namespace Test913Heaps
 }  // namespace art
diff --git a/test/913-heaps/src/Main.java b/test/913-heaps/src/Main.java
index 7f9c8fc..4402072 100644
--- a/test/913-heaps/src/Main.java
+++ b/test/913-heaps/src/Main.java
@@ -24,6 +24,8 @@
   public static void main(String[] args) throws Exception {
     doTest();
     new TestConfig().doFollowReferencesTest();
+
+    doStringTest();
   }
 
   public static void doTest() throws Exception {
@@ -34,6 +36,17 @@
     enableGcTracking(false);
   }
 
+  public static void doStringTest() throws Exception {
+    final String str = "HelloWorld";
+    Object o = new Object() {
+      String s = str;
+    };
+
+    setTag(str, 1);
+    System.out.println(Arrays.toString(followReferencesString(o)));
+    System.out.println(getTag(str));
+  }
+
   private static void run() {
     clearStats();
     forceGarbageCollection();
@@ -410,4 +423,5 @@
 
   public static native String[] followReferences(int heapFilter, Class<?> klassFilter,
       Object initialObject, int stopAfter, int followSet, Object jniRef);
+  public static native String[] followReferencesString(Object initialObject);
 }