[CFL] Cache DC_PRED during CfL-RDO

By default, the DC_PRED is not cached (this includes
decoding). During cfl_rd_pick_alpha(), DC_PRED caching
is enabled, the DC_PRED is cached after the first time it
is computed (for each plane) and then it is reused when
testing all the other scaling parameters.

Change-Id: Ie8ba0bb0427c4d9be8de5b44e6330e8a78b9c7d9
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 229d82a..6e0ae12 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -537,14 +537,22 @@
   (CFL_SUB8X8_VAL_MI_SIZE * CFL_SUB8X8_VAL_MI_SIZE)
 #endif  // CONFIG_DEBUG
 #define CFL_MAX_BLOCK_SIZE (BLOCK_32X32)
-#define CFL_PRED_BUF_LINE (32)
-#define CFL_PRED_BUF_SQUARE (CFL_PRED_BUF_LINE * CFL_PRED_BUF_LINE)
+#define CFL_BUF_LINE (32)
+#define CFL_BUF_SQUARE (CFL_BUF_LINE * CFL_BUF_LINE)
 typedef struct cfl_ctx {
   // The CfL prediction buffer is used in two steps:
   //   1. Stores Q3 reconstructed luma pixels
   //      (only Q2 is required, but Q3 is used to avoid shifts)
   //   2. Stores Q3 AC contributions (step1 - tx block avg)
-  int16_t pred_buf_q3[CFL_PRED_BUF_SQUARE];
+  int16_t pred_buf_q3[CFL_BUF_SQUARE];
+
+  // Cache the DC_PRED when performing RDO, so it does not have to be recomputed
+  // for every scaling parameter
+  int dc_pred_is_cached[CFL_PRED_PLANES];
+  // The DC_PRED cache is disable when decoding
+  int use_dc_pred_cache;
+  // Only cache the first row of the DC_PRED
+  int16_t dc_pred_cache[CFL_PRED_PLANES][CFL_BUF_LINE];
 
   // Height and width currently used in the CfL prediction buffer.
   int buf_height, buf_width;
diff --git a/av1/common/cfl.c b/av1/common/cfl.c
index 43fa62c..4b2ffde 100644
--- a/av1/common/cfl.c
+++ b/av1/common/cfl.c
@@ -14,8 +14,8 @@
 #include "av1/common/onyxc_int.h"
 
 void cfl_init(CFL_CTX *cfl, AV1_COMMON *cm) {
-  assert(block_size_wide[CFL_MAX_BLOCK_SIZE] == CFL_PRED_BUF_LINE);
-  assert(block_size_high[CFL_MAX_BLOCK_SIZE] == CFL_PRED_BUF_LINE);
+  assert(block_size_wide[CFL_MAX_BLOCK_SIZE] == CFL_BUF_LINE);
+  assert(block_size_high[CFL_MAX_BLOCK_SIZE] == CFL_BUF_LINE);
   if ((cm->subsampling_x != 0 && cm->subsampling_x != 1) ||
       (cm->subsampling_y != 0 && cm->subsampling_y != 1)) {
     aom_internal_error(&cm->error, AOM_CODEC_UNSUP_BITSTREAM,
@@ -29,6 +29,11 @@
   cfl->subsampling_y = cm->subsampling_y;
   cfl->are_parameters_computed = 0;
   cfl->store_y = 0;
+  // The DC_PRED cache is disabled by default and is only enabled in
+  // cfl_rd_pick_alpha
+  cfl->use_dc_pred_cache = 0;
+  cfl->dc_pred_is_cached[CFL_PRED_U] = 0;
+  cfl->dc_pred_is_cached[CFL_PRED_V] = 0;
 #if CONFIG_DEBUG
   cfl_clear_sub8x8_val(cfl);
   cfl->store_counter = 0;
@@ -36,6 +41,56 @@
 #endif  // CONFIG_DEBUG
 }
 
+void cfl_store_dc_pred(MACROBLOCKD *const xd, const uint8_t *input,
+                       CFL_PRED_TYPE pred_plane, int width) {
+  assert(pred_plane < CFL_PRED_PLANES);
+  assert(width <= CFL_BUF_LINE);
+#if CONFIG_HIGHBITDEPTH
+  if (get_bitdepth_data_path_index(xd)) {
+    uint16_t *const input_16 = CONVERT_TO_SHORTPTR(input);
+    memcpy(xd->cfl.dc_pred_cache[pred_plane], input_16, width << 1);
+    return;
+  }
+#endif  // CONFIG_HIGHBITDEPTH
+  memcpy(xd->cfl.dc_pred_cache[pred_plane], input, width);
+}
+
+static void cfl_load_dc_pred_lbd(const int16_t *dc_pred_cache, uint8_t *dst,
+                                 int dst_stride, int width, int height) {
+  for (int j = 0; j < height; j++) {
+    memcpy(dst, dc_pred_cache, width);
+    dst += dst_stride;
+  }
+}
+
+static void cfl_load_dc_pred_hbd(const int16_t *dc_pred_cache, uint16_t *dst,
+                                 int dst_stride, int width, int height) {
+  const size_t num_bytes = width << 1;
+  for (int j = 0; j < height; j++) {
+    memcpy(dst, dc_pred_cache, num_bytes);
+    dst += dst_stride;
+  }
+}
+
+void cfl_load_dc_pred(MACROBLOCKD *const xd, uint8_t *dst, int dst_stride,
+                      TX_SIZE tx_size, CFL_PRED_TYPE pred_plane) {
+  const int width = tx_size_wide[tx_size];
+  const int height = tx_size_high[tx_size];
+  assert(pred_plane < CFL_PRED_PLANES);
+  assert(width <= CFL_BUF_LINE);
+  assert(height <= CFL_BUF_LINE);
+#if CONFIG_HIGHBITDEPTH
+  if (get_bitdepth_data_path_index(xd)) {
+    uint16_t *dst_16 = CONVERT_TO_SHORTPTR(dst);
+    cfl_load_dc_pred_hbd(xd->cfl.dc_pred_cache[pred_plane], dst_16, dst_stride,
+                         width, height);
+    return;
+  }
+#endif  // CONFIG_HIGHBITDEPTH
+  cfl_load_dc_pred_lbd(xd->cfl.dc_pred_cache[pred_plane], dst, dst_stride,
+                       width, height);
+}
+
 // Due to frame boundary issues, it is possible that the total area covered by
 // chroma exceeds that of luma. When this happens, we fill the missing pixels by
 // repeating the last columns and/or rows.
@@ -48,25 +103,24 @@
     int16_t *pred_buf_q3 = cfl->pred_buf_q3 + (width - diff_width);
     for (int j = 0; j < min_height; j++) {
       const int16_t last_pixel = pred_buf_q3[-1];
-      assert(pred_buf_q3 + diff_width <=
-             cfl->pred_buf_q3 + CFL_PRED_BUF_SQUARE);
+      assert(pred_buf_q3 + diff_width <= cfl->pred_buf_q3 + CFL_BUF_SQUARE);
       for (int i = 0; i < diff_width; i++) {
         pred_buf_q3[i] = last_pixel;
       }
-      pred_buf_q3 += CFL_PRED_BUF_LINE;
+      pred_buf_q3 += CFL_BUF_LINE;
     }
     cfl->buf_width = width;
   }
   if (diff_height > 0) {
     int16_t *pred_buf_q3 =
-        cfl->pred_buf_q3 + ((height - diff_height) * CFL_PRED_BUF_LINE);
+        cfl->pred_buf_q3 + ((height - diff_height) * CFL_BUF_LINE);
     for (int j = 0; j < diff_height; j++) {
-      const int16_t *last_row_q3 = pred_buf_q3 - CFL_PRED_BUF_LINE;
-      assert(pred_buf_q3 + width <= cfl->pred_buf_q3 + CFL_PRED_BUF_SQUARE);
+      const int16_t *last_row_q3 = pred_buf_q3 - CFL_BUF_LINE;
+      assert(pred_buf_q3 + width <= cfl->pred_buf_q3 + CFL_BUF_SQUARE);
       for (int i = 0; i < width; i++) {
         pred_buf_q3[i] = last_row_q3[i];
       }
-      pred_buf_q3 += CFL_PRED_BUF_LINE;
+      pred_buf_q3 += CFL_BUF_LINE;
     }
     cfl->buf_height = height;
   }
@@ -84,11 +138,11 @@
   cfl_pad(cfl, tx_width, tx_height);
 
   for (int j = 0; j < tx_height; j++) {
-    assert(pred_buf_q3 + tx_width <= cfl->pred_buf_q3 + CFL_PRED_BUF_SQUARE);
+    assert(pred_buf_q3 + tx_width <= cfl->pred_buf_q3 + CFL_BUF_SQUARE);
     for (int i = 0; i < tx_width; i++) {
       sum_q3 += pred_buf_q3[i];
     }
-    pred_buf_q3 += CFL_PRED_BUF_LINE;
+    pred_buf_q3 += CFL_BUF_LINE;
   }
   const int avg_q3 = (sum_q3 + (1 << (num_pel_log2 - 1))) >> num_pel_log2;
   // Loss is never more than 1/2 (in Q3)
@@ -96,11 +150,11 @@
          1);
   pred_buf_q3 = cfl->pred_buf_q3;
   for (int j = 0; j < tx_height; j++) {
-    assert(pred_buf_q3 + tx_width <= cfl->pred_buf_q3 + CFL_PRED_BUF_SQUARE);
+    assert(pred_buf_q3 + tx_width <= cfl->pred_buf_q3 + CFL_BUF_SQUARE);
     for (int i = 0; i < tx_width; i++) {
       pred_buf_q3[i] -= avg_q3;
     }
-    pred_buf_q3 += CFL_PRED_BUF_LINE;
+    pred_buf_q3 += CFL_BUF_LINE;
   }
 }
 
@@ -117,14 +171,14 @@
 static void cfl_build_prediction_lbd(const int16_t *pred_buf_q3, uint8_t *dst,
                                      int dst_stride, int width, int height,
                                      int alpha_q3) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       dst[i] =
           clip_pixel(get_scaled_luma_q0(alpha_q3, pred_buf_q3[i]) + dst[i]);
     }
     dst += dst_stride;
-    pred_buf_q3 += CFL_PRED_BUF_LINE;
+    pred_buf_q3 += CFL_BUF_LINE;
   }
 }
 
@@ -132,14 +186,14 @@
 static void cfl_build_prediction_hbd(const int16_t *pred_buf_q3, uint16_t *dst,
                                      int dst_stride, int width, int height,
                                      int alpha_q3, int bit_depth) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       dst[i] = clip_pixel_highbd(
           get_scaled_luma_q0(alpha_q3, pred_buf_q3[i]) + dst[i], bit_depth);
     }
     dst += dst_stride;
-    pred_buf_q3 += CFL_PRED_BUF_LINE;
+    pred_buf_q3 += CFL_BUF_LINE;
   }
 }
 #endif  // CONFIG_HIGHBITDEPTH
@@ -202,7 +256,7 @@
 static void cfl_luma_subsampling_420_lbd(const uint8_t *input, int input_stride,
                                          int16_t *output_q3, int width,
                                          int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int top = i << 1;
@@ -211,47 +265,47 @@
                      << 1;
     }
     input += input_stride << 1;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_422_lbd(const uint8_t *input, int input_stride,
                                          int16_t *output_q3, int width,
                                          int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int left = i << 1;
       output_q3[i] = (input[left] + input[left + 1]) << 2;
     }
     input += input_stride;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_440_lbd(const uint8_t *input, int input_stride,
                                          int16_t *output_q3, int width,
                                          int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       output_q3[i] = (input[i] + input[i + input_stride]) << 2;
     }
     input += input_stride << 1;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_444_lbd(const uint8_t *input, int input_stride,
                                          int16_t *output_q3, int width,
                                          int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       output_q3[i] = input[i] << 3;
     }
     input += input_stride;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
