Add AUtils::isInRange, and use it to detect malformed MPEG4 nal sizes

Bug: 19641538
Change-Id: I5aae3f100846c125decc61eec7cd6563e3f33777
diff --git a/include/media/stagefright/foundation/AUtils.h b/include/media/stagefright/foundation/AUtils.h
index d7ecf50..47444c1 100644
--- a/include/media/stagefright/foundation/AUtils.h
+++ b/include/media/stagefright/foundation/AUtils.h
@@ -61,6 +61,28 @@
     return a > b ? a : b;
 }
 
+template<class T>
+void ENSURE_UNSIGNED_TYPE() {
+    T TYPE_MUST_BE_UNSIGNED[(T)-1 < 0 ? -1 : 0] __unused;
+}
+
+// needle is in range [hayStart, hayStart + haySize)
+template<class T, class U>
+inline static bool isInRange(const T &hayStart, const U &haySize, const T &needle) {
+    ENSURE_UNSIGNED_TYPE<U>();
+    return (T)(hayStart + haySize) >= hayStart && needle >= hayStart && (U)(needle - hayStart) < haySize;
+}
+
+// [needleStart, needleStart + needleSize) is in range [hayStart, hayStart + haySize)
+template<class T, class U>
+inline static bool isInRange(
+        const T &hayStart, const U &haySize, const T &needleStart, const U &needleSize) {
+    ENSURE_UNSIGNED_TYPE<U>();
+    return isInRange(hayStart, haySize, needleStart)
+            && (T)(needleStart + needleSize) >= needleStart
+            && (U)(needleStart + needleSize - hayStart) <= haySize;
+}
+
 /* T must be integer type, period must be positive */
 template<class T>
 inline static T periodicError(const T &val, const T &period) {
diff --git a/media/libstagefright/MPEG4Extractor.cpp b/media/libstagefright/MPEG4Extractor.cpp
index be3bc4a..9689ce4 100644
--- a/media/libstagefright/MPEG4Extractor.cpp
+++ b/media/libstagefright/MPEG4Extractor.cpp
@@ -33,6 +33,7 @@
 #include <media/stagefright/foundation/ABuffer.h>
 #include <media/stagefright/foundation/ADebug.h>
 #include <media/stagefright/foundation/AMessage.h>
+#include <media/stagefright/foundation/AUtils.h>
 #include <media/stagefright/MediaBuffer.h>
 #include <media/stagefright/MediaBufferGroup.h>
 #include <media/stagefright/MediaDefs.h>
@@ -221,8 +222,7 @@
 ssize_t MPEG4DataSource::readAt(off64_t offset, void *data, size_t size) {
     Mutex::Autolock autoLock(mLock);
 
-    if (offset >= mCachedOffset
-            && offset + size <= mCachedOffset + mCachedSize) {
+    if (isInRange(mCachedOffset, mCachedSize, offset, size)) {
         memcpy(data, &mCache[offset - mCachedOffset], size);
         return size;
     }
@@ -3882,12 +3882,12 @@
             size_t dstOffset = 0;
 
             while (srcOffset < size) {
-                bool isMalFormed = (srcOffset + mNALLengthSize > size);
+                bool isMalFormed = !isInRange((size_t)0u, size, srcOffset, mNALLengthSize);
                 size_t nalLength = 0;
                 if (!isMalFormed) {
                     nalLength = parseNALSize(&mSrcBuffer[srcOffset]);
                     srcOffset += mNALLengthSize;
-                    isMalFormed = srcOffset + nalLength > size;
+                    isMalFormed = !isInRange((size_t)0u, size, srcOffset, nalLength);
                 }
 
                 if (isMalFormed) {
@@ -4159,12 +4159,12 @@
             size_t dstOffset = 0;
 
             while (srcOffset < size) {
-                bool isMalFormed = (srcOffset + mNALLengthSize > size);
+                bool isMalFormed = !isInRange((size_t)0u, size, srcOffset, mNALLengthSize);
                 size_t nalLength = 0;
                 if (!isMalFormed) {
                     nalLength = parseNALSize(&mSrcBuffer[srcOffset]);
                     srcOffset += mNALLengthSize;
-                    isMalFormed = srcOffset + nalLength > size;
+                    isMalFormed = !isInRange((size_t)0u, size, srcOffset, nalLength);
                 }
 
                 if (isMalFormed) {
diff --git a/media/libstagefright/tests/Utils_test.cpp b/media/libstagefright/tests/Utils_test.cpp
index 5c323c1..c1e663c 100644
--- a/media/libstagefright/tests/Utils_test.cpp
+++ b/media/libstagefright/tests/Utils_test.cpp
@@ -192,6 +192,87 @@
     ASSERT_EQ(max(-4.3, 8.6), 8.6);
     ASSERT_EQ(max(8.6, -4.3), 8.6);
 
+    ASSERT_FALSE(isInRange(-43, 86u, -44));
+    ASSERT_TRUE(isInRange(-43, 87u, -43));
+    ASSERT_TRUE(isInRange(-43, 88u, -1));
+    ASSERT_TRUE(isInRange(-43, 89u, 0));
+    ASSERT_TRUE(isInRange(-43, 90u, 46));
+    ASSERT_FALSE(isInRange(-43, 91u, 48));
+    ASSERT_FALSE(isInRange(-43, 92u, 50));
+
+    ASSERT_FALSE(isInRange(43, 86u, 42));
+    ASSERT_TRUE(isInRange(43, 87u, 43));
+    ASSERT_TRUE(isInRange(43, 88u, 44));
+    ASSERT_TRUE(isInRange(43, 89u, 131));
+    ASSERT_FALSE(isInRange(43, 90u, 133));
+    ASSERT_FALSE(isInRange(43, 91u, 135));
+
+    ASSERT_FALSE(isInRange(43u, 86u, 42u));
+    ASSERT_TRUE(isInRange(43u, 85u, 43u));
+    ASSERT_TRUE(isInRange(43u, 84u, 44u));
+    ASSERT_TRUE(isInRange(43u, 83u, 125u));
+    ASSERT_FALSE(isInRange(43u, 82u, 125u));
+    ASSERT_FALSE(isInRange(43u, 81u, 125u));
+
+    ASSERT_FALSE(isInRange(-43, ~0u, 43));
+    ASSERT_FALSE(isInRange(-43, ~0u, 44));
+    ASSERT_FALSE(isInRange(-43, ~0u, ~0));
+    ASSERT_FALSE(isInRange(-43, ~0u, 41));
+    ASSERT_FALSE(isInRange(-43, ~0u, 40));
+
+    ASSERT_FALSE(isInRange(43u, ~0u, 43u));
+    ASSERT_FALSE(isInRange(43u, ~0u, 41u));
+    ASSERT_FALSE(isInRange(43u, ~0u, 40u));
+    ASSERT_FALSE(isInRange(43u, ~0u, ~0u));
+
+    ASSERT_FALSE(isInRange(-43, 86u, -44, 0u));
+    ASSERT_FALSE(isInRange(-43, 86u, -44, 1u));
+    ASSERT_FALSE(isInRange(-43, 86u, -44, 2u));
+    ASSERT_FALSE(isInRange(-43, 86u, -44, ~0u));
+    ASSERT_TRUE(isInRange(-43, 87u, -43, 0u));
+    ASSERT_TRUE(isInRange(-43, 87u, -43, 1u));
+    ASSERT_TRUE(isInRange(-43, 87u, -43, 86u));
+    ASSERT_TRUE(isInRange(-43, 87u, -43, 87u));
+    ASSERT_FALSE(isInRange(-43, 87u, -43, 88u));
+    ASSERT_FALSE(isInRange(-43, 87u, -43, ~0u));
+    ASSERT_TRUE(isInRange(-43, 88u, -1, 0u));
+    ASSERT_TRUE(isInRange(-43, 88u, -1, 45u));
+    ASSERT_TRUE(isInRange(-43, 88u, -1, 46u));
+    ASSERT_FALSE(isInRange(-43, 88u, -1, 47u));
+    ASSERT_FALSE(isInRange(-43, 88u, -1, ~3u));
+    ASSERT_TRUE(isInRange(-43, 90u, 46, 0u));
+    ASSERT_TRUE(isInRange(-43, 90u, 46, 1u));
+    ASSERT_FALSE(isInRange(-43, 90u, 46, 2u));
+    ASSERT_FALSE(isInRange(-43, 91u, 48, 0u));
+    ASSERT_FALSE(isInRange(-43, 91u, 48, 2u));
+    ASSERT_FALSE(isInRange(-43, 91u, 48, ~6u));
+    ASSERT_FALSE(isInRange(-43, 92u, 50, 0u));
+    ASSERT_FALSE(isInRange(-43, 92u, 50, 1u));
+
+    ASSERT_FALSE(isInRange(43u, 86u, 42u, 0u));
+    ASSERT_FALSE(isInRange(43u, 86u, 42u, 1u));
+    ASSERT_FALSE(isInRange(43u, 86u, 42u, 2u));
+    ASSERT_FALSE(isInRange(43u, 86u, 42u, ~0u));
+    ASSERT_TRUE(isInRange(43u, 87u, 43u, 0u));
+    ASSERT_TRUE(isInRange(43u, 87u, 43u, 1u));
+    ASSERT_TRUE(isInRange(43u, 87u, 43u, 86u));
+    ASSERT_TRUE(isInRange(43u, 87u, 43u, 87u));
+    ASSERT_FALSE(isInRange(43u, 87u, 43u, 88u));
+    ASSERT_FALSE(isInRange(43u, 87u, 43u, ~0u));
+    ASSERT_TRUE(isInRange(43u, 88u, 60u, 0u));
+    ASSERT_TRUE(isInRange(43u, 88u, 60u, 70u));
+    ASSERT_TRUE(isInRange(43u, 88u, 60u, 71u));
+    ASSERT_FALSE(isInRange(43u, 88u, 60u, 72u));
+    ASSERT_FALSE(isInRange(43u, 88u, 60u, ~3u));
+    ASSERT_TRUE(isInRange(43u, 90u, 132u, 0u));
+    ASSERT_TRUE(isInRange(43u, 90u, 132u, 1u));
+    ASSERT_FALSE(isInRange(43u, 90u, 132u, 2u));
+    ASSERT_FALSE(isInRange(43u, 91u, 134u, 0u));
+    ASSERT_FALSE(isInRange(43u, 91u, 134u, 2u));
+    ASSERT_FALSE(isInRange(43u, 91u, 134u, ~6u));
+    ASSERT_FALSE(isInRange(43u, 92u, 136u, 0u));
+    ASSERT_FALSE(isInRange(43u, 92u, 136u, 1u));
+
     ASSERT_EQ(periodicError(124, 100), 24);
     ASSERT_EQ(periodicError(288, 100), 12);
     ASSERT_EQ(periodicError(-345, 100), 45);