Add txfm functions corresponding to MRC_DCT

MRC_DCT uses a mask derived from the prediction signal to modify the
residual before applying DCT_DCT. This patch adds the functions needed
to perform this transform and makes the prediction signal available
to the 32x32 txfm functions so that the mask can be created. I am
still experimenting with different mask generation functions, so this
patch contains only a placeholder mask function. This patch has no
impact on performance.

Change-Id: Ie3772f528e82103187a85c91cf00bb291dba328a
diff --git a/aom_dsp/inv_txfm.c b/aom_dsp/inv_txfm.c
index bf1345c..398eb0a 100644
--- a/aom_dsp/inv_txfm.c
+++ b/aom_dsp/inv_txfm.c
@@ -1222,6 +1222,109 @@
   output[31] = WRAPLOW(step1[0] - step1[31]);
 }
 
+#if CONFIG_MRC_TX
+void aom_imrc32x32_1024_add_c(const tran_low_t *input, uint8_t *dest,
+                              int stride, int *mask) {
+  tran_low_t out[32 * 32];
+  tran_low_t *outptr = out;
+  int i, j;
+  tran_low_t temp_in[32], temp_out[32];
+
+  // Rows: OR-reduce each row's coeffs; rows that reduce to 0 are just zeroed
+  for (i = 0; i < 32; ++i) {
+    int16_t zero_coeff[16];
+    for (j = 0; j < 16; ++j) zero_coeff[j] = input[2 * j] | input[2 * j + 1];
+    for (j = 0; j < 8; ++j)
+      zero_coeff[j] = zero_coeff[2 * j] | zero_coeff[2 * j + 1];
+    for (j = 0; j < 4; ++j)
+      zero_coeff[j] = zero_coeff[2 * j] | zero_coeff[2 * j + 1];
+    for (j = 0; j < 2; ++j)
+      zero_coeff[j] = zero_coeff[2 * j] | zero_coeff[2 * j + 1];
+
+    if (zero_coeff[0] | zero_coeff[1])
+      aom_idct32_c(input, outptr);
+    else
+      memset(outptr, 0, sizeof(tran_low_t) * 32);
+    input += 32;
+    outptr += 32;
+  }
+
+  // Columns: 1-D transform of each column of out, then masked add into dest
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = out[j * 32 + i];
+    aom_idct32_c(temp_in, temp_out);
+    for (j = 0; j < 32; ++j) {
+      // Add the rounded residual to dest only where the mask value is nonzero
+      int mask_val = mask[j * 32 + i];
+      dest[j * stride + i] =
+          mask_val ? clip_pixel_add(dest[j * stride + i],
+                                    ROUND_POWER_OF_TWO(temp_out[j], 6))
+                   : dest[j * stride + i];
+    }
+  }
+}
+
+void aom_imrc32x32_135_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                             int *mask) {
+  tran_low_t out[32 * 32] = { 0 };
+  tran_low_t *outptr = out;
+  int i, j;
+  tran_low_t temp_in[32], temp_out[32];
+
+  // Rows: only the first 16 are transformed; the rest of out stays zero
+  // only upper-left 16x16 has non-zero coeff
+  for (i = 0; i < 16; ++i) {
+    aom_idct32_c(input, outptr);
+    input += 32;
+    outptr += 32;
+  }
+
+  // Columns: 1-D transform of each column of out, then masked add into dest
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = out[j * 32 + i];
+    aom_idct32_c(temp_in, temp_out);
+    for (j = 0; j < 32; ++j) {
+      // Add the rounded residual to dest only where the mask value is nonzero
+      int mask_val = mask[j * 32 + i];
+      dest[j * stride + i] =
+          mask_val ? clip_pixel_add(dest[j * stride + i],
+                                    ROUND_POWER_OF_TWO(temp_out[j], 6))
+                   : dest[j * stride + i];
+    }
+  }
+}
+
+void aom_imrc32x32_34_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                            int *mask) {
+  tran_low_t out[32 * 32] = { 0 };
+  tran_low_t *outptr = out;
+  int i, j;
+  tran_low_t temp_in[32], temp_out[32];
+
+  // Rows: only the first 8 are transformed; the rest of out stays zero
+  // only upper-left 8x8 has non-zero coeff
+  for (i = 0; i < 8; ++i) {
+    aom_idct32_c(input, outptr);
+    input += 32;
+    outptr += 32;
+  }
+
+  // Columns: 1-D transform of each column of out, then masked add into dest
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = out[j * 32 + i];
+    aom_idct32_c(temp_in, temp_out);
+    for (j = 0; j < 32; ++j) {
+      // Add the rounded residual to dest only where the mask value is nonzero
+      int mask_val = mask[j * 32 + i];
+      dest[j * stride + i] =
+          mask_val ? clip_pixel_add(dest[j * stride + i],
+                                    ROUND_POWER_OF_TWO(temp_out[j], 6))
+                   : dest[j * stride + i];
+    }
+  }
+}
+#endif  // CONFIG_MRC_TX
+
 void aom_idct32x32_1024_add_c(const tran_low_t *input, uint8_t *dest,
                               int stride) {
   tran_low_t out[32 * 32];
diff --git a/aom_dsp/inv_txfm.h b/aom_dsp/inv_txfm.h
index 5bf6ed6..a9c485e 100644
--- a/aom_dsp/inv_txfm.h
+++ b/aom_dsp/inv_txfm.h
@@ -52,6 +52,18 @@
 #define WRAPLOW(x) ((int32_t)check_range(x, 8))
 #define HIGHBD_WRAPLOW(x, bd) ((int32_t)check_range((x), bd))
 
+#if CONFIG_MRC_TX
+// 32x32 inverse DCT; the residual is added to dest only where mask is nonzero
+void aom_imrc32x32_1024_add_c(const tran_low_t *input, uint8_t *dest,
+                              int stride, int *mask);
+
+void aom_imrc32x32_135_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                             int *mask);
+
+void aom_imrc32x32_34_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                            int *mask);
+#endif  // CONFIG_MRC_TX
+
 void aom_idct4_c(const tran_low_t *input, tran_low_t *output);
 void aom_idct8_c(const tran_low_t *input, tran_low_t *output);
 void aom_idct16_c(const tran_low_t *input, tran_low_t *output);
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index 945c963..01732ae 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -27,11 +27,13 @@
   int tx_size;
   int lossless;
   int bd;
