blob: 1e8dfb21d993f0c0702e552b90a4674b2d88bb09 [file] [log] [blame]
// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"
#if LIBGAV1_ENABLE_NEON
#include <arm_neon.h>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"
namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {
// Treats |src| as the 16-byte concatenation src.val[0]:src.val[1] and returns
// the 8 bytes that start at byte offset |bytes|.
template <int bytes>
inline uint8x8_t VshrU128(const uint8x8x2_t src) {
  const uint8x8_t lower = src.val[0];
  const uint8x8_t upper = src.val[1];
  return vext_u8(lower, upper, bytes);
}
// 16-bit lane overload. |bytes| is still expressed in bytes, so it is halved
// to obtain the lane index handed to vextq_u16().
template <int bytes>
inline uint16x8_t VshrU128(const uint16x8x2_t src) {
  constexpr int kLaneOffset = bytes / 2;
  return vextq_u16(src.val[0], src.val[1], kLaneOffset);
}
// Wiener
// A local copy of the coefficients is made so that the compiler can tell they
// do not overlap with any other buffer; the 'const' keyword alone is not
// enough. In practice the compiler does not really copy anything, since there
// are enough registers in this case.
inline void PopulateWienerCoefficients(
    const RestorationUnitInfo& restoration_info, const int direction,
    int16_t filter[4]) {
  const auto& wiener_info = restoration_info.wiener_info;
  filter[0] = wiener_info.filter[direction][0];
  filter[1] = wiener_info.filter[direction][1];
  filter[2] = wiener_info.filter[direction][2];
  filter[3] = wiener_info.filter[direction][3];
  // To keep the horizontal pass intermediate values within 16 bits,
  // |filter[3]| is offset by 128. The 128 offset is added back in the loop.
  if (direction == WienerInfo::kHorizontal) {
    filter[3] -= 128;
  }
}
// Returns |sum| + (s0 + s1) * filter, computed in 16-bit lanes. The two taps
// share the same coefficient, so they are added before the multiply.
inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
                                   const int16_t filter, const int16x8_t sum) {
  const uint16x8_t pair_sum = vaddl_u8(s0, s1);
  return vmlaq_n_s16(sum, vreinterpretq_s16_u16(pair_sum), filter);
}
// 16-pixel overload: applies the 8-pixel WienerHorizontal2() to the low and
// high halves of the inputs independently.
inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
                                     const int16_t filter,
                                     const int16x8x2_t sum) {
  const uint8x8_t s0_lo = vget_low_u8(s0);
  const uint8x8_t s1_lo = vget_low_u8(s1);
  const uint8x8_t s0_hi = vget_high_u8(s0);
  const uint8x8_t s1_hi = vget_high_u8(s1);
  int16x8x2_t result;
  result.val[0] = WienerHorizontal2(s0_lo, s1_lo, filter, sum.val[0]);
  result.val[1] = WienerHorizontal2(s0_hi, s1_hi, filter, sum.val[1]);
  return result;
}
inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
int16x8_t sum, int16_t* const wiener_buffer) {
constexpr int offset =
1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
constexpr int limit = (offset << 2) - 1;
const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
const int16x8_t s_1 = ZeroExtend(s[1]);
sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
sum = vmlaq_n_s16(sum, s_1, filter[3]);
// Calculate scaled down offset correction, and add to sum here to prevent
// signed 16 bit outranging.
sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
kInterRoundBitsHorizontal);
sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
vst1q_s16(wiener_buffer, sum);
}
inline void WienerHorizontalSum(const uint8x16_t src[3],
const int16_t filter[4], int16x8x2_t sum,
int16_t* const wiener_buffer) {
uint8x8_t s[3];
s[0] = vget_low_u8(src[0]);
s[1] = vget_low_u8(src[1]);
s[2] = vget_low_u8(src[2]);
WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer);
s[0] = vget_high_u8(src[0]);
s[1] = vget_high_u8(src[1]);
s[2] = vget_high_u8(src[2]);
WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8);
}
inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
const ptrdiff_t width, const int height,
const int16_t filter[4],
int16_t** const wiener_buffer) {
int y = height;
do {
const uint8_t* src_ptr = src;
uint8x16_t s[8];
s[0] = vld1q_u8(src_ptr);
ptrdiff_t x = width;
do {
src_ptr += 16;
s[7] = vld1q_u8(src_ptr);
s[1] = vextq_u8(s[0], s[7], 1);
s[2] = vextq_u8(s[0], s[7], 2);
s[3] = vextq_u8(s[0], s[7], 3);
s[4] = vextq_u8(s[0], s[7], 4);
s[5] = vextq_u8(s[0], s[7], 5);
s[6] = vextq_u8(s[0], s[7], 6);
int16x8x2_t sum;
sum.val[0] = sum.val[1] = vdupq_n_s16(0);
sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
s[0] = s[7];
*wiener_buffer += 16;
x -= 16;
} while (x != 0);
src += src_stride;
} while (--y != 0);
}
inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
const ptrdiff_t width, const int height,
const int16_t filter[4],
int16_t** const wiener_buffer) {
int y = height;
do {
const uint8_t* src_ptr = src;
uint8x16_t s[6];
s[0] = vld1q_u8(src_ptr);
ptrdiff_t x = width;
do {
src_ptr += 16;
s[5] = vld1q_u8(src_ptr);
s[1] = vextq_u8(s[0], s[5], 1);
s[2] = vextq_u8(s[0], s[5], 2);
s[3] = vextq_u8(s[0], s[5], 3);
s[4] = vextq_u8(s[0], s[5], 4);
int16x8x2_t sum;
sum.val[0] = sum.val[1] = vdupq_n_s16(0);
sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
s[0] = s[5];
*wiener_buffer += 16;
x -= 16;
} while (x != 0);
src += src_stride;
} while (--y != 0);
}
inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
const ptrdiff_t width, const int height,
const int16_t filter[4],
int16_t** const wiener_buffer) {
int y = height;
do {
const uint8_t* src_ptr = src;
uint8x16_t s[4];
s[0] = vld1q_u8(src_ptr);
ptrdiff_t x = width;
do {
src_ptr += 16;
s[3] = vld1q_u8(src_ptr);
s[1] = vextq_u8(s[0], s[3], 1);
s[2] = vextq_u8(s[0], s[3], 2);
int16x8x2_t sum;
sum.val[0] = sum.val[1] = vdupq_n_s16(0);
WienerHorizontalSum(s, filter, sum, *wiener_buffer);
s[0] = s[3];
*wiener_buffer += 16;
x -= 16;
} while (x != 0);
src += src_stride;
} while (--y != 0);
}
inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
const ptrdiff_t width, const int height,
int16_t** const wiener_buffer) {
int y = height;
do {
const uint8_t* src_ptr = src;
ptrdiff_t x = width;
do {
const uint8x16_t s = vld1q_u8(src_ptr);
const uint8x8_t s0 = vget_low_u8(s);
const uint8x8_t s1 = vget_high_u8(s);
const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4));
const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4));
vst1q_s16(*wiener_buffer + 0, d0);
vst1q_s16(*wiener_buffer + 8, d1);
src_ptr += 16;
*wiener_buffer += 16;
x -= 16;
} while (x != 0);
src += src_stride;
} while (--y != 0);
}
// Returns |sum| + (a0 + a1) * filter with 32-bit accumulation. The two rows
// share one coefficient, so they are added (in 16 bits) before the widening
// multiply-accumulate.
inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
                                   const int16_t filter,
                                   const int32x4x2_t sum) {
  const int16x8_t pair_sum = vaddq_s16(a0, a1);
  int32x4x2_t acc = sum;
  acc.val[0] = vmlal_n_s16(acc.val[0], vget_low_s16(pair_sum), filter);
  acc.val[1] = vmlal_n_s16(acc.val[1], vget_high_s16(pair_sum), filter);
  return acc;
}
// Finishes the vertical filter for 8 pixels: adds the three innermost rows
// (|a[0]|, |a[1]|, |a[2]|) to |sum|, then rounds, shifts right by 11 and
// saturates down to 8 bits.
inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
                                const int32x4x2_t sum) {
  const int32x4x2_t outer = WienerVertical2(a[0], a[2], filter[2], sum);
  const int32x4_t total_lo =
      vmlal_n_s16(outer.val[0], vget_low_s16(a[1]), filter[3]);
  const int32x4_t total_hi =
      vmlal_n_s16(outer.val[1], vget_high_s16(a[1]), filter[3]);
  const uint16x8_t narrowed = vcombine_u16(vqrshrun_n_s32(total_lo, 11),
                                           vqrshrun_n_s32(total_hi, 11));
  return vqmovn_u16(narrowed);
}
// Vertical 7-tap filter for one row of 8 pixels. Loads 7 consecutive rows of
// |wiener_buffer| into |a[0..6]| (left there for the caller to reuse) and
// returns the filtered pixels.
inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[7]) {
  for (int i = 0; i < 7; ++i) {
    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
  }
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  acc = WienerVertical2(a[0], a[6], filter[0], acc);
  acc = WienerVertical2(a[1], a[5], filter[1], acc);
  return WienerVertical(a + 2, filter, acc);
}
// Vertical 7-tap filter for two consecutive rows of 8 pixels. The second row
// reuses the rows loaded for the first one and only loads one extra row.
inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t rows[8];
  uint8x8x2_t out;
  out.val[0] =
      WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, rows);
  rows[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  acc = WienerVertical2(rows[1], rows[7], filter[0], acc);
  acc = WienerVertical2(rows[2], rows[6], filter[1], acc);
  out.val[1] = WienerVertical(rows + 3, filter, acc);
  return out;
}
inline void WienerVerticalTap7(const int16_t* wiener_buffer,
const ptrdiff_t width, const int height,
const int16_t filter[4], uint8_t* dst,
const ptrdiff_t dst_stride) {
for (int y = height >> 1; y != 0; --y) {
uint8_t* dst_ptr = dst;
ptrdiff_t x = width;
do {
uint8x8x2_t d[2];
d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
wiener_buffer += 16;
dst_ptr += 16;
x -= 16;
} while (x != 0);
wiener_buffer += width;
dst += 2 * dst_stride;
}
if ((height & 1) != 0) {
ptrdiff_t x = width;
do {
int16x8_t a[7];
const uint8x8_t d0 =
WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
const uint8x8_t d1 =
WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
vst1q_u8(dst, vcombine_u8(d0, d1));
wiener_buffer += 16;
dst += 16;
x -= 16;
} while (x != 0);
}
}
// Vertical 5-tap filter for one row of 8 pixels. Loads 5 consecutive rows of
// |wiener_buffer| into |a[0..4]| (left there for the caller to reuse) and
// returns the filtered pixels. |filter[0]| is not used.
inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[5]) {
  for (int i = 0; i < 5; ++i) {
    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
  }
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  acc = WienerVertical2(a[0], a[4], filter[1], acc);
  return WienerVertical(a + 1, filter, acc);
}
// Vertical 5-tap filter for two consecutive rows of 8 pixels. The second row
// reuses the rows loaded for the first one and only loads one extra row.
inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t rows[6];
  uint8x8x2_t out;
  out.val[0] =
      WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, rows);
  rows[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  acc = WienerVertical2(rows[1], rows[5], filter[1], acc);
  out.val[1] = WienerVertical(rows + 2, filter, acc);
  return out;
}
inline void WienerVerticalTap5(const int16_t* wiener_buffer,
const ptrdiff_t width, const int height,
const int16_t filter[4], uint8_t* dst,
const ptrdiff_t dst_stride) {
for (int y = height >> 1; y != 0; --y) {
uint8_t* dst_ptr = dst;
ptrdiff_t x = width;
do {
uint8x8x2_t d[2];
d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
wiener_buffer += 16;
dst_ptr += 16;
x -= 16;
} while (x != 0);
wiener_buffer += width;
dst += 2 * dst_stride;
}
if ((height & 1) != 0) {
ptrdiff_t x = width;
do {
int16x8_t a[5];
const uint8x8_t d0 =
WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
const uint8x8_t d1 =
WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
vst1q_u8(dst, vcombine_u8(d0, d1));
wiener_buffer += 16;
dst += 16;
x -= 16;
} while (x != 0);
}
}
// Vertical 3-tap filter for one row of 8 pixels. Loads 3 consecutive rows of
// |wiener_buffer| into |a[0..2]| (left there for the caller to reuse) and
// returns the filtered pixels. Only |filter[2]| and |filter[3]| are used.
inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[3]) {
  for (int i = 0; i < 3; ++i) {
    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
  }
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  return WienerVertical(a, filter, acc);
}
// Vertical 3-tap filter for two consecutive rows of 8 pixels. The second row
// reuses the rows loaded for the first one and only loads one extra row.
inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t rows[4];
  uint8x8x2_t out;
  out.val[0] =
      WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, rows);
  rows[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  int32x4x2_t acc;
  acc.val[0] = vdupq_n_s32(0);
  acc.val[1] = vdupq_n_s32(0);
  out.val[1] = WienerVertical(rows + 1, filter, acc);
  return out;
}
inline void WienerVerticalTap3(const int16_t* wiener_buffer,
const ptrdiff_t width, const int height,
const int16_t filter[4], uint8_t* dst,
const ptrdiff_t dst_stride) {
for (int y = height >> 1; y != 0; --y) {
uint8_t* dst_ptr = dst;
ptrdiff_t x = width;
do {
uint8x8x2_t d[2];
d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
wiener_buffer += 16;
dst_ptr += 16;
x -= 16;
} while (x != 0);
wiener_buffer += width;
dst += 2 * dst_stride;
}
if ((height & 1) != 0) {
ptrdiff_t x = width;
do {
int16x8_t a[3];
const uint8x8_t d0 =
WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
const uint8x8_t d1 =
WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
vst1q_u8(dst, vcombine_u8(d0, d1));
wiener_buffer += 16;
dst += 16;
x -= 16;
} while (x != 0);
}
}
inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
uint8_t* const dst) {
const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
const uint8x8_t d0 = vqrshrun_n_s16(a0, 4);
const uint8x8_t d1 = vqrshrun_n_s16(a1, 4);
vst1q_u8(dst, vcombine_u8(d0, d1));
}
// Vertical pass when only the center tap remains. |width| is used both as the
// number of columns to process and as the row stride of |wiener_buffer|. Rows
// are produced two at a time; an odd |height| gets one extra single-row pass.
inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  int row_pairs_left = height >> 1;
  while (row_pairs_left != 0) {
    uint8_t* out = dst;
    ptrdiff_t columns_left = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, out);
      WienerVerticalTap1Kernel(wiener_buffer + width, out + dst_stride);
      wiener_buffer += 16;
      out += 16;
      columns_left -= 16;
    } while (columns_left != 0);
    // The column loop advanced by one row; step over the second row of the
    // pair as well.
    wiener_buffer += width;
    dst += 2 * dst_stride;
    --row_pairs_left;
  }
  if ((height & 1) == 0) return;
  // Odd height: convert the single remaining row.
  ptrdiff_t columns_left = width;
  do {
    WienerVerticalTap1Kernel(wiener_buffer, dst);
    wiener_buffer += 16;
    dst += 16;
    columns_left -= 16;
  } while (columns_left != 0);
}
// For width 16 and up, store the horizontal results, and then do the vertical
// filter row by row. This is faster than doing it column by column when
// considering cache issues.
//
// The horizontal and vertical kernels are each selected by the number of
// leading zero coefficients of the corresponding filter (0 -> 7 taps,
// 1 -> 5 taps, 2 -> 3 taps, otherwise 1 tap).
void WienerFilter_NEON(const void* const source, void* const dest,
                       const RestorationUnitInfo& restoration_info,
                       const ptrdiff_t source_stride,
                       const ptrdiff_t dest_stride, const int width,
                       const int height, RestorationBuffer* const buffer) {
  constexpr int kCenterTap = kWienerFilterTaps / 2;
  const int16_t* const leading_zeros =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int horizontal_zeros = leading_zeros[WienerInfo::kHorizontal];
  const int vertical_zeros = leading_zeros[WienerInfo::kVertical];
  const int number_rows_to_skip = std::max(vertical_zeros, 1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const vertical_buffer = buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* horizontal_buffer =
      vertical_buffer + number_rows_to_skip * wiener_stride;
  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
                             filter_horizontal);
  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
                             filter_vertical);

  // Horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const auto* const src = static_cast<const uint8_t*>(source) -
                          (kCenterTap - number_rows_to_skip) * source_stride;
  switch (horizontal_zeros) {
    case 0:
      WienerHorizontalTap7(src - 3, source_stride, wiener_stride,
                           height_horizontal, filter_horizontal,
                           &horizontal_buffer);
      break;
    case 1:
      WienerHorizontalTap5(src - 2, source_stride, wiener_stride,
                           height_horizontal, filter_horizontal,
                           &horizontal_buffer);
      break;
    case 2:
      // The maximum over-reads happen here.
      WienerHorizontalTap3(src - 1, source_stride, wiener_stride,
                           height_horizontal, filter_horizontal,
                           &horizontal_buffer);
      break;
    default:
      assert(horizontal_zeros == 3);
      WienerHorizontalTap1(src, source_stride, wiener_stride,
                           height_horizontal, &horizontal_buffer);
      break;
  }

  // Vertical filtering.
  // Over-writes up to 15 values.
  auto* dst = static_cast<uint8_t*>(dest);
  switch (vertical_zeros) {
    case 0:
      // Because the top row of |source| is a duplicate of the second row, and
      // the bottom row of |source| is a duplicate of its above row, we can
      // duplicate the top and bottom row of |wiener_buffer| accordingly.
      // At this point |horizontal_buffer| points just past the last row
      // written by the horizontal pass.
      memcpy(horizontal_buffer, horizontal_buffer - wiener_stride,
             sizeof(*horizontal_buffer) * wiener_stride);
      memcpy(vertical_buffer, vertical_buffer + wiener_stride,
             sizeof(*vertical_buffer) * wiener_stride);
      WienerVerticalTap7(vertical_buffer, wiener_stride, height,
                         filter_vertical, dst, dest_stride);
      break;
    case 1:
      WienerVerticalTap5(vertical_buffer + wiener_stride, wiener_stride,
                         height, filter_vertical, dst, dest_stride);
      break;
    case 2:
      WienerVerticalTap3(vertical_buffer + 2 * wiener_stride, wiener_stride,
                         height, filter_vertical, dst, dest_stride);
      break;
    default:
      assert(vertical_zeros == 3);
      WienerVerticalTap1(vertical_buffer + 3 * wiener_stride, wiener_stride,
                         height, dst, dest_stride);
      break;
  }
}
//------------------------------------------------------------------------------
// SGR
// Computes, for 4 lanes, the value |z| that is later used to index the |ma|
// lookup table:
//   a = |sum_sq|
//   d = |sum|
//   p = (a * n < d * d) ? 0 : a * n - d * d;
//   z = RightShiftWithRounding(p * scale, kSgrProjScaleBits);
// The result is narrowed to 16 bits without saturation.
template <int n>
inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
                              const uint32_t scale) {
  const uint32x4_t d_squared = vmull_u16(sum, sum);
  const uint32x4_t a_times_n = vmulq_n_u32(sum_sq, n);
  // The saturating subtraction keeps |p| at 0 instead of letting it underflow.
  const uint32x4_t p = vqsubq_u32(a_times_n, d_squared);
  // The narrowing shift vrshrn_n_u32() can shift by at most 16, but
  // kSgrProjScaleBits is 20, so shift at full width and narrow afterwards.
  const uint32x4_t z = vrshrq_n_u32(vmulq_n_u32(p, scale), kSgrProjScaleBits);
  return vmovn_u32(z);
}
inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) {
dst[0] = VshrU128<0>(src);
dst[1] = VshrU128<1>(src);
dst[2] = VshrU128<2>(src);
}
inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3],
uint16x4_t high[3]) {
uint16x8_t s[3];
s[0] = VshrU128<0>(src);
s[1] = VshrU128<2>(src);
s[2] = VshrU128<4>(src);
low[0] = vget_low_u16(s[0]);
low[1] = vget_low_u16(s[1]);
low[2] = vget_low_u16(s[2]);
high[0] = vget_high_u16(s[0]);
high[1] = vget_high_u16(s[1]);
high[2] = vget_high_u16(s[2]);
}
inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) {
dst[0] = VshrU128<0>(src);
dst[1] = VshrU128<1>(src);
dst[2] = VshrU128<2>(src);
dst[3] = VshrU128<3>(src);
dst[4] = VshrU128<4>(src);
}
inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5],
uint16x4_t high[5]) {
Prepare3_16(src, low, high);
const uint16x8_t s3 = VshrU128<6>(src);
const uint16x8_t s4 = VshrU128<8>(src);
low[3] = vget_low_u16(s3);
low[4] = vget_low_u16(s4);
high[3] = vget_high_u16(s3);
high[4] = vget_high_u16(s4);
}
// Lane-wise src0 + src1 + src2 in 16 bits (no widening).
inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
                          const uint16x8_t src2) {
  return vaddq_u16(vaddq_u16(src0, src1), src2);
}
// Array overload: lane-wise src[0] + src[1] + src[2] in 16 bits.
inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
  const uint16x8_t first_two = vaddq_u16(src[0], src[1]);
  return vaddq_u16(first_two, src[2]);
}
// Lane-wise src0 + src1 + src2 in 32 bits (no widening).
inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2) {
  return vaddq_u32(vaddq_u32(src0, src1), src2);
}
// Array overload for register pairs: sums the three inputs separately for
// each half of the pair.
inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
  uint32x4x2_t total;
  for (int half = 0; half < 2; ++half) {
    total.val[half] =
        Sum3_32(src[0].val[half], src[1].val[half], src[2].val[half]);
  }
  return total;
}
// Widening sum: src[0] + src[1] + src[2], 8-bit inputs to 16-bit lanes.
inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
  return vaddw_u8(vaddl_u8(src[0], src[1]), src[2]);
}
// Widening sum: src[0] + src[1] + src[2], 16-bit inputs to 32-bit lanes.
inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
  return vaddw_u16(vaddl_u16(src[0], src[1]), src[2]);
}
// Lane-wise sum of the five inputs in 16 bits (no widening).
inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
  const uint16x8_t first_four =
      vaddq_u16(vaddq_u16(src[0], src[1]), vaddq_u16(src[2], src[3]));
  return vaddq_u16(first_four, src[4]);
}
// Lane-wise sum of the five inputs in 32 bits (no widening).
inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2, const uint32x4_t src3,
                          const uint32x4_t src4) {
  const uint32x4_t first_four =
      vaddq_u32(vaddq_u32(src0, src1), vaddq_u32(src2, src3));
  return vaddq_u32(first_four, src4);
}
// Array overload for register pairs: sums the five inputs separately for each
// half of the pair.
inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
  uint32x4x2_t total;
  for (int half = 0; half < 2; ++half) {
    total.val[half] =
        Sum5_32(src[0].val[half], src[1].val[half], src[2].val[half],
                src[3].val[half], src[4].val[half]);
  }
  return total;
}
// Widening sum of the five inputs, 16-bit inputs to 32-bit lanes.
inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
  const uint32x4_t first_four =
      vaddq_u32(vaddl_u16(src[0], src[1]), vaddl_u16(src[2], src[3]));
  return vaddw_u16(first_four, src[4]);
}
// For each of 8 positions, sums the 3 consecutive bytes of |src| that start at
// that position. The result is widened to 16 bits.
inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) {
  uint8x8_t taps[3];
  Prepare3_8(src, taps);
  return Sum3W_16(taps);
}
// For each of 8 positions, sums the 3 consecutive 16-bit values of |src| that
// start at that position. The result is widened to 32 bits.
inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) {
  uint16x4_t taps_lo[3], taps_hi[3];
  Prepare3_16(src, taps_lo, taps_hi);
  uint32x4x2_t total;
  total.val[0] = Sum3W_32(taps_lo);
  total.val[1] = Sum3W_32(taps_hi);
  return total;
}
// For each of 8 positions, sums the 5 consecutive bytes of |src| that start at
// that position. The result is widened to 16 bits.
inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) {
  uint8x8_t taps[5];
  Prepare5_8(src, taps);
  const uint16x8_t first_four =
      vaddq_u16(vaddl_u8(taps[0], taps[1]), vaddl_u8(taps[2], taps[3]));
  return vaddw_u8(first_four, taps[4]);
}
// For each of 8 positions, sums the 5 consecutive 16-bit values of |src| that
// start at that position. The result is widened to 32 bits.
inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) {
  uint16x4_t taps_lo[5], taps_hi[5];
  Prepare5_16(src, taps_lo, taps_hi);
  uint32x4x2_t total;
  total.val[0] = Sum5W_32(taps_lo);
  total.val[1] = Sum5W_32(taps_hi);
  return total;
}
// Computes the 3-tap and 5-tap widening sums of |src| in one go, sharing the
// work between them:
//   *row_sq3 = src[1] + src[2] + src[3]
//   *row_sq5 = src[0] + src[4] + *row_sq3
// Declared inline like every other helper in this namespace.
inline void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
                          uint32x4_t* const row_sq5) {
  const uint32x4_t sum04 = vaddl_u16(src[0], src[4]);
  const uint32x4_t sum12 = vaddl_u16(src[1], src[2]);
  *row_sq3 = vaddw_u16(sum12, src[3]);
  // The 5-tap sum is the 3-tap sum plus the two outermost taps.
  *row_sq5 = vaddq_u32(sum04, *row_sq3);
}
// Computes, for 8 positions, the 3-tap and 5-tap horizontal sums of the pixels
// in |src| (to |*row3| / |*row5|, 16 bits) and of the squared pixels in |sq|
// (to |*row_sq3| / |*row_sq5|, 32 bits). The 3-tap window is centered inside
// the 5-tap one, i.e. it covers taps 1 to 3.
// Declared inline like every other helper in this namespace.
inline void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq,
                          uint16x8_t* const row3, uint16x8_t* const row5,
                          uint32x4x2_t* const row_sq3,
                          uint32x4x2_t* const row_sq5) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
  const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
  const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
  *row3 = vaddw_u8(sum12, s[3]);
  // The 5-tap sum is the 3-tap sum plus the two outermost taps.
  *row5 = vaddq_u16(sum04, *row3);
  uint16x4_t low[5], high[5];
  Prepare5_16(sq, low, high);
  SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
  SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
}
// Weighted 3-tap horizontal sum with weights (3, 4, 3), computed as
// 3 * (s0 + s1 + s2) + s1 and widened to 16 bits.
inline uint16x8_t Sum343(const uint8x8x2_t src) {
  uint8x8_t taps[3];
  Prepare3_8(src, taps);
  const uint16x8_t sum111 = Sum3W_16(taps);
  const uint16x8_t sum333 = Sum3_16(sum111, sum111, sum111);
  return vaddw_u8(sum333, taps[1]);
}
// Weighted sum with weights (3, 4, 3) of three 16-bit inputs, computed as
// 3 * (s0 + s1 + s2) + s1 and widened to 32 bits.
inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
  const uint32x4_t sum111 = Sum3W_32(src);
  const uint32x4_t sum333 = Sum3_32(sum111, sum111, sum111);
  return vaddw_u16(sum333, src[1]);
}
// (3, 4, 3) weighted horizontal sum for 8 positions of 16-bit input; the low
// and high 4 positions are handled separately.
inline uint32x4x2_t Sum343W(const uint16x8x2_t src) {
  uint16x4_t taps_lo[3], taps_hi[3];
  Prepare3_16(src, taps_lo, taps_hi);
  uint32x4x2_t total;
  total.val[0] = Sum343W(taps_lo);
  total.val[1] = Sum343W(taps_hi);
  return total;
}
// Weighted 3-tap horizontal sum with weights (5, 6, 5), computed as
// 4 * (s0 + s1 + s2) + (s0 + s1 + s2) + s1 and widened to 16 bits.
inline uint16x8_t Sum565(const uint8x8x2_t src) {
  uint8x8_t taps[3];
  Prepare3_8(src, taps);
  const uint16x8_t sum111 = Sum3W_16(taps);
  const uint16x8_t sum555 = vaddq_u16(vshlq_n_u16(sum111, 2), sum111);
  return vaddw_u8(sum555, taps[1]);
}
// Weighted sum with weights (5, 6, 5) of three 16-bit inputs, computed as
// 4 * (s0 + s1 + s2) + (s0 + s1 + s2) + s1 and widened to 32 bits.
inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
  const uint32x4_t sum111 = Sum3W_32(src);
  const uint32x4_t sum555 = vaddq_u32(vshlq_n_u32(sum111, 2), sum111);
  return vaddw_u16(sum555, src[1]);
}
// (5, 6, 5) weighted horizontal sum for 8 positions of 16-bit input; the low
// and high 4 positions are handled separately.
inline uint32x4x2_t Sum565W(const uint16x8x2_t src) {
  uint16x4_t taps_lo[3], taps_hi[3];
  Prepare3_16(src, taps_lo, taps_hi);
  uint32x4x2_t total;
  total.val[0] = Sum565W(taps_lo);
  total.val[1] = Sum565W(taps_hi);
  return total;
}
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
const ptrdiff_t x, uint16x8_t* const sum_ma343,
uint16x8_t* const sum_ma444,
uint32x4x2_t* const sum_b343,
uint32x4x2_t* const sum_b444, uint16_t* const ma343,
uint16_t* const ma444, uint32_t* const b343,
uint32_t* const b444) {
uint8x8_t s[3];
Prepare3_8(ma3, s);
const uint16x8_t sum_ma111 = Sum3W_16(s);
*sum_ma444 = vshlq_n_u16(sum_ma111, 2);
const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
*sum_ma343 = vaddw_u8(sum333, s[1]);
uint16x4_t low[3], high[3];
uint32x4x2_t sum_b111;
Prepare3_16(b3, low, high);
sum_b111.val[0] = Sum3W_32(low);
sum_b111.val[1] = Sum3W_32(high);
sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
vst1q_u16(ma343 + x, *sum_ma343);
vst1q_u16(ma444 + x, *sum_ma444);
vst1q_u32(b343 + x + 0, (*sum_b343).val[0]);
vst1q_u32(b343 + x + 4, (*sum_b343).val[1]);
vst1q_u32(b444 + x + 0, (*sum_b444).val[0]);
vst1q_u32(b444 + x + 4, (*sum_b444).val[1]);
}
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
const ptrdiff_t x, uint16x8_t* const sum_ma343,
uint32x4x2_t* const sum_b343, uint16_t* const ma343,
uint16_t* const ma444, uint32_t* const b343,
uint32_t* const b444) {
uint16x8_t sum_ma444;
uint32x4x2_t sum_b444;
Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343,
ma444, b343, b444);
}
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
const ptrdiff_t x, uint16_t* const ma343,
uint16_t* const ma444, uint32_t* const b343,
uint32_t* const b444) {
uint16x8_t sum_ma343;
uint32x4x2_t sum_b343;
Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444);
}
// Computes (b - ma * src), then applies a rounding right shift by
// kSgrProjSgrBits + shift - kSgrProjRestoreBits and narrows to 16 bits.
template <int shift>
inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
                              const uint32x4_t b) {
  // kSgrProjSgrBits = 8
  // kSgrProjRestoreBits = 4
  // shift = 4 or 5
  // => total shift of 8 or 9, leaving 13 bits.
  constexpr int kTotalShift = kSgrProjSgrBits + shift - kSgrProjRestoreBits;
  // ma: 255 * 32 = 8160 (13 bits)
  // b: 65088 * 32 = 2082816 (21 bits)
  // v: b - ma * 255 (22 bits)
  const uint32x4_t difference = vmlsl_u16(b, ma, src);
  return vrshrn_n_s32(vreinterpretq_s32_u32(difference), kTotalShift);
}
// Runs FilterOutput<shift>() on 8 pixels: |src| is widened to 16 bits and the
// low and high 4 lanes are processed separately, then recombined.
template <int shift>
inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
                                         const uint16x8_t ma,
                                         const uint32x4x2_t b) {
  const uint16x8_t pixels = vmovl_u8(src);
  const int16x4_t filtered_lo =
      FilterOutput<shift>(vget_low_u16(pixels), vget_low_u16(ma), b.val[0]);
  const int16x4_t filtered_hi =
      FilterOutput<shift>(vget_high_u16(pixels), vget_high_u16(ma), b.val[1]);
  // The combined result fits in 13 bits.
  return vcombine_s16(filtered_lo, filtered_hi);
}
// Adds the two |ma| entries and the two |b| entries, then computes the
// filtered output with a shift of 5.
inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
                                              uint16x8_t ma[2],
                                              uint32x4x2_t b[2]) {
  uint32x4x2_t b_total;
  b_total.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
  b_total.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
  const uint16x8_t ma_total = vaddq_u16(ma[0], ma[1]);
  return CalculateFilteredOutput<5>(s, ma_total, b_total);
}
// Adds the three |ma| entries and the three |b| entries, then computes the
// filtered output with a shift of 5.
inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
                                              uint16x8_t ma[3],
                                              uint32x4x2_t b[3]) {
  return CalculateFilteredOutput<5>(s, Sum3_16(ma), Sum3_32(b));
}
inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2],
uint8_t* const dst) {
const int16x4_t v_lo =
vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
const int16x4_t v_hi =
vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
const int16x8_t vv = vcombine_s16(v_lo, v_hi);
const int16x8_t s = ZeroExtend(src);
const int16x8_t d = vaddq_s16(s, vv);
vst1_u8(dst, vqmovun_s16(d));
}
inline void SelfGuidedDoubleMultiplier(const uint8x8_t src,
const int16x8_t filter[2], const int w0,
const int w2, uint8_t* const dst) {
int32x4_t v[2];
v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
SelfGuidedFinal(src, v, dst);
}
inline void SelfGuidedSingleMultiplier(const uint8x8_t src,
const int16x8_t filter, const int w0,
uint8_t* const dst) {
// weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
int32x4_t v[2];
v[0] = vmull_n_s16(vget_low_s16(filter), w0);
v[1] = vmull_n_s16(vget_high_s16(filter), w0);
SelfGuidedFinal(src, v, dst);
}
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
const int height, const ptrdiff_t width, uint16_t* sum3,
uint16_t* sum5, uint32_t* square_sum3,
uint32_t* square_sum5) {
int y = height;
do {
uint8x8x2_t s;
uint16x8x2_t sq;
s.val[0] = vld1_u8(src);
sq.val[0] = vmull_u8(s.val[0], s.val[0]);
ptrdiff_t x = 0;
do {
uint16x8_t row3, row5;
uint32x4x2_t row_sq3, row_sq5;
s.val[1] = vld1_u8(src + x + 8);
sq.val[1] = vmull_u8(s.val[1], s.val[1]);
SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
vst1q_u16(sum3, row3);
vst1q_u16(sum5, row5);
vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
s.val[0] = s.val[1];
sq.val[0] = sq.val[1];
sum3 += 8;
sum5 += 8;
square_sum3 += 8;
square_sum5 += 8;
x += 8;
} while (x < width);
src += src_stride;
} while (--y != 0);
}
template <int size>
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
const int height, const ptrdiff_t width, uint16_t* sums,
uint32_t* square_sums) {
static_assert(size == 3 || size == 5, "");
int y = height;
do {
uint8x8x2_t s;
uint16x8x2_t sq;
s.val[0] = vld1_u8(src);
sq.val[0] = vmull_u8(s.val[0], s.val[0]);
ptrdiff_t x = 0;
do {
uint16x8_t row;
uint32x4x2_t row_sq;
s.val[1] = vld1_u8(src + x + 8);
sq.val[1] = vmull_u8(s.val[1], s.val[1]);
if (size == 3) {
row = Sum3Horizontal(s);
row_sq = Sum3WHorizontal(sq);
} else {
row = Sum5Horizontal(s);
row_sq = Sum5WHorizontal(sq);
}
vst1q_u16(sums, row);
vst1q_u32(square_sums + 0, row_sq.val[0]);
vst1q_u32(square_sums + 4, row_sq.val[1]);
s.val[0] = s.val[1];
sq.val[0] = sq.val[1];
sums += 8;
square_sums += 8;
x += 8;
} while (x < width);
src += src_stride;
} while (--y != 0);
}
template <int n>
inline void CalculateIntermediate(const uint16x8_t sum,
const uint32x4x2_t sum_sq,
const uint32_t scale, uint8x8_t* const ma,
uint16x8_t* const b) {
constexpr uint32_t one_over_n =
((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
const uint16x4_t z1 =
CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
const uint16x8_t z01 = vcombine_u16(z0, z1);
// Using vqmovn_u16() needs an extra sign extension instruction.
const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255));
// Using vgetq_lane_s16() can save the sign extension instruction.
const uint8_t lookup[8] = {
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)],
kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]};
*ma = vld1_u8(lookup);
// b = ma * b * one_over_n
// |ma| = [0, 255]
// |sum| is a box sum with radius 1 or 2.
// For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
// For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
// |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
// When radius is 2 |n| is 25. |one_over_n| is 164.
// When radius is 1 |n| is 9. |one_over_n| is 455.
// |kSgrProjReciprocalBits| is 12.
// Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
// Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
const uint16x8_t maq = vmovl_u8(*ma);
const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
*b = vcombine_u16(b_lo, b_hi);
}
inline void CalculateIntermediate5(const uint16x8_t s5[5],
const uint32x4x2_t sq5[5],
const uint32_t scale, uint8x8_t* const ma,
uint16x8_t* const b) {
const uint16x8_t sum = Sum5_16(s5);
const uint32x4x2_t sum_sq = Sum5_32(sq5);
CalculateIntermediate<25>(sum, sum_sq, scale, ma, b);
}
inline void CalculateIntermediate3(const uint16x8_t s3[3],
const uint32x4x2_t sq3[3],
const uint32_t scale, uint8x8_t* const ma,
uint16x8_t* const b) {
const uint16x8_t sum = Sum3_16(s3);
const uint32x4x2_t sum_sq = Sum3_32(sq3);
CalculateIntermediate<9>(sum, sum_sq, scale, ma, b);
}
// One 8-column step of the 5-row pre-process for two new source rows.
//
// Loads the 8 pixels at offset |x| + 8 of |src| and |src| + |src_stride| into
// the val[1] halves of |s| / |sq| (the caller provides the val[0] halves),
// stores their Sum5Horizontal()/Sum5WHorizontal() results to rows 3 and 4 of
// |sum5| / |square_sum5| at column |x|, reloads rows 0..2 from the same column
// and hands all five rows to CalculateIntermediate5() to produce |ma| and |b|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
    const uint8_t* const src, const ptrdiff_t src_stride, const ptrdiff_t x,
    const uint32_t scale, uint8x8x2_t s[2], uint16x8x2_t sq[2],
    uint16_t* const sum5[5], uint32_t* const square_sum5[5],
    uint8x8_t* const ma, uint16x8_t* const b) {
  uint16x8_t row_sums[5];
  uint32x4x2_t row_square_sums[5];
  for (int row = 0; row < 2; ++row) {
    s[row].val[1] = vld1_u8(src + row * src_stride + x + 8);
    sq[row].val[1] = vmull_u8(s[row].val[1], s[row].val[1]);
  }
  // The two new rows occupy slots 3 and 4; store them before reading the
  // older rows back.
  for (int row = 0; row < 2; ++row) {
    const int slot = row + 3;
    row_sums[slot] = Sum5Horizontal(s[row]);
    row_square_sums[slot] = Sum5WHorizontal(sq[row]);
  }
  for (int slot = 3; slot < 5; ++slot) {
    vst1q_u16(sum5[slot] + x, row_sums[slot]);
    vst1q_u32(square_sum5[slot] + x + 0, row_square_sums[slot].val[0]);
    vst1q_u32(square_sum5[slot] + x + 4, row_square_sums[slot].val[1]);
  }
  for (int slot = 0; slot < 3; ++slot) {
    row_sums[slot] = vld1q_u16(sum5[slot] + x);
    row_square_sums[slot].val[0] = vld1q_u32(square_sum5[slot] + x + 0);
    row_square_sums[slot].val[1] = vld1q_u32(square_sum5[slot] + x + 4);
  }
  CalculateIntermediate5(row_sums, row_square_sums, scale, ma, b);
}
// Single-row variant of the 5-row pre-process step.
//
// Only one new source row is loaded; its horizontal sums are used for both
// slot 3 and slot 4. Nothing is written back to |sum5| / |square_sum5|; rows
// 0..2 are only read.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], uint8x8_t* const ma, uint16x8_t* const b) {
  uint16x8_t row_sums[5];
  uint32x4x2_t row_square_sums[5];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  // The single new row stands in for both of the two newest rows.
  row_sums[4] = Sum5Horizontal(*s);
  row_sums[3] = row_sums[4];
  row_square_sums[4] = Sum5WHorizontal(*sq);
  row_square_sums[3] = row_square_sums[4];
  for (int slot = 0; slot < 3; ++slot) {
    row_sums[slot] = vld1q_u16(sum5[slot] + x);
    row_square_sums[slot].val[0] = vld1q_u32(square_sum5[slot] + x + 0);
    row_square_sums[slot].val[1] = vld1q_u32(square_sum5[slot] + x + 4);
  }
  CalculateIntermediate5(row_sums, row_square_sums, scale, ma, b);
}
// One 8-column step of the 3-row pre-process for a single new source row.
//
// Loads the 8 pixels at offset |x| + 8 into the val[1] halves of |s| / |sq|,
// stores the Sum3Horizontal()/Sum3WHorizontal() results to row 2 of |sum3| /
// |square_sum3| at column |x|, reloads rows 0 and 1 from the same column and
// passes the three rows to CalculateIntermediate3() to produce |ma| and |b|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint16_t* const sum3[3],
    uint32_t* const square_sum3[3], uint8x8_t* const ma, uint16x8_t* const b) {
  uint16x8_t row_sums[3];
  uint32x4x2_t row_square_sums[3];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  // Newest row goes to slot 2 and is written out before the older rows are
  // read.
  row_sums[2] = Sum3Horizontal(*s);
  row_square_sums[2] = Sum3WHorizontal(*sq);
  vst1q_u16(sum3[2] + x, row_sums[2]);
  vst1q_u32(square_sum3[2] + x + 0, row_square_sums[2].val[0]);
  vst1q_u32(square_sum3[2] + x + 4, row_square_sums[2].val[1]);
  for (int slot = 0; slot < 2; ++slot) {
    row_sums[slot] = vld1q_u16(sum3[slot] + x);
    row_square_sums[slot].val[0] = vld1q_u32(square_sum3[slot] + x + 0);
    row_square_sums[slot].val[1] = vld1q_u32(square_sum3[slot] + x + 4);
  }
  CalculateIntermediate3(row_sums, row_square_sums, scale, ma, b);
}
// One 8-column step of the combined 3-row and 5-row pre-process for two new
// source rows.
//
// Loads the 8 pixels at offset |x| + 8 of |src| and |src| + |src_stride| into
// the val[1] halves of |s| / |sq|, runs SumHorizontal() on each row, stores
// the new sums (rows 2 and 3 of the 3-tap buffers, rows 3 and 4 of the 5-tap
// buffers) at column |x|, then reloads the older rows. Outputs:
//   |ma3_0|, |b3_0|: CalculateIntermediate3() over 3-tap rows 0..2, scales[1]
//   |ma3_1|, |b3_1|: CalculateIntermediate3() over 3-tap rows 1..3, scales[1]
//   |ma5|, |b5|:     CalculateIntermediate5() over 5-tap rows 0..4, scales[0]
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
    const uint8_t* const src, const ptrdiff_t src_stride, const ptrdiff_t x,
    const uint16_t scales[2], uint8x8x2_t s[2], uint16x8x2_t sq[2],
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint8x8_t* const ma3_0, uint8x8_t* const ma3_1, uint16x8_t* const b3_0,
    uint16x8_t* const b3_1, uint8x8_t* const ma5, uint16x8_t* const b5) {
  uint16x8_t rows3[4], rows5[5];
  uint32x4x2_t sq_rows3[4], sq_rows5[5];
  for (int row = 0; row < 2; ++row) {
    s[row].val[1] = vld1_u8(src + row * src_stride + x + 8);
    sq[row].val[1] = vmull_u8(s[row].val[1], s[row].val[1]);
  }
  for (int row = 0; row < 2; ++row) {
    SumHorizontal(s[row], sq[row], &rows3[row + 2], &rows5[row + 3],
                  &sq_rows3[row + 2], &sq_rows5[row + 3]);
  }
  // Write the new rows out before reading the older ones back.
  for (int slot = 2; slot < 4; ++slot) {
    vst1q_u16(sum3[slot] + x, rows3[slot]);
    vst1q_u32(square_sum3[slot] + x + 0, sq_rows3[slot].val[0]);
    vst1q_u32(square_sum3[slot] + x + 4, sq_rows3[slot].val[1]);
  }
  for (int slot = 3; slot < 5; ++slot) {
    vst1q_u16(sum5[slot] + x, rows5[slot]);
    vst1q_u32(square_sum5[slot] + x + 0, sq_rows5[slot].val[0]);
    vst1q_u32(square_sum5[slot] + x + 4, sq_rows5[slot].val[1]);
  }
  for (int slot = 0; slot < 2; ++slot) {
    rows3[slot] = vld1q_u16(sum3[slot] + x);
    sq_rows3[slot].val[0] = vld1q_u32(square_sum3[slot] + x + 0);
    sq_rows3[slot].val[1] = vld1q_u32(square_sum3[slot] + x + 4);
  }
  for (int slot = 0; slot < 3; ++slot) {
    rows5[slot] = vld1q_u16(sum5[slot] + x);
    sq_rows5[slot].val[0] = vld1q_u32(square_sum5[slot] + x + 0);
    sq_rows5[slot].val[1] = vld1q_u32(square_sum5[slot] + x + 4);
  }
  CalculateIntermediate3(rows3, sq_rows3, scales[1], ma3_0, b3_0);
  CalculateIntermediate3(rows3 + 1, sq_rows3 + 1, scales[1], ma3_1, b3_1);
  CalculateIntermediate5(rows5, sq_rows5, scales[0], ma5, b5);
}
// Single-row variant of the combined pre-process step.
//
// One new source row is loaded and summed with SumHorizontal(). For the 5-tap
// path its sums fill both slot 3 and slot 4; for the 3-tap path they fill slot
// 2. The stored rows in |sum3|, |sum5|, |square_sum3| and |square_sum5| are
// only read, never written.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
    const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3,
    uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
  uint16x8_t rows3[3], rows5[5];
  uint32x4x2_t sq_rows3[3], sq_rows5[5];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  SumHorizontal(*s, *sq, &rows3[2], &rows5[3], &sq_rows3[2], &sq_rows5[3]);
  // 5-tap path: the single new row is used for both of the newest slots.
  rows5[4] = rows5[3];
  sq_rows5[4] = sq_rows5[3];
  for (int slot = 0; slot < 3; ++slot) {
    rows5[slot] = vld1q_u16(sum5[slot] + x);
    sq_rows5[slot].val[0] = vld1q_u32(square_sum5[slot] + x + 0);
    sq_rows5[slot].val[1] = vld1q_u32(square_sum5[slot] + x + 4);
  }
  CalculateIntermediate5(rows5, sq_rows5, scales[0], ma5, b5);
  // 3-tap path.
  for (int slot = 0; slot < 2; ++slot) {
    rows3[slot] = vld1q_u16(sum3[slot] + x);
    sq_rows3[slot].val[0] = vld1q_u32(square_sum3[slot] + x + 0);
    sq_rows3[slot].val[1] = vld1q_u32(square_sum3[slot] + x + 4);
  }
  CalculateIntermediate3(rows3, sq_rows3, scales[1], ma3, b3);
}
inline void BoxSumFilterPreProcess5(const uint8_t* const src,
const ptrdiff_t src_stride, const int width,
const uint32_t scale,
uint16_t* const sum5[5],
uint32_t* const square_sum5[5],
uint16_t* ma565, uint32_t* b565) {
uint8x8x2_t s[2], mas;
uint16x8x2_t sq[2], bs;
s[0].val[0] = vld1_u8(src);
sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
s[1].val[0] = vld1_u8(src + src_stride);
sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
BoxFilterPreProcess5(src, src_stride, 0, scale, s, sq, sum5, square_sum5,
&mas.val[0], &bs.val[0]);
int x = 0;
do {
s[0].val[0] = s[0].val[1];
s[1].val[0] = s[1].val[1];
sq[0].val[0] = sq[0].val[1];
sq[1].val[0] = sq[1].val[1];
BoxFilterPreProcess5(src, src_stride, x + 8, scale, s, sq, sum5,
square_sum5, &mas.val[1], &bs.val[1]);
const uint16x8_t ma = Sum565(mas);
const uint32x4x2_t b = Sum565W(bs);
vst1q_u16(ma565, ma);
vst1q_u32(b565 + 0, b.val[0]);
vst1q_u32(b565 + 4, b.val[1]);
mas.val[0] = mas.val[1];
bs.val[0] = bs.val[1];
ma565 += 8;
b565 += 8;
x += 8;
} while (x < width);
}
// Runs the 3-row pre-process over one row and stores the combined results,
// 8 entries per step, until the column counter reaches |width|.
//
// With |calculate444| set, Store343_444() writes both the 343 and the 444
// outputs and all four output pointers advance. Otherwise only the
// Sum343()/Sum343W() results are stored and |ma444| / |b444| are never touched
// (the caller may pass nullptr for them).
template <bool calculate444>
LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
    const uint8_t* const src, const int width, const uint32_t scale,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
    uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
  uint8x8x2_t pixels, ma_pair;
  uint16x8x2_t squares, b_pair;
  pixels.val[0] = vld1_u8(src);
  squares.val[0] = vmull_u8(pixels.val[0], pixels.val[0]);
  // Prime the first half of |ma_pair| / |b_pair| with column 0.
  BoxFilterPreProcess3(src, 0, scale, &pixels, &squares, sum3, square_sum3,
                       &ma_pair.val[0], &b_pair.val[0]);
  int x = 0;
  do {
    pixels.val[0] = pixels.val[1];
    squares.val[0] = squares.val[1];
    BoxFilterPreProcess3(src, x + 8, scale, &pixels, &squares, sum3,
                         square_sum3, &ma_pair.val[1], &b_pair.val[1]);
    if (!calculate444) {
      const uint16x8_t ma_sum = Sum343(ma_pair);
      const uint32x4x2_t b_sum = Sum343W(b_pair);
      vst1q_u16(ma343, ma_sum);
      vst1q_u32(b343 + 0, b_sum.val[0]);
      vst1q_u32(b343 + 4, b_sum.val[1]);
    } else {
      // The output pointers are advanced manually, so the offset stays 0.
      Store343_444(ma_pair, b_pair, 0, ma343, ma444, b343, b444);
      ma444 += 8;
      b444 += 8;
    }
    ma343 += 8;
    b343 += 8;
    ma_pair.val[0] = ma_pair.val[1];
    b_pair.val[0] = b_pair.val[1];
    x += 8;
  } while (x < width);
}
inline void BoxSumFilterPreProcess(
const uint8_t* const src, const ptrdiff_t src_stride, const int width,
const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565,
uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) {
uint8x8x2_t s[2];
uint8x8x2_t ma3[2], ma5;
uint16x8x2_t sq[2], b3[2], b5;
s[0].val[0] = vld1_u8(src + 0);
s[1].val[0] = vld1_u8(src + src_stride + 0);
sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
BoxFilterPreProcess(src, src_stride, 0, scales, s, sq, sum3, sum5,
square_sum3, square_sum5, &ma3[0].val[0], &ma3[1].val[0],
&b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
int x = 0;
do {
s[0].val[0] = s[0].val[1];
s[1].val[0] = s[1].val[1];
sq[0].val[0] = sq[0].val[1];
sq[1].val[0] = sq[1].val[1];
BoxFilterPreProcess(src, src_stride, x + 8, scales, s, sq, sum3, sum5,
square_sum3, square_sum5, &ma3[0].val[1],
&ma3[1].val[1], &b3[0].val[1], &b3[1].val[1],
&ma5.val[1], &b5.val[1]);
uint16x8_t ma = Sum343(ma3[0]);
uint32x4x2_t b = Sum343W(b3[0]);
vst1q_u16(ma343[0] + x, ma);
vst1q_u32(b343[0] + x, b.val[0]);
vst1q_u32(b343[0] + x + 4, b.val[1]);
Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]);
ma = Sum565(ma5);
b = Sum565W(b5);
vst1q_u16(ma565, ma);
vst1q_u32(b565 + 0, b.val[0]);
vst1q_u32(b565 + 4, b.val[1]);
ma3[0].val[0] = ma3[0].val[1];
ma3[1].val[0] = ma3[1].val[1];
b3[0].val[0] = b3[0].val[1];
b3[1].val[0] = b3[1].val[1];
ma5.val[0] = ma5.val[1];
b5.val[0] = b5.val[1];
ma565 += 8;
b565 += 8;
x += 8;
} while (x < width);
}
inline void BoxFilterPass1(const uint8_t* const src0, const uint8_t* const src,
const ptrdiff_t src_stride, uint16_t* const sum5[5],
uint32_t* const square_sum5[5], const int width,
const uint32_t scale, const int16_t w0,
uint16_t* const ma565[2], uint32_t* const b565[2],
uint8_t* const dst, const ptrdiff_t dst_stride) {
uint8x8x2_t s[2], mas;
uint16x8x2_t sq[2], bs;
s[0].val[0] = vld1_u8(src);
s[1].val[0] = vld1_u8(src + src_stride);
sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
BoxFilterPreProcess5(src, src_stride, 0, scale, s, sq, sum5, square_sum5,
&mas.val[0], &bs.val[0]);
int x = 0;
do {
s[0].val[0] = s[0].val[1];
s[1].val[0] = s[1].val[1];
sq[0].val[0] = sq[0].val[1];
sq[1].val[0] = sq[1].val[1];
BoxFilterPreProcess5(src, src_stride, x + 8, scale, s, sq, sum5,
square_sum5, &mas.val[1], &bs.val[1]);
uint16x8_t ma[2];
uint32x4x2_t b[2];
ma[1] = Sum565(mas);
b[1] = Sum565W(bs);
vst1q_u16(ma565[1] + x, ma[1]);
vst1q_u32(b565[1] + x + 0, b[1].val[0]);
vst1q_u32(b565[1] + x + 4, b[1].val[1]);
const uint8x8_t s0 = vld1_u8(src0 + x);
const uint8x8_t s1 = vld1_u8(src0 + src_stride + x);
int16x8_t p0, p1;
ma[0] = vld1q_u16(ma565[0] + x);
b[0].val[0] = vld1q_u32(b565[0] + x + 0);
b[0].val[1] = vld1q_u32(b565[0] + x + 4);
p0 = CalculateFilteredOutputPass1(s0, ma, b);
p1 = CalculateFilteredOutput<4>(s1, ma[1], b[1]);
SelfGuidedSingleMultiplier(s0, p0, w0, dst + x);
SelfGuidedSingleMultiplier(s1, p1, w0, dst + dst_stride + x);
mas.val[0] = mas.val[1];
bs.val[0] = bs.val[1];
x += 8;
} while (x < width);
}
inline void BoxFilterPass1LastRow(const uint8_t* const src0,
const uint8_t* const src, const int width,
const uint32_t scale, const int16_t w0,
uint16_t* const sum5[5],
uint32_t* const square_sum5[5],
uint16_t* ma565, uint32_t* b565,
uint8_t* const dst) {
uint8x8x2_t s, mas;
uint16x8x2_t sq, bs;
s.val[0] = vld1_u8(src);
sq.val[0] = vmull_u8(s.val[0], s.val[0]);
BoxFilterPreProcess5LastRow(src, 0, scale, &s, &sq, sum5, square_sum5,
&mas.val[0], &bs.val[0]);
int x = 0;
do {
s.val[0] = s.val[1];
sq.val[0] = sq.val[1];
BoxFilterPreProcess5LastRow(src, x + 8, scale, &s, &sq, sum5, square_sum5,
&mas.val[1], &bs.val[1]);
uint16x8_t ma[2];
uint32x4x2_t b[2];
ma[1] = Sum565(mas);
b[1] = Sum565W(bs);
mas.val[0] = mas.val[1];
bs.val[0] = bs.val[1];
ma[0] = vld1q_u16(ma565);
b[0].val[0] = vld1q_u32(b565 + 0);
b[0].val[1] = vld1q_u32(b565 + 4);
const uint8x8_t s = vld1_u8(src0 + x);
const int16x8_t p = CalculateFilteredOutputPass1(s, ma, b);
SelfGuidedSingleMultiplier(s, p, w0, dst + x);
ma565 += 8;
b565 += 8;
x += 8;
} while (x < width);
}
inline void BoxFilterPass2(const uint8_t* const src0, const uint8_t* const src,
const int width, const uint32_t scale,
const int16_t w0, uint16_t* const sum3[3],
uint32_t* const square_sum3[3],
uint16_t* const ma343[3], uint16_t* const ma444[2],
uint32_t* const b343[3], uint32_t* const b444[2],
uint8_t* const dst) {
uint8x8x2_t s, mas;
uint16x8x2_t sq, bs;
s.val[0] = vld1_u8(src);
sq.val[0] = vmull_u8(s.val[0], s.val[0]);
BoxFilterPreProcess3(src, 0, scale, &s, &sq, sum3, square_sum3, &mas.val[0],
&bs.val[0]);
int x = 0;
do {
s.val[0] = s.val[1];
sq.val[0] = sq.val[1];
BoxFilterPreProcess3(src, x + 8, scale, &s, &sq, sum3, square_sum3,
&mas.val[1], &bs.val[1]);
uint16x8_t ma[3];
uint32x4x2_t b[3];
Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
b444[1]);
const uint8x8_t s0 = vld1_u8(src0 + x);
ma[0] = vld1q_u16(ma343[0] + x);
ma[1] = vld1q_u16(ma444[0] + x);
b[0].val[0] = vld1q_u32(b343[0] + x + 0);
b[0].val[1] = vld1q_u32(b343[0] + x + 4);
b[1].val[0] = vld1q_u32(b444[0] + x + 0);
b[1].val[1] = vld1q_u32(b444[0] + x + 4);
const int16x8_t p = CalculateFilteredOutputPass2(s0, ma, b);
SelfGuidedSingleMultiplier(s0, p, w0, dst + x);
mas.val[0] = mas.val[1];
bs.val[0] = bs.val[1];
x += 8;
} while (x < width);
}
// Combined 5-tap and 3-tap filtering of two output rows.
//
// |src| supplies the two rows fed to BoxFilterPreProcess(); |src0| supplies
// the two rows being filtered, whose results go to |dst| and |dst| +
// |dst_stride|. Per 8-column step:
//   - new 3-tap values are stored through Store343_444() into |ma343[2]|,
//     |ma444[1]|, |b343[2]|, |b444[1]| (first row) and |ma343[3]|,
//     |ma444[2]|, |b343[3]|, |b444[2]| (second row);
//   - new 5-tap values are stored into |ma565[1]| / |b565[1]|;
//   - earlier values are read from |ma565[0]|, |ma343[0]|, |ma343[1]|,
//     |ma444[0]| and the matching |b| buffers;
//   - each output row blends its two filtered results with weights |w0| and
//     |w2| in SelfGuidedDoubleMultiplier().
inline void BoxFilter(const uint8_t* const src0, const uint8_t* const src,
                      const ptrdiff_t src_stride, const int width,
                      const uint16_t scales[2], const int16_t w0,
                      const int16_t w2, uint16_t* const sum3[4],
                      uint16_t* const sum5[5], uint32_t* const square_sum3[4],
                      uint32_t* const square_sum5[5], uint16_t* const ma343[4],
                      uint16_t* const ma444[3], uint16_t* const ma565[2],
                      uint32_t* const b343[4], uint32_t* const b444[3],
                      uint32_t* const b565[2], uint8_t* const dst,
                      const ptrdiff_t dst_stride) {
  uint8x8x2_t s[2], ma3[2], ma5;
  uint16x8x2_t sq[2], b3[2], b5;
  s[0].val[0] = vld1_u8(src);
  s[1].val[0] = vld1_u8(src + src_stride);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  // Prime the val[0] halves with column 0.
  BoxFilterPreProcess(src, src_stride, 0, scales, s, sq, sum3, sum5,
                      square_sum3, square_sum5, &ma3[0].val[0], &ma3[1].val[0],
                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
  int x = 0;
  do {
    // Slide the pixel window. This is done once per iteration, here; nothing
    // below reads or writes |s| / |sq| after BoxFilterPreProcess() returns.
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess(src, src_stride, x + 8, scales, s, sq, sum3, sum5,
                        square_sum3, square_sum5, &ma3[0].val[1],
                        &ma3[1].val[1], &b3[0].val[1], &b3[1].val[1],
                        &ma5.val[1], &b5.val[1]);
    // ma[0] / b[0]: 5-tap inputs for the first row.
    // ma[1] / b[1]: 3-tap inputs for the first row.
    // ma[2] / b[2]: 3-tap inputs for the second row.
    uint16x8_t ma[3][3];
    uint32x4x2_t b[3][3];
    Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
                 ma343[2], ma444[1], b343[2], b444[1]);
    Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
                 b343[3], b444[2]);
    ma[0][1] = Sum565(ma5);
    b[0][1] = Sum565W(b5);
    vst1q_u16(ma565[1] + x, ma[0][1]);
    vst1q_u32(b565[1] + x, b[0][1].val[0]);
    vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
    // Slide the intermediate windows for the next step.
    ma3[0].val[0] = ma3[0].val[1];
    ma3[1].val[0] = ma3[1].val[1];
    b3[0].val[0] = b3[0].val[1];
    b3[1].val[0] = b3[1].val[1];
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    int16x8_t p[2][2];
    const uint8x8_t s0 = vld1_u8(src0 + x);
    const uint8x8_t s1 = vld1_u8(src0 + src_stride + x);
    // 5-tap results for both rows.
    ma[0][0] = vld1q_u16(ma565[0] + x);
    b[0][0].val[0] = vld1q_u32(b565[0] + x);
    b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
    p[0][0] = CalculateFilteredOutputPass1(s0, ma[0], b[0]);
    p[1][0] = CalculateFilteredOutput<4>(s1, ma[0][1], b[0][1]);
    // 3-tap result for the first row.
    ma[1][0] = vld1q_u16(ma343[0] + x);
    ma[1][1] = vld1q_u16(ma444[0] + x);
    b[1][0].val[0] = vld1q_u32(b343[0] + x);
    b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1][1].val[0] = vld1q_u32(b444[0] + x);
    b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
    p[0][1] = CalculateFilteredOutputPass2(s0, ma[1], b[1]);
    // 3-tap result for the second row.
    ma[2][0] = vld1q_u16(ma343[1] + x);
    b[2][0].val[0] = vld1q_u32(b343[1] + x);
    b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
    p[1][1] = CalculateFilteredOutputPass2(s1, ma[2], b[2]);
    SelfGuidedDoubleMultiplier(s0, p[0], w0, w2, dst + x);
    SelfGuidedDoubleMultiplier(s1, p[1], w0, w2, dst + dst_stride + x);
    x += 8;
  } while (x < width);
}
// Combined 5-tap and 3-tap filtering of a single, final output row.
//
// |src| supplies the one row fed to BoxFilterPreProcessLastRow(); |src0|
// supplies the row being filtered, whose result goes to |dst|. Unlike
// BoxFilter(), nothing is written back to the |ma*| / |b*| buffers; they are
// only read. The local |ma| / |b| arrays are shared by both passes, so the
// statement order inside the loop matters.
inline void BoxFilterLastRow(
const uint8_t* const src0, const uint8_t* const src, const int width,
const uint16_t scales[2], const int16_t w0, const int16_t w2,
uint16_t* const sum3[4], uint16_t* const sum5[5],
uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
uint16_t* const ma343[4], uint16_t* const ma444[3],
uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
uint32_t* const b565[2], uint8_t* const dst) {
uint8x8x2_t s, ma3, ma5;
uint16x8x2_t sq, b3, b5;
uint16x8_t ma[3];
uint32x4x2_t b[3];
s.val[0] = vld1_u8(src);
sq.val[0] = vmull_u8(s.val[0], s.val[0]);
// Prime the val[0] halves with column 0.
BoxFilterPreProcessLastRow(src, 0, scales, sum3, sum5, square_sum3,
square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0],
&b3.val[0], &b5.val[0]);
int x = 0;
do {
// Slide the pixel window before loading the next 8 pixels.
s.val[0] = s.val[1];
sq.val[0] = sq.val[1];
BoxFilterPreProcessLastRow(src, x + 8, scales, sum3, sum5, square_sum3,
square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1],
&b3.val[1], &b5.val[1]);
// New 5-tap values go to index 1, new 3-tap values to index 2.
ma[1] = Sum565(ma5);
b[1] = Sum565W(b5);
ma5.val[0] = ma5.val[1];
b5.val[0] = b5.val[1];
ma[2] = Sum343(ma3);
b[2] = Sum343W(b3);
ma3.val[0] = ma3.val[1];
b3.val[0] = b3.val[1];
const uint8x8_t s0 = vld1_u8(src0 + x);
int16x8_t p[2];
// 5-tap pass: index 0 holds the earlier 565 values, index 1 the new ones.
ma[0] = vld1q_u16(ma565[0] + x);
b[0].val[0] = vld1q_u32(b565[0] + x + 0);
b[0].val[1] = vld1q_u32(b565[0] + x + 4);
p[0] = CalculateFilteredOutputPass1(s0, ma, b);
// 3-tap pass: indices 0 and 1 are overwritten with the earlier 343 and 444
// values; index 2 still holds the Sum343()/Sum343W() results from above.
ma[0] = vld1q_u16(ma343[0] + x);
ma[1] = vld1q_u16(ma444[0] + x);
b[0].val[0] = vld1q_u32(b343[0] + x + 0);
b[0].val[1] = vld1q_u32(b343[0] + x + 4);
b[1].val[0] = vld1q_u32(b444[0] + x + 0);
b[1].val[1] = vld1q_u32(b444[0] + x + 4);
p[1] = CalculateFilteredOutputPass2(s0, ma, b);
// Blend the two filtered results with weights |w0| and |w2|.
SelfGuidedDoubleMultiplier(s0, p, w0, w2, dst + x);
x += 8;
} while (x < width);
}
// Rotates the three pointers in |p| left by one position:
// {a, b, c} becomes {b, c, a}.
template <typename T>
void Circulate3PointersBy1(T* p[3]) {
  T* const front = p[0];
  for (int i = 0; i < 2; ++i) {
    p[i] = p[i + 1];
  }
  p[2] = front;
}
// Rotates the four pointers in |p| by two positions, i.e. exchanges the first
// pair with the second pair: {a, b, c, d} becomes {c, d, a, b}.
template <typename T>
void Circulate4PointersBy2(T* p[4]) {
  for (int i = 0; i < 2; ++i) {
    std::swap(p[i], p[i + 2]);
  }
}
// Rotates the five pointers in |p| left by two positions:
// {a, b, c, d, e} becomes {c, d, e, a, b}.
template <typename T>
void Circulate5PointersBy2(T* p[5]) {
  T* const front[2] = {p[0], p[1]};
  for (int i = 0; i < 3; ++i) {
    p[i] = p[i + 2];
  }
  p[3] = front[0];
  p[4] = front[1];
}
// Drives the combined 5-tap and 3-tap self-guided filter over a |width| x
// |height| block, two rows per iteration plus a single-row tail when |height|
// is odd. All scratch rows live in |sgr_buffer|.
inline void BoxFilterProcess(const RestorationUnitInfo& restoration_info,
const uint8_t* src, const ptrdiff_t src_stride,
const int width, const int height,
SgrBuffer* const sgr_buffer, uint8_t* dst,
const ptrdiff_t dst_stride) {
// |temp_stride| separates rows of the ma/b buffers; the sum buffers use rows
// that are 8 entries longer.
const auto temp_stride = Align<ptrdiff_t>(width, 8);
const ptrdiff_t sum_stride = temp_stride + 8;
const int sgr_proj_index = restoration_info.sgr_proj_info.index;
const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index]; // < 2^12.
const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
// The third weight is whatever remains of 1 << kSgrProjPrecisionBits.
const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
// Carve the row pointers out of the contiguous buffers in |sgr_buffer|.
sum3[0] = sgr_buffer->sum3;
square_sum3[0] = sgr_buffer->square_sum3;
ma343[0] = sgr_buffer->ma343;
b343[0] = sgr_buffer->b343;
for (int i = 1; i <= 3; ++i) {
sum3[i] = sum3[i - 1] + sum_stride;
square_sum3[i] = square_sum3[i - 1] + sum_stride;
ma343[i] = ma343[i - 1] + temp_stride;
b343[i] = b343[i - 1] + temp_stride;
}
sum5[0] = sgr_buffer->sum5;
square_sum5[0] = sgr_buffer->square_sum5;
for (int i = 1; i <= 4; ++i) {
sum5[i] = sum5[i - 1] + sum_stride;
square_sum5[i] = square_sum5[i - 1] + sum_stride;
}
ma444[0] = sgr_buffer->ma444;
b444[0] = sgr_buffer->b444;
for (int i = 1; i <= 2; ++i) {
ma444[i] = ma444[i - 1] + temp_stride;
b444[i] = b444[i - 1] + temp_stride;
}
ma565[0] = sgr_buffer->ma565;
ma565[1] = ma565[0] + temp_stride;
b565[0] = sgr_buffer->b565;
b565[1] = b565[0] + temp_stride;
assert(scales[0] != 0);
assert(scales[1] != 0);
// Box sums for the two rows above the block; the 3-tap sums start at row 0 of
// their buffers and the 5-tap sums at row 1.
BoxSum(src - 2 * src_stride - 3, src_stride, 2, sum_stride, sum3[0], sum5[1],
square_sum3[0], square_sum5[1]);
// Temporarily alias 5-tap row 0 to row 1 for the priming pre-process; the real
// row 0 pointers are restored right after it.
sum5[0] = sum5[1];
square_sum5[0] = square_sum5[1];
// Pre-process the first two rows of the block without producing output.
BoxSumFilterPreProcess(src - 3, src_stride, width, scales, sum3, sum5,
square_sum3, square_sum5, ma343, ma444, ma565[0], b343,
b444, b565[0]);
sum5[0] = sgr_buffer->sum5;
square_sum5[0] = sgr_buffer->square_sum5;
// Main loop: two output rows per iteration.
for (int y = height >> 1; y != 0; --y) {
// Rotate the sum rows by two before filtering.
Circulate4PointersBy2<uint16_t>(sum3);
Circulate4PointersBy2<uint32_t>(square_sum3);
Circulate5PointersBy2<uint16_t>(sum5);
Circulate5PointersBy2<uint32_t>(square_sum5);
BoxFilter(src, src + 2 * src_stride - 3, src_stride, width, scales, w0, w2,
sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565, b343,
b444, b565, dst, dst_stride);
src += 2 * src_stride;
dst += 2 * dst_stride;
// Rotate the ma/b rows so that the values BoxFilter() just stored are the
// ones read at index 0 (and index 1 for ma343/b343) next time.
Circulate4PointersBy2<uint16_t>(ma343);
Circulate4PointersBy2<uint32_t>(b343);
std::swap(ma444[0], ma444[2]);
std::swap(b444[0], b444[2]);
std::swap(ma565[0], ma565[1]);
std::swap(b565[0], b565[1]);
}
// Odd height: one remaining row, filtered by the single-row variant.
if ((height & 1) != 0) {
Circulate4PointersBy2<uint16_t>(sum3);
Circulate4PointersBy2<uint32_t>(square_sum3);
Circulate5PointersBy2<uint16_t>(sum5);
Circulate5PointersBy2<uint32_t>(square_sum5);
BoxFilterLastRow(src, src + 2 * src_stride - 3, width, scales, w0, w2, sum3,
sum5, square_sum3, square_sum5, ma343, ma444, ma565, b343,
b444, b565, dst);
}
}
// Drives the 5-tap-only self-guided filter over a |width| x |height| block,
// two rows per iteration plus a single-row tail when |height| is odd. Scratch
// rows are taken from |sgr_buffer|.
inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const ptrdiff_t src_stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst, const ptrdiff_t dst_stride) {
  // Rows of the ma/b buffers are |temp_stride| apart; rows of the sum buffers
  // are 8 entries longer.
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  uint16_t *sum5[5], *ma565[2];
  uint32_t *square_sum5[5], *b565[2];
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 0; i < 4; ++i) {
    sum5[i + 1] = sum5[i] + sum_stride;
    square_sum5[i + 1] = square_sum5[i] + sum_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  b565[0] = sgr_buffer->b565;
  ma565[1] = ma565[0] + temp_stride;
  b565[1] = b565[0] + temp_stride;
  assert(scale != 0);
  // Box sums for the two rows above the block go to rows 1 and 2.
  BoxSum<5>(src - 2 * src_stride - 3, src_stride, 2, sum_stride, sum5[1],
            square_sum5[1]);
  // Row 0 temporarily aliases row 1 for the priming pre-process and is
  // restored immediately afterwards.
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  BoxSumFilterPreProcess5(src - 3, src_stride, width, scale, sum5, square_sum5,
                          ma565[0], b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  int row_pairs = height >> 1;
  while (row_pairs != 0) {
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilterPass1(src, src + 2 * src_stride - 3, src_stride, sum5, square_sum5,
                   width, scale, w0, ma565, b565, dst, dst_stride);
    src += 2 * src_stride;
    dst += 2 * dst_stride;
    // What was just stored in row 1 becomes the "earlier" row 0.
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
    --row_pairs;
  }
  if ((height & 1) != 0) {
    // Odd height: filter the remaining row with the single-row variant.
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilterPass1LastRow(src, src + 2 * src_stride - 3, width, scale, w0, sum5,
                          square_sum5, ma565[0], b565[0], dst);
  }
}
// Drives the 3-tap-only self-guided filter over a |width| x |height| block,
// one row per iteration. Requires multiplier[0] to be 0 (asserted). Scratch
// rows are taken from |sgr_buffer|.
inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const ptrdiff_t src_stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst, const ptrdiff_t dst_stride) {
  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
  // Rows of the ma/b buffers are |temp_stride| apart; rows of the sum buffers
  // are 8 entries longer.
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  // The weight handed to the filter is what remains of
  // 1 << kSgrProjPrecisionBits after |w1|.
  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
  uint16_t *sum3[3], *ma343[3], *ma444[2];
  uint32_t *square_sum3[3], *b343[3], *b444[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 0; i < 2; ++i) {
    sum3[i + 1] = sum3[i] + sum_stride;
    square_sum3[i + 1] = square_sum3[i] + sum_stride;
    ma343[i + 1] = ma343[i] + temp_stride;
    b343[i + 1] = b343[i] + temp_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  b444[0] = sgr_buffer->b444;
  ma444[1] = ma444[0] + temp_stride;
  b444[1] = b444[0] + temp_stride;
  assert(scale != 0);
  // Box sums for the two rows above the block.
  BoxSum<3>(src - 2 * src_stride - 2, src_stride, 2, sum_stride, sum3[0],
            square_sum3[0]);
  // Prime with the first block row (343 outputs only) ...
  BoxSumFilterPreProcess3<false>(src - 2, width, scale, sum3, square_sum3,
                                 ma343[0], nullptr, b343[0], nullptr);
  Circulate3PointersBy1<uint16_t>(sum3);
  Circulate3PointersBy1<uint32_t>(square_sum3);
  // ... and with the second block row (343 and 444 outputs).
  BoxSumFilterPreProcess3<true>(src + src_stride - 2, width, scale, sum3,
                                square_sum3, ma343[1], ma444[0], b343[1],
                                b444[0]);
  int rows_left = height;
  do {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src, src + 2 * src_stride - 2, width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += src_stride;
    dst += dst_stride;
    // Age the stored ma/b rows by one.
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  } while (--rows_left != 0);
}
// Self-guided restoration filter entry point. Dispatches on the two pass radii
// that kSgrProjParams lists for the unit's parameter set:
//   second radius 0 -> first pass only,
//   first radius 0  -> second pass only,
//   otherwise       -> both passes combined.
// If |width| is not a multiple of 8, up to 7 extra pixels are written to |dest|
// at the end of each row. Overwriting them is safe because they are not part
// of the visible frame.
void SelfGuidedFilter_NEON(const void* const source, void* const dest,
                           const RestorationUnitInfo& restoration_info,
                           const ptrdiff_t source_stride,
                           const ptrdiff_t dest_stride, const int width,
                           const int height,
                           RestorationBuffer* const restoration_buffer) {
  const int index = restoration_info.sgr_proj_info.index;
  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
  const auto* const src = static_cast<const uint8_t*>(source);
  auto* const dst = static_cast<uint8_t*>(dest);
  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
  if (radius_pass_1 == 0) {
    // The two radii are never both 0, which is what the assertion checks.
    assert(radius_pass_0 != 0);
    BoxFilterProcessPass1(restoration_info, src, source_stride, width, height,
                          sgr_buffer, dst, dest_stride);
    return;
  }
  if (radius_pass_0 == 0) {
    BoxFilterProcessPass2(restoration_info, src, source_stride, width, height,
                          sgr_buffer, dst, dest_stride);
    return;
  }
  BoxFilterProcess(restoration_info, src, source_stride, width, height,
                   sgr_buffer, dst, dest_stride);
}
void Init8bpp() {
Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
assert(dsp != nullptr);
dsp->loop_restorations[0] = WienerFilter_NEON;
dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
}
} // namespace
} // namespace low_bitdepth
// NEON build: registers the 8-bit NEON loop restoration functions by calling
// low_bitdepth::Init8bpp().
void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); }
} // namespace dsp
} // namespace libgav1
#else // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {
// Build without LIBGAV1_ENABLE_NEON: intentionally empty, so nothing is
// registered by this function.
void LoopRestorationInit_NEON() {}
} // namespace dsp
} // namespace libgav1
#endif // LIBGAV1_ENABLE_NEON