[TF2XLA] Support dynamic slice size in strided slice op.

- Add two side outputs in ValidateStridedSliceOp to help analyze dynamic dimensions.
- Correctly set strided slice op's dynamic size if the slice size (slice end) is dynamic.

PiperOrigin-RevId: 327552472
Change-Id: Ia85e7bc377c432e5032f49278754659452ec9f86
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 807c061..d7a8e67 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -16,7 +16,6 @@
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -29,26 +28,13 @@
       : XlaOpKernel(context) {}
 
   void Compile(XlaOpKernelContext* context) override {
+    const TensorShape input_shape = context->InputShape(0);
     TensorShape output_shape;
     OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
-    auto output_status_or =
-        BroadcastTo(context->Input(0), output_shape.dim_sizes());
-    OP_REQUIRES_OK(context, output_status_or.status());
-    auto output = output_status_or.ValueOrDie();
-    std::vector<bool> dynamic_dims;
-    OP_REQUIRES_OK(
-        context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
-    for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
-      if (dynamic_dims[dim]) {
-        output = xla::SetDimensionSize(
-            output,
-            xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
-                         {}),
-            dim);
-      }
-    }
 
-    context->SetOutput(0, output);
+    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+    OP_REQUIRES_OK(context, output.status());
+    context->SetOutput(0, output.ValueOrDie());
   }
 };
 
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 72cb746..784b790 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -15,9 +15,6 @@
 
 #include "tensorflow/core/util/strided_slice_op.h"
 
