blob: 09b6da76fc08fbe63e6e37c4b55941b2c6a76277 [file] [log] [blame]
/******************************************************************************
*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
#include "bits.h"
#include "common.h"
/* ----------------------------------------------------------------------------
* Common
* -------------------------------------------------------------------------- */
/* Forward declarations of the read helpers defined in the "Reading" section
* further down: lc3_setup_bits() calls both of them when the bitstream is
* opened in read mode, before their definitions are reached. */
static inline int ac_get(struct lc3_bits_buffer *);
static inline void accu_load(struct lc3_bits_accu *, struct lc3_bits_buffer *);
/**
 * Width in bits of the arithmetic coder range
 * ac             Arithmetic coder
 * return         Number of significant bits of `ac->range`, that is
 *                1 + floor(log2(range)); 0 when the range is 0
 */
static int ac_get_range_bits(const struct lc3_bits_ac *ac)
{
    unsigned range = ac->range;
    int width = 0;

    /* Count how many right shifts are needed to clear the value */
    while (range != 0) {
        range >>= 1;
        width++;
    }

    return width;
}
/**
 * Number of bits the arithmetic coder still has to output
 * ac             Arithmetic coder
 * return         Pending bits: `26 - ac_get_range_bits(ac)`, plus 8 bits
 *                for each byte held back (the cached byte, when
 *                `cache >= 0`, and every byte counted by `carry_count`)
 */
static int ac_get_pending_bits(const struct lc3_bits_ac *ac)
{
    /* Bytes retained by the coder and not yet put in the buffer */
    int held_bytes = ac->carry_count;
    if (ac->cache >= 0)
        held_bytes++;

    return 26 - ac_get_range_bits(ac) + 8 * held_bytes;
}
/**
 * Signed count of the bits left in the bitstream
 * bits           Bitstream context
 * return         >= 0: Number of bits left, < 0: Overflow
 *
 * The count is the gap between the backward pointer `p_bw` and the forward
 * pointer `p_fw`, minus the bits held in the accumulator (`n`, `nover`)
 * and the bits pending in the arithmetic coder. In read mode the pointers
 * are first moved back by `LC3_ACCU_BITS / 8` and `LC3_AC_BITS / 8` bytes
 * (presumably to compensate for the bytes preloaded at setup).
 */
static int get_bits_left(const struct lc3_bits *bits)
{
    uintptr_t end = (uintptr_t)bits->buffer.p_bw;
    uintptr_t start = (uintptr_t)bits->buffer.p_fw;

    if (bits->mode == LC3_BITS_MODE_READ) {
        end += LC3_ACCU_BITS/8;
        start -= LC3_AC_BITS/8;
    }

    /* Signed distance, in bytes, computed without unsigned wrap-around */
    int nbytes;
    if (end > start)
        nbytes = (int)(end - start);
    else
        nbytes = -(int)(start - end);

    int nused = bits->accu.n + bits->accu.nover +
        ac_get_pending_bits(&bits->ac);

    return 8 * nbytes - nused;
}
/**
 * Setup bitstream reading/writing
 * bits           Bitstream context to initialize
 * mode           Either reading (LC3_BITS_MODE_READ) or writing
 * buffer, len    Memory holding the bitstream, and its size in bytes
 *
 * The forward pointer starts at the first byte of the buffer, and the
 * backward pointer just past its last byte. In read mode, the arithmetic
 * coder is primed with the first 3 bytes, and the accumulator is loaded
 * from the end of the buffer.
 */
void lc3_setup_bits(struct lc3_bits *bits,
    enum lc3_bits_mode mode, void *buffer, int len)
{
    uint8_t *first = (uint8_t *)buffer;
    uint8_t *last = (uint8_t *)buffer + len;
    int reading = (mode == LC3_BITS_MODE_READ);

    *bits = (struct lc3_bits){
        .mode = mode,
        /* A "full" accumulator forces a complete reload in read mode */
        .accu = { .n = reading ? LC3_ACCU_BITS : 0 },
        .ac = { .range = 0xffffff, .cache = -1 },
        .buffer = {
            .start = first, .end = last,
            .p_fw = first, .p_bw = last,
        }
    };

    if (!reading)
        return;

