mp3dec: Fix out of bound read error

Add check for required number of bytes before stream read
while reading side info.
Modify bitstream read functions to only read required number of bytes

Bug: 154075955
Bug: 154076193
Test: POC in bug description
Test: atest android.mediav2.cts.CodecDecoderTest
Test: atest Mp3DecoderTest -- --enable-module-dynamic-download=true

Change-Id: I777f22d21cbf026056f1ac69de4bb763846b1a9d
(cherry picked from commit c57092d196049acacd85623365b99e45ccc61b86)
diff --git a/media/libstagefright/codecs/mp3dec/src/pvmp3_framedecoder.cpp b/media/libstagefright/codecs/mp3dec/src/pvmp3_framedecoder.cpp
index a5c7f5e..15d2feb 100644
--- a/media/libstagefright/codecs/mp3dec/src/pvmp3_framedecoder.cpp
+++ b/media/libstagefright/codecs/mp3dec/src/pvmp3_framedecoder.cpp
@@ -219,6 +219,11 @@
 
     if (info->error_protection)
     {
+        if (!bitsAvailable(&pVars->inputStream, 16))
+        {
+            return SIDE_INFO_ERROR;
+        }
+
         /*
          *  Get crc content
          */
diff --git a/media/libstagefright/codecs/mp3dec/src/pvmp3_get_side_info.cpp b/media/libstagefright/codecs/mp3dec/src/pvmp3_get_side_info.cpp
index d644207..1a3fca5 100644
--- a/media/libstagefright/codecs/mp3dec/src/pvmp3_get_side_info.cpp
+++ b/media/libstagefright/codecs/mp3dec/src/pvmp3_get_side_info.cpp
@@ -73,6 +73,7 @@
 
 #include "pvmp3_get_side_info.h"
 #include "pvmp3_crc.h"
+#include "pvmp3_getbits.h"
 
 
 /*----------------------------------------------------------------------------
@@ -125,12 +126,22 @@
     {
         if (stereo == 1)
         {
+            if (!bitsAvailable(inputStream, 14))
+            {
+                return SIDE_INFO_ERROR;
+            }
+
             tmp = getbits_crc(inputStream, 14, crc, info->error_protection);
             si->main_data_begin = (tmp << 18) >> 23;    /* 9 */
             si->private_bits    = (tmp << 27) >> 27;    /* 5 */
         }
         else
         {
+            if (!bitsAvailable(inputStream, 12))
+            {
+                return SIDE_INFO_ERROR;
+            }
+
             tmp = getbits_crc(inputStream, 12, crc, info->error_protection);
             si->main_data_begin = (tmp << 20) >> 23;    /* 9 */
             si->private_bits    = (tmp << 29) >> 29;    /* 3 */
@@ -139,6 +150,11 @@
 
         for (ch = 0; ch < stereo; ch++)
         {
+            if (!bitsAvailable(inputStream, 4))
+            {
+                return SIDE_INFO_ERROR;
+            }
+
             tmp = getbits_crc(inputStream, 4, crc, info->error_protection);
             si->ch[ch].scfsi[0] = (tmp << 28) >> 31;    /* 1 */
             si->ch[ch].scfsi[1] = (tmp << 29) >> 31;    /* 1 */
@@ -150,6 +166,11 @@
         {
             for (ch = 0; ch < stereo; ch++)
             {
+                if (!bitsAvailable(inputStream, 34))
+                {
+                    return SIDE_INFO_ERROR;
+                }
+
                 si->ch[ch].gran[gr].part2_3_length    = getbits_crc(inputStream, 12, crc, info->error_protection);
                 tmp = getbits_crc(inputStream, 22, crc, info->error_protection);
 
@@ -160,6 +181,11 @@
 
                 if (si->ch[ch].gran[gr].window_switching_flag)
                 {
+                    if (!bitsAvailable(inputStream, 22))
+                    {
+                        return SIDE_INFO_ERROR;
+                    }
+
                     tmp = getbits_crc(inputStream, 22, crc, info->error_protection);
 
                     si->ch[ch].gran[gr].block_type       = (tmp << 10) >> 30;   /* 2 */;
@@ -192,6 +218,11 @@
                 }
                 else
                 {
+                    if (!bitsAvailable(inputStream, 22))
+                    {
+                        return SIDE_INFO_ERROR;
+                    }
+
                     tmp = getbits_crc(inputStream, 22, crc, info->error_protection);
 
                     si->ch[ch].gran[gr].table_select[0] = (tmp << 10) >> 27;   /* 5 */;
@@ -204,6 +235,11 @@
                     si->ch[ch].gran[gr].block_type      = 0;
                 }
 
+                if (!bitsAvailable(inputStream, 3))
+                {
+                    return SIDE_INFO_ERROR;
+                }
+
                 tmp = getbits_crc(inputStream, 3, crc, info->error_protection);
                 si->ch[ch].gran[gr].preflag            = (tmp << 29) >> 31;    /* 1 */
                 si->ch[ch].gran[gr].scalefac_scale     = (tmp << 30) >> 31;    /* 1 */
@@ -213,11 +249,21 @@
     }
     else /* Layer 3 LSF */
     {
+        if (!bitsAvailable(inputStream, 8 + stereo))
+        {
+            return SIDE_INFO_ERROR;
+        }
+
         si->main_data_begin = getbits_crc(inputStream,      8, crc, info->error_protection);
         si->private_bits    = getbits_crc(inputStream, stereo, crc, info->error_protection);
 
         for (ch = 0; ch < stereo; ch++)
         {
+            if (!bitsAvailable(inputStream, 39))
+            {
+                return SIDE_INFO_ERROR;
+            }
+
             tmp = getbits_crc(inputStream, 21, crc, info->error_protection);
             si->ch[ch].gran[0].part2_3_length    = (tmp << 11) >> 20;  /* 12 */
             si->ch[ch].gran[0].big_values        = (tmp << 23) >> 23;  /*  9 */
@@ -230,6 +276,11 @@
             if (si->ch[ch].gran[0].window_switching_flag)
             {
 
+                if (!bitsAvailable(inputStream, 22))
+                {
+                    return SIDE_INFO_ERROR;
+                }
+
                 tmp = getbits_crc(inputStream, 22, crc, info->error_protection);
 
                 si->ch[ch].gran[0].block_type       = (tmp << 10) >> 30;   /* 2 */;
@@ -262,6 +313,11 @@
             }
             else
             {
+                if (!bitsAvailable(inputStream, 22))
+                {
+                    return SIDE_INFO_ERROR;
+                }
+
                 tmp = getbits_crc(inputStream, 22, crc, info->error_protection);
 
                 si->ch[ch].gran[0].table_select[0] = (tmp << 10) >> 27;   /* 5 */;
@@ -274,6 +330,11 @@
                 si->ch[ch].gran[0].block_type      = 0;
             }
 
+            if (!bitsAvailable(inputStream, 2))
+            {
+                return SIDE_INFO_ERROR;
+            }
+
             tmp = getbits_crc(inputStream, 2, crc, info->error_protection);
             si->ch[ch].gran[0].scalefac_scale     =  tmp >> 1;  /* 1 */
             si->ch[ch].gran[0].count1table_select =  tmp & 1;  /* 1 */
diff --git a/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.cpp b/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.cpp
index 8ff7953..4d252ef 100644
--- a/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.cpp
+++ b/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.cpp
@@ -113,10 +113,11 @@
 
     uint32    offset;
     uint32    bitIndex;
-    uint8     Elem;         /* Needs to be same type as pInput->pBuffer */
-    uint8     Elem1;
-    uint8     Elem2;
-    uint8     Elem3;
+    uint32    bytesToFetch;
+    uint8     Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
+    uint8     Elem1 = 0;
+    uint8     Elem2 = 0;
+    uint8     Elem3 = 0;
     uint32   returnValue = 0;
 
     if (!neededBits)
@@ -126,10 +127,25 @@
 
     offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;
 
-    Elem  = *(ptBitStream->pBuffer + module(offset  , BUFSIZE));
-    Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
-    Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
-    Elem3 = *(ptBitStream->pBuffer + module(offset + 3, BUFSIZE));
+    /* Remove extra high bits by shifting up */
+    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
+
+    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;
+
+    switch (bytesToFetch)
+    {
+    case 4:
+        Elem3 = *(ptBitStream->pBuffer + module(offset + 3, BUFSIZE));
+        [[fallthrough]];
+    case 3:
+        Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
+        [[fallthrough]];
+    case 2:
+        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
+        [[fallthrough]];
+    case 1:
+        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+    }
 
 
     returnValue = (((uint32)(Elem)) << 24) |
@@ -137,9 +153,6 @@
                   (((uint32)(Elem2)) << 8) |
                   ((uint32)(Elem3));
 
-    /* Remove extra high bits by shifting up */
-    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
-
     /* This line is faster than to mask off the high bits. */
     returnValue <<= bitIndex;
 
@@ -161,22 +174,32 @@
 
     uint32    offset;
     uint32    bitIndex;
-    uint8    Elem;         /* Needs to be same type as pInput->pBuffer */
-    uint8    Elem1;
+    uint32    bytesToFetch;
+    uint8    Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
+    uint8    Elem1 = 0;
     uint16   returnValue;
 
     offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;
 
-    Elem  = *(ptBitStream->pBuffer + module(offset  , BUFSIZE));
-    Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
+    /* Remove extra high bits by shifting up */
+    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
+
+    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;
+
+    if (bytesToFetch > 1)
+    {
+        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
+    }
+    else if (bytesToFetch > 0)
+    {
+        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+    }
 
 
     returnValue = (((uint16)(Elem)) << 8) |
                   ((uint16)(Elem1));
 
-    /* Remove extra high bits by shifting up */
-    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
-
     ptBitStream->usedBits += neededBits;
     /* This line is faster than to mask off the high bits. */
     returnValue = (returnValue << (bitIndex));
@@ -197,25 +220,40 @@
 
     uint32    offset;
     uint32    bitIndex;
-    uint8     Elem;         /* Needs to be same type as pInput->pBuffer */
-    uint8     Elem1;
-    uint8     Elem2;
+    uint32    bytesToFetch;
+    uint8     Elem  = 0;         /* Needs to be same type as pInput->pBuffer */
+    uint8     Elem1 = 0;
+    uint8     Elem2 = 0;
     uint32   returnValue;
 
     offset = (ptBitStream->usedBits) >> INBUF_ARRAY_INDEX_SHIFT;
 
-    Elem  = *(ptBitStream->pBuffer + module(offset  , BUFSIZE));
-    Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
-    Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
+    /* Remove extra high bits by shifting up */
+    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
+
+    bytesToFetch = (bitIndex + neededBits + 7 ) >> 3 ;
+
+    if (bytesToFetch > 2)
+    {
+        Elem  = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
+        Elem2 = *(ptBitStream->pBuffer + module(offset + 2, BUFSIZE));
+    }
+    else if (bytesToFetch > 1)
+    {
+        Elem  = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+        Elem1 = *(ptBitStream->pBuffer + module(offset + 1, BUFSIZE));
+    }
+    else if (bytesToFetch > 0)
+    {
+        Elem = *(ptBitStream->pBuffer + module(offset, BUFSIZE));
+    }
 
 
     returnValue = (((uint32)(Elem)) << 16) |
                   (((uint32)(Elem1)) << 8) |
                   ((uint32)(Elem2));
 
-    /* Remove extra high bits by shifting up */
-    bitIndex = module(ptBitStream->usedBits, INBUF_BIT_WIDTH);
-
     ptBitStream->usedBits += neededBits;
     /* This line is faster than to mask off the high bits. */
     returnValue = 0xFFFFFF & (returnValue << (bitIndex));
diff --git a/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.h b/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.h
index b058b00..b04fe6d 100644
--- a/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.h
+++ b/media/libstagefright/codecs/mp3dec/src/pvmp3_getbits.h
@@ -104,6 +104,11 @@
 ; Function Prototype declaration
 ----------------------------------------------------------------------------*/
 
+static inline bool bitsAvailable(tmp3Bits *inputStream, uint32 neededBits)
+{
+    return (inputStream->inputBufferCurrentLength << 3) >= (neededBits + inputStream->usedBits);
+}
+
 /*----------------------------------------------------------------------------
 ; END
 ----------------------------------------------------------------------------*/