Fast search sgr_params in self-guided filter

This change is applicable for speed>2. This CL
Introduce a binary search like algorithm in
self-guided filter based on the group.

           Quality drop(AWCY)
Preset   Cb      Cr    Overall  Encode time reduction
   2   -0.03%  -0.05%   0.00%    0.9%
   3   -0.05%   0.08%   0.01%    1.4%
   4   -0.02%  -0.15%  -0.01%    0.9%
Encode time is averaged across multiple test cases

STATS_CHANGED

Change-Id: I7560b10a5cc284e7beb2dfb89107231498db5f58
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index c6ea82e..acdaf56 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -46,6 +46,19 @@
 // Working precision for Wiener filter coefficients
 #define WIENER_TAP_SCALE_FACTOR ((int64_t)1 << 16)
 
+#define SGRPROJ_EP_GRP1_START_IDX 0
+#define SGRPROJ_EP_GRP1_END_IDX 9
+#define SGRPROJ_EP_GRP1_SEARCH_COUNT 4
+#define SGRPROJ_EP_GRP2_3_SEARCH_COUNT 2
+static const int sgproj_ep_grp1_seed[SGRPROJ_EP_GRP1_SEARCH_COUNT] = { 0, 3, 6,
+                                                                       9 };
+static const int sgproj_ep_grp2_3[SGRPROJ_EP_GRP2_3_SEARCH_COUNT][14] = {
+  { 10, 10, 11, 11, 12, 12, 13, 13, 13, 13, -1, -1, -1, -1 },
+  { 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15 }
+};
+
+const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
+
 typedef int64_t (*sse_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                       const YV12_BUFFER_CONFIG *b);
 typedef int64_t (*sse_part_extractor_type)(const YV12_BUFFER_CONFIG *a,
@@ -547,13 +560,45 @@
   }
 }
 
+static void compute_sgrproj_err(const uint8_t *dat8, const int width,
+                                const int height, const int dat_stride,
+                                const uint8_t *src8, const int src_stride,
+                                const int use_highbitdepth, const int bit_depth,
+                                const int pu_width, const int pu_height,
+                                const int ep, int32_t *flt0, int32_t *flt1,
+                                const int flt_stride, int *exqd, int64_t *err) {
+  int exq[2];
+  apply_sgr(ep, dat8, width, height, dat_stride, use_highbitdepth, bit_depth,
+            pu_width, pu_height, flt0, flt1, flt_stride);
+  aom_clear_system_state();
+  const sgr_params_type *const params = &av1_sgr_params[ep];
+  get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
+                    use_highbitdepth, flt0, flt_stride, flt1, flt_stride, exq,
+                    params);
+  aom_clear_system_state();
+  encode_xq(exq, exqd, params);
+  *err = finer_search_pixel_proj_error(
+      src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth, flt0,
+      flt_stride, flt1, flt_stride, 2, exqd, params);
+}
+
+static void get_best_error(int64_t *besterr, const int64_t err, const int *exqd,
+                           int *bestxqd, int *bestep, const int ep) {
+  if (*besterr == -1 || err < *besterr) {
+    *bestep = ep;
+    *besterr = err;
+    bestxqd[0] = exqd[0];
+    bestxqd[1] = exqd[1];
+  }
+}
+
 static SgrprojInfo search_selfguided_restoration(
     const uint8_t *dat8, int width, int height, int dat_stride,
     const uint8_t *src8, int src_stride, int use_highbitdepth, int bit_depth,
-    int pu_width, int pu_height, int32_t *rstbuf) {
+    int pu_width, int pu_height, int32_t *rstbuf, int disable_sgr_ep_pruning) {
   int32_t *flt0 = rstbuf;
   int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
-  int ep, bestep = 0;
+  int ep, idx, bestep = 0;
   int64_t besterr = -1;
   int exqd[2], bestxqd[2] = { 0, 0 };
   int flt_stride = ((width + 7) & ~7) + 8;
@@ -561,26 +606,43 @@
          pu_width == RESTORATION_PROC_UNIT_SIZE);
   assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
          pu_height == RESTORATION_PROC_UNIT_SIZE);
