Roll-forward of: [Grappler] Extend the RemoveStackStridedSliceSameAxis optimization to support Slice.

`tf.slice(tf.pack([x, y, z]), ...)` always has rank one greater than `x`, `y`,
and `z`. The previous version had a bug where it would incorrectly return the
original `x`, `y`, or `z` tensor directly when insufficient shape information
was available. The fix is to explicitly track whether the rank must be
expanded, based on which slicing operation was used.
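
For example (a minimal sketch, using `tf.stack`, the modern name for
`tf.pack`; `x`, `y`, and `z` are assumed to be 2x2 tensors):

    stacked = tf.stack([x, y, z], axis=1)           # shape [2, 3, 2]
    s1 = tf.slice(stacked, [0, 1, 0], [-1, 1, -1])  # Slice: shape [2, 1, 2]
    s2 = stacked[:, 1, :]   # StridedSlice, shrink_axis_mask set: shape [2, 2]

Here `s1` must be rewritten as `tf.expand_dims(y, axis=1)` to preserve the
extra axis, while `s2` may be rewritten as `y` itself; the new
`must_expand_dims` flag distinguishes the two cases.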

PiperOrigin-RevId: 297961546
Change-Id: I80b10bd0cd39e783f8a9a41e9d0d5a3dcc250b82
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d256a66..3281f97 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -48,6 +48,7 @@
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/tensor_coding.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
@@ -3095,6 +3096,7 @@
 //    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
 // with
 //    expand_dims(a_i, axis=k)
+// where the slice operator can be StridedSlice or Slice.
 //
 // TODO(ebrevdo): Extend to also replace operations of the form
 //    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
@@ -3103,17 +3105,16 @@
 // when
 //    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
 // and slicing is in the k'th axis.
-class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
+class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
  public:
-  explicit RemoveStackStridedSliceSameAxis(
-      const GraphOptimizerContext& ctx,
-      const ArithmeticOptimizerContext& ctx_ext)
+  explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
+                                    const ArithmeticOptimizerContext& ctx_ext)
       : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
                                  ctx_ext) {}
-  ~RemoveStackStridedSliceSameAxis() override = default;
+  ~RemoveStackSliceSameAxis() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
-    return IsStridedSlice(*node);
+    return IsStridedSlice(*node) || IsSlice(*node);
   }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
@@ -3131,14 +3132,16 @@
         CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
     if (return_early) return Status::OK();
 
-    int slice_start_value;
+    int64 slice_start_value;
     bool found;
+    bool must_expand_dims;
     TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
-                                    &slice_start_value, &found));
+                                    &slice_start_value, &found,
+                                    &must_expand_dims));
     if (!found) return Status::OK();
 
     return RewriteGraph(node, pack, slice_start_value, pack_axis,
-                        simplified_node_name);
+                        must_expand_dims, simplified_node_name);
   }
 
  protected:
@@ -3171,8 +3174,113 @@
 
   Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
                       const PartialTensorShape& pack_output_shape,
-                      int pack_axis, int* slice_start_value, bool* found) {
+                      int pack_axis, int64* slice_start_value, bool* found,
+                      bool* must_expand_dims) {
     *found = false;
+    if (IsSlice(*node)) {
+      *must_expand_dims = true;
+      return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
+                                slice_start_value, found);
+    } else {
+      return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
+                                 slice_start_value, found, must_expand_dims);
+    }
+  }
+
+  Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
+                            const PartialTensorShape& pack_output_shape,
+                            int pack_axis, int64* slice_start_value,
+                            bool* found) {
+    NodeDef* slice_begin;
+    NodeDef* slice_size;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
+    for (const auto* n : {slice_begin, slice_size}) {
+      if (!IsReallyConstant(*n)) return Status::OK();
+    }
+
+    Tensor slice_begin_t;
+    Tensor slice_size_t;
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
+    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
+      return Status::OK();
+    }
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
+    if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
+      return Status::OK();
+    }
+
+    auto copy_tensor_values_to_vector =
+        [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
+          if (t.dtype() == DT_INT32) {
+            auto t_flat = t.flat<int32>();
+            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
+          } else if (t.dtype() == DT_INT64) {
+            auto t_flat = t.flat<int64>();
+            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
+          } else {
+            return errors::InvalidArgument("Node ", node->name(),
+                                           " has invalid type for Index attr: ",
+                                           DataTypeString(t.dtype()));
+          }
+          return Status::OK();
+        };
+
+    gtl::InlinedVector<int64, 4> slice_begin_vec;
+    gtl::InlinedVector<int64, 4> slice_size_vec;
+    TF_RETURN_IF_ERROR(
+        copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
+    TF_RETURN_IF_ERROR(
+        copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));
+
+    if (slice_begin_vec.size() != slice_size_vec.size()) {
+      return errors::InvalidArgument("Node ", node->name(),
+                                     " has mismatched lengths for begin (",
+                                     slice_begin_vec.size(), ") and size (",
+                                     slice_size_vec.size(), ") vectors.");
+    }
+    if (!pack_output_shape.unknown_rank() &&
+        slice_begin_vec.size() != pack_output_shape.dims()) {
+      return Status::OK();
+    }
+    if (pack_axis >= slice_begin_vec.size()) {
+      return errors::InvalidArgument(
+          "Input to node ", node->name(), " had pack_axis ", pack_axis,
+          " but rank was ", slice_begin_vec.size(), ".");
+    }
+
+    *slice_start_value = slice_begin_vec[pack_axis];
+    if (slice_size_vec[pack_axis] != 1) {
+      // Not slicing a single value out.
+      return Status::OK();
+    }
+
+    for (size_t i = 0; i < slice_begin_vec.size(); ++i) {
+      if (i != pack_axis) {
+        if (slice_begin_vec[i] != 0 ||
+            !(slice_size_vec[i] == -1 ||
+              slice_size_vec[i] == pack_output_shape.dim_size(i))) {
+          // Not slicing on the same axis as the Pack op.
+          return Status::OK();
+        }
+      }
+    }
+
+    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
+      return errors::InvalidArgument(
+          "Node ", node->name(), " requested invalid slice index ",
+          *slice_start_value, " on axis ", pack_axis,
+          " from tensor of shape: ", pack_output_shape.DebugString());
+    }
+
+    *found = true;  // slice_start_value is valid.
+    return Status::OK();
+  }
+
+  Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
+                             const PartialTensorShape& pack_output_shape,
+                             int pack_axis, int64* slice_start_value,
+                             bool* found, bool* must_expand_dims) {
     TF_RETURN_IF_ERROR(
         CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
                                 "new_axis_mask", "shrink_axis_mask"}));
@@ -3286,13 +3394,22 @@
           " from tensor of shape: ", pack_output_shape.DebugString());
     }
 
+    if (shrink_axis_mask == 0) {
+      *must_expand_dims = true;
+    } else if (shrink_axis_mask == (1 << slice_axis)) {
+      *must_expand_dims = false;
+    } else {
+      // Shrinking on a different axis from the one that we are slicing on.
+      return Status::OK();
+    }
+
     *found = true;  // slice_start_value is valid.
     return Status::OK();
   }
 
   Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
-                      int slice_start_value, int pack_axis,
-                      string* simplified_node_name) {
+                      int64 slice_start_value, int pack_axis,
+                      bool must_expand_dims, string* simplified_node_name) {
     const string& input_slice = pack->input(slice_start_value);
 
     const OpInfo::TensorProperties* input_slice_properties;
@@ -3306,7 +3423,7 @@
     PartialTensorShape output_shape(output_properties->shape());
     NodeDef* output =
         AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
-    if (input_slice_shape.IsCompatibleWith(output_shape)) {
+    if (!must_expand_dims) {
       output->set_op("Identity");
       output->set_device(node->device());
       SetDataTypeToAttr(output_properties->dtype(), "T", output);
@@ -3449,7 +3566,6 @@
   if (node.device().find("SPU") != string::npos) {
     return false;
   }
-  // Workaround for Assert and Print mistakenly being labeled as stateful.
   if (IsAssert(node) || IsPrint(node)) {
     return true;
   }
@@ -3621,8 +3737,8 @@
     pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
   if (options_.unary_ops_composition)
     pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
-  if (options_.remove_stack_strided_slice_same_axis)
-    pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
+  if (options_.remove_stack_slice_same_axis)
+    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
   if (options_.fuse_squared_diff)
     pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
 
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index e7c847c..76aca8b 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -84,7 +84,7 @@
     bool convert_log_softmax = true;
     bool convert_expm1 = true;
     bool unary_ops_composition = true;
-    bool remove_stack_strided_slice_same_axis = true;
+    bool remove_stack_slice_same_axis = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 4c70343..a421daa 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
 
+#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/math_ops.h"
@@ -4002,7 +4003,7 @@
 
   GraphDef output;
   ArithmeticOptimizer optimizer;
-  EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
+  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
   OptimizeAndPrune(&optimizer, &item, &output);
 
   for (const auto& node : output.node()) {
@@ -4050,6 +4051,122 @@
                                  tensors_expected[fExpandedC]);
 }
 
+TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto a_in =
+      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  auto b_in =
+      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
+  auto c_in =
+      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
+  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
+                                       PartialTensorShape({-1, -1}));
+  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
+                                       PartialTensorShape({-1, -1}));
+  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
+                                       PartialTensorShape({-1, -1}));
+  // stacked = tf.stack((a, b, c), axis=1).
+  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
+  auto stacked =
+      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
+                 ops::Stack::Axis(1));
+  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
+  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
+  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
+  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
+  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
+  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
+  auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});
+
+  // stacked[:, 0:1, :]
+  auto pa_slice = ops::Identity(
+      s.WithOpName("pa_slice_out"),
+      ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));
+
+  // stacked[:, 1:2, :]
+  auto pb_slice = ops::Identity(
+      s.WithOpName("pb_slice_out"),
+      ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));
+
+  // stacked[:, 2:3, :]
+  auto pc_slice = ops::Identity(
+      s.WithOpName("pc_slice_out"),
+      ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));
+
+  GrapplerItem item;
+  item.fetch = {"a",
+                "b",
+                "c",
+                "pa_slice_out",
+                "pb_slice_out",
+                "pc_slice_out",
+                "expanded_a",
+                "expanded_b",
+                "expanded_c"};
+  enum FetchItem {
+    fA,
+    fB,
+    fC,
+    fASliceOut,
+    fBSliceOut,
+    fCSliceOut,
+    fExpandedA,
+    fExpandedB,
+    fExpandedC,
+  };
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  // stacked[:, 0:1, :] == a.
+  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
+                                 tensors_expected[fExpandedA]);
+  // stacked[:, 1:2, :] == b.
+  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
+                                 tensors_expected[fExpandedB]);
+  // stacked[:, 2:3, :] == c.
+  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
+                                 tensors_expected[fExpandedC]);
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  const string kExpandDimsNamePrefix(
+      "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");
+
+  for (const auto& node : output.node()) {
+    if (node.name() == "pa_slice_out") {
+      ASSERT_EQ(node.input_size(), 1);
+      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
+    } else if (node.name() == "pb_slice_out") {
+      ASSERT_EQ(node.input_size(), 1);
+      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
+    } else if (node.name() == "pc_slice_out") {
+      ASSERT_EQ(node.input_size(), 1);
+      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
+    } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
+      EXPECT_EQ(node.op(), "ExpandDims");
+      // The input is "a", "b", or "c", as appropriate.
+      EXPECT_EQ(node.input(0),
+                node.name().substr(kExpandDimsNamePrefix.size(), 1));
+    }
+  }
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+
+  // stacked[:, 0:1, :] == a.
+  test::ExpectTensorEqual<float>(tensors[fASliceOut],
+                                 tensors_expected[fExpandedA]);
+
+  // stacked[:, 1:2, :] == b.
+  test::ExpectTensorEqual<float>(tensors[fBSliceOut],
+                                 tensors_expected[fExpandedB]);
+  // stacked[:, 2:3, :] == c.
+  test::ExpectTensorEqual<float>(tensors[fCSliceOut],
+                                 tensors_expected[fExpandedC]);
+}
+
 TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
index d3ad437..4d3ba97 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -202,10 +202,9 @@
     optimizer->options_.unary_ops_composition = true;
   }
 
-  void EnableOnlyRemoveStackStridedSliceSameAxis(
-      ArithmeticOptimizer* optimizer) {
+  void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
-    optimizer->options_.remove_stack_strided_slice_same_axis = true;
+    optimizer->options_.remove_stack_slice_same_axis = true;
   }
 
  private: