Refactor OBU header+size parsing code

av1_decode_frame_from_obus() currently has a block of code to handle
OBU headers and to correctly derive the payload size in both the
annexb and non-annexb cases.

This patch moves that code into its own function, so that it can be
shared with other places which do OBU parsing (in particular,
decoder_peek_si_internal()).

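Along the way, tighten up the error detection, so that we're more
likely to detect cases where the input buffer is too short:
read_obu_header() now rejects an empty buffer, and rejects a header
whose extension flag is set when only one byte is available (the old
code silently skipped the extension byte in that case).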

BUG=aomedia:1794

Change-Id: If16fe25b3a044743d15f44b9c91c2f3ce7c49df2
diff --git a/av1/decoder/obu.c b/av1/decoder/obu.c
index ca72a56..c8cafee 100644
--- a/av1/decoder/obu.c
+++ b/av1/decoder/obu.c
@@ -79,6 +79,9 @@
                                        int is_annexb, ObuHeader *header) {
   if (!rb || !header) return AOM_CODEC_INVALID_PARAM;
 
+  const ptrdiff_t bit_buffer_byte_length = rb->bit_buffer_end - rb->bit_buffer;
+  if (bit_buffer_byte_length < 1) return AOM_CODEC_CORRUPT_FRAME;
+
   header->size = 1;
 
   // first bit is obu_forbidden_bit (0) according to R19
@@ -98,8 +101,9 @@
 
   aom_rb_read_bit(rb);  // reserved
 
-  const ptrdiff_t bit_buffer_byte_length = rb->bit_buffer_end - rb->bit_buffer;
-  if (header->has_extension && bit_buffer_byte_length > 1) {
+  if (header->has_extension) {
+    if (bit_buffer_byte_length == 1) return AOM_CODEC_CORRUPT_FRAME;
+
     header->size += 1;
     header->temporal_layer_id = aom_rb_read_literal(rb, 3);
     header->enhancement_layer_id = aom_rb_read_literal(rb, 2);
@@ -410,6 +414,46 @@
   return AOM_CODEC_OK;
 }
 
+aom_codec_err_t aom_read_obu_header_and_size(const uint8_t *data,
+                                             size_t bytes_available,
+                                             int is_annexb,
+                                             ObuHeader *obu_header,
+                                             size_t *const payload_size,
+                                             size_t *const bytes_read) {
+  size_t length_field_size = 0, obu_size = 0;
+  aom_codec_err_t status;
+
+  if (is_annexb) {
+    // With is_annexb set, the size field precedes the OBU header and covers it
+    status =
+        read_obu_size(data, bytes_available, &obu_size, &length_field_size);
+
+    if (status != AOM_CODEC_OK) return status;
+  }
+
+  struct aom_read_bit_buffer rb = { data + length_field_size,
+                                    data + bytes_available, 0, NULL, NULL };
+
+  status = read_obu_header(&rb, is_annexb, obu_header);
+  if (status != AOM_CODEC_OK) return status;
+
+  if (is_annexb) {
+    // Payload size = OBU size minus header size; reject an undersized OBU
+    if (obu_size < obu_header->size) return AOM_CODEC_CORRUPT_FRAME;
+
+    *payload_size = obu_size - obu_header->size;
+  } else {
+    // Otherwise the size field follows the header and is the payload size only
+    status = read_obu_size(data + obu_header->size,
+                           bytes_available - obu_header->size, payload_size,
+                           &length_field_size);
+    if (status != AOM_CODEC_OK) return status;
+  }
+
+  *bytes_read = length_field_size + obu_header->size;
+  return AOM_CODEC_OK;
+}
+
 void av1_decode_frame_from_obus(struct AV1Decoder *pbi, const uint8_t *data,
                                 const uint8_t *data_end,
                                 const uint8_t **p_data_end) {
@@ -434,61 +478,26 @@
     size_t payload_size = 0;
     size_t decoded_payload_size = 0;
     size_t obu_payload_offset = 0;
+    size_t bytes_read = 0;
     const size_t bytes_available = data_end - data;
 
-    if (bytes_available < 1) {
-      cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-      return;
-    }
-
-    size_t length_field_size = 0;
-    size_t obu_size = 0;
-    if (cm->is_annexb) {
-      if (read_obu_size(data, bytes_available, &obu_size, &length_field_size) !=
-          AOM_CODEC_OK) {
-        cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-        return;
-      }
-    }
-    if (data_end < data + length_field_size) {
-      cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-      return;
-    }
-    av1_init_read_bit_buffer(pbi, &rb, data + length_field_size, data_end);
-
-    const aom_codec_err_t status =
-        read_obu_header(&rb, cm->is_annexb, &obu_header);
+    aom_codec_err_t status =
+        aom_read_obu_header_and_size(data, bytes_available, cm->is_annexb,
+                                     &obu_header, &payload_size, &bytes_read);
     if (status != AOM_CODEC_OK) {
       cm->error.error_code = status;
       return;
     }
 
-    if (!cm->is_annexb) {
-      if (data_end < data + obu_header.size) {
-        cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-        return;
-      }
-      if (read_obu_size(data + obu_header.size,
-                        bytes_available - obu_header.size, &payload_size,
-                        &length_field_size) != AOM_CODEC_OK) {
-        cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-        return;
-      }
-      av1_init_read_bit_buffer(
-          pbi, &rb, data + length_field_size + obu_header.size, data_end);
-    } else {
-      payload_size = obu_size - obu_header.size;
-    }
-
-    data += length_field_size + obu_header.size;
-    if (data_end < data) {
-      cm->error.error_code = AOM_CODEC_CORRUPT_FRAME;
-      return;
-    }
+    // Note: aom_read_obu_header_and_size() takes care of checking that this
+    // doesn't cause 'data' to advance past 'data_end'.
+    data += bytes_read;
 
     cm->temporal_layer_id = obu_header.temporal_layer_id;
     cm->enhancement_layer_id = obu_header.enhancement_layer_id;
 
+    av1_init_read_bit_buffer(pbi, &rb, data, data_end);
+
     switch (obu_header.type) {
       case OBU_TEMPORAL_DELIMITER:
         decoded_payload_size = read_temporal_delimiter_obu();
diff --git a/av1/decoder/obu.h b/av1/decoder/obu.h
index 2295c69..42f4e54 100644
--- a/av1/decoder/obu.h
+++ b/av1/decoder/obu.h
@@ -38,6 +38,13 @@
                                     size_t *consumed, ObuHeader *header,
                                     int is_annexb);
 
+aom_codec_err_t aom_read_obu_header_and_size(const uint8_t *data,
+                                             size_t bytes_available,
+                                             int is_annexb,
+                                             ObuHeader *obu_header,
+                                             size_t *const payload_size,
+                                             size_t *const bytes_read);
+
 void av1_decode_frame_from_obus(struct AV1Decoder *pbi, const uint8_t *data,
                                 const uint8_t *data_end,
                                 const uint8_t **p_data_end);