blob: 6fc7428f16b0fa6a68c8bf6f786f52a909b2c139 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <type_traits>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h"
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
namespace tensorflow {
namespace {
// Test fixture `BinaryOpsTest`, which sets the TF device as expected by the
// TEST macros below.
class BinaryOpsTest : public BinaryOpsTestBase {
 protected:
  // Creates a GPU device for "/job:a/replica:0/task:0" and transfers its
  // ownership to SetDevice as the DEVICE_GPU device under test.
  void SetUp() override {
    auto gpu_device = std::unique_ptr<tensorflow::Device>(
        tensorflow::DeviceFactory::NewDevice("GPU", {},
                                             "/job:a/replica:0/task:0"));
    SetDevice(tensorflow::DEVICE_GPU, std::move(gpu_device));
  }
};
/// Test `tf.Add`.
// Reference implementation shared by Add and AddV2: the element-wise sum using
// T's own `+`, converted back to T on return.
template <typename T>
T baseline_add(T lhs, T rhs) {
  const auto sum = lhs + rhs;
  return sum;
}
// Default test suites for `tf.Add` on half, float, double and int64 inputs,
// each checked against `baseline_add`.
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Float, float, float, baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Double, double, double, baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Int64, int64, int64, baseline_add)
/// Test `tf.AddV2`.
// Same element types and reference implementation (`baseline_add`) as the
// `tf.Add` registrations.
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Float, float, float, baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Double, double, double,
baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Int64, int64, int64, baseline_add)
/// Test `tf.Atan2`.
// test::DefaultInput may contain zeros. To keep the (0, 0) input pair out of
// these suites, one operand is always drawn from test::DefaultInputNonZero.
// Here that operand is the rhs (the second argument of std::atan2).
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Atan2,
    /*test_name=*/FloatRhsNonZero, float, float, test::DefaultInput<float>(),
    test::DefaultInputNonZero<float>(), std::atan2)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Atan2,
    /*test_name=*/DoubleRhsNonZero, double, double,
    test::DefaultInput<double>(), test::DefaultInputNonZero<double>(),
    std::atan2)
// The same (0, 0) exclusion, this time with a non-zero lhs so that rhs == 0
// is still covered.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Atan2,
    /*test_name=*/FloatLhsNonZero, float, float,
    test::DefaultInputNonZero<float>(), test::DefaultInput<float>(),
    std::atan2)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Atan2,
    /*test_name=*/DoubleLhsNonZero, double, double,
    test::DefaultInputNonZero<double>(), test::DefaultInput<double>(),
    std::atan2)
