tf.numpy: minor fixes. Change trax to use tf's numpy code.
PiperOrigin-RevId: 314968872
Change-Id: Ib8682cc9fdf2c8d6626dafe05f410ded1c9ea3d9
diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD
index e3a1ec2..28ea2b6 100644
--- a/tensorflow/python/ops/numpy_ops/BUILD
+++ b/tensorflow/python/ops/numpy_ops/BUILD
@@ -30,6 +30,7 @@
"//tensorflow/python:framework_ops",
"//tensorflow/python:indexed_slices",
"//tensorflow/python:linalg_ops",
+ "//tensorflow/python:manip_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sort_ops",
"//tensorflow/python:tensor_util",
diff --git a/tensorflow/python/ops/numpy_ops/__init__.py b/tensorflow/python/ops/numpy_ops/__init__.py
index f18d75a..383206a 100644
--- a/tensorflow/python/ops/numpy_ops/__init__.py
+++ b/tensorflow/python/ops/numpy_ops/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tensorflow numpy API."""
+# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
@@ -21,8 +22,14 @@
from tensorflow.python.ops.numpy_ops import np_random as random
# pylint: disable=wildcard-import
+
from tensorflow.python.ops.numpy_ops.np_array_ops import *
+# TODO(wangpeng): Move ShardedNdArray, convert_to_tensor, tensor_to_ndarray out
+# of here.
+from tensorflow.python.ops.numpy_ops.np_arrays import convert_to_tensor
from tensorflow.python.ops.numpy_ops.np_arrays import ndarray
+from tensorflow.python.ops.numpy_ops.np_arrays import ShardedNdArray
+from tensorflow.python.ops.numpy_ops.np_arrays import tensor_to_ndarray
from tensorflow.python.ops.numpy_ops.np_dtypes import *
from tensorflow.python.ops.numpy_ops.np_math_ops import *
from tensorflow.python.ops.numpy_ops.np_utils import finfo
diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index aba7ce4..1616f8b 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -31,6 +31,7 @@
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
@@ -881,7 +882,7 @@
else:
# Use float as the working dtype when a.dtype is exact (e.g. integer),
# because `decimals` can be negative.
- float_dtype = dtypes.default_float_type()
+ float_dtype = np_dtypes.default_float_type()
a = a.astype(float_dtype).data
factor = math_ops.cast(factor, float_dtype)
a = math_ops.multiply(a, factor)
@@ -1109,7 +1110,7 @@
setattr(np_arrays.ndarray, 'transpose', transpose)
-setattr(np_arrays.ndarray, 'reshape', reshape)
+setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper)
setattr(np_arrays.ndarray, '__setitem__', _setitem)
@@ -1518,11 +1519,11 @@
a = asarray(a).data
if axis is not None:
- return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
+ return np_utils.tensor_to_ndarray(manip_ops.roll(a, shift, axis))
# If axis is None, the roll happens as a 1-d tensor.
original_shape = array_ops.shape(a)
- a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
+ a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
@@ -1538,7 +1539,7 @@
return flip(flip(m, ax1), ax2)
else:
perm = math_ops.range(m_rank)
- perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])
+ perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])
if k == 1:
return transpose(flip(m, ax2), perm)
diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py
index 598a814..0fa3ab8 100644
--- a/tensorflow/python/ops/numpy_ops/np_utils.py
+++ b/tensorflow/python/ops/numpy_ops/np_utils.py
@@ -36,6 +36,22 @@
tensor_to_ndarray = np_arrays.tensor_to_ndarray
+def _canonicalize_axis(axis, rank):
+ return _canonicalize_axes([axis], rank)[0]
+
+
+def _canonicalize_axes(axes, rank):
+ rank = _maybe_static(rank)
+
+ if isinstance(rank, ops.Tensor):
+ canonicalizer = (
+ lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
+ else:
+ canonicalizer = lambda axis: axis+rank if axis < 0 else axis
+
+ return [canonicalizer(axis) for axis in axes]
+
+
def _supports_signature():
return hasattr(inspect, 'signature')