Add search_txk_type

Change-Id: I50493fa9daf2de8859608d57f8d2010842c9eb07
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 3910faa..bdd04ad 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -325,6 +325,9 @@
 #endif
   MV_REFERENCE_FRAME ref_frame[2];
   TX_TYPE tx_type;
+#if CONFIG_LV_MAP
+  TX_TYPE txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
+#endif
 
 #if CONFIG_FILTER_INTRA
   FILTER_INTRA_MODE_INFO filter_intra_mode_info;
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 2c1e22f..d8835d6 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -14,6 +14,7 @@
 #include "av1/common/pred_common.h"
 #include "av1/encoder/cost.h"
 #include "av1/encoder/encodetxb.h"
+#include "av1/encoder/rdopt.h"
 #include "av1/encoder/subexp.h"
 #include "av1/encoder/tokenize.h"
 
@@ -705,3 +706,69 @@
   for (tx_size = TX_4X4; tx_size <= max_tx_size; ++tx_size)
     write_txb_probs(w, cpi, tx_size);
 }
+
+// Returns 1 when a per-transform-block transform type may be searched, and 0
+// otherwise. The search is disabled for planes other than plane 0, for
+// TX_32X32, and when the current segment is lossless.
+static INLINE int allow_txk_type(MACROBLOCKD *xd, TX_SIZE tx_size, int plane) {
+  if (plane != 0 || tx_size == TX_32X32 ||
+      xd->lossless[xd->mi[0]->mbmi.segment_id])
+    return 0;
+
+  return 1;
+}
+
+// Searches the transform types for one transform block and keeps the one with
+// the lowest RD cost.
+//
+// Each candidate is passed through av1_xform_quant(), then through
+// av1_optimize_b() when it has a nonzero eob and the segment is not lossless,
+// and measured with av1_dist_block(). RDCOST is evaluated with a rate of 0, so
+// candidates are compared on distortion only.
+//
+// On return, mbmi->txk_type[block] (plane 0 only) and
+// x->plane[plane].eobs[block] are set to the best candidate's values. The
+// best candidate's stats are written to *rd_stats and its cost is returned.
+// NOTE(review): the coefficient buffers are not restored and presumably still
+// hold the last candidate tried -- confirm callers do not depend on them.
+int64_t av1_search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                            int block, int blk_row, int blk_col,
+                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+                            int coeff_ctx, RD_STATS *rd_stats) {
+  const AV1_COMMON *cm = &cpi->common;
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  TX_TYPE txk_start = DCT_DCT;
+  TX_TYPE txk_end = TX_TYPES - 1;
+  TX_TYPE best_tx_type = txk_start;
+  int64_t best_rd = INT64_MAX;
+  int best_eob = tx_size_2d[tx_size];
+  // Only DCT_DCT is tried when a per-block type search is not allowed.
+  if (!allow_txk_type(xd, tx_size, plane)) txk_end = DCT_DCT;
+  TX_TYPE tx_type;
+  for (tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
+    RD_STATS this_rd_stats;
+    av1_invalid_rd_stats(&this_rd_stats);
+    if (plane == 0) mbmi->txk_type[block] = tx_type;
+    av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+                    coeff_ctx, AV1_XFORM_QUANT_FP);
+    if (x->plane[plane].eobs[block] && !xd->lossless[mbmi->segment_id])
+      av1_optimize_b(cm, x, plane, block, tx_size, coeff_ctx);
+    av1_dist_block(cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size,
+                   &this_rd_stats.dist, &this_rd_stats.sse,
+                   OUTPUT_HAS_PREDICTED_PIXELS);
+    // Keep the cost in 64 bits: best_rd and the return value are int64_t, so
+    // narrowing to int could truncate a large cost before the comparison.
+    const int64_t rd = RDCOST(x->rdmult, x->rddiv, 0, this_rd_stats.dist);
+    if (rd < best_rd) {
+      best_rd = rd;
+      *rd_stats = this_rd_stats;
+      best_tx_type = tx_type;
+      best_eob = x->plane[plane].eobs[block];
+    }
+  }
+  // Restore the winning candidate's type and eob.
+  if (plane == 0) mbmi->txk_type[block] = best_tx_type;
+  x->plane[plane].eobs[block] = best_eob;
+  return best_rd;
+}
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index dd51f04..f6346d0 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -38,6 +38,11 @@
                             RUN_TYPE dry_run, BLOCK_SIZE bsize, int *rate,
                             const int mi_row, const int mi_col);
 void av1_write_txb_probs(AV1_COMP *cpi, aom_writer *w);