// Test some particularly interesting cases: the (lhs, rhs) pairs (1,1), (1,0),
// (1,-1), (0,-1), (-1,-1), (-1,0), (-1,1), (0,1), i.e. the four axis directions
// and the four diagonals. ExpectStrictlyEqual() requests an exact comparison
// against std::atan2.
// NOTE(review): shape {20} is larger than the 8 listed values; presumably the
// test helper repeats them to fill the shape — confirm in base_ops_test.h.
TEST_F(BinaryOpsTest, Atan2FloatSpecialCases) {
TestEqualShapes<float, float, float, float>(
"Atan2", /*shape=*/{20},
test::InputAsVector<float>({1, 1, 1, 0, -1, -1, -1, 0}),
test::InputAsVector<float>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
test::OpsTestConfig().ExpectStrictlyEqual());
}
// Same special cases as above, in double precision.
TEST_F(BinaryOpsTest, Atan2DoubleSpecialCases) {
TestEqualShapes<double, double, double, double>(
"Atan2", /*shape=*/{20},
test::InputAsVector<double>({1, 1, 1, 0, -1, -1, -1, 0}),
test::InputAsVector<double>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
test::OpsTestConfig().ExpectStrictlyEqual());
}
/// Test `tf.BitwiseAnd`.
// Reference implementation: keeps only the bits set in both operands. Narrow
// integer types are promoted by `&`, so the result is converted back to T.
template <typename T>
T baseline_bitwise_and(T lhs, T rhs) {
  return static_cast<T>(rhs & lhs);
}
// Default test suites for `tf.BitwiseAnd` on int8, int16, int32 and int64,
// checked against `baseline_bitwise_and`.
GENERATE_DEFAULT_TESTS(BitwiseAnd,
/*test_name=*/Int8, int8, int8, baseline_bitwise_and)
GENERATE_DEFAULT_TESTS(BitwiseAnd,
/*test_name=*/Int16, int16, int16, baseline_bitwise_and)
GENERATE_DEFAULT_TESTS(BitwiseAnd,
/*test_name=*/Int32, int32, int32, baseline_bitwise_and)
GENERATE_DEFAULT_TESTS(BitwiseAnd,
/*test_name=*/Int64, int64, int64, baseline_bitwise_and)
/// Test `tf.BitwiseOr`.
// Reference implementation: keeps the bits set in either operand. Narrow
// integer types are promoted by `|`, so the result is converted back to T.
template <typename T>
T baseline_bitwise_or(T lhs, T rhs) {
  return static_cast<T>(rhs | lhs);
}
// Default test suites for `tf.BitwiseOr` on int8, int16, int32 and int64,
// checked against `baseline_bitwise_or`.
GENERATE_DEFAULT_TESTS(BitwiseOr,
/*test_name=*/Int8, int8, int8, baseline_bitwise_or)
GENERATE_DEFAULT_TESTS(BitwiseOr,
/*test_name=*/Int16, int16, int16, baseline_bitwise_or)
GENERATE_DEFAULT_TESTS(BitwiseOr,
/*test_name=*/Int32, int32, int32, baseline_bitwise_or)
GENERATE_DEFAULT_TESTS(BitwiseOr,
/*test_name=*/Int64, int64, int64, baseline_bitwise_or)
/// Test `tf.BitwiseXor`.
// Reference implementation: keeps the bits set in exactly one operand. Narrow
// integer types are promoted by `^`, so the result is converted back to T.
template <typename T>
T baseline_bitwise_xor(T lhs, T rhs) {
  return static_cast<T>(rhs ^ lhs);
}
// Default test suites for `tf.BitwiseXor` on int8, int16, int32 and int64,
// checked against `baseline_bitwise_xor`.
GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int8, int8, int8, baseline_bitwise_xor)
GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int16, int16, int16, baseline_bitwise_xor)
GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int32, int32, int32, baseline_bitwise_xor)
GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
/// Test `tf.Complex`.
// Reference implementation: lhs becomes the real part and rhs the imaginary
// part of the result.
template <typename T>
std::complex<T> baseline_complex(T lhs, T rhs) {
  return {lhs, rhs};
}
// `tf.Complex` builds complex64 / complex128 outputs from float / double real
// and imaginary parts and is compared exactly against `baseline_complex`.
// NOTE(review): AddTout() presumably adds the op's `Tout` output-type attribute
// — confirm in base_ops_test.h.
GENERATE_DEFAULT_TESTS_2(Complex,
/*test_name=*/C64, float, float, std::complex<float>,
std::complex<float>, test::DefaultInput<float>(),
test::DefaultInput<float>(), baseline_complex,
test::OpsTestConfig().ExpectStrictlyEqual().AddTout())
GENERATE_DEFAULT_TESTS_2(Complex,
/*test_name=*/C128, double, double,
std::complex<double>, std::complex<double>,
test::DefaultInput<double>(),
test::DefaultInput<double>(), baseline_complex,
test::OpsTestConfig().ExpectStrictlyEqual().AddTout())
/// Test `tf.Div`.
// Reference implementation, also used for RealDiv and TruncateDiv: T's own
// `/`, which truncates toward zero for integer T.
template <typename T>
T baseline_div(T lhs, T rhs) {
  const auto quotient = lhs / rhs;
  return quotient;
}
// Default test suites for `tf.Div`, checked against `baseline_div`. The integer
// variants draw the divisor from test::DefaultInputNonZero so that integer
// division by zero never occurs.
GENERATE_DEFAULT_TESTS(Div,
                       /*test_name=*/Half, Eigen::half, Eigen::half,
                       baseline_div)
