tf.numpy: minor fixes. Change trax to use tf's numpy code.

PiperOrigin-RevId: 314968872
Change-Id: Ib8682cc9fdf2c8d6626dafe05f410ded1c9ea3d9
diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD
index e3a1ec2..28ea2b6 100644
--- a/tensorflow/python/ops/numpy_ops/BUILD
+++ b/tensorflow/python/ops/numpy_ops/BUILD
@@ -30,6 +30,7 @@
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:indexed_slices",
         "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:manip_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:sort_ops",
         "//tensorflow/python:tensor_util",
diff --git a/tensorflow/python/ops/numpy_ops/__init__.py b/tensorflow/python/ops/numpy_ops/__init__.py
index f18d75a..383206a 100644
--- a/tensorflow/python/ops/numpy_ops/__init__.py
+++ b/tensorflow/python/ops/numpy_ops/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Tensorflow numpy API."""
+# pylint: disable=g-direct-tensorflow-import
 
 from __future__ import absolute_import
 from __future__ import division
@@ -21,8 +22,14 @@
 from tensorflow.python.ops.numpy_ops import np_random as random
 
 # pylint: disable=wildcard-import
+
 from tensorflow.python.ops.numpy_ops.np_array_ops import *
+# TODO(wangpeng): Move ShardedNdArray, convert_to_tensor, tensor_to_ndarray out
+# of here.
+from tensorflow.python.ops.numpy_ops.np_arrays import convert_to_tensor
 from tensorflow.python.ops.numpy_ops.np_arrays import ndarray
+from tensorflow.python.ops.numpy_ops.np_arrays import ShardedNdArray
+from tensorflow.python.ops.numpy_ops.np_arrays import tensor_to_ndarray
 from tensorflow.python.ops.numpy_ops.np_dtypes import *
 from tensorflow.python.ops.numpy_ops.np_math_ops import *
 from tensorflow.python.ops.numpy_ops.np_utils import finfo
diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index aba7ce4..1616f8b 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -31,6 +31,7 @@
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import manip_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sort_ops
 from tensorflow.python.ops.numpy_ops import np_arrays
@@ -881,7 +882,7 @@
   else:
     # Use float as the working dtype when a.dtype is exact (e.g. integer),
     # because `decimals` can be negative.
-    float_dtype = dtypes.default_float_type()
+    float_dtype = np_dtypes.default_float_type()
     a = a.astype(float_dtype).data
     factor = math_ops.cast(factor, float_dtype)
   a = math_ops.multiply(a, factor)
@@ -1109,7 +1110,7 @@
 
 
 setattr(np_arrays.ndarray, 'transpose', transpose)
-setattr(np_arrays.ndarray, 'reshape', reshape)
+setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper)
 setattr(np_arrays.ndarray, '__setitem__', _setitem)
 
 
@@ -1518,11 +1519,11 @@
   a = asarray(a).data
 
   if axis is not None:
-    return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
+    return np_utils.tensor_to_ndarray(manip_ops.roll(a, shift, axis))
 
   # If axis is None, the roll happens as a 1-d tensor.
   original_shape = array_ops.shape(a)
-  a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
+  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
   return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
 
 
@@ -1538,7 +1539,7 @@
     return flip(flip(m, ax1), ax2)
   else:
     perm = math_ops.range(m_rank)
-    perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])
+    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])
 
     if k == 1:
       return transpose(flip(m, ax2), perm)
diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py
index 598a814..0fa3ab8 100644
--- a/tensorflow/python/ops/numpy_ops/np_utils.py
+++ b/tensorflow/python/ops/numpy_ops/np_utils.py
@@ -36,6 +36,22 @@
 tensor_to_ndarray = np_arrays.tensor_to_ndarray
 
 
+def _canonicalize_axis(axis, rank):
+  return _canonicalize_axes([axis], rank)[0]
+
+
+def _canonicalize_axes(axes, rank):
+  rank = _maybe_static(rank)
+
+  if isinstance(rank, ops.Tensor):
+    canonicalizer = (
+        lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
+  else:
+    canonicalizer = lambda axis: axis+rank if axis < 0 else axis
+
+  return [canonicalizer(axis) for axis in axes]
+
+
 def _supports_signature():
   return hasattr(inspect, 'signature')