Make exhaustive_op_test_utils compile with C++14.
This is needed because in open source we are still using C++14.
Before, this library was using template specialization inside
a template class. To avoid this, move those template functions
to outside the class. Also fix the definition of the static
member of ComponentStringifyFormat. It would lead to linker errors
in open source before.
PiperOrigin-RevId: 293538921
Change-Id: I276a9758c8ff556a2636739800e89ea59351c8e6
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b5e69f7..5a60d5a 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -740,9 +740,7 @@
name = "exhaustive_op_test_utils",
testonly = True,
srcs = ["exhaustive_op_test_utils.cc"],
- hdrs = [
- "exhaustive_op_test_utils.h",
- ],
+ hdrs = ["exhaustive_op_test_utils.h"],
tags = ["no_pip"],
deps = [
":client_library_test_base",
@@ -753,6 +751,7 @@
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
index 5bb838a..9c67ee4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc
@@ -20,6 +20,7 @@
#endif
namespace xla {
+namespace exhaustive_op_test {
namespace {
template <PrimitiveType T>
@@ -415,4 +416,5 @@
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
} // namespace
+} // namespace exhaustive_op_test
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
index 2ae1e2c..ba860d4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc
@@ -15,7 +15,14 @@
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
+#include <array>
+#include <string>
+#include <type_traits>
+
+#include "absl/strings/string_view.h"
+
namespace xla {
+namespace exhaustive_op_test {
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
// precision to be guaranteed that we're printing the full number.
@@ -28,40 +35,291 @@
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
namespace {
template <typename T>
-struct ComponentStringifyFormat {};
-
-template <>
-struct ComponentStringifyFormat<double> {
- static constexpr absl::string_view value = "%0.17g (0x%16x)";
+struct ComponentStringifyFormat {
+ static const absl::string_view value;
};
template <>
-struct ComponentStringifyFormat<float> {
- static constexpr absl::string_view value = "%0.9g (0x%08x)";
-};
+constexpr absl::string_view ComponentStringifyFormat<double>::value =
+ "%0.17g (0x%16x)";
template <>
-struct ComponentStringifyFormat<Eigen::half> {
- static constexpr absl::string_view value = "%0.5g (0x%04x)";
-};
+constexpr absl::string_view ComponentStringifyFormat<float>::value =
+ "%0.9g (0x%08x)";
template <>
-struct ComponentStringifyFormat<bfloat16> {
- static constexpr absl::string_view value = "%0.4g (0x%04x)";
-};
-} // namespace
+constexpr absl::string_view ComponentStringifyFormat<Eigen::half>::value =
+ "%0.5g (0x%04x)";
-/*static*/
-template <PrimitiveType T, size_t N>
-string ExhaustiveOpTestBase<T, N>::StringifyNum(
- typename ExhaustiveOpTestBase<T, N>::ComponentNativeT x) {
- typedef typename ExhaustiveOpTestBase<T, N>::ComponentNativeT ComponentType;
- typedef typename ExhaustiveOpTestBase<T, N>::ComponentIntegralNativeT
- IntegralType;
- return absl::StrFormat(ComponentStringifyFormat<ComponentType>::value,
+template <>
+constexpr absl::string_view ComponentStringifyFormat<bfloat16>::value =
+ "%0.4g (0x%04x)";
+
+template <typename Type, typename FuncPtr>
+ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
+ return func(in[0]);
+}
+
+template <typename Type, typename FuncPtr>
+ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 2>& in) {
+ return func(in[0], in[1]);
+}
+
+template <typename Type, typename FuncPtr>
+Type CallOperation(FuncPtr* func, const std::array<Type, 1>& in) {
+ return func(in[0]);
+}
+
+template <typename Type, typename FuncPtr>
+Type CallOperation(FuncPtr* func, const std::array<Type, 2>& in) {
+ return func(in[0], in[1]);
+}
+
+// The number of values that can be substituted for subnormal inputs.
+constexpr int kNumSubnormalSubstitutionValues = 4;
+
+// Encodings used to determine where subnormal test values are cached.
+constexpr int kPositiveMin = 0;
+constexpr int kNegativeMin = 1;
+constexpr int kPositiveZero = 2;
+constexpr int kNegativeZero = 3;
+constexpr int kNonSubnormal = -1;
+constexpr int kInvalidCacheIndex = -1;
+
+template <typename T>
+struct is_complex_t : absl::disjunction<std::is_same<T, complex64>,
+ std::is_same<T, complex128>> {};
+
+// When we are testing a value such that all of its components are subnormal,
+// we also need to test inputs made up of the Cartesian product of values
+// replaced for each subnormal component. These additional test inputs are
+// common enough where it will be efficient to just cache the results of these
+// Cartesian products. In order to cache these values, we need a one to one
+// mapping between these Cartesian products and cache locations.
+//
+// Our mapping works by assigning each component an integer in
+// [0, kNumSubnormalSubstitutionValues) based on its test value. By lining
+// these integers up with the n'th component corresponding to the n'th digit,
+// then for each Cartesian product element we essentially create a unique base
+// kNumSubnormalSubstitutionValues number. This number represents our cache
+// index.
+//
+// In the event that there a component is not a subnormal, the value should
+// not be cached, so we return a kNonSubnormal value.
+
+template <
+ typename NativeRefT,
+ typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
+int GetCacheLocation(NativeRefT value) {
+ bool positive = !std::signbit(value);
+ if (std::abs(value) == std::numeric_limits<NativeRefT>::min()) {
+ return positive ? kPositiveMin : kNegativeMin;
+ } else if (value != 0) {
+ CHECK(std::fpclassify(value) != FP_SUBNORMAL);
+ return kNonSubnormal;
+ } else {
+ return positive ? kPositiveZero : kNegativeZero;
+ }
+}
+
+template <
+ typename NativeRefT,
+ typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
+int GetCacheLocation(NativeRefT value) {
+ int real_loc =
+ GetCacheLocation<typename NativeRefT::value_type>(value.real());
+ int imag_loc =
+ GetCacheLocation<typename NativeRefT::value_type>(value.imag());
+ if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) {
+ return kNonSubnormal;
+ } else {
+ return real_loc * kNumSubnormalSubstitutionValues + imag_loc;
+ }
+}
+
+template <bool is_complex, typename NativeRefT, size_t N>
+int GetCacheLocation(const std::array<NativeRefT, N>& input) {
+ int location = 0;
+ int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
+ kNumSubnormalSubstitutionValues
+ : kNumSubnormalSubstitutionValues);
+ for (int i = 0; i < N; ++i) {
+ int comp_loc = GetCacheLocation<NativeRefT>(input[i]);
+ if (i == kNonSubnormal) {
+ return kNonSubnormal;
+ }
+ location *= cache_size_per_element;
+ location += comp_loc;
+ }
+ return location;
+}
+
+// The inverse function of GetCacheLocation.
+
+template <typename RetT,
+ typename std::enable_if<!is_complex_t<RetT>::value>::type* = nullptr>
+RetT FromCacheLocationComponent(int cache_loc) {
+ switch (cache_loc) {
+ case kPositiveMin:
+ return std::numeric_limits<RetT>::min();
+ case kNegativeMin:
+ return -std::numeric_limits<RetT>::min();
+ case kPositiveZero:
+ return static_cast<RetT>(0.0);
+ case kNegativeZero:
+ return static_cast<RetT>(-0.0);
+ default:
+ LOG(FATAL) << "Invalid cache_loc value of " << cache_loc;
+ }
+}
+
+template <typename RetT,
+ typename std::enable_if<is_complex_t<RetT>::value>::type* = nullptr>
+RetT FromCacheLocationComponent(int cache_loc) {
+ CHECK_LT(cache_loc,
+ kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues);
+ CHECK_GE(cache_loc, 0);
+
+ RetT value;
+ value.real(FromCacheLocationComponent<typename RetT::value_type>(
+ cache_loc / kNumSubnormalSubstitutionValues));
+ value.imag(FromCacheLocationComponent<typename RetT::value_type>(
+ cache_loc % kNumSubnormalSubstitutionValues));
+ return std::move(value);
+}
+
+template <bool is_complex, typename NativeRefT, size_t N>
+std::array<NativeRefT, N> FromCacheLocation(int cache_loc) {
+ std::array<NativeRefT, N> input;
+ int cache_size_per_element = (is_complex ? kNumSubnormalSubstitutionValues *
+ kNumSubnormalSubstitutionValues
+ : kNumSubnormalSubstitutionValues);
+ for (int i = N - 1; i >= 0; --i) {
+ input[i] = FromCacheLocationComponent<NativeRefT>(cache_loc %
+ cache_size_per_element);
+ cache_loc /= cache_size_per_element;
+ }
+
+ return input;
+}
+
+// Returns a string that describes the test value for the actual value.
+template <
+ typename NativeRefT,
+ typename std::enable_if<!is_complex_t<NativeRefT>::value>::type* = nullptr>
+std::string GetSubnormalDescription(NativeRefT test_val,
+ NativeRefT actual_val) {
+ std::string sp_min_normal = "sign-preserving min-normal-float";
+ std::string sp_zero = "sign-preserving zero";
+ std::string nsp_zero = "non-sign-preserving zero";
+
+ switch (GetCacheLocation<NativeRefT>(test_val)) {
+ case kNegativeMin:
+ case kPositiveMin:
+ return sp_min_normal;
+ case kNegativeZero:
+ case kPositiveZero:
+ return (std::signbit(test_val) == std::signbit(actual_val)) ? sp_zero
+ : nsp_zero;
+ default:
+ return "";
+ }
+}
+
+template <
+ typename NativeRefT,
+ typename std::enable_if<is_complex_t<NativeRefT>::value>::type* = nullptr>
+std::string GetSubnormalDescription(NativeRefT test_val,
+ NativeRefT actual_val) {
+ std::string real = GetSubnormalDescription<typename NativeRefT::value_type>(
+ test_val.real(), actual_val.real());
+ std::string imag = GetSubnormalDescription<typename NativeRefT::value_type>(
+ test_val.imag(), actual_val.imag());
+
+ if (real.empty()) {
+ if (imag.empty()) {
+ return "";
+ }
+ real = "real";
+ } else if (imag.empty()) {
+ imag = "imag";
+ }
+
+ return absl::StrCat("(", real, ", ", imag, ")");
+}
+
+template <bool is_complex, typename NativeRefT, size_t N>
+std::string GetSubnormalDescription(std::array<NativeRefT, N> test_vals,
+ std::array<NativeRefT, N> actual_vals) {
+ if (N == 1) {
+ return GetSubnormalDescription<NativeRefT>(test_vals[0], actual_vals[0]);
+ }
+
+ std::array<std::string, N> str_vals;
+ for (int i = 0; i < N; ++i) {
+ str_vals[i] =
+ GetSubnormalDescription<NativeRefT>(test_vals[i], actual_vals[i]);
+ if (str_vals[i].empty()) {
+ str_vals[i] = "original";
+ }
+ }
+
+ return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
+}
+
+template <
+ typename NativeT, typename IntegralType,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+std::string StringifyNum(NativeT x) {
+ return absl::StrFormat(ComponentStringifyFormat<NativeT>::value,
static_cast<double>(x), BitCast<IntegralType>(x));
}
+template <
+ typename NativeT, typename IntegralType,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+std::string StringifyNum(NativeT x) {
+ return absl::StrCat(
+ "(", StringifyNum<typename NativeT::value_type, IntegralType>(x.real()),
+ ", ", StringifyNum<typename NativeT::value_type, IntegralType>(x.imag()),
+ ")");
+}
+
+template <typename NativeT, typename IntegralType, size_t N>
+std::string StringifyNum(const std::array<NativeT, N>& inputs) {
+ if (N == 1) {
+ return StringifyNum<NativeT, IntegralType>(inputs[0]);
+ }
+
+ std::array<std::string, N> str_vals;
+ for (int i = 0; i < N; ++i) {
+ str_vals[i] = StringifyNum<NativeT, IntegralType>(inputs[i]);
+ }
+
+ return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
+}
+
+template <typename ErrorGenerator>
+void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
+ // We send a few mismatches to gunit so they show up nicely in test logs.
+ // Then we send more to LOG(ERROR). The remainder we squelch unless we're
+ // at vlog level 2.
+ constexpr int64 kMaxMismatchesLoggedToGunit = 10;
+ constexpr int64 kMaxMismatchesLoggedToErr = 1000;
+
+ (*mismatches)++;
+ if (*mismatches < kMaxMismatchesLoggedToGunit) {
+ FAIL() << err_generator();
+ } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
+ LOG(ERROR) << err_generator();
+ } else if (*mismatches == kMaxMismatchesLoggedToErr) {
+ LOG(ERROR) << "Not printing any more mismatches; pass "
+ "--vmodule=exhaustive_op_test=2 to see "
+ "all of them.";
+ }
+}
+} // namespace
+
template <PrimitiveType T, size_t N>
void ExhaustiveOpTestBase<T, N>::ExpectNear(const InputLiterals& input_literals,
const Literal& result_literal,
@@ -69,10 +327,17 @@
ErrorSpecGen error_spec_gen) {
// Cache for when all components are subnormal testing values.
std::vector<NativeRefT> pure_subnormal_cache;
- pure_subnormal_cache.reserve(GetMaxCacheSize());
- for (int i = 0; i < GetMaxCacheSize(); ++i) {
- pure_subnormal_cache.push_back(
- CallOperation(evaluate_op, FromCacheLocation(i)));
+ // Since we take the cross product of all possible test values, and each
+ // component has kNumSubnormalSubstitutionValues possible test values, then
+ // the total number of different cache locations are
+ // kNumSubnormalSubstitutionValues raised to the num_components.
+ // num_components = N for the reals, and 2*N for the complex.
+ int64 max_cache_size =
+ pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
+ pure_subnormal_cache.reserve(max_cache_size);
+ for (int i = 0; i < max_cache_size; ++i) {
+ pure_subnormal_cache.push_back(CallOperation(
+ evaluate_op, FromCacheLocation<kIsComplex, NativeRefT, N>(i)));
}
NativeInputsList inputs_arr;
@@ -111,9 +376,11 @@
// error_spec), print an error.
if (subnormal_test_inputs.size() == 1) {
PrintMismatch(&mismatches, [&] {
- return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
- StringifyNum(inputs), StringifyNum(expected),
- StringifyNum(actual));
+ return absl::StrFormat(
+ "Mismatch on %s. Expected %s, but got %s.",
+ StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
+ StringifyNum<NativeT, ComponentIntegralNativeT>(expected),
+ StringifyNum<NativeT, ComponentIntegralNativeT>(actual));
});
continue;
}
@@ -125,7 +392,9 @@
for (NativeRefInputs test_value : subnormal_test_inputs) {
NativeRefT result;
- int cache_loc = GetCacheLocation(test_value);
+ int cache_loc =
+ GetCacheLocation<kIsComplex, typename NativeRefInputs::value_type, N>(
+ test_value);
if (cache_loc == kInvalidCacheIndex) {
result = CallOperation(evaluate_op, test_value);
} else {
@@ -146,107 +415,33 @@
std::string mismatch = absl::StrFormat(
"Mismatch on subnormal value %s. Expected one of:\n"
" %10s (evaluated at full-precision value)\n",
- StringifyNum(inputs), StringifyNum(expected));
+ StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs),
+ StringifyNum<NativeT, ComponentIntegralNativeT>(expected));
CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size());
for (int i = 0; i < subnormal_test_inputs.size(); ++i) {
+ using IntegralNativeRefT =
+ typename ExhaustiveOpTestBase<RefT::value,
+ N>::ComponentIntegralNativeT;
absl::StrAppend(
&mismatch,
absl::StrFormat(" %10s (evaluated at %s)\n",
- StringifyNum(subnormal_test_results[i]),
- GetSubnormalDescription(subnormal_test_inputs[i],
- inputs_ref_ty)));
+ StringifyNum<NativeRefT, IntegralNativeRefT>(
+ subnormal_test_results[i]),
+ GetSubnormalDescription<kIsComplex, NativeRefT, N>(
+ subnormal_test_inputs[i], inputs_ref_ty)));
}
- absl::StrAppend(&mismatch,
- absl::StrFormat("but got %s", StringifyNum(actual)));
+ absl::StrAppend(
+ &mismatch,
+ absl::StrFormat(
+ "but got %s",
+ StringifyNum<NativeT, ComponentIntegralNativeT>(actual)));
PrintMismatch(&mismatches, [mismatch] { return mismatch; });
}
EXPECT_EQ(mismatches, 0);
}
-namespace {
-template <PrimitiveType T, size_t N>
-inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
- typename ExhaustiveOpTestBase<T, N>::NativeT) {
- LOG(FATAL) << "Unhandled Type";
-}
-
-template <PrimitiveType T, size_t N>
-inline typename ExhaustiveOpTestBase<T, N>::ErrorSpec DefaultSpecGenerator(
- typename ExhaustiveOpTestBase<T, N>::NativeT,
- typename ExhaustiveOpTestBase<T, N>::NativeT) {
- LOG(FATAL) << "Unhandled Type";
-}
-
-template <>
-inline ExhaustiveOpTestBase<C128, 1>::ErrorSpec DefaultSpecGenerator<C128, 1>(
- complex128) {
- return ExhaustiveOpTestBase<C128, 1>::ErrorSpec{0.0001, 0.0001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<C64, 1>::ErrorSpec DefaultSpecGenerator<C64, 1>(
- complex64) {
- return ExhaustiveOpTestBase<C64, 1>::ErrorSpec{0.0001, 0.0001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F64, 1>::ErrorSpec DefaultSpecGenerator<F64, 1>(
- double) {
- return ExhaustiveOpTestBase<F64, 1>::ErrorSpec{0.0001, 0.0001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F32, 1>::ErrorSpec DefaultSpecGenerator<F32, 1>(
- float) {
- return ExhaustiveOpTestBase<F32, 1>::ErrorSpec{0.0001, 0.0001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F16, 1>::ErrorSpec DefaultSpecGenerator<F16, 1>(
- Eigen::half) {
- return ExhaustiveOpTestBase<F16, 1>::ErrorSpec{0.001, 0.001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<BF16, 1>::ErrorSpec DefaultSpecGenerator<BF16, 1>(
- bfloat16) {
- return ExhaustiveOpTestBase<BF16, 1>::ErrorSpec{0.002, 0.02};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F64, 2>::ErrorSpec DefaultSpecGenerator<F64, 2>(
- double, double) {
- return ExhaustiveOpTestBase<F64, 2>::ErrorSpec{0.001, 0.001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F32, 2>::ErrorSpec DefaultSpecGenerator<F32, 2>(
- float, float) {
- return ExhaustiveOpTestBase<F32, 2>::ErrorSpec{0.001, 0.001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<F16, 2>::ErrorSpec DefaultSpecGenerator<F16, 2>(
- Eigen::half, Eigen::half) {
- return ExhaustiveOpTestBase<F16, 2>::ErrorSpec{0.001, 0.001};
-}
-
-template <>
-inline ExhaustiveOpTestBase<BF16, 2>::ErrorSpec DefaultSpecGenerator<BF16, 2>(
- bfloat16, bfloat16) {
- return ExhaustiveOpTestBase<BF16, 2>::ErrorSpec{0.002, 0.02};
-}
-} // namespace
-
-/*static*/
-template <PrimitiveType T, size_t N>
-typename ExhaustiveOpTestBase<T, N>::ErrorSpecGen
-ExhaustiveOpTestBase<T, N>::GetDefaultSpecGenerator() {
- return DefaultSpecGenerator<T, N>;
-}
-
template class ExhaustiveOpTestBase<C128, 1>;
template class ExhaustiveOpTestBase<C64, 1>;
template class ExhaustiveOpTestBase<F64, 1>;
@@ -259,4 +454,5 @@
template class ExhaustiveOpTestBase<F16, 2>;
template class ExhaustiveOpTestBase<BF16, 2>;
+} // namespace exhaustive_op_test
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
index 67e6d6d6..009669b 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
+#include <array>
#include <cmath>
#include <iterator>
@@ -28,25 +29,73 @@
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
+namespace exhaustive_op_test {
+
+struct ErrorSpec {
+ float abs_err;
+ float rel_err;
+
+ // If true, will consider -0 not near to +0 and vice versa. Note that
+ // +epsilon may still be considered close to -0, depending on the error
+ // spec; this only covers the case when both `expected` and `actual` are
+ // equal to 0.
+ bool strict_signed_zeros = false;
+
+ ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
+};
+
+// Representations of the reference function passed in by the user.
+template <typename NativeRefT, size_t K>
+struct EvaluateOpWrapper {};
+template <typename NativeRefT>
+struct EvaluateOpWrapper<NativeRefT, 1> {
+ using type = NativeRefT (*)(NativeRefT);
+};
+template <typename NativeRefT>
+struct EvaluateOpWrapper<NativeRefT, 2> {
+ using type = NativeRefT (*)(NativeRefT, NativeRefT);
+};
+
+// Representations of the reference function passed in by the user.
+template <typename XlaInputs, size_t K>
+struct EnqueueOpWrapper {};
+template <typename XlaInputs>
+struct EnqueueOpWrapper<XlaInputs, 1> {
+ using type = std::function<XlaOp(XlaOp)>;
+ static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
+ return ty(inputs[0]);
+ }
+};
+template <typename XlaInputs>
+struct EnqueueOpWrapper<XlaInputs, 2> {
+ using type = std::function<XlaOp(XlaOp, XlaOp)>;
+ static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
+ return ty(inputs[0], inputs[1]);
+ }
+};
+
+// Representations of the ErrorSpecGen function passed in by the user.
+template <PrimitiveType T, size_t K>
+struct ErrorSpecGenWrapper {};
+template <PrimitiveType T>
+struct ErrorSpecGenWrapper<T, 1> {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
+ using type = ErrorSpec (*)(NativeT);
+};
+template <PrimitiveType T>
+struct ErrorSpecGenWrapper<T, 2> {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
+ using type = ErrorSpec (*)(NativeT, NativeT);
+};
+
+template <PrimitiveType T, size_t N>
+typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();
// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
public:
- struct ErrorSpec {
- float abs_err;
- float rel_err;
-
- // If true, will consider -0 not near to +0 and vice versa. Note that
- // +epsilon may still be considered close to -0, depending on the error
- // spec; this only covers the case when both `expected` and `actual` are
- // equal to 0.
- bool strict_signed_zeros = false;
-
- ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
- };
-
// Definitions depending on the primitive type T.
static constexpr bool kIsComplex = (T == C128 || T == C64);
@@ -112,52 +161,10 @@
// N data items representing a single input to an XLA function.
using XlaInputs = std::array<XlaOp, N>;
- // Representations of the reference function passed in by the user.
- template <size_t K>
- struct EvaluateOpWrapper {};
- template <>
- struct EvaluateOpWrapper<1> {
- using type = NativeRefT (*)(NativeRefT);
- };
- template <>
- struct EvaluateOpWrapper<2> {
- using type = NativeRefT (*)(NativeRefT, NativeRefT);
- };
-
- // Representations of the reference function passed in by the user.
- template <size_t K>
- struct EnqueueOpWrapper {};
- template <>
- struct EnqueueOpWrapper<1> {
- using type = std::function<XlaOp(XlaOp)>;
- static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
- return ty(inputs[0]);
- }
- };
- template <>
- struct EnqueueOpWrapper<2> {
- using type = std::function<XlaOp(XlaOp, XlaOp)>;
- static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
- return ty(inputs[0], inputs[1]);
- }
- };
-
- // Representations of the ErrorSpecGen function passed in by the user.
- template <size_t K>
- struct ErrorSpecGenWrapper {};
- template <>
- struct ErrorSpecGenWrapper<1> {
- using type = ErrorSpec (*)(NativeT);
- };
- template <>
- struct ErrorSpecGenWrapper<2> {
- using type = ErrorSpec (*)(NativeT, NativeT);
- };
-
public:
- using ErrorSpecGen = typename ErrorSpecGenWrapper<N>::type;
- using EvaluateOp = typename EvaluateOpWrapper<N>::type;
- using EnqueueOp = typename EnqueueOpWrapper<N>::type;
+ using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
+ using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
+ using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;
explicit ExhaustiveOpTestBase()
: ty_(T), platform_(client_->platform()->Name()) {
@@ -169,7 +176,7 @@
}
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
- Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator());
+ Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator<T, N>());
}
// A helper for implementing the Run method for exhaustive op tests. It
@@ -190,7 +197,7 @@
xla_inputs[i] =
Parameter(&builder, i, input_literals[i].shape(), "input");
}
- EnqueueOpWrapper<N>::BuildFromInputs(xla_inputs, enqueue_op);
+ EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
@@ -437,200 +444,6 @@
return test_values;
}
- // The number of values that can be substituted for subnormal inputs.
- static constexpr int kNumSubnormalSubstitutionValues = 4;
-
- // Encodings used to determine where subnormal test values are cached.
- static constexpr int kPositiveMin = 0;
- static constexpr int kNegativeMin = 1;
- static constexpr int kPositiveZero = 2;
- static constexpr int kNegativeZero = 3;
- static constexpr int kNonSubnormal = -1;
- static constexpr int kInvalidCacheIndex = -1;
-
- // Since we take the cross product of all possible test values, and each
- // component has kNumSubnormalSubstitutionValues possible test values, then
- // the total number of different cache locations are
- // kNumSubnormalSubstitutionValues raised to the num_components.
- // num_components = N for the reals, and 2*N for the complex.
- static constexpr int GetMaxCacheSize() {
- return pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
- }
-
- // When we are testing a value such that all of its components are subnormal,
- // we also need to test inputs made up of the Cartesian product of values
- // replaced for each subnormal component. These additional test inputs are
- // common enough where it will be efficient to just cache the results of these
- // Cartesian products. In order to cache these values, we need a one to one
- // mapping between these Cartesian products and cache locations.
- //
- // Our mapping works by assigning each component an integer in
- // [0, kNumSubnormalSubstitutionValues) based on its test value. By lining
- // these integers up with the n'th component corresponding to the n'th digit,
- // then for each Cartesian product element we essentially create a unique base
- // kNumSubnormalSubstitutionValues number. This number represents our cache
- // index.
- //
- // In the event that there a component is not a subnormal, the value should
- // not be cached, so we return a kNonSubnormal value.
-
- static int GetCacheLocation(ComponentNativeRefT value) {
- bool positive = !std::signbit(value);
- if (std::abs(value) == std::numeric_limits<ComponentNativeRefT>::min()) {
- if (positive) {
- return kPositiveMin;
- } else {
- return kNegativeMin;
- }
- } else if (value != 0) {
- CHECK(std::fpclassify(value) != FP_SUBNORMAL);
- return kNonSubnormal;
- } else if (positive) {
- return kPositiveZero;
- } else {
- return kNegativeZero;
- }
- }
-
- static int GetCacheLocation(std::complex<ComponentNativeRefT> value) {
- int real_loc = GetCacheLocation(value.real());
- int imag_loc = GetCacheLocation(value.imag());
- if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) {
- return kNonSubnormal;
- } else {
- return real_loc * kNumSubnormalSubstitutionValues + imag_loc;
- }
- }
-
- static int GetCacheLocation(const NativeRefInputs& input) {
- int location = 0;
- int cache_size_per_element =
- (kIsComplex
- ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues
- : kNumSubnormalSubstitutionValues);
- for (int i = 0; i < N; ++i) {
- int comp_loc = GetCacheLocation(input[i]);
- if (i == kNonSubnormal) {
- return kNonSubnormal;
- }
- location *= cache_size_per_element;
- location += comp_loc;
- }
- return location;
- }
-
- // The inverse function of GetCacheLocation.
-
- template <bool complex, typename RetT>
- static RetT FromCacheLocationComponent(int cache_loc) {
- LOG(FATAL) << "Not implemented.";
- }
-
- template <>
- static ComponentNativeRefT
- FromCacheLocationComponent<false, ComponentNativeRefT>(int cache_loc) {
- switch (cache_loc) {
- case kPositiveMin:
- return std::numeric_limits<ComponentNativeRefT>::min();
- case kNegativeMin:
- return -std::numeric_limits<ComponentNativeRefT>::min();
- case kPositiveZero:
- return static_cast<ComponentNativeRefT>(0.0);
- case kNegativeZero:
- return static_cast<ComponentNativeRefT>(-0.0);
- default:
- LOG(FATAL) << "Invalid cache_loc value of " << cache_loc;
- }
- }
-
- template <>
- static std::complex<ComponentNativeRefT>
- FromCacheLocationComponent<true, std::complex<ComponentNativeRefT>>(
- int cache_loc) {
- CHECK_LT(cache_loc,
- kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues);
- CHECK_GE(cache_loc, 0);
-
- std::complex<ComponentNativeRefT> value;
- value.real(FromCacheLocationComponent<false, ComponentNativeRefT>(
- cache_loc / kNumSubnormalSubstitutionValues));
- value.imag(FromCacheLocationComponent<false, ComponentNativeRefT>(
- cache_loc % kNumSubnormalSubstitutionValues));
- return std::move(value);
- }
-
- static NativeRefInputs FromCacheLocation(int cache_loc) {
- NativeRefInputs input;
- int cache_size_per_element =
- (kIsComplex
- ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues
- : kNumSubnormalSubstitutionValues);
- for (int i = N - 1; i >= 0; --i) {
- input[i] = FromCacheLocationComponent<kIsComplex, NativeRefT>(
- cache_loc % cache_size_per_element);
- cache_loc /= cache_size_per_element;
- }
-
- return input;
- }
-
- // Returns a string that describes the test value for the actual value.
- std::string GetSubnormalDescription(ComponentNativeRefT test_val,
- ComponentNativeRefT actual_val) {
- const string sp_min_normal = "sign-preserving min-normal-float";
- const string sp_zero = "sign-preserving zero";
- const string nsp_zero = "non-sign-preserving zero";
-
- switch (GetCacheLocation(test_val)) {
- case kNegativeMin:
- case kPositiveMin:
- return sp_min_normal;
- case kNegativeZero:
- case kPositiveZero:
- return (std::signbit(test_val) == std::signbit(actual_val)) ? sp_zero
- : nsp_zero;
- default:
- return "";
- }
- }
-
- std::string GetSubnormalDescription(
- std::complex<ComponentNativeRefT> test_val,
- std::complex<ComponentNativeRefT> actual_val) {
- std::string real =
- GetSubnormalDescription(test_val.real(), actual_val.real());
- std::string imag =
- GetSubnormalDescription(test_val.imag(), actual_val.imag());
-
- if (real.empty()) {
- if (imag.empty()) {
- return "";
- }
- real = "real";
- } else if (imag.empty()) {
- imag = "imag";
- }
-
- return absl::StrCat("(", real, ", ", imag, ")");
- }
-
- std::string GetSubnormalDescription(std::array<NativeRefT, N> test_vals,
- std::array<NativeRefT, N> actual_vals) {
- if (N == 1) {
- return GetSubnormalDescription(test_vals[0], actual_vals[0]);
- }
-
- std::array<std::string, N> str_vals;
- for (int i = 0; i < N; ++i) {
- str_vals[i] = GetSubnormalDescription(test_vals[i], actual_vals[i]);
- if (str_vals[i].empty()) {
- str_vals[i] = "original";
- }
- }
-
- return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
- }
-
InputLiterals CreateInputLiterals() {
InputLiterals literals;
for (int i = 0; i < N; ++i) {
@@ -662,26 +475,6 @@
return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
}
- template <typename ErrorGenerator>
- void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
- // We send a few mismatches to gunit so they show up nicely in test logs.
- // Then we send more to LOG(ERROR). The remainder we squelch unless we're
- // at vlog level 2.
- constexpr int64 kMaxMismatchesLoggedToGunit = 10;
- constexpr int64 kMaxMismatchesLoggedToErr = 1000;
-
- (*mismatches)++;
- if (*mismatches < kMaxMismatchesLoggedToGunit) {
- FAIL() << err_generator();
- } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
- LOG(ERROR) << err_generator();
- } else if (*mismatches == kMaxMismatchesLoggedToErr) {
- LOG(ERROR) << "Not printing any more mismatches; pass "
- "--vmodule=exhaustive_op_test=2 to see "
- "all of them.";
- }
- }
-
// Converts part or all bits in an uint64 to the value of the floating point
// data type being tested.
//
@@ -704,41 +497,6 @@
return ConvertValue(bits);
}
- static string StringifyNum(ComponentNativeT x);
-
- static string StringifyNum(std::complex<ComponentNativeT> x) {
- return absl::StrCat("(", StringifyNum(x.real()), ", ",
- StringifyNum(x.imag()), ")");
- }
-
- // We also stringify the NativeRefT, so we need to generate an additional
- // version of this function when NativeRefT != NativeT.
- template <
- typename T1 = NativeRefT,
- class = typename std::enable_if<!std::is_same<NativeT, T1>::value>::type>
- static string StringifyNum(NativeRefT x) {
- return ExhaustiveOpTestBase<RefT::value, N>::StringifyNum(x);
- }
-
- static string StringifyNum(const NativeInputs& inputs) {
- if (N == 1) {
- return StringifyNum(inputs[0]);
- }
-
- std::array<std::string, N> str_vals;
- for (int i = 0; i < N; ++i) {
- str_vals[i] = StringifyNum(inputs[i]);
- }
-
- return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
- }
-
- static void AppendStringifyNum(std::string* s, NativeT x) {
- absl::StrAppend(s, StringifyNum(x));
- }
-
- static ErrorSpecGen GetDefaultSpecGenerator();
-
protected:
// The primitive type being tested.
const PrimitiveType ty_;
@@ -759,30 +517,6 @@
//
// XLA:GPU preserves denormal signs, but other backends don't.
bool relaxed_denormal_signs_ = platform_ != "CUDA";
-
- private:
- using EvaluateOpInternal = NativeRefT (*)(NativeRefInputs);
- using ErrorSpecGenInternal = ErrorSpec (*)(NativeInputs);
-
- template <typename Type, typename FuncPtr>
- ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 1>& in) {
- return func(in[0]);
- }
-
- template <typename Type, typename FuncPtr>
- ErrorSpec CallErrorSpec(FuncPtr* func, const std::array<Type, 2>& in) {
- return func(in[0], in[1]);
- }
-
- template <typename Type, typename FuncPtr>
- Type CallOperation(FuncPtr* func, const std::array<Type, 1>& in) {
- return func(in[0]);
- }
-
- template <typename Type, typename FuncPtr>
- Type CallOperation(FuncPtr* func, const std::array<Type, 2>& in) {
- return func(in[0], in[1]);
- }
};
// Represents a set of 64 bit chunks by representing the starting bit chunk,
@@ -1202,5 +936,74 @@
return result;
}
+template <PrimitiveType T, size_t N>
+inline ErrorSpec DefaultSpecGenerator(
+ typename ExhaustiveOpTestBase<T, N>::NativeT) {
+ LOG(FATAL) << "Unhandled Type";
+}
+
+template <PrimitiveType T, size_t N>
+inline ErrorSpec DefaultSpecGenerator(
+ typename ExhaustiveOpTestBase<T, N>::NativeT,
+ typename ExhaustiveOpTestBase<T, N>::NativeT) {
+ LOG(FATAL) << "Unhandled Type";
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<C128, 1>(complex128) {
+ return ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<C64, 1>(complex64) {
+ return ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F64, 1>(double) {
+ return ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F32, 1>(float) {
+ return ErrorSpec{0.0001, 0.0001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F16, 1>(Eigen::half) {
+ return ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<BF16, 1>(bfloat16) {
+ return ErrorSpec{0.002, 0.02};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
+ return ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F32, 2>(float, float) {
+ return ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<F16, 2>(Eigen::half, Eigen::half) {
+ return ErrorSpec{0.001, 0.001};
+}
+
+template <>
+inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
+ return ErrorSpec{0.002, 0.02};
+}
+
+template <PrimitiveType T, size_t N>
+typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
+ return DefaultSpecGenerator<T, N>;
+}
+
+} // namespace exhaustive_op_test
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
index 9f14774..d008ffe 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc
@@ -22,6 +22,7 @@
#endif
namespace xla {
+namespace exhaustive_op_test {
using Eigen::half;
@@ -158,7 +159,13 @@
}
template <PrimitiveType T>
-using ExhaustiveUnaryTest = ExhaustiveOpTestBase<T, 1>;
+class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
+ public:
+ using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
+ static ErrorSpecGen GetDefaultSpecGenerator() {
+ return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
+ }
+};
// Exhaustive test for unary operations for <= 32bit floating point types.
//
@@ -977,4 +984,5 @@
::testing::ValuesIn(
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
+} // namespace exhaustive_op_test
} // namespace xla