Validate input of JNI::NewStringUTF().

Do the minimal validation needed to construct valid strings.
If the validation fails, replace each invalid sequence with '?'.

(cherry picked from commit d0b940349294a363e6d578adf58db8222c425669)

Test: Additional tests in JniInternalTest.NewStringUTF
Test: m test-art-host-gtest
Bug: 172655291
Merged-In: I683c142fe5972599297d604f775bb8cfe6154bb7
Merged-In: Ica2c5c33f981cbd2f07e7990b3e321cd3b7473b6
Change-Id: Ieda5a89b738bbb624872512a400855866e120acd
diff --git a/runtime/jni_internal.cc b/runtime/jni_internal.cc
index cd66a60..60c7d3b 100644
--- a/runtime/jni_internal.cc
+++ b/runtime/jni_internal.cc
@@ -27,6 +27,7 @@
 #include "art_method-inl.h"
 #include "base/allocator.h"
 #include "base/atomic.h"
+#include "base/casts.h"
 #include "base/enums.h"
 #include "base/logging.h"  // For VLOG.
 #include "base/mutex.h"
@@ -34,7 +35,7 @@
 #include "base/stl_util.h"
 #include "class_linker-inl.h"
 #include "dex/dex_file-inl.h"
-#include "dex/utf.h"
+#include "dex/utf-inl.h"
 #include "fault_handler.h"
 #include "hidden_api.h"
 #include "gc/accounting/card_table-inl.h"
@@ -60,7 +61,10 @@
 #include "thread.h"
 #include "well_known_classes.h"
 
