Add get_coeff_cost() and get_txb_cost()

Change-Id: I085f2bc706fde41afbee5ff48b56acc095f804c2
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 3f0a515..16866f5 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -130,17 +130,18 @@
   return 8 + ctx;
 }
 
-static int sig_ref_offset[11][2] = {
+#define SIG_REF_OFFSET_NUM 11
+static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
   { -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 }, { -1, 0 },
   { -1, 1 },  { 0, -2 }, { 0, -1 }, { 1, -2 },  { 1, -1 },
 };
 
 static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
                                  const uint8_t *txb_mask,
-                                 const int c,  // raster order
+                                 const int coeff_idx,  // raster order
                                  const int bwl) {
-  const int row = c >> bwl;
-  const int col = c - (row << bwl);
+  const int row = coeff_idx >> bwl;
+  const int col = coeff_idx - (row << bwl);
   int ctx = 0;
   int idx;
   int stride = 1 << bwl;
@@ -166,7 +167,7 @@
     return 5 + ctx;
   }
 
-  for (idx = 0; idx < 11; ++idx) {
+  for (idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
     int ref_row = row + sig_ref_offset[idx][0];
     int ref_col = col + sig_ref_offset[idx][1];
     int pos;
@@ -200,14 +201,93 @@
   return 14 + ctx;
 }
 
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, int row,
+                               int col, const int16_t *iscan) {
+  int count = 0;
+  const int pos = row * stride + col;
+  for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
+    const int ref_row = row + sig_ref_offset[idx][0];
+    const int ref_col = col + sig_ref_offset[idx][1];
+    if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+      continue;
+    const int nb_pos = ref_row * stride + ref_col;
+    if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0);
+  }
+  return count;
+}
+
+// TODO(angiebird): optimize this function by generate a table that maps from
+// count to ctx
+static INLINE int get_nz_map_ctx_from_count(int count,
+                                            const tran_low_t *tcoeffs,
+                                            int coeff_idx,  // raster order
+                                            int bwl, const int16_t *iscan) {
+  const int row = coeff_idx >> bwl;
+  const int col = coeff_idx - (row << bwl);
+  int ctx = 0;
+
+  if (row == 0 && col == 0) return 0;
+
+  if (row == 0 && col == 1) return 1 + (tcoeffs[0] != 0);
+
+  if (row == 1 && col == 0) return 3 + (tcoeffs[0] != 0);
+
+  if (row == 1 && col == 1) {
+    int pos;
+    ctx = (tcoeffs[0] != 0);
+
+    if (iscan[1] < iscan[coeff_idx]) ctx += (tcoeffs[1] != 0);
+    pos = 1 << bwl;
+    if (iscan[pos] < iscan[coeff_idx]) ctx += (tcoeffs[pos] != 0);
+
+    ctx = (ctx + 1) >> 1;
+
+    assert(5 + ctx <= 7);
+
+    return 5 + ctx;
+  }
+
+  if (row == 0) {
+    ctx = (count + 1) >> 1;
+
+    assert(ctx < 3);
+    return 8 + ctx;
+  }
+
+  if (col == 0) {
+    ctx = (count + 1) >> 1;
+
+    assert(ctx < 3);
+    return 11 + ctx;
+  }
+
+  ctx = count >> 1;
+
+  assert(14 + ctx < 20);
+
+  return 14 + ctx;
+}
+
+// TODO(angiebird): merge this function with get_nz_map_ctx() after proper
+// testing
+static INLINE int get_nz_map_ctx2(const tran_low_t *tcoeffs,
+                                  const int coeff_idx,  // raster order
+                                  const int bwl, const int16_t *iscan) {
+  int stride = 1 << bwl;
+  const int row = coeff_idx >> bwl;
+  const int col = coeff_idx - (row << bwl);
+  int count = get_nz_count(tcoeffs, stride, row, col, iscan);
+  return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan);
+}
+
 static INLINE int get_eob_ctx(const tran_low_t *tcoeffs,
-                              const int c,  // raster order
+                              const int coeff_idx,  // raster order
                               const int bwl) {
   (void)tcoeffs;
-  if (bwl == 2) return av1_coeff_band_4x4[c];
-  if (bwl == 3) return av1_coeff_band_8x8[c];
-  if (bwl == 4) return av1_coeff_band_16x16[c];
-  if (bwl == 5) return av1_coeff_band_32x32[c];
+  if (bwl == 2) return av1_coeff_band_4x4[coeff_idx];
+  if (bwl == 3) return av1_coeff_band_8x8[coeff_idx];
+  if (bwl == 4) return av1_coeff_band_16x16[coeff_idx];
+  if (bwl == 5) return av1_coeff_band_32x32[coeff_idx];
 
   assert(0);
   return 0;
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index a5d1dce..3c95604 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -251,6 +251,32 @@
   return;
 }
 
+static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
+                              const aom_prob *coeff_lps) {
+  const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
+  const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
+  if (abs_qc >= min_level) {
+    const int cost0 = av1_cost_bit(coeff_lps[ctx], 0);
+    const int cost1 = av1_cost_bit(coeff_lps[ctx], 1);
+    if (abs_qc >= max_level)
+      return COEFF_BASE_RANGE * cost0;
+    else
+      return (abs_qc - min_level) * cost0 + cost1;
+  } else {
+    return 0;
+  }
+}
+
+static INLINE int get_base_cost(tran_low_t abs_qc, int ctx,
+                                aom_prob (*coeff_base)[COEFF_BASE_CONTEXTS],
+                                int base_idx) {
+  const int level = base_idx + 1;
+  if (abs_qc < level)
+    return 0;
+  else
+    return av1_cost_bit(coeff_base[base_idx][ctx], abs_qc == level);
+}
+
 int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,
                         int block, TXB_CTX *txb_ctx) {
   const AV1_COMMON *const cm = &cpi->common;
@@ -373,6 +399,96 @@
   return cost;
 }
 
