Add get_coeff_cost() and get_txb_cost()
Change-Id: I085f2bc706fde41afbee5ff48b56acc095f804c2
diff --git a/av1/common/txb_common.h b/av1/common/txb_common.h
index 3f0a515..16866f5 100644
--- a/av1/common/txb_common.h
+++ b/av1/common/txb_common.h
@@ -130,17 +130,18 @@
return 8 + ctx;
}
-static int sig_ref_offset[11][2] = {
+#define SIG_REF_OFFSET_NUM 11
+static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
{ -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 }, { -1, 0 },
{ -1, 1 }, { 0, -2 }, { 0, -1 }, { 1, -2 }, { 1, -1 },
};
static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
const uint8_t *txb_mask,
- const int c, // raster order
+ const int coeff_idx, // raster order
const int bwl) {
- const int row = c >> bwl;
- const int col = c - (row << bwl);
+ const int row = coeff_idx >> bwl;
+ const int col = coeff_idx - (row << bwl);
int ctx = 0;
int idx;
int stride = 1 << bwl;
@@ -166,7 +167,7 @@
return 5 + ctx;
}
- for (idx = 0; idx < 11; ++idx) {
+ for (idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
int ref_row = row + sig_ref_offset[idx][0];
int ref_col = col + sig_ref_offset[idx][1];
int pos;
@@ -200,14 +201,93 @@
return 14 + ctx;
}
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, int row,
+ int col, const int16_t *iscan) {
+ int count = 0;
+ const int pos = row * stride + col;
+ for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
+ const int ref_row = row + sig_ref_offset[idx][0];
+ const int ref_col = col + sig_ref_offset[idx][1];
+ if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ continue;
+ const int nb_pos = ref_row * stride + ref_col;
+ if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0);
+ }
+ return count;
+}
+
+// TODO(angiebird): optimize this function by generating a table that maps from
+// count to ctx
+static INLINE int get_nz_map_ctx_from_count(int count,
+ const tran_low_t *tcoeffs,
+ int coeff_idx, // raster order
+ int bwl, const int16_t *iscan) {
+ const int row = coeff_idx >> bwl;
+ const int col = coeff_idx - (row << bwl);
+ int ctx = 0;
+
+ if (row == 0 && col == 0) return 0;
+
+ if (row == 0 && col == 1) return 1 + (tcoeffs[0] != 0);
+
+ if (row == 1 && col == 0) return 3 + (tcoeffs[0] != 0);
+
+ if (row == 1 && col == 1) {
+ int pos;
+ ctx = (tcoeffs[0] != 0);
+
+ if (iscan[1] < iscan[coeff_idx]) ctx += (tcoeffs[1] != 0);
+ pos = 1 << bwl;
+ if (iscan[pos] < iscan[coeff_idx]) ctx += (tcoeffs[pos] != 0);
+
+ ctx = (ctx + 1) >> 1;
+
+ assert(5 + ctx <= 7);
+
+ return 5 + ctx;
+ }
+
+ if (row == 0) {
+ ctx = (count + 1) >> 1;
+
+ assert(ctx < 3);
+ return 8 + ctx;
+ }
+
+ if (col == 0) {
+ ctx = (count + 1) >> 1;
+
+ assert(ctx < 3);
+ return 11 + ctx;
+ }
+
+ ctx = count >> 1;
+
+ assert(14 + ctx < 20);
+
+ return 14 + ctx;
+}
+
+// TODO(angiebird): merge this function with get_nz_map_ctx() after proper
+// testing
+static INLINE int get_nz_map_ctx2(const tran_low_t *tcoeffs,
+ const int coeff_idx, // raster order
+ const int bwl, const int16_t *iscan) {
+ int stride = 1 << bwl;
+ const int row = coeff_idx >> bwl;
+ const int col = coeff_idx - (row << bwl);
+ int count = get_nz_count(tcoeffs, stride, row, col, iscan);
+ return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan);
+}
+
static INLINE int get_eob_ctx(const tran_low_t *tcoeffs,
- const int c, // raster order
+ const int coeff_idx, // raster order
const int bwl) {
(void)tcoeffs;
- if (bwl == 2) return av1_coeff_band_4x4[c];
- if (bwl == 3) return av1_coeff_band_8x8[c];
- if (bwl == 4) return av1_coeff_band_16x16[c];
- if (bwl == 5) return av1_coeff_band_32x32[c];
+ if (bwl == 2) return av1_coeff_band_4x4[coeff_idx];
+ if (bwl == 3) return av1_coeff_band_8x8[coeff_idx];
+ if (bwl == 4) return av1_coeff_band_16x16[coeff_idx];
+ if (bwl == 5) return av1_coeff_band_32x32[coeff_idx];
assert(0);
return 0;
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index a5d1dce..3c95604 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -251,6 +251,32 @@
return;
}
+static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
+ const aom_prob *coeff_lps) {
+ const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
+ const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
+ if (abs_qc >= min_level) {
+ const int cost0 = av1_cost_bit(coeff_lps[ctx], 0);
+ const int cost1 = av1_cost_bit(coeff_lps[ctx], 1);
+ if (abs_qc >= max_level)
+ return COEFF_BASE_RANGE * cost0;
+ else
+ return (abs_qc - min_level) * cost0 + cost1;
+ } else {
+ return 0;
+ }
+}
+
+static INLINE int get_base_cost(tran_low_t abs_qc, int ctx,
+ aom_prob (*coeff_base)[COEFF_BASE_CONTEXTS],
+ int base_idx) {
+ const int level = base_idx + 1;
+ if (abs_qc < level)
+ return 0;
+ else
+ return av1_cost_bit(coeff_base[base_idx][ctx], abs_qc == level);
+}
+
int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,
int block, TXB_CTX *txb_ctx) {
const AV1_COMMON *const cm = &cpi->common;
@@ -373,6 +399,96 @@
return cost;
}
+static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
+ const aom_prob *dc_sign_prob,
+ int dc_sign_ctx) {
+ const int sign = (qc < 0) ? 1 : 0;
+ // sign bit cost
+ if (coeff_idx == 0) {
+ return av1_cost_bit(dc_sign_prob[dc_sign_ctx], sign);
+ } else {
+ return av1_cost_bit(128, sign);
+ }
+}
+static INLINE int get_golomb_cost(int abs_qc) {
+ if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
+ // residual cost
+ int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
+ int ri = r;
+ int length = 0;
+
+ while (ri) {
+ ri >>= 1;
+ ++length;
+ }
+
+ return av1_cost_literal(2 * length - 1);
+ } else {
+ return 0;
+ }
+}
+
+static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
+ TxbProbs *txb_probs) {
+ const TXB_CTX *txb_ctx = txb_info->txb_ctx;
+ const int is_nz = (qc != 0);
+ const tran_low_t abs_qc = abs(qc);
+ int cost = 0;
+ const int16_t *scan = txb_info->scan_order->scan;
+ const int16_t *iscan = txb_info->scan_order->iscan;
+
+ if (scan_idx < txb_info->seg_eob) {
+ int coeff_ctx =
+ get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, iscan);
+ cost += av1_cost_bit(txb_probs->nz_map[coeff_ctx], is_nz);
+ }
+
+ if (is_nz) {
+ cost += get_sign_bit_cost(qc, scan_idx, txb_probs->dc_sign_prob,
+ txb_ctx->dc_sign_ctx);
+
+ int ctx_ls[NUM_BASE_LEVELS] = { 0 };
+ get_base_ctx_set(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, ctx_ls);
+
+ int i;
+ for (i = 0; i < NUM_BASE_LEVELS; ++i) {
+ cost += get_base_cost(abs_qc, ctx_ls[i], txb_probs->coeff_base, i);
+ }
+
+ if (abs_qc > NUM_BASE_LEVELS) {
+ int ctx = get_level_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
+ cost += get_br_cost(abs_qc, ctx, txb_probs->coeff_lps);
+ cost += get_golomb_cost(abs_qc);
+ }
+
+ if (scan_idx < txb_info->seg_eob) {
+ int eob_ctx =
+ get_eob_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
+ cost += av1_cost_bit(txb_probs->eob_flag[eob_ctx],
+ scan_idx == (txb_info->eob - 1));
+ }
+ }
+ return cost;
+}
+
+// TODO(angiebird): make this static once it's called
+int get_txb_cost(TxbInfo *txb_info, TxbProbs *txb_probs) {
+ int cost = 0;
+ int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx;
+ const int16_t *scan = txb_info->scan_order->scan;
+ if (txb_info->eob == 0) {
+ cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 1);
+ return cost;
+ }
+ cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 0);
+ for (int c = 0; c < txb_info->eob; ++c) {
+ tran_low_t qc = txb_info->qcoeff[scan[c]];
+ int coeff_cost = get_coeff_cost(qc, c, txb_info, txb_probs);
+ cost += coeff_cost;
+ }
+ return cost;
+}
+
int av1_get_txb_entropy_context(const tran_low_t *qcoeff,
const SCAN_ORDER *scan_order, int eob) {
const int16_t *scan = scan_order->scan;
diff --git a/av1/encoder/encodetxb.h b/av1/encoder/encodetxb.h
index e1aca2c..5fd1e50 100644
--- a/av1/encoder/encodetxb.h
+++ b/av1/encoder/encodetxb.h
@@ -22,6 +22,47 @@
#ifdef __cplusplus
extern "C" {
#endif
+
+typedef struct TxbInfo {
+ tran_low_t *qcoeff;
+ tran_low_t *dqcoeff;
+ const tran_low_t *tcoeff;
+ const int16_t *dequant;
+ int shift;
+ TX_SIZE tx_size;
+ int bwl;
+ int stride;
+ int eob;
+ int seg_eob;
+ const SCAN_ORDER *scan_order;
+ TXB_CTX *txb_ctx;
+ int64_t rdmult;
+ int64_t rddiv;
+} TxbInfo;
+
+typedef struct TxbCache {
+ int nz_count_arr[MAX_TX_SQUARE];
+ int nz_ctx_arr[MAX_TX_SQUARE][2];
+ int base_count_arr[NUM_BASE_LEVELS][MAX_TX_SQUARE];
+ int base_mag_arr[MAX_TX_SQUARE]
+ [2]; // [0]: max magnitude [1]: num of max magnitude
+ int base_ctx_arr[NUM_BASE_LEVELS][MAX_TX_SQUARE][2]; // [1]: not used
+
+ int br_count_arr[MAX_TX_SQUARE];
+ int br_mag_arr[MAX_TX_SQUARE]
+ [2]; // [0]: max magnitude [1]: num of max magnitude
+ int br_ctx_arr[MAX_TX_SQUARE][2]; // [1]: not used
+} TxbCache;
+
+typedef struct TxbProbs {
+ const aom_prob *dc_sign_prob;
+ const aom_prob *nz_map;
+ aom_prob (*coeff_base)[COEFF_BASE_CONTEXTS];
+ const aom_prob *coeff_lps;
+ const aom_prob *eob_flag;
+ const aom_prob *txb_skip;
+} TxbProbs;
+
void av1_alloc_txb_buf(AV1_COMP *cpi);
void av1_free_txb_buf(AV1_COMP *cpi);
int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,