GetItem should return the python `float` type instead of `bfloat16` to match numpy convention.
PiperOrigin-RevId: 451537539
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 3aa5233..ab7e2f1 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -429,7 +429,7 @@
// character is unique.
/*type=*/'E',
/*byteorder=*/'=',
- /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
+ /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
/*type_num=*/0,
/*elsize=*/sizeof(bfloat16),
/*alignment=*/alignof(bfloat16),
@@ -447,7 +447,7 @@
PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
bfloat16 x;
memcpy(&x, data, sizeof(bfloat16));
- return PyFloat_FromDouble(static_cast<float>(x));
+ return PyBfloat16_FromBfloat16(x).release();
}
int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index 53ac95b..73ebe5c 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -115,9 +115,6 @@
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
- def testItem(self):
- self.assertIsInstance(bfloat16(0).item(), float)
-
def testHashZero(self):
"""Tests that negative zero and zero hash to the same value."""
self.assertEqual(hash(bfloat16(-0.0)), hash(bfloat16(0.0)))