    /* --- Prime the arithmetic coder, most significant byte first --- */

    struct lc3_bits_buffer *buf = &bits->buffer;
    int byte_hi = ac_get(buf);
    int byte_mid = ac_get(buf);
    int byte_lo = ac_get(buf);

    bits->ac.low = (byte_hi << 16) | (byte_mid << 8) | byte_lo;

    /* --- Fill the accumulator from the end of the buffer --- */

    accu_load(&bits->accu, buf);
}
/**
 * Return number of bits left in the bitstream
 * bits           Bitstream context
 * return         Number of bits left, 0 when the bitstream overflowed
 */
int lc3_get_bits_left(const struct lc3_bits *bits)
{
    int nleft = get_bits_left(bits);

    return nleft < 0 ? 0 : nleft;
}
/**
* Return number of bits left in the bitstream
*/
int lc3_check_bits(const struct lc3_bits *bits)
{
const struct lc3_bits_ac *ac = &bits->ac;
return -(get_bits_left(bits) < 0 || ac->error);
}
/* ----------------------------------------------------------------------------
* Writing
* -------------------------------------------------------------------------- */
/**
 * Flush the bits accumulator
 * accu           Bitstream accumulator
 * buffer         Bitstream buffer
 *
 * Complete bytes of the accumulator are written backward from `p_bw`,
 * least significant byte first, without going below the forward pointer
 * `p_fw`. Complete bytes that do not fit are dropped.
 */
static inline void accu_flush(
    struct lc3_bits_accu *accu, struct lc3_bits_buffer *buffer)
{
    int nbytes = LC3_MIN(accu->n >> 3,
        LC3_MAX(buffer->p_bw - buffer->p_fw, 0));
    int nremain = accu->n - 8 * nbytes;

    while (nbytes) {
        buffer->p_bw--;
        *buffer->p_bw = accu->v & 0xff;
        accu->v >>= 8;
        nbytes--;
    }

    /* A whole byte remaining means that the buffer is full, discard */
    accu->n = nremain >= 8 ? 0 : nremain;
}
/**
 * Arithmetic coder put byte
 * buffer         Bitstream buffer
 * byte           Byte to output
 *
 * The byte is stored at the forward pointer, which is then advanced.
 * Nothing is written once the forward pointer reached the buffer end.
 */
static inline void ac_put(struct lc3_bits_buffer *buffer, int byte)
{
    if (buffer->p_fw >= buffer->end)
        return;

    *buffer->p_fw = byte;
    buffer->p_fw++;
}
/**
 * Arithmetic coder range shift
 * ac             Arithmetic coder
 * buffer         Bitstream buffer
 *
 * Moves the top byte of the 24 bits `low` value out of the coder. The
 * byte is kept in `cache` rather than written immediately, as a later
 * carry may still increment it. A top byte of 0xff, without carry set,
 * is only counted in `carry_count`, its final value being unknown yet.
 */
LC3_HOT static inline void ac_shift(
    struct lc3_bits_ac *ac, struct lc3_bits_buffer *buffer)
{
    int undecided = ac->low >= 0xff0000 && !ac->carry;

    if (undecided) {
        ac->carry_count++;
    } else {
        /* --- Output the cached byte, adjusted by the carry --- */

        if (ac->cache >= 0)
            ac_put(buffer, ac->cache + ac->carry);

        /* --- Resolve the counted bytes: 0x00 on carry, 0xff otherwise --- */

        int fill = ac->carry ? 0x00 : 0xff;
        while (ac->carry_count > 0) {
            ac_put(buffer, fill);
            ac->carry_count--;
        }

        ac->cache = ac->low >> 16;
        ac->carry = 0;
    }