+#if CONFIG_MRC_TX || CONFIG_LGT
+  int stride;
+  uint8_t *dst;
+#endif  // CONFIG_MRC_TX || CONFIG_LGT
 #if CONFIG_LGT
   int is_inter;
-  int stride;
   int mode;
-  uint8_t *dst;
 #endif
 // for inverse transforms only
 #if CONFIG_ADAPT_SCAN
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index 1304e4c..269ef57 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -209,6 +209,16 @@
   }
 }
 
+#if CONFIG_MRC_TX
+static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
+                                int mask_stride, int width, int height) {
+  // Placeholder mask: 1 where the prediction pixel exceeds 100, otherwise 0
+  for (int r = 0; r < height; ++r)
+    for (int c = 0; c < width; ++c)
+      mask[r * mask_stride + c] = pred[r * pred_stride + c] > 100;
+}
+#endif  // CONFIG_MRC_TX
+
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 63dfdb0..09f01d3 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -1463,6 +1463,35 @@
     aom_idct16x16_256_add(input, dest, stride);
 }
 
+#if CONFIG_MRC_TX
+static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                            const TxfmParam *txfm_param) {
+#if CONFIG_ADAPT_SCAN
+  const int16_t half = txfm_param->eob_threshold[0];
+  const int16_t quarter = txfm_param->eob_threshold[1];
+#else
+  const int16_t half = 135;
+  const int16_t quarter = 34;
+#endif
+
+  const int eob = txfm_param->eob;
+  if (eob == 1) {
+    aom_idct32x32_1_add_c(input, dest, stride);  // DC only; takes no mask
+  } else {
+    int mask[32 * 32];  // int: the type get_mrc_mask/aom_imrc32x32_* take
+    get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
+    if (eob <= quarter)
+      // non-zero coeff only in upper-left 8x8
+      aom_imrc32x32_34_add_c(input, dest, stride, mask);
+    else if (eob <= half)
+      // non-zero coeff only in upper-left 16x16
+      aom_imrc32x32_135_add_c(input, dest, stride, mask);
+    else
+      aom_imrc32x32_1024_add_c(input, dest, stride, mask);
+  }
+}
+#endif  // CONFIG_MRC_TX
+
 static void idct32x32_add(const tran_low_t *input, uint8_t *dest, int stride,
                           const TxfmParam *txfm_param) {
 #if CONFIG_ADAPT_SCAN
@@ -1486,24 +1515,6 @@
     aom_idct32x32_1024_add(input, dest, stride);
 }
 
-#if CONFIG_MRC_TX
-static void get_masked_residual32_inv(const tran_low_t *input, uint8_t *dest,
-                                      tran_low_t *output) {
-  // placeholder for bitmask creation, in the future it
-  // will likely be made based on dest
-  (void)dest;
-  memcpy(output, input, 32 * 32 * sizeof(*input));
-}
-
-static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
-                            const TxfmParam *param) {
-  // placeholder mrc tx function
-  tran_low_t masked_input[32 * 32];
-  get_masked_residual32_inv(input, dest, masked_input);
-  idct32x32_add(input, dest, stride, param);
-}
-#endif  // CONFIG_MRC_TX
-
 #if CONFIG_TX64X64
 static void idct64x64_add(const tran_low_t *input, uint8_t *dest, int stride,
                           const TxfmParam *txfm_param) {
@@ -2200,10 +2211,12 @@
 #endif  // CONFIG_PVQ
   TxfmParam txfm_param;
   init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
-#if CONFIG_LGT
+#if CONFIG_LGT || CONFIG_MRC_TX
   txfm_param.dst = dst;
-  txfm_param.mode = mode;
   txfm_param.stride = stride;
+#endif  // CONFIG_LGT || CONFIG_MRC_TX
+#if CONFIG_LGT
+  txfm_param.mode = mode;
 #endif
 
   const int is_hbd = get_bitdepth_data_path_index(xd);
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index 2ffc656..850b84c 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -1064,17 +1064,29 @@
 }
 
 #if CONFIG_MRC_TX
