Do not wait for the entire reference frame to be decoded

In frame parallel mode, make use of the progress tracking of the
reference frame to wait only until the necessary rows are decoded.

PiperOrigin-RevId: 299206010
Change-Id: I51a60db4d1605ed33196cdaad6223128519abe99
diff --git a/src/decoder_impl.cc b/src/decoder_impl.cc
index f80701e..1ee9851 100644
--- a/src/decoder_impl.cc
+++ b/src/decoder_impl.cc
@@ -886,8 +886,7 @@
   if (IsFrameParallel()) {
     return DecodeTilesFrameParallel(
         sequence_header, frame_header, tiles, saved_symbol_decoder_context,
-        prev_segment_ids, state, frame_scratch_buffer, &post_filter,
-        current_frame);
+        prev_segment_ids, frame_scratch_buffer, &post_filter, current_frame);
   }
   StatusCode status;
   if (settings_.threads == 1) {
@@ -1015,7 +1014,7 @@
     const ObuFrameHeader& frame_header,
     const Vector<std::unique_ptr<Tile>>& tiles,
     const SymbolDecoderContext& saved_symbol_decoder_context,
-    const SegmentationMap* const prev_segment_ids, const DecoderState& state,
+    const SegmentationMap* const prev_segment_ids,
     FrameScratchBuffer* const frame_scratch_buffer,
     PostFilter* const post_filter, RefCountedBuffer* const current_frame) {
   // Parse the frame.
@@ -1043,17 +1042,8 @@
   SetCurrentFrameSegmentationMap(frame_header, prev_segment_ids, current_frame);
   // Mark frame as parsed.
   current_frame->SetFrameState(kFrameStateParsed);
-
-  // We can decode the current frame if all the reference frames have been
-  // decoded.
-  for (int i = 0; i < kNumReferenceFrameTypes; ++i) {
-    if (!state.reference_valid[i] || state.reference_frame[i] == nullptr) {
-      continue;
-    }
-    // Wait for this reference frame to be decoded.
-    state.reference_frame[i]->WaitUntilDecoded();
-  }
-  // Decode in superblock row order.
+  // Decode in superblock row order (inter prediction in the Tile class will
+  // block until the required superblocks in the reference frame are decoded).
   int row4x4;
   for (row4x4 = 0; row4x4 < frame_header.rows4x4; row4x4 += block_width4x4) {
     for (const auto& tile_ptr : tiles) {
diff --git a/src/decoder_impl.h b/src/decoder_impl.h
index 2fbb4b0..7ef97f3 100644
--- a/src/decoder_impl.h
+++ b/src/decoder_impl.h
@@ -216,7 +216,7 @@
       const ObuFrameHeader& frame_header,
       const Vector<std::unique_ptr<Tile>>& tiles,
       const SymbolDecoderContext& saved_symbol_decoder_context,
-      const SegmentationMap* prev_segment_ids, const DecoderState& state,
+      const SegmentationMap* prev_segment_ids,
       FrameScratchBuffer* frame_scratch_buffer, PostFilter* post_filter,
       RefCountedBuffer* current_frame);
   // Sets the current frame's segmentation map for two cases. The third case
diff --git a/src/tile.h b/src/tile.h
index d03b78a..ba00af5 100644
--- a/src/tile.h
+++ b/src/tile.h
@@ -697,6 +697,7 @@
   bool delta_lf_all_zero_;
   bool build_bit_mask_when_parsing_;
   bool initialized_;
+  const bool frame_parallel_;
 };
 
 struct Tile::Block {
diff --git a/src/tile/prediction.cc b/src/tile/prediction.cc
index 75ce824..9fd3d70 100644
--- a/src/tile/prediction.cc
+++ b/src/tile/prediction.cc
@@ -1063,6 +1063,26 @@
       reference_buffer->top_border(plane),
       reference_buffer->bottom_border(plane), &ref_block_start_x,
       &ref_block_start_y, &ref_block_end_x);
+
+  // In frame parallel mode, ensure that the reference block has been decoded
+  // and available for referencing.
+  if (reference_frame_index != -1 && frame_parallel_) {
+    int reference_y_max;
+    if (is_scaled) {
+      // TODO(vigneshv): For now, we wait for the entire reference frame to be
+      // decoded if we are using scaled references. This will eventually be
+      // fixed.
+      reference_y_max = reference_height;
+    } else {
+      reference_y_max = std::max(
+          std::min(ref_block_start_y + height + kSubPixelTaps, ref_last_y), 0);
+      // For U and V planes with subsampling, we need to multiply
+      // reference_y_max by 2 since we only track the progress of Y planes.
+      reference_y_max <<= subsampling_y;
+    }
+    reference_frames_[reference_frame_index]->WaitUntil(reference_y_max);
+  }
+
   const uint8_t* block_start = nullptr;
   ptrdiff_t convolve_buffer_stride;
   if (!extend_block) {
@@ -1173,6 +1193,31 @@
   const int source_height =
       reference_frames_[reference_frame_index]->buffer()->height(plane);
   uint16_t* const prediction = block.scratch_buffer->prediction_buffer[index];
+
+  // In frame parallel mode, ensure that the reference block has been decoded
+  // and available for referencing.
+  if (frame_parallel_) {
+    int reference_y_max = kLargeNegativeValue;
+    // Find out the maximum y-coordinate for warping.
+    for (int start_y = block_start_y; start_y < block_start_y + height;
+         start_y += 8) {
+      for (int start_x = block_start_x; start_x < block_start_x + width;
+           start_x += 8) {
+        const int src_x = (start_x + 4) << subsampling_x_[plane];
+        const int src_y = (start_y + 4) << subsampling_y_[plane];
+        const int dst_y = src_x * warp_params->params[4] +
+                          src_y * warp_params->params[5] +
+                          warp_params->params[1];
+        const int y4 = dst_y >> subsampling_y_[plane];
+        const int iy4 = y4 >> kWarpedModelPrecisionBits;
+        reference_y_max = std::max(iy4 + 8, reference_y_max);
+      }
+    }
+    // For U and V planes with subsampling, we need to multiply reference_y_max
+    // by 2 since we only track the progress of Y planes.
+    reference_y_max <<= subsampling_y_[plane];
+    reference_frames_[reference_frame_index]->WaitUntil(reference_y_max);
+  }
   if (is_compound) {
     dsp_.warp_compound(source, source_stride, source_width, source_height,
                        warp_params->params, subsampling_x_[plane],
diff --git a/src/tile/tile.cc b/src/tile/tile.cc
index 223e2cc..400f1dc 100644
--- a/src/tile/tile.cc
+++ b/src/tile/tile.cc
@@ -435,7 +435,8 @@
       tile_scratch_buffer_pool_(tile_scratch_buffer_pool),
       pending_tiles_(pending_tiles),
       build_bit_mask_when_parsing_(false),
-      initialized_(false) {
+      initialized_(false),
+      frame_parallel_(frame_parallel) {
   row_ = number_ / frame_header.tile_info.tile_columns;
   column_ = number_ % frame_header.tile_info.tile_columns;
   row4x4_start_ = frame_header.tile_info.tile_row_start[row_];