Set as much partial shape as can be inferred statically in the gradient implementation of the gather op.

PiperOrigin-RevId: 279964455
Change-Id: I82610bdac0b6affe848f2b7db957bb806439d1e0
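
For illustration, here is a minimal sketch of the user-visible effect. It uses the
public tf.compat.v1 API rather than the internal modules exercised by the new test,
so the exact entry points are an assumption; the asserted shapes are taken from
testGatherGradHasPartialStaticShape below.

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.random.normal([4, 10, 10])
    # Gather along axis 1 with a data-dependent index, so dim 1 becomes
    # unknown: x.shape.as_list() == [4, None, 10].
    x = tf.gather(x, tf.reshape(tf.where(x[0, :, 0] > 0.5), [-1]), axis=1)

    a = tf.gather(tf.gather(x, [0, 1]), [0, 1])
    grad = tf.convert_to_tensor(tf.gradients(a, x)[0])
    # With this change the gradient keeps the dims that can be inferred
    # statically: grad.shape.as_list() == [None, None, 10].
    print(grad.shape)
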
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index cec7ca6..807e105 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -463,6 +463,40 @@
     // Source tensor is a vector of length 0, so the shape it
     // represents is a scalar.
     *result = target_context->Scalar();
+  } else if (src_op == "Cast") {
+    // First try to evaluate the current tensor, as it might be a valid cast of
+    // a float.
+    Tensor t;
+    bool evaluated = false;
+    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
+      if (evaluated &&
+          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
+        return Status::OK();
+      }
+    }
+
+    // Then try to infer a partial shape from the input of the Cast op.
+    ShapeHandle pre_cast_shape;
+    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
+                              &pre_cast_shape)
+             .ok()) {
+      TF_RETURN_IF_ERROR(
+          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
+    }
+    if (!target_context->RankKnown(pre_cast_shape)) {
+      // Failed to evaluate. Treat the output as completely unknown.
+      *result = target_context->UnknownShape();
+      return Status::OK();
+    }
+    auto* dest_type = input_edge->src()->attrs().Find("DstT");
+    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
+        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
+      // Casting to a type other than int32/int64; do not infer across it.
+      *result = target_context->MakeShape(std::vector<DimensionHandle>(
+          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
+      return Status::OK();
+    }
+    *result = pre_cast_shape;
   } else if (src_op == "Shape") {
     *result = src_context->input(0);
   } else if (src_op == "ShapeN") {
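
The hunk above teaches ConstantPartialShape to look through a Cast when the
destination type is int32 or int64. A minimal sketch of the resulting behavior,
again assuming the public tf.compat.v1 API (it mirrors testReshapeShapeInference
below):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.random.normal([4, 10, 10])
    x = tf.gather(x, tf.reshape(tf.where(x[0, :, 0] > 0.5), [-1]), axis=1)

    # Casting the shape to int64 still allows partial inference:
    # b.shape.as_list() == [4, None, 10].
    b = tf.reshape(x, tf.cast(tf.shape(x), tf.int64))

    # Casting through any other dtype does not; only the rank survives:
    # c.shape.as_list() == [None, None, None].
    c = tf.reshape(x, tf.cast(tf.cast(tf.shape(x), tf.float32), tf.int32))
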
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f4f3fa7..74517c7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -6684,6 +6684,24 @@
     tags = ["no_windows"],
 )
 
+py_test(
+    name = "ops/array_ops_test",
+    srcs = ["ops/array_ops_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":array_ops",
+        ":client_testlib",
+        ":constant_op",
+        ":dtypes",
+        ":framework_ops",
+        ":framework_test_lib",
+        ":gradients",
+        ":math_ops",
+        ":random_ops",
+    ],
+)
+
 cuda_py_test(
     name = "accumulate_n_benchmark",
     size = "medium",
diff --git a/tensorflow/python/ops/array_ops_test.py b/tensorflow/python/ops/array_ops_test.py
new file mode 100644
index 0000000..1bdd9d7
--- /dev/null
+++ b/tensorflow/python/ops/array_ops_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for array operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class ArrayOpTest(test.TestCase):
+
+  @test_util.deprecated_graph_mode_only
+  def testGatherGradHasPartialStaticShape(self):
+    # Create a tensor with an unknown dim 1.
+    x = random_ops.random_normal([4, 10, 10])
+    x = array_ops.gather(
+        x,
+        array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
+        axis=1)
+    self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+
+    a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1])
+    b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1])
+    grad_a = ops.convert_to_tensor(gradients.gradients(a, x)[0])
+    grad_b = ops.convert_to_tensor(gradients.gradients(b, x)[0])
+
+    # We check the list representation of the shapes, since comparing the
+    # partial shapes for equality would always evaluate to False.
+    self.assertAllEqual(grad_a.shape.as_list(), [None, None, 10])
+    self.assertAllEqual(grad_b.shape.as_list(), [4, None, 10])
+
+  @test_util.deprecated_graph_mode_only
+  def testReshapeShapeInference(self):
+    # Create a tensor with an unknown dim 1.
+    x = random_ops.random_normal([4, 10, 10])
+    x = array_ops.gather(
+        x,
+        array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
+        axis=1)
+    self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+    a = array_ops.reshape(x, array_ops.shape(x))
+    self.assertAllEqual(a.shape.as_list(), [4, None, 10])
+    b = array_ops.reshape(x, math_ops.cast(array_ops.shape(x), dtypes.int64))
+    self.assertAllEqual(b.shape.as_list(), [4, None, 10])
+
+    # We do not shape-infer across a tf.cast into anything that's not tf.int32
+    # or tf.int64, since such a cast might end up mangling the shape values.
+    c = array_ops.reshape(
+        x,
+        math_ops.cast(
+            math_ops.cast(array_ops.shape(x), dtypes.float32), dtypes.int32))
+    self.assertAllEqual(c.shape.as_list(), [None, None, None])
+
+
+if __name__ == "__main__":
+  test.main()