-static void get_masked_residual32_fwd(const tran_low_t *input,
-                                      tran_low_t *output) {
-  // placeholder for future bitmask creation
-  memcpy(output, input, 32 * 32 * sizeof(*input));
-}
-
-static void fmrc32(const tran_low_t *input, tran_low_t *output) {
-  // placeholder for mrc_dct, this just performs regular dct
-  tran_low_t masked_input[32 * 32];
-  get_masked_residual32_fwd(input, masked_input);
-  fdct32(masked_input, output);
+static void get_masked_residual32(const int16_t **input, int *input_stride,
+                                  const uint8_t *pred, int pred_stride,
+                                  int16_t *masked_input) {
+  int mrc_mask[32 * 32];
+  get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
+  int32_t sum = 0, n_masked = 0;
+  // Sum the residual over the masked (mask == 1) pixels and count them
+  for (int i = 0; i < 32; ++i) {
+    for (int j = 0; j < 32; ++j) {
+      sum += mrc_mask[i * 32 + j] * (*input)[i * (*input_stride) + j];
+      n_masked += mrc_mask[i * 32 + j];
+    }
+  }
+  // Mean over the masked pixels only (not all 1024); 0 if nothing is masked
+  const int16_t avg = n_masked ? (int16_t)(sum / n_masked) : 0;
+  // Replace the residual at every unmasked pixel with that mean
+  for (int i = 0; i < 32; ++i) {
+    for (int j = 0; j < 32; ++j)
+      masked_input[i * 32 + j] =
+          (mrc_mask[i * 32 + j]) ? (*input)[i * (*input_stride) + j] : avg;
+  }
+  *input = masked_input;
+  *input_stride = 32;
 }
 #endif  // CONFIG_MRC_TX
 
@@ -2387,7 +2399,7 @@
     { fidtx32, fhalfright32 },       // H_FLIPADST
 #endif
 #if CONFIG_MRC_TX
-    { fmrc32, fmrc32 },  // MRC_TX
+    { fdct32, fdct32 },  // MRC_TX
 #endif                   // CONFIG_MRC_TX
   };
   const transform_2d ht = FHT[tx_type];
@@ -2400,6 +2412,14 @@
   maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
 #endif
 
+#if CONFIG_MRC_TX
+  int16_t masked_input[32 * 32];  // function scope: input may point here below
+  if (tx_type == MRC_DCT) {
+    get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
+                          masked_input);
+  }
+#endif  // CONFIG_MRC_TX
+
   // Columns
   for (i = 0; i < 32; ++i) {
     for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index b532c13..5b91cec 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -538,7 +538,7 @@
 
   TxfmParam txfm_param;
 
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
   uint8_t *dst;
   const int dst_stride = pd->dst.stride;
 #if CONFIG_PVQ || CONFIG_DIST_8X8
@@ -601,7 +601,7 @@
 #endif  // CONFIG_HIGHBITDEPTH
 #endif
 
-#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#if CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
   dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
 #if CONFIG_PVQ || CONFIG_DIST_8X8
   pred = &pd->pred[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
@@ -623,17 +623,19 @@
   }
 #endif  // CONFIG_HIGHBITDEPTH
 #endif  // CONFIG_PVQ || CONFIG_DIST_8X8
-#endif  // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT
+#endif  // CONFIG_PVQ || CONFIG_DIST_8X8 || CONFIG_LGT || CONFIG_MRC_TX
 
   (void)ctx;
 
   txfm_param.tx_type = tx_type;
   txfm_param.tx_size = tx_size;
   txfm_param.lossless = xd->lossless[mbmi->segment_id];
-#if CONFIG_LGT
-  txfm_param.is_inter = is_inter_block(mbmi);
+#if CONFIG_MRC_TX || CONFIG_LGT
   txfm_param.dst = dst;
   txfm_param.stride = dst_stride;
+#endif  // CONFIG_MRC_TX || CONFIG_LGT
+#if CONFIG_LGT
+  txfm_param.is_inter = is_inter_block(mbmi);
   txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
 #endif
 
diff --git a/configure b/configure
index 8d5b19c..c8e1f3e 100755
--- a/configure
+++ b/configure
@@ -538,6 +538,7 @@
     enabled pvq && disable_feature var_tx
     enabled pvq && disable_feature highbitdepth
     enabled pvq && disable_feature lgt
+    enabled pvq && disable_feature mrc_tx
     enabled palette_throughput && soft_enable palette
     enabled ext_delta_q && soft_enable delta_q
     enabled txk_sel && soft_enable lv_map