+namespace art {
+
 namespace {
+
 // Frees the given va_list upon destruction.
 // This also guards the returns from inside of the CHECK_NON_NULL_ARGUMENTs.
 struct ScopedVAArgs {
@@ -72,9 +76,84 @@
  private:
   va_list* args;
 };
-}  // namespace
 
-namespace art {
+constexpr char kBadUtf8ReplacementChar[] = "?";
+
+// This is a modified version of CountModifiedUtf8Chars() from utf.cc
+// with extra checks and different output options.
+//
+// The `good` functor is called as good(ptr, length) for each accepted byte sequence.
+// The `bad` functor is called when we find an invalid character and
+// returns true to abort processing, or false to continue processing.
+//
+// When aborted, VisitModifiedUtf8Chars() returns 0, otherwise the number
+// of UTF-16 chars, counting each sequence reported to `bad` as one char.
+template <typename GoodFunc, typename BadFunc>
+size_t VisitModifiedUtf8Chars(const char* utf8, size_t byte_count, GoodFunc good, BadFunc bad) {
+  DCHECK_LE(byte_count, strlen(utf8));  // No null byte within the visited range.
+  size_t len = 0;
+  const char* end = utf8 + byte_count;
+  while (utf8 != end) {
+    int ic = *utf8;  // Leading byte; only its high bits are examined below.
+    if (LIKELY((ic & 0x80) == 0)) {
+      // One-byte encoding.
+      good(utf8, 1u);
+      utf8 += 1u;
+      len += 1u;
+      continue;
+    }
+    auto is_ascii = [utf8]() {
+      const char* ptr = utf8;  // Make a copy that can be modified by GetUtf16FromUtf8().
+      return mirror::String::IsASCII(dchecked_integral_cast<uint16_t>(GetUtf16FromUtf8(&ptr)));
+    };
+    // Note: Neither CountModifiedUtf8Chars() nor GetUtf16FromUtf8() checks whether
+    // the bit 0x40 is correctly set in the leading byte of a multi-byte sequence.
+    if ((ic & 0x20) == 0) {
+      // Two-byte encoding.
+      if (static_cast<size_t>(end - utf8) < 2u) {
+        return bad() ? 0u : len + 1u;  // Truncated sequence at the end of input.
+      }
+      if (mirror::kUseStringCompression && is_ascii()) {  // Decodes to ASCII (1..0x7f).
+        if (bad()) {
+          return 0u;
+        }
+      } else {
+        good(utf8, 2u);
+      }
+      utf8 += 2u;
+      len += 1u;
+      continue;
+    }
+    if ((ic & 0x10) == 0) {
+      // Three-byte encoding.
+      if (static_cast<size_t>(end - utf8) < 3u) {
+        return bad() ? 0u : len + 1u;  // Truncated sequence at the end of input.
+      }
+      if (mirror::kUseStringCompression && is_ascii()) {  // Decodes to ASCII (1..0x7f).
+        if (bad()) {
+          return 0u;
+        }
+      } else {
+        good(utf8, 3u);
+      }
+      utf8 += 3u;
+      len += 1u;
+      continue;
+    }
+
+    // Four-byte encoding: needs to be converted into a surrogate pair.
+    // The decoded chars are never ASCII.
+    if (static_cast<size_t>(end - utf8) < 4u) {
+      return bad() ? 0u : len + 1u;  // Truncated sequence at the end of input.
+    }
+    good(utf8, 4u);
+    utf8 += 4u;
+    len += 2u;
+  }
+  return len;
+}
+
+}  // namespace
 
 // Consider turning this on when there is errors which could be related to JNI array copies such as
 // things not rendering correctly. E.g. b/16858794
@@ -1778,8 +1857,69 @@
     if (utf == nullptr) {
       return nullptr;
     }
+
+    // The input may come from an untrusted source, so we need to validate it.
+    // We do not perform full validation, only as much as necessary to avoid reading
+    // beyond the terminating null character or breaking string compression invariants.
+    // CheckJNI performs stronger validation.
+    size_t utf8_length = strlen(utf);
+    if (UNLIKELY(utf8_length > static_cast<uint32_t>(std::numeric_limits<int32_t>::max()))) {
+      // Converting the utf8_length to int32_t for String::AllocFromModifiedUtf8() would
+      // overflow. Throw OOME eagerly to avoid 2GiB allocation when trying to replace
+      // invalid sequences (even if such replacements could reduce the size below 2GiB).
+      std::string error =
+          android::base::StringPrintf("NewStringUTF input is 2 GiB or more: %zu", utf8_length);
+      ScopedObjectAccess soa(env);
+      soa.Self()->ThrowOutOfMemoryError(error.c_str());
+      return nullptr;
+    }
+    std::unique_ptr<char[]> replacement_utf;
+    size_t replacement_utf_pos = 0u;
+    size_t utf16_length = VisitModifiedUtf8Chars(
+        utf,
+        utf8_length,
+        /*good=*/ [](const char* ptr ATTRIBUTE_UNUSED, size_t length ATTRIBUTE_UNUSED) {},
+        /*bad=*/ []() { return true; });  // Abort processing and return 0 for bad characters.
+    if (UNLIKELY(utf8_length != 0u && utf16_length == 0u)) {
+      // VisitModifiedUtf8Chars() aborted for a bad character.
+      // Report the error to logcat but avoid too much spam.
+      static const uint64_t kMinDelay = UINT64_C(10000000000);  // 10s
+      static std::atomic<uint64_t> prev_bad_input_time(UINT64_C(0));
+      uint64_t prev_time = prev_bad_input_time.load(std::memory_order_relaxed);
+      uint64_t now = NanoTime();
+      if ((prev_time == 0u || now - prev_time >= kMinDelay) &&
+          prev_bad_input_time.compare_exchange_strong(prev_time, now, std::memory_order_relaxed)) {
+        LOG(ERROR) << "Invalid UTF-8 input to JNI::NewStringUTF()";
+      }
+      // Copy the input to the `replacement_utf` and replace bad characters.
+      replacement_utf.reset(new char[utf8_length + 1u]);
+      utf16_length = VisitModifiedUtf8Chars(
+          utf,
+          utf8_length,
+          /*good=*/ [&](const char* ptr, size_t length) {
+            DCHECK_GE(utf8_length - replacement_utf_pos, length);
+            memcpy(&replacement_utf[replacement_utf_pos], ptr, length);
+            replacement_utf_pos += length;
+          },
+          /*bad=*/ [&]() {
+            DCHECK_GE(utf8_length - replacement_utf_pos, sizeof(kBadUtf8ReplacementChar) - 1u);
+            memcpy(&replacement_utf[replacement_utf_pos],
+                   kBadUtf8ReplacementChar,
+                   sizeof(kBadUtf8ReplacementChar) - 1u);
+            replacement_utf_pos += sizeof(kBadUtf8ReplacementChar) - 1u;
+            return false;  // Continue processing.
+          });
+      DCHECK_LE(replacement_utf_pos, utf8_length);
+      replacement_utf[replacement_utf_pos] = 0;  // Terminating null.
+      utf = replacement_utf.get();
+      utf8_length = replacement_utf_pos;
+    }
+    DCHECK_LE(utf16_length, utf8_length);
+    DCHECK_LE(utf8_length, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+
     ScopedObjectAccess soa(env);
-    mirror::String* result = mirror::String::AllocFromModifiedUtf8(soa.Self(), utf);
+    mirror::String* result =
+        mirror::String::AllocFromModifiedUtf8(soa.Self(), utf16_length, utf, utf8_length);
     return soa.AddLocalReference<jstring>(result);
   }
 
diff --git a/runtime/jni_internal_test.cc b/runtime/jni_internal_test.cc
index 5d74181..100fe23 100644
--- a/runtime/jni_internal_test.cc
+++ b/runtime/jni_internal_test.cc
@@ -23,6 +23,7 @@
 #include "indirect_reference_table.h"
 #include "java_vm_ext.h"
 #include "jni_env_ext.h"
+#include "mem_map.h"
 #include "mirror/string-inl.h"
 #include "nativehelper/scoped_local_ref.h"
 #include "scoped_thread_state_change-inl.h"
@@ -1551,6 +1552,119 @@
   EXPECT_EQ(13, env_->GetStringUTFLength(s));
 }
 
