blob: bdc8657d7442aaf5502c8e9b6bddb51eda4140e3 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""
import collections
import copy
import itertools
import math
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
# pylint: disable=unused-import,g-bad-import-order
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.core import _pywrap_bfloat16
from tensorflow.python.platform import test
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
def numpy_assert_allclose(a, b, **kwargs):
a = a.astype(np.float32) if a.dtype == bfloat16 else a
b = b.astype(np.float32) if b.dtype == bfloat16 else b
return np.testing.assert_allclose(a, b, **kwargs)
def type_bits(x):
if x == bfloat16:
return 16
return np.finfo(x).bits
def promote_types(a, b):
num_bits_a = type_bits(a)
num_bits_b = type_bits(b)
# Pick the greater of the two types.
if num_bits_a < num_bits_b:
return b
if num_bits_b < num_bits_a:
return a
# Pick either type if both are equivalent.
if a == b:
return a
# The only possibility at this point is that the two types are bfloat16 and
# np.float16. We expect to have promoted to np.float32.
assert num_bits_a == 16
assert num_bits_b == 16
return np.float32
epsilon = float.fromhex("1.0p-7")
# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
class Bfloat16Test(parameterized.TestCase):
"""Tests the non-numpy Python methods of the bfloat16 type."""
def testRoundTripToFloat(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, float(bfloat16(v)))
def testRoundTripNumpyTypes(self):
for dtype in [np.float16, np.float32, np.float64, np.longdouble]:
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
np.testing.assert_equal(
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
def testRoundTripToInt(self):
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
self.assertEqual(v, int(bfloat16(v)))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + dtype.__name__,
"dtype": dtype
} for dtype in [bfloat16, np.float16, np.float32, np.float64, np.longdouble]))
def testRoundTripToNumpy(self, dtype):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, bfloat16(dtype(v)))
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
if dtype != bfloat16:
np.testing.assert_equal(
np.array(FLOAT_VALUES, dtype),
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
def testStr(self):
self.assertEqual("0", str(bfloat16(0.0)))
self.assertEqual("1", str(bfloat16(1.0)))
self.assertEqual("-3.5", str(bfloat16(-3.5)))
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", str(bfloat16(float("inf"))))
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
self.assertEqual("nan", str(bfloat16(float("nan"))))
def testRepr(self):
self.assertEqual("0", repr(bfloat16(0)))
self.assertEqual("1", repr(bfloat16(1)))
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", repr(bfloat16(float("inf"))))
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
def testHash(self):
self.assertEqual(0, hash(bfloat16(0.0)))
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
# Tests for Python operations
def testNegate(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(-v, float(-bfloat16(v)))
def testAdd(self):
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
def testAddScalarTypePromotion(self):
"""Tests type promotion against Numpy scalar values."""
types = [bfloat16, np.float16, np.float32, np.float64, np.longdouble]
for lhs_type in types:
for rhs_type in types:
expected_type = promote_types(lhs_type, rhs_type)
actual_type = type(lhs_type(3.5) + rhs_type(2.25))
self.assertEqual(expected_type, actual_type)
def testAddArrayTypePromotion(self):
self.assertEqual(np.float32,
type(bfloat16(3.5) + np.array(2.25, np.float32)))
self.assertEqual(np.float32,
type(np.array(3.5, np.float32) + bfloat16(2.25)))
def testSub(self):
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
np.testing.assert_equal(
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
def testMul(self):
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
def testDiv(self):
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
a = np.isnan(bfloat16(float("nan")))
self.assertTrue(a)
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
a = np.array([bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
b = np.array(
[bfloat16(1.3359375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16)) # pylint: disable=too-many-function-args
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
def testArgmax(self):
values_to_sort = np.float32(bfloat16(np.float32(FLOAT_VALUES)))
argmax_f32 = np.argmax(values_to_sort)
argmax_bf16 = np.argmax(values_to_sort.astype(bfloat16)) # pylint: disable=too-many-function-args
np.testing.assert_equal(argmax_f32, argmax_bf16)
def testArgmaxOnNan(self):
"""Ensures we return the right thing for multiple NaNs."""
one_with_nans = np.array(
[1.0, float("nan"), float("nan")], dtype=np.float32)
np.testing.assert_equal(
np.argmax(one_with_nans.astype(bfloat16)), np.argmax(one_with_nans))
def testArgmaxOnNegativeInfinity(self):
"""Ensures we return the right thing for negative infinities."""
inf = np.array([float("-inf")], dtype=np.float32)
np.testing.assert_equal(np.argmax(inf.astype(bfloat16)), np.argmax(inf))
def testArgmin(self):
values_to_sort = np.float32(bfloat16(np.float32(FLOAT_VALUES)))
argmin_f32 = np.argmin(values_to_sort)
argmin_bf16 = np.argmin(values_to_sort.astype(bfloat16)) # pylint: disable=too-many-function-args
np.testing.assert_equal(argmin_f32, argmin_bf16)
def testArgminOnNan(self):
"""Ensures we return the right thing for multiple NaNs."""
one_with_nans = np.array(
[1.0, float("nan"), float("nan")], dtype=np.float32)
np.testing.assert_equal(
np.argmin(one_with_nans.astype(bfloat16)), np.argmin(one_with_nans))
def testArgminOnPositiveInfinity(self):
"""Ensures we return the right thing for positive infinities."""
inf = np.array([float("inf")], dtype=np.float32)
np.testing.assert_equal(np.argmin(inf.astype(bfloat16)), np.argmin(inf))
def testDtypeFromString(self):
assert np.dtype("bfloat16") == np.dtype(bfloat16)
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
UNARY_UFUNCS = [
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]
BINARY_UFUNCS = [
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]
BINARY_PREDICATE_UFUNCS = [
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]
class Bfloat16NumPyTest(parameterized.TestCase):
"""Tests the NumPy integration of the bfloat16 type."""
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)
self.assertEqual("[[1 2 3]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
x = np.array([401408, 7, -32], dtype=np.float32)
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
np.testing.assert_equal(x < y, bx < by)
np.testing.assert_equal(x > y, bx > by)
np.testing.assert_equal(x <= y, bx <= by)
np.testing.assert_equal(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b))
def testCanCast(self):
allowed_casts = [
(np.bool_, bfloat16),
(np.int8, bfloat16),
(np.uint8, bfloat16),
(bfloat16, np.float32),
(bfloat16, np.float64),
(bfloat16, np.longdouble),
(bfloat16, np.complex64),
(bfloat16, np.complex128),
(bfloat16, np.clongdouble),
]
all_dtypes = [
np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16,
np.int32, np.int64, np.complex64, np.complex128, np.clongdouble,
np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_,
np.longlong, np.uintc, np.ulonglong
]
for d in all_dtypes:
self.assertEqual((bfloat16, d) in allowed_casts, np.can_cast(bfloat16, d))
self.assertEqual((d, bfloat16) in allowed_casts, np.can_cast(d, bfloat16))
def testCasts(self):
for dtype in [
np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16,
np.int32, np.int64, np.complex64, np.complex128, np.clongdouble,
np.uint8, np.uint16, np.uint32, np.uint64, np.intc, np.int_,
np.longlong, np.uintc, np.ulonglong
]:
x = np.array([[1, 2, 3]], dtype=dtype)
y = x.astype(bfloat16)
z = y.astype(dtype)
self.assertTrue(np.all(x == y))
self.assertEqual(bfloat16, y.dtype)
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)
def testConformNumpyComplex(self):
for dtype in [np.complex64, np.complex128, np.clongdouble]:
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
y_np = x.astype(np.float32)
y_tf = x.astype(bfloat16)
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
z_np = y_np.astype(dtype)
z_tf = y_tf.astype(dtype)
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
def testArange(self):
np.testing.assert_equal(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_PREDICATE_UFUNCS))
def testBinaryPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
np.testing.assert_equal(
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1
vals = rng.randn(*shape)
vals = np.where(posinf_flips, np.inf, vals)
vals = np.where(neginf_flips, -np.inf, vals)
vals = np.where(nan_flips, np.nan, vals)
vals = vals.astype(bfloat16)
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
def testDivmod(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
o1, o2 = np.divmod(x, y)
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
numpy_assert_allclose(o1, e1, rtol=1e-2)
numpy_assert_allclose(o2, e2, rtol=1e-2)
def testModf(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
o1, o2 = np.modf(x)
e1, e2 = np.modf(x.astype(np.float32))
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
def testLdexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randint(-50, 50, (1, 7))
numpy_assert_allclose(
np.ldexp(x, y).astype(np.float32),
np.ldexp(x.astype(np.float32), y),
rtol=1e-2,
atol=1e-6)
def testFrexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
mant1, exp1 = np.frexp(x)
mant2, exp2 = np.frexp(x.astype(np.float32))
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
def testSpacing(self):
# Sweep a variety of binades to see that spacing gives the proper ULP.
# All subnormals have a fixed distance of 2^-133.
with self.subTest(name="Subnormals"):
for i in range(-133, -126):
power_of_two = bfloat16(2.0**i)
distance = float.fromhex("0x1p-133")
np.testing.assert_equal(np.spacing(power_of_two), distance)
np.testing.assert_equal(np.spacing(-power_of_two), -distance)
# Normals have a distance which depends on their binade.
with self.subTest(name="Normals"):
for i in range(-126, 127):
power_of_two = bfloat16(2.0**i)
distance = epsilon * power_of_two
np.testing.assert_equal(np.spacing(power_of_two), distance)
np.testing.assert_equal(np.spacing(-power_of_two), -distance)
inf = bfloat16(float("inf"))
nan = bfloat16(float("nan"))
# Check that spacing agrees with arithmetic involving nextafter.
with self.subTest(name="NextAfter"):
for x in FLOAT_VALUES:
x_bfloat16 = bfloat16(x)
spacing = np.spacing(x_bfloat16)
toward = np.copysign(inf, x_bfloat16)
nextup = np.nextafter(x_bfloat16, toward)
np.testing.assert_equal(spacing, nextup - x_bfloat16)
# Check that spacing for special values gives the correct answer.
with self.subTest(name="NonFinite"):
np.testing.assert_equal(np.spacing(nan), np.spacing(np.float32(nan)))
np.testing.assert_equal(np.spacing(inf), np.spacing(np.float32(inf)))
if __name__ == "__main__":
absltest.main()