-
-  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
-    int exq[2];
-    apply_sgr(ep, dat8, width, height, dat_stride, use_highbitdepth, bit_depth,
-              pu_width, pu_height, flt0, flt1, flt_stride);
-    aom_clear_system_state();
-    const sgr_params_type *const params = &av1_sgr_params[ep];
-    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
-                      use_highbitdepth, flt0, flt_stride, flt1, flt_stride, exq,
-                      params);
-    aom_clear_system_state();
-    encode_xq(exq, exqd, params);
-    int64_t err = finer_search_pixel_proj_error(
-        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
-        flt0, flt_stride, flt1, flt_stride, 2, exqd, params);
-    if (besterr == -1 || err < besterr) {
-      bestep = ep;
-      besterr = err;
-      bestxqd[0] = exqd[0];
-      bestxqd[1] = exqd[1];
+  if (disable_sgr_ep_pruning) {
+    for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
+      int64_t err;
+      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
+                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
+                          flt0, flt1, flt_stride, exqd, &err);
+      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
+    }
+  } else {
+    // evaluate first four seed ep in first group
+    for (idx = 0; idx < SGRPROJ_EP_GRP1_SEARCH_COUNT; idx++) {
+      ep = sgproj_ep_grp1_seed[idx];
+      int64_t err;
+      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
+                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
+                          flt0, flt1, flt_stride, exqd, &err);
+      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
+    }
+    // evaluate left and right ep of winner in seed ep
+    int bestep_ref = bestep;
+    for (ep = bestep_ref - 1; ep < bestep_ref + 2; ep += 2) {
+      if (ep < SGRPROJ_EP_GRP1_START_IDX || ep > SGRPROJ_EP_GRP1_END_IDX)
+        continue;
+      int64_t err;
+      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
+                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
+                          flt0, flt1, flt_stride, exqd, &err);
+      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
+    }
+    // evaluate last two group
+    for (idx = 0; idx < SGRPROJ_EP_GRP2_3_SEARCH_COUNT; idx++) {
+      ep = sgproj_ep_grp2_3[idx][bestep];
+      int64_t err;
+      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
+                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
+                          flt0, flt1, flt_stride, exqd, &err);
+      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
     }
   }
 
@@ -636,7 +698,7 @@
       dgd_start, limits->h_end - limits->h_start,
       limits->v_end - limits->v_start, rsc->dgd_stride, src_start,
       rsc->src_stride, highbd, bit_depth, procunit_width, procunit_height,
-      tmpbuf);
+      tmpbuf, rsc->sf->disable_sgr_ep_pruning);
 
   RestorationUnitInfo rui;
   rui.restoration_type = RESTORE_SGRPROJ;
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index a765afc..8d26d43 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -285,6 +285,7 @@
   if (speed >= 2) {
     sf->gm_erroradv_type = GM_ERRORADV_TR_2;
 
+    sf->disable_sgr_ep_pruning = 0;
     sf->selective_ref_frame = 3;
     sf->inter_tx_size_search_init_depth_rect = 1;
     sf->inter_tx_size_search_init_depth_sqr = 1;
@@ -721,6 +722,7 @@
   sf->use_first_partition_pass_interintra_stats = 0;
   sf->disable_wedge_search_var_thresh = 0;
   sf->disable_loop_restoration_chroma = 0;
+  sf->disable_sgr_ep_pruning = 1;
   sf->reduce_wiener_window_size = 0;
   sf->fast_wedge_sign_estimate = 0;
   sf->prune_wedge_pred_diff_based = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index b767b8c..8de96e7 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -635,6 +635,9 @@
   // aggressiveness
   int prune_motion_mode_level;
 
+  // prune sgr ep using binary search like mechanism
+  int disable_sgr_ep_pruning;
+
   // Gate warp evaluation for motions of type IDENTITY,
   // TRANSLATION and AFFINE(based on number of warp neighbors)
   int prune_warp_using_wmtype;