Fixes `assertAllEqual()` function in framework/test_util.py such that the function has the originally intended behavior without breaking PY3 compatibility.
PiperOrigin-RevId: 313902146
Change-Id: I3f9337ee4b58fdeb01fc08d3f49cbec7d3022d3e
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index aa52bbd..1adec3d 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -2713,26 +2713,8 @@
x, y = a, b
msgs.append("not equal lhs = %r" % x)
msgs.append("not equal rhs = %r" % y)
-
- # Handle mixed string types as a result of PY2to3 migration. That is, the
- # mixing between bytes (b-prefix strings, PY2 default) and unicodes
- # (u-prefix strings, PY3 default).
- if six.PY3:
- if (a.dtype.kind != b.dtype.kind and
- {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
- a_list = []
- b_list = []
- # OK to flatten `a` and `b` because they are guaranteed to have the
- # same shape.
- for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
- for item in flat_arr:
- if isinstance(item, str):
- out_list.append(item.encode("utf-8"))
- else:
- out_list.append(item)
- a = np.array(a_list)
- b = np.array(b_list)
-
+ # With Python 3, we need to make sure the dtype matches between a and b.
+ b = b.astype(a.dtype)
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
@py_func_if_in_function