// blob: 445fc7ad443c27794748e555718dcd2c184a1449 [file] [log] [blame]
#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
#include "test.h"
#include "../internal/fixedpoint.h"
using namespace gemmlowp;
// Round-trip check: converting x to double with ToDouble and back with
// ToFixedPoint (same raw type, same number of integer bits) must give back a
// value that compares equal to x.
template <int tIntegerBits>
void test_convert(FixedPoint<int32_t, tIntegerBits> x) {
  using Fixed = FixedPoint<int32_t, tIntegerBits>;
  const double as_double = ToDouble(x);
  const Fixed round_tripped = ToFixedPoint<int32_t, tIntegerBits>(as_double);
  Check(round_tripped == x);
}
// Rescale<tIntegerBits_b>(a) must equal the reference value obtained by
// converting a to double and re-quantizing it with ToFixedPoint into a
// fixed-point type with tIntegerBits_b integer bits.
template <int tIntegerBits_a, int tIntegerBits_b>
void test_Rescale(FixedPoint<int32_t, tIntegerBits_a> a) {
  using Target = FixedPoint<int32_t, tIntegerBits_b>;
  const Target rescaled = Rescale<tIntegerBits_b>(a);
  const Target reference = ToFixedPoint<int32_t, tIntegerBits_b>(ToDouble(a));
  Check(rescaled == reference);
}
template <int tIntegerBits_a, int tIntegerBits_b>
void test_Rescale(const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
FixedPoint<int32_t, tIntegerBits_a> aq;
aq.raw() = a;
test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
}
}
// Checks the fixed-point product a * b against a double-precision reference.
// The product type has tIntegerBits_a + tIntegerBits_b integer bits. The
// reference is the product of the two double values, quantized into that same
// type with ToFixedPoint. The raw values may differ by at most 1.
template <int tIntegerBits_a, int tIntegerBits_b>
void test_mul(FixedPoint<int32_t, tIntegerBits_a> a,
              FixedPoint<int32_t, tIntegerBits_b> b) {
  static const int IntegerBits_ab = tIntegerBits_a + tIntegerBits_b;
  using Product = FixedPoint<int32_t, IntegerBits_ab>;

  const Product product = a * b;
  const double reference_double = ToDouble(a) * ToDouble(b);
  const Product reference =
      ToFixedPoint<int32_t, IntegerBits_ab>(reference_double);

  // Widen to 64 bits so the difference of two int32 raw values cannot overflow.
  const int64_t raw_delta = static_cast<int64_t>(product.raw()) -
                            static_cast<int64_t>(reference.raw());
  Check(raw_delta >= -1 && raw_delta <= 1);
}
template <int tIntegerBits_a, int tIntegerBits_b>
void test_mul(const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
for (auto b : testvals_int32) {
FixedPoint<int32_t, tIntegerBits_a> aq;
FixedPoint<int32_t, tIntegerBits_b> bq;
aq.raw() = a;
bq.raw() = b;
test_mul(aq, bq);
}
}
}
// ExactMulByPot<tExponent>(a) must be exact: as a double, the result has to
// equal ToDouble(a) * 2^tExponent with no tolerance.
template <int tExponent, int tIntegerBits_a>
void test_ExactMulByPot(FixedPoint<int32_t, tIntegerBits_a> a) {
  const double expected = ToDouble(a) * std::pow(2.0, tExponent);
  const double actual = ToDouble(ExactMulByPot<tExponent>(a));
  Check(expected == actual);
}
template <int tExponent, int tIntegerBits_a>
void test_ExactMulByPot(const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
FixedPoint<int32_t, tIntegerBits_a> aq;
aq.raw() = a;
test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
}
}
// Compares exp_on_interval_between_negative_one_quarter_and_0_excl(a) with
// std::exp evaluated on ToDouble(a). The absolute error must be below 3e-7.
void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
    FixedPoint<int32_t, 0> a) {
  const double reference = std::exp(ToDouble(a));
  const double approximation =
      ToDouble(exp_on_interval_between_negative_one_quarter_and_0_excl(a));
  Check(std::abs(reference - approximation) < 3e-7);
}
void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
typedef FixedPoint<int32_t, 0> F;
F aq = SaturatingRoundingMultiplyByPOT<-3>(F::FromRaw(a)) -
F::ConstantPOT<-3>();
test_exp_on_interval_between_negative_one_quarter_and_0_excl(aq);
}
}
// Compares exp_on_negative_values(a) with std::exp evaluated on ToDouble(a).
// The absolute error must be below 3e-7.
template <int tIntegerBits>
void test_exp_on_negative_values(FixedPoint<int32_t, tIntegerBits> a) {
  const double reference = std::exp(ToDouble(a));
  const double approximation = ToDouble(exp_on_negative_values(a));
  const double abs_error = std::abs(reference - approximation);
  Check(abs_error < 3e-7);
}
template <int tIntegerBits>
void test_exp_on_negative_values(const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
if (a < 0) {
FixedPoint<int32_t, tIntegerBits> aq;
aq.raw() = a;
test_exp_on_negative_values(aq);
}
}
}
// Compares one_minus_x_over_one_plus_x_for_x_in_0_1(a) with (1 - x) / (1 + x),
// computed in double for x = ToDouble(a). The absolute error must be below
// 6e-9.
void test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t, 0> a) {
  const double x = ToDouble(a);
  const double reference = (1 - x) / (1 + x);
  const FixedPoint<int32_t, 0> result =
      one_minus_x_over_one_plus_x_for_x_in_0_1(a);
  const double approximation = ToDouble(result);
  Check(std::abs(reference - approximation) < 6e-9);
}
void test_one_minus_x_over_one_plus_x_for_x_in_0_1(
const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
if (a > 0) {
FixedPoint<int32_t, 0> aq;
aq.raw() = a;
test_one_minus_x_over_one_plus_x_for_x_in_0_1(aq);
}
}
}
// Compares the fixed-point tanh(a) with std::tanh evaluated on ToDouble(a).
// The absolute error must be below 1.5e-7.
template <int tIntegerBits>
void test_tanh(FixedPoint<int32_t, tIntegerBits> a) {
  const double x = ToDouble(a);
  const double reference = std::tanh(x);
  const double approximation = ToDouble(tanh(a));
  Check(std::abs(reference - approximation) < 1.5e-7);
}
template <int tIntegerBits>
void test_tanh(const std::vector<int32_t>& testvals_int32) {
for (auto a : testvals_int32) {
FixedPoint<int32_t, tIntegerBits> aq;
aq.raw() = a;
test_tanh(aq);
}
}
#ifdef GEMMLOWP_NEON
// Checks that tanh on FixedPoint<int32x4_t, 4> (NEON, 4 lanes at a time)
// produces raw results identical to tanh on scalar FixedPoint<int32_t, 4>.
// Only the longest prefix of testvals_int32 whose length is a multiple of 4
// is tested, so that every vector load/store stays in bounds.
void test_int32x4(const std::vector<int32_t>& testvals_int32) {
  const size_t n = testvals_int32.size();
  const size_t n4 = n - (n % 4);
  std::vector<int32_t> results_int32(n4);
  std::vector<int32_t> results_int32x4(n4);

  // Scalar reference results, one element at a time.
  for (size_t i = 0; i < n4; i++) {
    results_int32[i] =
        tanh(FixedPoint<int32_t, 4>::FromRaw(testvals_int32[i])).raw();
  }

  // Vector results. Each iteration loads and stores 4 consecutive lanes, so
  // the index must advance by 4. Advancing by 1 would read up to 3 elements
  // past the tested prefix and write up to 3 elements past the end of
  // results_int32x4.
  for (size_t i = 0; i < n4; i += 4) {
    vst1q_s32(
        &results_int32x4[i],
        tanh(FixedPoint<int32x4_t, 4>::FromRaw(vld1q_s32(&testvals_int32[i])))
            .raw());
  }

  for (size_t i = 0; i < n4; i++) {
    Check(results_int32[i] == results_int32x4[i]);
  }
}
#endif  // GEMMLOWP_NEON
int main() {
std::vector<int32_t> testvals_int32;
for (int i = 0; i < 31; i++) {
testvals_int32.push_back((1 << i) - 2);
testvals_int32.push_back((1 << i) - 1);
testvals_int32.push_back((1 << i));
testvals_int32.push_back((1 << i) + 1);
testvals_int32.push_back((1 << i) + 2);
testvals_int32.push_back(-(1 << i) - 2);
testvals_int32.push_back(-(1 << i) - 1);
testvals_int32.push_back(-(1 << i));
testvals_int32.push_back(-(1 << i) + 1);
testvals_int32.push_back(-(1 << i) + 2);
}
testvals_int32.push_back(std::numeric_limits<int32_t>::min());
testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 1);
testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 2);
testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 2);
testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 1);
testvals_int32.push_back(std::numeric_limits<int32_t>::max());
uint32_t random = 1;
for (int i = 0; i < 1000; i++) {
random = random * 1664525 + 1013904223;
testvals_int32.push_back(static_cast<int32_t>(random));
}
std::sort(testvals_int32.begin(), testvals_int32.end());
for (auto a : testvals_int32) {
FixedPoint<int32_t, 4> x;
x.raw() = a;
test_convert(x);
}
test_mul<0, 0>(testvals_int32);
test_mul<0, 1>(testvals_int32);
test_mul<2, 0>(testvals_int32);
test_mul<1, 1>(testvals_int32);
test_mul<4, 4>(testvals_int32);
test_mul<3, 5>(testvals_int32);
test_mul<7, 2>(testvals_int32);
test_mul<14, 15>(testvals_int32);
test_Rescale<0, 0>(testvals_int32);
test_Rescale<0, 1>(testvals_int32);
test_Rescale<2, 0>(testvals_int32);
test_Rescale<4, 4>(testvals_int32);
test_Rescale<4, 5>(testvals_int32);
test_Rescale<6, 3>(testvals_int32);
test_Rescale<13, 9>(testvals_int32);
test_ExactMulByPot<0, 0>(testvals_int32);
test_ExactMulByPot<0, 4>(testvals_int32);
test_ExactMulByPot<1, 4>(testvals_int32);
test_ExactMulByPot<3, 2>(testvals_int32);
test_ExactMulByPot<-4, 5>(testvals_int32);
test_ExactMulByPot<-2, 6>(testvals_int32);
test_exp_on_interval_between_negative_one_quarter_and_0_excl(testvals_int32);
test_exp_on_negative_values<1>(testvals_int32);
test_exp_on_negative_values<2>(testvals_int32);
test_exp_on_negative_values<3>(testvals_int32);
test_exp_on_negative_values<4>(testvals_int32);
test_exp_on_negative_values<5>(testvals_int32);
test_exp_on_negative_values<6>(testvals_int32);
test_one_minus_x_over_one_plus_x_for_x_in_0_1(testvals_int32);
test_tanh<1>(testvals_int32);
test_tanh<2>(testvals_int32);
test_tanh<3>(testvals_int32);
test_tanh<4>(testvals_int32);
test_tanh<5>(testvals_int32);
test_tanh<6>(testvals_int32);
#ifdef GEMMLOWP_NEON
test_int32x4(testvals_int32);
#endif // GEMMLOWP_NEON
std::cerr << "All tests passed." << std::endl;
}