-#include <vector>
-
-#include "absl/algorithm/container.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -26,7 +23,6 @@
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -37,7 +33,6 @@
 
 namespace tensorflow {
 namespace {
-using errors::InvalidArgument;
 
 class StridedSliceOp : public XlaOpKernel {
  public:
@@ -53,7 +48,7 @@
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
-    VLOG(0) << "strided slice";
+
     OP_REQUIRES(
         ctx, begin_shape.dims() == 1,
         errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
@@ -83,24 +78,20 @@
     TensorShape final_shape;
     PartialTensorShape dummy_processing_shape, partial_final_shape;
     bool dummy = false;
-    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
-    absl::InlinedVector<int64, 4> output_to_processing_mapping;
-    OP_REQUIRES_OK(
-        ctx,
-        ValidateStridedSliceOp(
-            begin_is_constant ? &begin_tensor : nullptr,
-            end_is_constant ? &end_tensor : nullptr, strides_tensor,
-            input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
-            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
-            &dummy, &dummy, &dummy, &begin, &end, &strides,
-            &output_to_sparse_mapping, &output_to_processing_mapping));
+    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+                            begin_is_constant ? &begin_tensor : nullptr,
+                            end_is_constant ? &end_tensor : nullptr,
+                            strides_tensor, input_shape, begin_mask_, end_mask_,
+                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                            &dummy_processing_shape, &partial_final_shape,
+                            &dummy, &dummy, &dummy, &begin, &end, &strides));
 
-    OP_REQUIRES(
-        ctx, partial_final_shape.AsTensorShape(&final_shape),
-        InvalidArgument("XLA can't deduce compile time constant output "
-                        "shape for strided slice: ",
-                        partial_final_shape.DebugString(),
-                        ", output shape must be a compile-time constant"));
+    OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
+                errors::InvalidArgument(
+                    "XLA can't deduce compile time constant output "
+                    "shape for strided slice: ",
+                    partial_final_shape.DebugString(),
+                    ", output shape must be a compile-time constant"));
 
     xla::XlaOp slice = ctx->Input(0);
     if (begin_is_constant && end_is_constant) {
@@ -128,84 +119,69 @@
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      std::vector<bool> begins_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
-      std::vector<bool> ends_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
-      bool begins_are_static = absl::c_all_of(
-          begins_are_dynamic, [](bool dynamic) { return !dynamic; });
-      OP_REQUIRES(ctx, begins_are_static,
-                  errors::InvalidArgument(
-                      "XLA can't use dynamic begin values for slice."));
-      bool ends_are_static = absl::c_all_of(
-          ends_are_dynamic, [](bool dynamic) { return !dynamic; });
-      // Static output shape, return a static slice.
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      if (xla_shape.is_static() && ends_are_static) {
+      if (xla_shape.is_static()) {
+        // Static output shape, return a static slice.
+        slice = xla::Reshape(slice, final_shape.dim_sizes());
+        ctx->SetOutput(0, slice);
+        return;
+      }
+      auto input_dim_sizes = input_shape.dim_sizes();
+
+      for (int64 i = 0; i < xla_shape.rank(); ++i) {
+        if (xla_shape.is_dynamic_dimension(i)) {
+          input_dim_sizes[i] = -1;
+        }
+      }
+      PartialTensorShape input_partial_shape(input_dim_sizes);
+      partial_final_shape.Clear();
+      end.clear();
+      strides.clear();
+      begin.clear();
+      // Run shape inference again with partial shape.
+      OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+                              &begin_tensor, &end_tensor, strides_tensor,
+                              input_partial_shape, begin_mask_, end_mask_,
+                              ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                              &dummy_processing_shape, &partial_final_shape,
+                              &dummy, &dummy, &dummy, &begin, &end, &strides));
+      if (partial_final_shape.AsTensorShape(&final_shape)) {
+        // Static output shape, return a static slice.
+        slice = xla::Reshape(slice, final_shape.dim_sizes());
         ctx->SetOutput(0, slice);
         return;
       }
 
-      for (int64 i = 0; i < final_shape.dims(); ++i) {
-        int64 input_index = output_to_processing_mapping[i];
-        if (input_index == -1) {
-          continue;
-        }
-        bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
-
-        int64 sparse_index = output_to_sparse_mapping[i];
-        bool end_is_dynamic =
-            sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
-        bool backward_slice = sparse_index == -1
-                                  ? false
-                                  : end_literal.Get<int32>({sparse_index}) < 0;
-        if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
+      // A dim of size -1 in partial_final_shape is dynamic. Slicing such a dim
+      // with a negative end, e.g. t[:-n], yields a length of shape(t) - n.
+      for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+        const bool dynamic_dim = partial_final_shape.dim_size(i) == -1;
+        bool backward_slice = end[i] < 0;
+        if (dynamic_dim && backward_slice) {
           OP_REQUIRES(
-              ctx, strides[input_index] == 1,
+              ctx, strides[i] == 1,
               errors::InvalidArgument("XLA has not implemented dynamic "
                                       "sized slice with non-trival stride yet. "
                                       "Please file a bug against XLA"));
+
+          OP_REQUIRES(ctx, begin[i] >= 0,
+                      errors::InvalidArgument(
+                          "XLA has not implemented dynamic "
+                          "sized slice with negative begin index ",
+                          begin[i],
+                          ". Please file a bug against XLA"));
           // If there is a dynamic dimension, properly set dimension size of
           // the result.
-          auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
-          if (backward_slice) {
-            // We consider slicing a dynamic tensor t with negative indices as
-            // a dynamic sized slice. E.g., t[: -n], the result length is
-            // shape(t) - n.
-            OP_REQUIRES(ctx, !end_is_dynamic,
-                        errors::InvalidArgument(
-                            "XLA has not implemented dynamic "
-                            "sized slice with dynamic negative index %lld. "));
-            operand_size = xla::Add(
-                operand_size,
-                xla::ConstantR0<int32>(ctx->builder(),
-                                       end_literal.Get<int32>({sparse_index})));
-          } else {
-            // The end of slice with dynamic slice size is the min of operand
-            // shape and slice size. E.g., t[:end_size], result size is
-            // min(shape(t), end_size).
-            xla::XlaOp end_size;
-            if (end_is_dynamic) {
-              end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
-                                                 {sparse_index + 1}, {1}),
-                                      {});
-            } else {
-              end_size =
-                  xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
-            }
-            operand_size = xla::Min(operand_size, end_size);
-          }
+          auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
+
+          operand_size = xla::Add(
+              operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
           slice = xla::SetDimensionSize(
               slice,
-              xla::Sub(operand_size, xla::ConstantR0<int32>(
-                                         ctx->builder(), begin[input_index])),
+              xla::Sub(operand_size,
+                       xla::ConstantR0<int32>(ctx->builder(), begin[i])),
               i);
         }
       }