    ac->low = (ac->low << 8) & 0xffffff;
}
/**
* Arithmetic coder termination
* ac Arithmetic coder
* buffer Bitstream buffer
*
* Selects a final value from `low` and `range`, shifts it out of the coder,
* then merges the last `nbits` bits into the byte located at the forward
* pointer, leaving the other bits of that byte untouched.
*/
static void ac_terminate(struct lc3_bits_ac *ac,
struct lc3_bits_buffer *buffer)
{
/* Number of bits to emit, and mask of the low bits that are dropped */
int nbits = 25 - ac_get_range_bits(ac);
unsigned mask = 0xffffff >> nbits;
/* `val` is `low` rounded up to a multiple of `mask + 1` (completed by the
* `& ~mask` below), `high` is the upper bound `low + range` */
unsigned val = ac->low + mask;
unsigned high = ac->low + ac->range;
/* Remember whether each value went past 24 bits before truncating */
bool over_val = val >> 24;
bool over_high = high >> 24;
val = (val & 0xffffff) & ~mask;
high = (high & 0xffffff);
if (over_val == over_high) {
/* `val + mask` is the largest value sharing the emitted bits; when it
* reaches `high`, one more bit is emitted (mask halved) and `val` is
* recomputed from `low` */
if (val + mask >= high) {
nbits++;
mask >>= 1;
val = ((ac->low + mask) & 0xffffff) & ~mask;
}
/* The 24 bits truncation made `val` wrap below `low`: flag a carry */
ac->carry |= val < ac->low;
}
ac->low = val;
/* One shift per group of 8 bits to emit, and always at least one */
for (; nbits > 8; nbits -= 8)
ac_shift(ac, buffer);
ac_shift(ac, buffer);
/* Top `nbits` bits of the byte cached by the last shift */
int end_val = ac->cache >> (8 - nbits);
if (ac->carry_count) {
/* Bytes are still counted as pending: write the cached byte, then 0xff
* for all of them but one; the last one is represented by `end_val` */
ac_put(buffer, ac->cache);
for ( ; ac->carry_count > 1; ac->carry_count--)
ac_put(buffer, 0xff);
end_val = nbits < 8 ? 0 : 0xff;
}
/* Merge into the byte at the forward pointer, which is not advanced: the
* low `8 - nbits` bits already stored there are preserved (presumably bits
* written from the back of the buffer -- TODO confirm) */
if (buffer->p_fw < buffer->end) {
*buffer->p_fw &= 0xff >> nbits;
*buffer->p_fw |= end_val << (8 - nbits);
}
}
/**
 * Flush and terminate bitstream
 * bits           Bitstream context
 *
 * The space left between the forward and backward pointers, less the
 * bits already held by the accumulator, is padded with zero bits. The
 * accumulator is then flushed and the arithmetic coder terminated.
 */
void lc3_flush_bits(struct lc3_bits *bits)
{
    struct lc3_bits_buffer *buffer = &bits->buffer;

    /* --- Zero padding, by chunks of 32 bits at most --- */

    int nleft = buffer->p_bw - buffer->p_fw;
    int npad = 8 * nleft - bits->accu.n;

    while (npad > 0) {
        lc3_put_bits(bits, 0, LC3_MIN(npad, 32));
        npad -= 32;
    }

    /* --- Write out the remaining state --- */

    accu_flush(&bits->accu, buffer);
    ac_terminate(&bits->ac, buffer);
}
/**
 * Write from 1 to 32 bits,
 * exceeding the capacity of the accumulator
 * bits           Bitstream context
 * v, n           Value, and number of bits of the value to write
 *
 * Expected to be called when `n` bits do not fit in the free space of the
 * accumulator; this keeps the shift counts below the accumulator width.
 */
LC3_HOT void lc3_put_bits_generic(struct lc3_bits *bits, unsigned v, int n)
{
    struct lc3_bits_accu *accu = &bits->accu;

    /* --- Complete the accumulator with the low bits, and flush --- */

    int nfree = LC3_ACCU_BITS - accu->n;
    int nhead = LC3_MIN(nfree, n);

    if (nhead != 0) {
        accu->v |= v << accu->n;
        accu->n = LC3_ACCU_BITS;
    }

    accu_flush(accu, &bits->buffer);

    /* --- Restart the accumulator with the bits not yet stored --- */

    accu->n = n - nhead;
    accu->v = v >> nhead;
}
/**
 * Arithmetic coder renormalization, on writing
 * bits           Bitstream context
 *
 * Shifts one byte out of the coder, and scales the range by 256, until
 * the range is back to 0x10000 or more.
 */
LC3_HOT void lc3_ac_write_renorm(struct lc3_bits *bits)
{
    struct lc3_bits_ac *ac = &bits->ac;

    while (ac->range < 0x10000) {
        ac_shift(ac, &bits->buffer);
        ac->range <<= 8;
    }
}
/* ----------------------------------------------------------------------------
* Reading
* -------------------------------------------------------------------------- */
/**
 * Arithmetic coder get byte
 * buffer         Bitstream buffer
 * return         Byte read at the forward pointer, which is advanced;
 *                0 once the end of the buffer is reached
 */
static inline int ac_get(struct lc3_bits_buffer *buffer)
{
    if (buffer->p_fw >= buffer->end)
        return 0;

    int byte = *buffer->p_fw;
    buffer->p_fw++;

    return byte;
}
/**
 * Load the accumulator
 * accu           Bitstream accumulator
 * buffer         Bitstream buffer
 *
 * Every byte already consumed (`accu->n >> 3`) is replaced by a byte read
 * backward from `p_bw`, inserted at the top of the accumulator. When the
 * start of the buffer is reached before all of them are replaced, the
 * missing bits read as zeros and are counted in `accu->nover`.
 */
static inline void accu_load(struct lc3_bits_accu *accu,
    struct lc3_bits_buffer *buffer)
{
    int nbytes = LC3_MIN(accu->n >> 3, buffer->p_bw - buffer->start);

    accu->n -= 8 * nbytes;

    for ( ; nbytes; nbytes--) {
        accu->v >>= 8;

        /* Convert to unsigned before shifting: the byte would otherwise be
         * promoted to `int`, and moving a value >= 0x80 to the top byte
         * overflows a signed integer (undefined behavior) */
        accu->v |= (unsigned)*(--buffer->p_bw) << (LC3_ACCU_BITS - 8);
    }

    if (accu->n >= 8) {
        /* --- Buffer exhausted, account for the bits read past it --- */

        accu->nover = LC3_MIN(accu->nover + accu->n, LC3_ACCU_BITS);

        /* `accu->n` reaches the accumulator width when it was fully consumed
         * and no byte is left; a shift by the width of the type is undefined,
         * the expected result is an accumulator holding only zeros */
        accu->v = accu->n < LC3_ACCU_BITS ? accu->v >> accu->n : 0;
        accu->n = 0;
    }
}
/**
 * Read from 1 to 32 bits,
 * exceeding the capacity of the accumulator
 * bits           Bitstream context
 * n              Number of bits to read
 * return         The value read, first bits read in the low positions
 *
 * The accumulator is reloaded, then as many bits as available are taken.
 * A second reload provides the remaining bits, when needed.
 */
LC3_HOT unsigned lc3_get_bits_generic(struct lc3_bits *bits, int n)
{
    struct lc3_bits_accu *accu = &bits->accu;
    struct lc3_bits_buffer *buffer = &bits->buffer;

    /* --- Fulfill accumulator and read -- */

    accu_load(accu, buffer);

    int n1 = LC3_MIN(LC3_ACCU_BITS - accu->n, n);

    /* `n1` equals the accumulator width on a full-width read from a byte
     * aligned accumulator; `1u << n1` is undefined in that case, select
     * all the bits explicitly */
    unsigned mask1 = n1 < LC3_ACCU_BITS ? (1u << n1) - 1 : ~0u;
    unsigned v = (accu->v >> accu->n) & mask1;
    accu->n += n1;

    /* --- Second round --- */

    int n2 = n - n1;

    if (n2) {
        accu_load(accu, buffer);

        v |= ((accu->v >> accu->n) & ((1u << n2) - 1)) << n1;
        accu->n += n2;
    }

    return v;
}
/**
 * Arithmetic coder renormalization, on reading
 * bits           Bitstream context
 *
 * Appends one byte of the bitstream to the 24 bits `low` value, and
 * scales the range by 256, until the range is back to 0x10000 or more.
 */
LC3_HOT void lc3_ac_read_renorm(struct lc3_bits *bits)
{
    struct lc3_bits_ac *ac = &bits->ac;

    while (ac->range < 0x10000) {
        unsigned byte = ac_get(&bits->buffer);

        ac->low = ((ac->low << 8) | byte) & 0xffffff;
        ac->range <<= 8;
    }
}