+TEST_F(JniInternalTest, NewStringUTF_Validation) {
+  // For the following tests, allocate two pages, one R/W and the next inaccessible.
+  std::string error_msg;
+  std::unique_ptr<MemMap> head_map(MemMap::MapAnonymous("head",
+                                                        /*addr=*/ nullptr,
+                                                        2 * kPageSize,
+                                                        PROT_READ | PROT_WRITE,
+                                                        /*low_4gb=*/ false,
+                                                        /*reuse=*/ false,
+                                                        &error_msg));
+  ASSERT_TRUE(head_map != nullptr) << error_msg;
+  std::unique_ptr<MemMap> tail_map(
+      head_map->RemapAtEnd(head_map->Begin() + kPageSize, "tail", PROT_NONE, &error_msg));
+  ASSERT_TRUE(tail_map != nullptr) << error_msg;
+  char* utf_src = reinterpret_cast<char*>(head_map->Begin());
+
+  // Prepare for checking the `count` field.
+  jclass c = env_->FindClass("java/lang/String");
+  ASSERT_NE(c, nullptr);
+  jfieldID count_fid = env_->GetFieldID(c, "count", "I");
+  ASSERT_TRUE(count_fid != nullptr);
+
+  // Prepare for testing with the unchecked interface.
+  const JNINativeInterface* base_env = down_cast<JNIEnvExt*>(env_)->GetUncheckedFunctions();
+
+  // Start with a simple ASCII string consisting of `kPageSize - 1` characters 'x'.
+  memset(utf_src, 'x', kPageSize - 1u);
+  utf_src[kPageSize - 1u] = 0u;
+  jstring s = base_env->NewStringUTF(env_, utf_src);
+  ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 1u, /* compressible= */ true),
+            env_->GetIntField(s, count_fid));
+  const char* chars = env_->GetStringUTFChars(s, nullptr);
+  for (size_t pos = 0; pos != kPageSize - 1u; ++pos) {
+    ASSERT_EQ('x', chars[pos]) << pos;
+  }
+  env_->ReleaseStringUTFChars(s, chars);
+
+  // Replace the last character with an invalid character that requires continuation.
+  for (char invalid : { '\xc0', '\xe0', '\xf0' }) {
+    utf_src[kPageSize - 2u] = invalid;
+    s = base_env->NewStringUTF(env_, utf_src);
+    ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 1u, /* compressible= */ true),
+              env_->GetIntField(s, count_fid));
+    chars = env_->GetStringUTFChars(s, nullptr);
+    for (size_t pos = 0; pos != kPageSize - 2u; ++pos) {
+      ASSERT_EQ('x', chars[pos]) << pos;
+    }
+    EXPECT_EQ('?', chars[kPageSize - 2u]);
+    env_->ReleaseStringUTFChars(s, chars);
+  }
+
+  // Replace the first two characters with a valid two-byte sequence yielding one character.
+  utf_src[0] = '\xc2';
+  utf_src[1] = '\x80';
+  s = base_env->NewStringUTF(env_, utf_src);
+  ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 2u, /* compressible= */ false),
+            env_->GetIntField(s, count_fid));
+  const jchar* jchars = env_->GetStringChars(s, nullptr);
+  EXPECT_EQ(jchars[0], 0x80u);
+  for (size_t pos = 1; pos != kPageSize - 3u; ++pos) {
+    ASSERT_EQ('x', jchars[pos]) << pos;
+  }
+  EXPECT_EQ('?', jchars[kPageSize - 3u]);
+  env_->ReleaseStringChars(s, jchars);
+
+  // Replace the leading two-byte sequence with a two-byte sequence that decodes as ASCII (0x40).
+  // The sequence shall be replaced if string compression is used.
+  utf_src[0] = '\xc1';
+  utf_src[1] = '\x80';
+  s = base_env->NewStringUTF(env_, utf_src);
+  // Note: All invalid characters are replaced by the ASCII replacement character.
+  ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 2u, /* compressible= */ true),
+            env_->GetIntField(s, count_fid));
+  jchars = env_->GetStringChars(s, nullptr);
+  EXPECT_EQ(mirror::kUseStringCompression ? '?' : '\x40', jchars[0]);
+  for (size_t pos = 1; pos != kPageSize - 3u; ++pos) {
+    ASSERT_EQ('x', jchars[pos]) << pos;
+  }
+  EXPECT_EQ('?', jchars[kPageSize - 3u]);
+  env_->ReleaseStringChars(s, jchars);
+
+  // Replace the leading three bytes with a three-byte sequence that decodes as ASCII (0x40).
+  // The sequence shall be replaced if string compression is used.
+  utf_src[0] = '\xe0';
+  utf_src[1] = '\x81';
+  utf_src[2] = '\x80';
+  s = base_env->NewStringUTF(env_, utf_src);
+  // Note: All invalid characters are replaced by the ASCII replacement character.
+  ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 3u, /* compressible= */ true),
+            env_->GetIntField(s, count_fid));
+  jchars = env_->GetStringChars(s, nullptr);
+  EXPECT_EQ(mirror::kUseStringCompression ? '?' : '\x40', jchars[0]);
+  for (size_t pos = 1; pos != kPageSize - 4u; ++pos) {
+    ASSERT_EQ('x', jchars[pos]) << pos;
+  }
+  EXPECT_EQ('?', jchars[kPageSize - 4u]);
+  env_->ReleaseStringChars(s, jchars);
+
+  // Replace the last two characters with a valid two-byte sequence that decodes as 0.
+  utf_src[kPageSize - 3u] = '\xc0';
+  utf_src[kPageSize - 2u] = '\x80';
+  s = base_env->NewStringUTF(env_, utf_src);
+  ASSERT_EQ(mirror::String::GetFlaggedCount(kPageSize - 4u, /* compressible= */ false),
+            env_->GetIntField(s, count_fid));
+  jchars = env_->GetStringChars(s, nullptr);
+  EXPECT_EQ(mirror::kUseStringCompression ? '?' : '\x40', jchars[0]);
+  for (size_t pos = 1; pos != kPageSize - 5u; ++pos) {
+    ASSERT_EQ('x', jchars[pos]) << pos;
+  }
+  EXPECT_EQ('\0', jchars[kPageSize - 5u]);
+  env_->ReleaseStringChars(s, jchars);
+}
+
 TEST_F(JniInternalTest, NewString) {
   jchar chars[] = { 'h', 'i' };
   jstring s;
diff --git a/runtime/mirror/string.h b/runtime/mirror/string.h
index 545fe93..5942dc3 100644
--- a/runtime/mirror/string.h
+++ b/runtime/mirror/string.h
@@ -230,13 +230,13 @@
   std::string PrettyStringDescriptor()
       REQUIRES_SHARED(Locks::mutator_lock_);
 
- private:
   static constexpr bool IsASCII(uint16_t c) {
     // Valid ASCII characters are in range 1..0x7f. Zero is not considered ASCII
     // because it would complicate the detection of ASCII strings in Modified-UTF8.
     return (c - 1u) < 0x7fu;
   }
 
+ private:
   static bool AllASCIIExcept(const uint16_t* chars, int32_t length, uint16_t non_ascii);
 
   void SetHashCode(int32_t new_hash_code) REQUIRES_SHARED(Locks::mutator_lock_) {