GENERATE_DEFAULT_TESTS(Div,
                       /*test_name=*/Float, float, float, baseline_div)
GENERATE_DEFAULT_TESTS(Div,
                       /*test_name=*/Double, double, double, baseline_div)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Div,
    /*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
    test::DefaultInputNonZero<int16>(), baseline_div)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    Div,
    /*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
    test::DefaultInputNonZero<int64>(), baseline_div)
/// Test `tf.Equal`.
// Reference implementation: element-wise `==` (so NaN never equals anything).
template <typename T>
bool baseline_equal(T lhs, T rhs) {
  return rhs == lhs;
}
// Default test suites for `tf.Equal`. Every variant produces bool outputs and
// is checked against `baseline_equal`.
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Half, Eigen::half, bool,
baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Float, float, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Double, double, bool,
baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Bool, bool, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
/// Test `tf.FloorDiv`.
// Generic reference implementation: the quotient rounded toward negative
// infinity with std::floor, converted back to T.
template <typename T>
T baseline_floor_div(T lhs, T rhs) {
  const auto rounded = std::floor(lhs / rhs);
  return rounded;
}
// Eigen::half specialization: the half-precision quotient is widened to float
// so std::floor can be applied, and the result is narrowed back to half.
template <>
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
  const float quotient = static_cast<float>(lhs / rhs);
  return static_cast<Eigen::half>(std::floor(quotient));
}
// Default test suites for `tf.FloorDiv` on half, float and double. The rhs
// comes from test::DefaultInputNonZero so the divisor is never zero.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
FloorDiv,
/*test_name=*/Half, Eigen::half, Eigen::half,
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
baseline_floor_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
FloorDiv,
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
test::DefaultInputNonZero<float>(), baseline_floor_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
FloorDiv,
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
test::DefaultInputNonZero<double>(), baseline_floor_div);
/// Test `tf.Greater`.
// Reference implementation: true when lhs is strictly greater than rhs (false
// whenever either input is NaN).
template <typename T>
bool baseline_greater(T lhs, T rhs) {
  return rhs < lhs;
}
// Default test suites for `tf.Greater` with bool outputs, checked against
// `baseline_greater`.
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Half, Eigen::half, bool,
baseline_greater)
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Float, float, bool,
baseline_greater)
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Double, double, bool,
baseline_greater)
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int8, int8, bool,
baseline_greater)
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int16, int16, bool,
baseline_greater)
GENERATE_DEFAULT_TESTS(Greater, /*test_name=*/Int64, int64, bool,
baseline_greater)
/// Test `tf.GreaterEqual`.
// Reference implementation: true when lhs is greater than or equal to rhs
// (false whenever either input is NaN).
template <typename T>
bool baseline_greater_equal(T lhs, T rhs) {
  return rhs <= lhs;
}
// Default test suites for `tf.GreaterEqual` with bool outputs, checked against
// `baseline_greater_equal`.
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_greater_equal)
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Float, float, bool,
baseline_greater_equal)
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Double, double, bool,
baseline_greater_equal)
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int8, int8, bool,
baseline_greater_equal)
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
baseline_greater_equal)
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
baseline_greater_equal)
/// Test `tf.LeftShift`.
// Reference implementation for signed integer T. Left-shifting a negative
// signed value, or shifting a 1 into or past the sign bit, is undefined
// behavior before C++20, and the default inputs may be negative. The shift is
// therefore done on the unsigned counterpart of T, which is well defined and
// wraps modulo 2^N, and the result is converted back to T. For 8- and 16-bit T
// the unsigned operand is promoted to int, but with shift counts below the
// bit width the result still fits in int.
template <typename T>
T baseline_left_shift(T lhs, T rhs) {
  using UnsignedT = std::make_unsigned_t<T>;
  return static_cast<T>(static_cast<UnsignedT>(lhs) << rhs);
}
// Default test suites for `tf.LeftShift` on int8/int16/int32/int64. Shift
// amounts come from test::DefaultInputLessThanBitwidth, presumably in
// [0, bit width of T), since shifting by the full width or more is undefined
// in C++.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
LeftShift, /*test_name=*/Int8, int8, int8, test::DefaultInput<int8>(),
test::DefaultInputLessThanBitwidth<int8>(), baseline_left_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
LeftShift, /*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
test::DefaultInputLessThanBitwidth<int16>(), baseline_left_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
LeftShift, /*test_name=*/Int32, int32, int32, test::DefaultInput<int32>(),
test::DefaultInputLessThanBitwidth<int32>(), baseline_left_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
LeftShift, /*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
test::DefaultInputLessThanBitwidth<int64>(), baseline_left_shift)
/// Test `tf.Less`.
// Reference implementation: true when lhs is strictly less than rhs (false
// whenever either input is NaN).
template <typename T>
bool baseline_less(T lhs, T rhs) {
  return rhs > lhs;
}
// Default test suites for `tf.Less` with bool outputs, checked against
// `baseline_less`.
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Half, Eigen::half, bool,
baseline_less)
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Float, float, bool, baseline_less)
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Double, double, bool, baseline_less)
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int8, int8, bool, baseline_less)
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int16, int16, bool, baseline_less)
GENERATE_DEFAULT_TESTS(Less, /*test_name=*/Int64, int64, bool, baseline_less)
/// Test `tf.LessEqual`.
// Reference implementation: true when lhs is less than or equal to rhs (false
// whenever either input is NaN).
template <typename T>
bool baseline_less_equal(T lhs, T rhs) {
  return rhs >= lhs;
}
// Default test suites for `tf.LessEqual` with bool outputs, checked against
// `baseline_less_equal`.
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_less_equal)
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Float, float, bool,
baseline_less_equal)
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Double, double, bool,
baseline_less_equal)
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int8, int8, bool,
baseline_less_equal)
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int16, int16, bool,
baseline_less_equal)
GENERATE_DEFAULT_TESTS(LessEqual, /*test_name=*/Int64, int64, bool,
baseline_less_equal)
/// Test `tf.LogicalAnd`.
// Reference implementation: true only when both inputs are true.
bool baseline_logical_and(bool lhs, bool rhs) {
  if (!lhs) {
    return false;
  }
  return rhs;
}
// Default test suite for `tf.LogicalAnd` on bool inputs, compared exactly
// against `baseline_logical_and`.
// NOTE(review): NoT() presumably omits the `T` type attribute, since the op
// only accepts bool — confirm in base_ops_test.h.
GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
test::DefaultInput<bool>(), baseline_logical_and,
test::OpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.LogicalOr`.
// Reference implementation: true when at least one input is true.
bool baseline_logical_or(bool lhs, bool rhs) {
  if (lhs) {
    return true;
  }
  return rhs;
}
// Default test suite for `tf.LogicalOr` on bool inputs, compared exactly
// against `baseline_logical_or`.
// NOTE(review): NoT() presumably omits the `T` type attribute, since the op
// only accepts bool — confirm in base_ops_test.h.
GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
test::DefaultInput<bool>(), baseline_logical_or,
test::OpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.Maximum`.
// Reference implementation. If either input is NaN the result is `lhs + rhs`,
// which is itself NaN, so NaN propagates. Otherwise the larger value is
// returned.
template <typename T>
T baseline_maximum(T lhs, T rhs) {
  const bool has_nan = std::isnan(lhs) || std::isnan(rhs);
  if (!has_nan) {
    return std::max(lhs, rhs);
  }
  return lhs + rhs;
}
// Default test suites for `tf.Maximum` on half, float, double and int64,
// checked against `baseline_maximum`.
GENERATE_DEFAULT_TESTS(Maximum, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_maximum)
GENERATE_DEFAULT_TESTS(Maximum, /*test_name=*/Float, float, float,
baseline_maximum)
GENERATE_DEFAULT_TESTS(Maximum, /*test_name=*/Double, double, double,
baseline_maximum)
GENERATE_DEFAULT_TESTS(Maximum, /*test_name=*/Int64, int64, int64,
baseline_maximum)
/// Test `tf.Minimum`.
// Reference implementation. If either input is NaN the result is `lhs + rhs`,
// which is itself NaN, so NaN propagates. Otherwise the smaller value is
// returned.
template <typename T>
T baseline_minimum(T lhs, T rhs) {
  const bool has_nan = std::isnan(lhs) || std::isnan(rhs);
  if (!has_nan) {
    return std::min(lhs, rhs);
  }
  return lhs + rhs;
}
// Default test suites for `tf.Minimum` on half, float, double and int64,
// checked against `baseline_minimum`.
GENERATE_DEFAULT_TESTS(Minimum, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_minimum)
GENERATE_DEFAULT_TESTS(Minimum, /*test_name=*/Float, float, float,
baseline_minimum)
GENERATE_DEFAULT_TESTS(Minimum, /*test_name=*/Double, double, double,
baseline_minimum)
GENERATE_DEFAULT_TESTS(Minimum, /*test_name=*/Int64, int64, int64,
baseline_minimum)
/// Test `tf.Mul`.
// Reference implementation: the element-wise product using T's own `*`,
// converted back to T on return.
template <typename T>
T baseline_mul(T lhs, T rhs) {
  const auto product = lhs * rhs;
  return product;
}
// Default test suites for `tf.Mul` on half, float, double, int8, int16 and
// int64, checked against `baseline_mul`.
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
/// Test `tf.NotEqual`.
// Reference implementation: element-wise `!=` (so NaN compares unequal to
// everything, including itself).
template <typename T>
bool baseline_not_equal(T lhs, T rhs) {
  return rhs != lhs;
}
// Default test suites for `tf.NotEqual` with bool outputs, checked against
// `baseline_not_equal`.
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
/// Test `tf.Pow`.
// Generic reference implementation: defers to std::pow and converts the result
// back to T. For integer T, std::pow computes in double, so the conversion
// truncates toward zero (e.g. pow(2, -1) yields 0).
template <typename T>
T baseline_pow(T lhs, T rhs) {
  const auto power = std::pow(lhs, rhs);
  return power;
}
// Pow inputs for the floating-point types, presumably converted to T from the
// double literals below. All values are non-negative, presumably so that no
// negative base is raised to a fractional exponent (which would produce NaN).
// Confirm intent.
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> PowInput() {
return test::InputAsVector<T, double>({0.0, 0.1, 0.2, 0.3, 1.0, 2.0, 3.0});
}
// Pow inputs for the integer types. Zero is absent, presumably to avoid raising
// 0 to a negative power. Confirm intent.
template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true>
absl::InlinedVector<T, 10> PowInput() {
return test::InputAsVector<T, double>({-2, -1, -1, 1, 1, 3});
}
// Eigen::half specialization: evaluates std::pow in float precision and
// narrows the result back to half.
template <>
Eigen::half baseline_pow(Eigen::half lhs, Eigen::half rhs) {
  const float base = static_cast<float>(lhs);
  const float exponent = static_cast<float>(rhs);
  return static_cast<Eigen::half>(std::pow(base, exponent));
}
// Pow test suites. Both operands draw from PowInput<T>() rather than the
// default inputs. Only Int64 is registered among the integer types, although
// PowInput also accepts int8/int16/int32.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
/*test_name=*/Half,
Eigen::half, Eigen::half,
PowInput<Eigen::half>(),
PowInput<Eigen::half>(),
baseline_pow)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
/*test_name=*/Float, float,
float, PowInput<float>(),
PowInput<float>(),
baseline_pow)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
/*test_name=*/Double, double,
double, PowInput<double>(),
PowInput<double>(),
baseline_pow)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(Pow,
/*test_name=*/Int64, int64,
int64, PowInput<int64>(),
PowInput<int64>(),
baseline_pow)
/// Test `tf.RealDiv`.
// Floating-point only. The reference is the same `baseline_div` used for
// `tf.Div`.
GENERATE_DEFAULT_TESTS(RealDiv,
                       /*test_name=*/Half, Eigen::half, Eigen::half,
                       baseline_div)
