Set as much partial shape as we can infer statically within the gradient impl of the gather op.
PiperOrigin-RevId: 279964455
Change-Id: I82610bdac0b6affe848f2b7db957bb806439d1e0
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index cec7ca6..807e105 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -463,6 +463,40 @@
// Source tensor is a vector of length 0, so the shape it
// represents is as scalar.
*result = target_context->Scalar();
+ } else if (src_op == "Cast") {
+ // First try to evaluate the current tensor, as it might be a valid cast of
+ // a float.
+ Tensor t;
+ bool evaluated = false;
+ if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
+ if (evaluated &&
+ target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
+ return Status::OK();
+ }
+ }
+
+ // Then try to infer partial shape from the input to the cast tensor.
+ ShapeHandle pre_cast_shape;
+ if (!ConstantPartialShape(target_context, input_edge->src(), 0,
+ &pre_cast_shape)
+ .ok()) {
+ TF_RETURN_IF_ERROR(
+ target_context->MakeShapeFromTensor(nullptr, src_shape, result));
+ }
+ if (!target_context->RankKnown(pre_cast_shape)) {
+ // Pre-cast shape has unknown rank. Treat the output as completely unknown.
+ *result = target_context->UnknownShape();
+ return Status::OK();
+ }
+ auto* dest_type = input_edge->src()->attrs().Find("DstT");
+ if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
+ (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
+ // DstT is absent or not int32/int64. Do not attempt to infer across it.
+ *result = target_context->MakeShape(std::vector<DimensionHandle>(
+ target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
+ return Status::OK();
+ }
+ *result = pre_cast_shape;
} else if (src_op == "Shape") {
*result = src_context->input(0);
} else if (src_op == "ShapeN") {
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f4f3fa7..74517c7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -6684,6 +6684,24 @@
tags = ["no_windows"],
)
+py_test(
+ name = "ops/array_ops_test",
+ srcs = ["ops/array_ops_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":constant_op",
+ ":dtypes",
+ ":framework_ops",
+ ":framework_test_lib",
+ ":gradients",
+ ":math_ops",
+ ":random_ops",
+ ],
+)
+
cuda_py_test(
name = "accumulate_n_benchmark",
size = "medium",
diff --git a/tensorflow/python/ops/array_ops_test.py b/tensorflow/python/ops/array_ops_test.py
new file mode 100644
index 0000000..1bdd9d7
--- /dev/null
+++ b/tensorflow/python/ops/array_ops_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for array operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class ArrayOpTest(test.TestCase):
+
+ @test_util.deprecated_graph_mode_only
+ def testGatherGradHasPartialStaticShape(self):
+ # Create a tensor whose dimension 1 is statically unknown.
+ x = random_ops.random_normal([4, 10, 10])
+ x = array_ops.gather(
+ x,
+ array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
+ axis=1)
+ self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+
+ a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1])
+ b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1])
+ grad_a = ops.convert_to_tensor(gradients.gradients(a, x)[0])
+ grad_b = ops.convert_to_tensor(gradients.gradients(b, x)[0])
+
+ # We make sure that the representation of the shapes is correct; the shape
+ # equality check will always evaluate to false because the shapes are partial.
+ self.assertAllEqual(grad_a.shape.as_list(), [None, None, 10])
+ self.assertAllEqual(grad_b.shape.as_list(), [4, None, 10])
+
+ @test_util.deprecated_graph_mode_only
+ def testReshapeShapeInference(self):
+ # Create a tensor whose dimension 1 is statically unknown.
+ x = random_ops.random_normal([4, 10, 10])
+ x = array_ops.gather(
+ x,
+ array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
+ axis=1)
+ self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+ a = array_ops.reshape(x, array_ops.shape(x))
+ self.assertAllEqual(a.shape.as_list(), [4, None, 10])
+ b = array_ops.reshape(x, math_ops.cast(array_ops.shape(x), dtypes.int64))
+ self.assertAllEqual(b.shape.as_list(), [4, None, 10])
+
+ # We do not shape-infer across a tf.cast into anything that's not tf.int32
+ # or tf.int64, since such a cast might end up mangling the shape.
+ c = array_ops.reshape(
+ x,
+ math_ops.cast(
+ math_ops.cast(array_ops.shape(x), dtypes.float32), dtypes.int32))
+ self.assertAllEqual(c.shape.as_list(), [None, None, None])
+
+
+if __name__ == "__main__":
+ test.main()