@@ -269,7 +323,7 @@
 static void cfl_luma_subsampling_420_hbd(const uint16_t *input,
                                          int input_stride, int16_t *output_q3,
                                          int width, int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int top = i << 1;
@@ -278,47 +332,47 @@
                      << 1;
     }
     input += input_stride << 1;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_422_hbd(const uint16_t *input,
                                          int input_stride, int16_t *output_q3,
                                          int width, int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       int left = i << 1;
       output_q3[i] = (input[left] + input[left + 1]) << 2;
     }
     input += input_stride;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_440_hbd(const uint16_t *input,
                                          int input_stride, int16_t *output_q3,
                                          int width, int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       output_q3[i] = (input[i] + input[i + input_stride]) << 2;
     }
     input += input_stride << 1;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
 static void cfl_luma_subsampling_444_hbd(const uint16_t *input,
                                          int input_stride, int16_t *output_q3,
                                          int width, int height) {
-  assert((height - 1) * CFL_PRED_BUF_LINE + width <= CFL_PRED_BUF_SQUARE);
+  assert((height - 1) * CFL_BUF_LINE + width <= CFL_BUF_SQUARE);
   for (int j = 0; j < height; j++) {
     for (int i = 0; i < width; i++) {
       output_q3[i] = input[i] << 3;
     }
     input += input_stride;
-    output_q3 += CFL_PRED_BUF_LINE;
+    output_q3 += CFL_BUF_LINE;
   }
 }
 
@@ -358,12 +412,12 @@
   }
 
   // Check that we will remain inside the pixel buffer.
-  assert(store_row + store_height <= CFL_PRED_BUF_LINE);
-  assert(store_col + store_width <= CFL_PRED_BUF_LINE);
+  assert(store_row + store_height <= CFL_BUF_LINE);
+  assert(store_col + store_width <= CFL_BUF_LINE);
 
   // Store the input into the CfL pixel buffer
   int16_t *pred_buf_q3 =