+static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
+                                    const aom_prob *dc_sign_prob,
+                                    int dc_sign_ctx) {
+  const int sign = (qc < 0) ? 1 : 0;
+  // sign bit cost
+  if (coeff_idx == 0) {
+    return av1_cost_bit(dc_sign_prob[dc_sign_ctx], sign);
+  } else {
+    return av1_cost_bit(128, sign);
+  }
+}
+static INLINE int get_golomb_cost(int abs_qc) {
+  if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
+    // residual cost
+    int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
+    int ri = r;
+    int length = 0;
+
+    while (ri) {
+      ri >>= 1;
+      ++length;
+    }
+
+    return av1_cost_literal(2 * length - 1);
+  } else {
+    return 0;
+  }
+}
+
+static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
+                          TxbProbs *txb_probs) {
+  const TXB_CTX *txb_ctx = txb_info->txb_ctx;
+  const int is_nz = (qc != 0);
+  const tran_low_t abs_qc = abs(qc);
+  int cost = 0;
+  const int16_t *scan = txb_info->scan_order->scan;
+  const int16_t *iscan = txb_info->scan_order->iscan;
+
+  if (scan_idx < txb_info->seg_eob) {
+    int coeff_ctx =
+        get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, iscan);
+    cost += av1_cost_bit(txb_probs->nz_map[coeff_ctx], is_nz);
+  }
+
+  if (is_nz) {
+    cost += get_sign_bit_cost(qc, scan_idx, txb_probs->dc_sign_prob,
+                              txb_ctx->dc_sign_ctx);
+
+    int ctx_ls[NUM_BASE_LEVELS] = { 0 };
+    get_base_ctx_set(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, ctx_ls);
+
+    int i;
+    for (i = 0; i < NUM_BASE_LEVELS; ++i) {
+      cost += get_base_cost(abs_qc, ctx_ls[i], txb_probs->coeff_base, i);
+    }
+
+    if (abs_qc > NUM_BASE_LEVELS) {
+      int ctx = get_level_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
+      cost += get_br_cost(abs_qc, ctx, txb_probs->coeff_lps);
+      cost += get_golomb_cost(abs_qc);
+    }
+
+    if (scan_idx < txb_info->seg_eob) {
+      int eob_ctx =
+          get_eob_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
+      cost += av1_cost_bit(txb_probs->eob_flag[eob_ctx],
+                           scan_idx == (txb_info->eob - 1));
+    }
+  }
+  return cost;
+}
+
+// TODO(angiebird): make this static once it's called
+int get_txb_cost(TxbInfo *txb_info, TxbProbs *txb_probs) {
+  int cost = 0;
+  int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx;
+  const int16_t *scan = txb_info->scan_order->scan;
+  if (txb_info->eob == 0) {
+    cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 1);
+    return cost;
+  }
+  cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 0);
+  for (int c = 0; c < txb_info->eob; ++c) {
+    tran_low_t qc = txb_info->qcoeff[scan[c]];
+    int coeff_cost = get_coeff_cost(qc, c, txb_info, txb_probs);
+    cost += coeff_cost;
+  }
+  return cost;
+}
+
 int av1_get_txb_entropy_context(const tran_low_t *qcoeff,
                                 const SCAN_ORDER *scan_order, int eob) {
   const int16_t *scan = scan_order->scan;
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index e1aca2c..5fd1e50 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -22,6 +22,47 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+typedef struct TxbInfo {
+  tran_low_t *qcoeff;
+  tran_low_t *dqcoeff;
+  const tran_low_t *tcoeff;
+  const int16_t *dequant;
+  int shift;
+  TX_SIZE tx_size;
+  int bwl;
+  int stride;
+  int eob;
+  int seg_eob;
+  const SCAN_ORDER *scan_order;
+  TXB_CTX *txb_ctx;
+  int64_t rdmult;
+  int64_t rddiv;
+} TxbInfo;
+
+typedef struct TxbCache {
+  int nz_count_arr[MAX_TX_SQUARE];
+  int nz_ctx_arr[MAX_TX_SQUARE][2];
+  int base_count_arr[NUM_BASE_LEVELS][MAX_TX_SQUARE];
+  int base_mag_arr[MAX_TX_SQUARE]
+                  [2];  // [0]: max magnitude [1]: num of max magnitude
+  int base_ctx_arr[NUM_BASE_LEVELS][MAX_TX_SQUARE][2];  // [1]: not used
+
+  int br_count_arr[MAX_TX_SQUARE];
+  int br_mag_arr[MAX_TX_SQUARE]
+                [2];  // [0]: max magnitude [1]: num of max magnitude
+  int br_ctx_arr[MAX_TX_SQUARE][2];  // [1]: not used
+} TxbCache;
+
+typedef struct TxbProbs {
+  const aom_prob *dc_sign_prob;
+  const aom_prob *nz_map;
+  aom_prob (*coeff_base)[COEFF_BASE_CONTEXTS];
+  const aom_prob *coeff_lps;
+  const aom_prob *eob_flag;
+  const aom_prob *txb_skip;
+} TxbProbs;
+
 void av1_alloc_txb_buf(AV1_COMP *cpi);
 void av1_free_txb_buf(AV1_COMP *cpi);
 int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,