-      ctx->SetOutput(0, slice);
-      return;
     } else {
       // When output shape is fully defined, it must be a size one slice:
       //
@@ -263,9 +239,9 @@
 
       std::vector<int64> output_shape_dim_sizes;
       slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      ctx->SetOutput(0, slice);
     }
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    ctx->SetOutput(0, slice);
   }
 
  private:
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index 1cf9a8c..0df810a 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -59,11 +59,6 @@
   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
   // it will be 1. A shrunk dimension is skipped.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
-  // This vector has the same size as final_shape_gather_indices, but it
-  // remembers the sparse index that a dimension comes from, instead of dense
-  // index. A -1 in this vector means there the index is not from the sparse
-  // input.
-  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -113,11 +108,9 @@
           dense->begin_mask |= (1 << full_index);
           dense->end_mask |= (1 << full_index);
           dense->final_shape_gather_indices.push_back(full_index);
-          dense->final_shape_gather_indices_sparse.push_back(-1);
         }
       } else if ((1 << i) & sparse.new_axis_mask) {
         dense->final_shape_gather_indices.push_back(kNewAxis);
-        dense->final_shape_gather_indices_sparse.push_back(-1);
       } else {
         if (full_index == dense->begin.size()) {
           return errors::InvalidArgument("Index out of range using input dim ",
@@ -145,13 +138,9 @@
         // axis (now in dense form) so we can ignore dense->end below.
         if (sparse.shrink_axis_mask & (1 << i)) {
           dense->final_shape_gather_indices.push_back(kShrinkAxis);
-          dense->final_shape_gather_indices_sparse.push_back(-1);
           dense->shrink_axis_mask |= (1 << full_index);
         } else {
           dense->final_shape_gather_indices.push_back(full_index);
-          // Remember that where in the sparse shape the dense dim comes
-          // from.
-          dense->final_shape_gather_indices_sparse.push_back(i);
         }
         full_index++;
       }
@@ -168,9 +157,7 @@
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    gtl::InlinedVector<int64, 4>* strides) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,34 +362,11 @@
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  if (output_to_sparse_mapping != nullptr) {
-    output_to_sparse_mapping->clear();
-  }
-
-  if (output_to_processing_mapping != nullptr) {
-    output_to_processing_mapping->clear();
-  }
-  for (int64 dense_dim = 0;
-       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
-    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
-    int64 sparse_index =
-        dense_spec.final_shape_gather_indices_sparse[dense_dim];
+  for (auto gather_index : dense_spec.final_shape_gather_indices) {
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(sparse_index);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(gather_index);
-      }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(-1);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(-1);
-      }
     }
   }
   return Status::OK();
@@ -415,17 +379,14 @@
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides,
-      output_to_sparse_mapping, output_to_processing_mapping));
+      is_simple_slice, slice_dim0, begin, end, strides));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index 9e49477..25ecccd 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -40,17 +40,6 @@
 // some dimensions of <processing_shape> and/or <final_shape> may be unknown
 // (-1). Any validation that can be done without complete information is
 // performed.
-//
-// This function changes the orders of dimensions, output_to_sparse_mapping and
-// output_to_processing_mapping are used to track the order change.
-//
-// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
-// index in the begin_tensor. If
-// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
-// sparse_mapping.
-//
-// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
-// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -59,9 +48,7 @@
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    gtl::InlinedVector<int64, 4>* strides);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -71,9 +58,7 @@
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
 
 }  // namespace tensorflow