-      cfl->pred_buf_q3 + (store_row * CFL_PRED_BUF_LINE + store_col);
+      cfl->pred_buf_q3 + (store_row * CFL_BUF_LINE + store_col);
 
 #if CONFIG_HIGHBITDEPTH
   if (use_hbd) {
diff --git a/av1/common/cfl.h b/av1/common/cfl.h
index 7dbc2e2..13efdf4 100644
--- a/av1/common/cfl.h
+++ b/av1/common/cfl.h
@@ -25,6 +25,11 @@
   return ROUND_POWER_OF_TWO_SIGNED(scaled_luma_q6, 6);
 }
 
+static INLINE CFL_PRED_TYPE get_cfl_pred_type(PLANE_TYPE plane) {
+  assert(plane > 0);
+  return plane - 1;
+}
+
 void cfl_predict_block(MACROBLOCKD *const xd, uint8_t *dst, int dst_stride,
                        TX_SIZE tx_size, int plane);
 
@@ -32,4 +37,10 @@
 
 void cfl_store_tx(MACROBLOCKD *const xd, int row, int col, TX_SIZE tx_size,
                   BLOCK_SIZE bsize);
+
+void cfl_store_dc_pred(MACROBLOCKD *const xd, const uint8_t *input,
+                       CFL_PRED_TYPE pred_plane, int width);
+
+void cfl_load_dc_pred(MACROBLOCKD *const xd, uint8_t *dst, int dst_stride,
+                      TX_SIZE tx_size, CFL_PRED_TYPE pred_plane);
 #endif  // AV1_COMMON_CFL_H_
diff --git a/av1/common/reconintra.c b/av1/common/reconintra.c
index 1236699..b5e75d3 100644
--- a/av1/common/reconintra.c
+++ b/av1/common/reconintra.c
@@ -2762,10 +2762,6 @@
   const PREDICTION_MODE mode =
       (plane == AOM_PLANE_Y) ? mbmi->mode : get_uv_mode(mbmi->uv_mode);
 
-  av1_predict_intra_block(cm, xd, pd->width, pd->height,
-                          txsize_to_bsize[tx_size], mode, dst, dst_stride, dst,
-                          dst_stride, blk_col, blk_row, plane);
-
 #if CONFIG_CFL
   if (plane != AOM_PLANE_Y && mbmi->uv_mode == UV_CFL_PRED) {
 #if CONFIG_DEBUG
@@ -2777,9 +2773,26 @@
     assert(block_size_wide[plane_bsize] == tx_size_wide[tx_size]);
     assert(block_size_high[plane_bsize] == tx_size_high[tx_size]);
 #endif
+    CFL_CTX *const cfl = &xd->cfl;
+    CFL_PRED_TYPE pred_plane = get_cfl_pred_type(plane);
+    if (cfl->dc_pred_is_cached[pred_plane] == 0) {
+      av1_predict_intra_block(cm, xd, pd->width, pd->height,
+                              txsize_to_bsize[tx_size], mode, dst, dst_stride,
+                              dst, dst_stride, blk_col, blk_row, plane);
+      if (cfl->use_dc_pred_cache) {
+        cfl_store_dc_pred(xd, dst, pred_plane, tx_size_wide[tx_size]);
+        cfl->dc_pred_is_cached[pred_plane] = 1;
+      }
+    } else {
+      cfl_load_dc_pred(xd, dst, dst_stride, tx_size, pred_plane);
+    }
     cfl_predict_block(xd, dst, dst_stride, tx_size, plane);
+    return;
   }
 #endif
+  av1_predict_intra_block(cm, xd, pd->width, pd->height,
+                          txsize_to_bsize[tx_size], mode, dst, dst_stride, dst,
+                          dst_stride, blk_col, blk_row, plane);
 }
 
 // Copy the given row of dst into the equivalent row of ref, saving
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 7bf435d..fb8abf7 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5418,6 +5418,7 @@
   assert(block_size_high[plane_bsize] == tx_size_high[tx_size]);
 #endif
 
+  xd->cfl.use_dc_pred_cache = 1;
   const int64_t mode_rd =
       RDCOST(x->rdmult, x->intra_uv_mode_cost[mbmi->mode][UV_CFL_PRED], 0);
   int64_t best_rd_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
@@ -5508,6 +5509,9 @@
 
   mbmi->cfl_alpha_idx = ind;
   mbmi->cfl_alpha_signs = best_joint_sign;
+  xd->cfl.use_dc_pred_cache = 0;
+  xd->cfl.dc_pred_is_cached[0] = 0;
+  xd->cfl.dc_pred_is_cached[1] = 0;
   return best_rate_overhead;
 }
 #endif  // CONFIG_CFL