+// Searches transform types for one transform block; returns the best RD cost.
+int64_t av1_search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                            int block, int blk_row, int blk_col,
+                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+                            int coeff_ctx, RD_STATS *rd_stats);
 #ifdef __cplusplus
 }
 #endif
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c9a51f3..5cc3f7d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1360,15 +1360,10 @@
   return aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
 }
 
-typedef enum OUTPUT_STATUS {
-  OUTPUT_HAS_PREDICTED_PIXELS,
-  OUTPUT_HAS_DECODED_PIXELS
-} OUTPUT_STATUS;
-
-static void dist_block(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
-                       BLOCK_SIZE plane_bsize, int block, int blk_row,
-                       int blk_col, TX_SIZE tx_size, int64_t *out_dist,
-                       int64_t *out_sse, OUTPUT_STATUS output_status) {
+void av1_dist_block(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                    BLOCK_SIZE plane_bsize, int block, int blk_row, int blk_col,
+                    TX_SIZE tx_size, int64_t *out_dist, int64_t *out_sse,
+                    OUTPUT_STATUS output_status) {
   MACROBLOCKD *const xd = &x->e_mbd;
   const struct macroblock_plane *const p = &x->plane[plane];
   const struct macroblockd_plane *const pd = &xd->plane[plane];
@@ -1581,13 +1576,13 @@
     struct macroblock_plane *const p = &x->plane[plane];
     av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
                                        p->eobs[block]);
-    dist_block(args->cpi, x, plane, plane_bsize, block, blk_row, blk_col,
-               tx_size, &this_rd_stats.dist, &this_rd_stats.sse,
-               OUTPUT_HAS_DECODED_PIXELS);
+    av1_dist_block(args->cpi, x, plane, plane_bsize, block, blk_row, blk_col,
+                   tx_size, &this_rd_stats.dist, &this_rd_stats.sse,
+                   OUTPUT_HAS_DECODED_PIXELS);
   } else {
-    dist_block(args->cpi, x, plane, plane_bsize, block, blk_row, blk_col,
-               tx_size, &this_rd_stats.dist, &this_rd_stats.sse,
-               OUTPUT_HAS_PREDICTED_PIXELS);
+    av1_dist_block(args->cpi, x, plane, plane_bsize, block, blk_row, blk_col,
+                   tx_size, &this_rd_stats.dist, &this_rd_stats.sse,
+                   OUTPUT_HAS_PREDICTED_PIXELS);
   }
 
   rd = RDCOST(x->rdmult, x->rddiv, 0, this_rd_stats.dist);
@@ -3853,7 +3848,7 @@
 
   av1_optimize_b(cm, x, plane, block, tx_size, coeff_ctx);
 
-// TODO(any): Use dist_block to compute distortion
+// TODO(any): Use av1_dist_block to compute distortion
 #if CONFIG_AOM_HIGHBITDEPTH
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
     rec_buffer = CONVERT_TO_BYTEPTR(rec_buffer16);
@@ -5225,8 +5220,9 @@
                       BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
       if (xd->lossless[xd->mi[0]->mbmi.segment_id] == 0)
         av1_optimize_b(cm, x, 0, block, tx_size, coeff_ctx);
-      dist_block(cpi, x, 0, BLOCK_8X8, block, idy + (i >> 1), idx + (i & 0x1),
-                 tx_size, &dist, &ssz, OUTPUT_HAS_PREDICTED_PIXELS);
+      av1_dist_block(cpi, x, 0, BLOCK_8X8, block, idy + (i >> 1),
+                     idx + (i & 0x1), tx_size, &dist, &ssz,
+                     OUTPUT_HAS_PREDICTED_PIXELS);
       thisdistortion += dist;
       thissse += ssz;
 #if !CONFIG_PVQ
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 4017da2..8eb1760 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -127,6 +127,17 @@
   }
 #endif
 }
+// Tells av1_dist_block() whether the output has predicted or decoded pixels.
+typedef enum OUTPUT_STATUS {
+  OUTPUT_HAS_PREDICTED_PIXELS,
+  OUTPUT_HAS_DECODED_PIXELS
+} OUTPUT_STATUS;
+// Reports a transform block's distortion via *out_dist and SSE via *out_sse.
+void av1_dist_block(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                    BLOCK_SIZE plane_bsize, int block, int blk_row, int blk_col,
+                    TX_SIZE tx_size, int64_t *out_dist, int64_t *out_sse,
+                    OUTPUT_STATUS output_status);
+
 #if !CONFIG_PVQ || CONFIG_VAR_TX
 int av1_cost_coeffs(const AV1_COMMON *const cm, MACROBLOCK *x, int plane,
                     int block, TX_SIZE tx_size, const SCAN_ORDER *scan_order,