GENERATE_DEFAULT_TESTS(RealDiv,
                       /*test_name=*/Float, float, float, baseline_div)
GENERATE_DEFAULT_TESTS(RealDiv,
                       /*test_name=*/Double, double, double, baseline_div)
/// Test `tf.RightShift`.
// Reference implementation: the built-in `>>`, converted back to T. For
// negative lhs the shift is arithmetic on mainstream compilers; this is
// implementation-defined before C++20.
template <typename T>
T baseline_right_shift(T lhs, T rhs) {
  const auto shifted = lhs >> rhs;
  return shifted;
}
// Default test suites for `tf.RightShift` on int8/int16/int32/int64. Shift
// amounts come from test::DefaultInputLessThanBitwidth, presumably in
// [0, bit width of T), since larger counts are undefined in C++.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RightShift,
/*test_name=*/Int8, int8, int8, test::DefaultInput<int8>(),
test::DefaultInputLessThanBitwidth<int8>(), baseline_right_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RightShift,
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
test::DefaultInputLessThanBitwidth<int16>(), baseline_right_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RightShift,
/*test_name=*/Int32, int32, int32, test::DefaultInput<int32>(),
test::DefaultInputLessThanBitwidth<int32>(), baseline_right_shift)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RightShift,
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
test::DefaultInputLessThanBitwidth<int64>(), baseline_right_shift)
/// Test `tf.SquaredDifference`.
// Reference implementation: (lhs - rhs)^2, computed by squaring the difference
// in the type produced by `-` and converting back to T.
template <typename T>
T baseline_squared_difference(T lhs, T rhs) {
  const auto delta = lhs - rhs;
  return delta * delta;
}
// Default test suites for `tf.SquaredDifference` on half, float, double and
// int64, checked against `baseline_squared_difference`.
GENERATE_DEFAULT_TESTS(SquaredDifference, /*test_name=*/Half, Eigen::half,
Eigen::half, baseline_squared_difference)
GENERATE_DEFAULT_TESTS(SquaredDifference, /*test_name=*/Float, float, float,
baseline_squared_difference)
GENERATE_DEFAULT_TESTS(SquaredDifference, /*test_name=*/Double, double, double,
baseline_squared_difference)
GENERATE_DEFAULT_TESTS(SquaredDifference, /*test_name=*/Int64, int64, int64,
baseline_squared_difference)
/// Test `tf.Sub`.
// Reference implementation: the element-wise difference using T's own `-`,
// converted back to T on return.
template <typename T>
T baseline_sub(T lhs, T rhs) {
  const auto difference = lhs - rhs;
  return difference;
}
// Default test suites for `tf.Sub` on half, float, double and int64, checked
// against `baseline_sub`.
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Half, Eigen::half, Eigen::half,
baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Float, float, float, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Double, double, double, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Int64, int64, int64, baseline_sub)
/// Test `tf.TruncateDiv`.
// Integer-only. `baseline_div` is a valid reference because C++ integer `/`
// truncates toward zero. The divisor comes from test::DefaultInputNonZero so
// that division by zero never occurs.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    TruncateDiv,
    /*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
    test::DefaultInputNonZero<int16>(), baseline_div)
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
    TruncateDiv,
    /*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
    test::DefaultInputNonZero<int64>(), baseline_div)
/// Test `tf.Zeta`.
// This test data was generated using the scipy implementation of zeta.
// x operands for the Zeta tests. Element i pairs with element i of
// GetZetaTestDataQ and GetZetaTestExpected, and baseline_zeta relies on that
// shared index. Every value must be unique for that lookup to work.
template <typename T>
static absl::InlinedVector<T, 10> GetZetaTestDataX() {
return test::InputAsVector<T, double>(
{1., 169.23969873, 105.93557562, 114.43259882, 179.62388639,
172.80836494, 127.82036549, 163.07586688, 157.31865127, 121.55091407,
132.49244284, 14.74785056, 61.69721805, 49.37079477, 32.73957728,
8.63833678, 5.77183618, 7.43098888, 9.68867483, 6.90594844,
1.10974422, 9.15604525, 5.39278873, 4.82471684, 3.61560063,
5.95540334});
}
// q operands for the Zeta tests, index-aligned with GetZetaTestDataX.
template <typename T>
static absl::InlinedVector<T, 10> GetZetaTestDataQ() {
return test::InputAsVector<T, double>(
{0.23672766, 0.92926068, 0.33551547, 0.53241745, 0.39939397, 0.73085145,
0.91634121, 0.92935301, 0.90518735, 0.93155356, 0.31607971, 3.76257433,
3.41533379, 3.4542971, 8.07960302, 7.49355634, 0.26524244, 0.11061626,
0.26367137, 0.17993167, 0.17947252, 0.27949224, 0.20880047, 0.12189132,
0.18806052, 0.19976058});
}
// scipy-generated zeta(x[i], q[i]) values, index-aligned with GetZetaTestDataX
// and GetZetaTestDataQ. Entry 0 is +infinity because x[0] == 1.
template <typename T>
static absl::InlinedVector<T, 10> GetZetaTestExpected() {
return test::InputAsVector<T, double>(
{std::numeric_limits<double>::infinity(),
2.46825299e+05,
1.75353388e+50,
2.11671833e+31,
3.96105582e+71,
3.39991735e+23,
7.07718091e+04,
1.54510527e+05,
6.39506276e+06,
5.53116025e+03,
1.87572363e+66,
3.36459087e-09,
1.22647410e-33,
2.63484970e-27,
2.00525974e-30,
4.37777089e-08,
2.12174334e+03,
1.27459042e+07,
4.06567559e+05,
1.39376449e+05,
1.61538935e+01,
1.17236802e+05,
4.66207773e+03,
2.56999783e+04,
4.21203884e+02,
1.46472701e+04});
}
// Reference implementation for `tf.Zeta`. Rather than computing zeta, it looks
// up the scipy-generated expected value for (x, q) in the Zeta test tables, so
// only inputs taken from GetZetaTestDataX / GetZetaTestDataQ are supported.
template <typename T>
T baseline_zeta(T x, T q) {
  // x == 1 is the pole of zeta, so the expected result is +infinity.
  if (x == 1.) {
    return std::numeric_limits<T>::infinity();
  }
  auto x_data = GetZetaTestDataX<T>();
  auto pos = std::find(x_data.begin(), x_data.end(), x);
  // `assert` is compiled out under NDEBUG (e.g. optimized builds). Without the
  // explicit check, an x missing from the table would index past the end of
  // the tables below, which is undefined behavior. Returning NaN instead gives
  // an obviously wrong expected value.
  assert(pos != x_data.end());
  if (pos == x_data.end()) {
    return std::numeric_limits<T>::quiet_NaN();
  }
  auto index = std::distance(x_data.begin(), pos);
  auto q_data = GetZetaTestDataQ<T>();
  // The q operand must be the one paired with x in the tables. Otherwise the
  // looked-up expected value would belong to a different input.
  assert(q_data[index] == q);
  if (q_data[index] != q) {
    return std::numeric_limits<T>::quiet_NaN();
  }
  auto expected = GetZetaTestExpected<T>();
  return expected[index];
}
// Compares the Zeta kernel on the scipy reference table (via baseline_zeta)
// using equal-shaped inputs. The expected values span roughly 1e-33 to 1e+71,
// which presumably explains the loose relative tolerances. Confirm the chosen
// ATol/RTol.
TEST_F(BinaryOpsTest, ZetaEqShapesFloat) {
TestEqualShapes<float, float, float, float>(
"Zeta", test::DefaultInputShape(), GetZetaTestDataX<float>(),
GetZetaTestDataQ<float>(), baseline_zeta,
test::OpsTestConfig().ATol(1e-11).RTol(1e-2));
}
// Same table as above in double precision, with a tighter relative tolerance.
TEST_F(BinaryOpsTest, ZetaEqShapesDouble) {
TestEqualShapes<double, double, double, double>(
"Zeta", test::DefaultInputShape(), GetZetaTestDataX<double>(),
GetZetaTestDataQ<double>(), baseline_zeta,
test::OpsTestConfig().ATol(1e-30).RTol(1e-4));
}
} // namespace
} // namespace tensorflow