[TF2XLA] Support dynamic slice size in strided slice op.
- Add two side outputs in ValidateStridedSliceOp to help analyze dynamic dimensions.
- Correctly set strided slice op's dynamic size if the slice size (slice end) is dynamic.
PiperOrigin-RevId: 327552472
Change-Id: Ia85e7bc377c432e5032f49278754659452ec9f86
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 807c061..d7a8e67 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -16,7 +16,6 @@
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -29,26 +28,13 @@
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
TensorShape output_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
- auto output_status_or =
- BroadcastTo(context->Input(0), output_shape.dim_sizes());
- OP_REQUIRES_OK(context, output_status_or.status());
- auto output = output_status_or.ValueOrDie();
- std::vector<bool> dynamic_dims;
- OP_REQUIRES_OK(
- context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
- for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
- if (dynamic_dims[dim]) {
- output = xla::SetDimensionSize(
- output,
- xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
- {}),
- dim);
- }
- }
- context->SetOutput(0, output);
+ auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+ OP_REQUIRES_OK(context, output.status());
+ context->SetOutput(0, output.ValueOrDie());
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 72cb746..784b790 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -15,9 +15,6 @@
#include "tensorflow/core/util/strided_slice_op.h"
-#include <vector>
-
-#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -26,7 +23,6 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
@@ -37,7 +33,6 @@
namespace tensorflow {
namespace {
-using errors::InvalidArgument;
class StridedSliceOp : public XlaOpKernel {
public:
@@ -53,7 +48,7 @@
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape begin_shape = ctx->InputShape("begin");
- VLOG(0) << "strided slice";
+
OP_REQUIRES(
ctx, begin_shape.dims() == 1,
errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
@@ -83,24 +78,20 @@
TensorShape final_shape;
PartialTensorShape dummy_processing_shape, partial_final_shape;
bool dummy = false;
- absl::InlinedVector<int64, 4> output_to_sparse_mapping;
- absl::InlinedVector<int64, 4> output_to_processing_mapping;
- OP_REQUIRES_OK(
- ctx,
- ValidateStridedSliceOp(
- begin_is_constant ? &begin_tensor : nullptr,
- end_is_constant ? &end_tensor : nullptr, strides_tensor,
- input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
- shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
- &dummy, &dummy, &dummy, &begin, &end, &strides,
- &output_to_sparse_mapping, &output_to_processing_mapping));
+ OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+ begin_is_constant ? &begin_tensor : nullptr,
+ end_is_constant ? &end_tensor : nullptr,
+ strides_tensor, input_shape, begin_mask_, end_mask_,
+ ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+ &dummy_processing_shape, &partial_final_shape,
+ &dummy, &dummy, &dummy, &begin, &end, &strides));
- OP_REQUIRES(
- ctx, partial_final_shape.AsTensorShape(&final_shape),
- InvalidArgument("XLA can't deduce compile time constant output "
- "shape for strided slice: ",
- partial_final_shape.DebugString(),
- ", output shape must be a compile-time constant"));
+ OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
+ errors::InvalidArgument(
+ "XLA can't deduce compile time constant output "
+ "shape for strided slice: ",
+ partial_final_shape.DebugString(),
+ ", output shape must be a compile-time constant"));
xla::XlaOp slice = ctx->Input(0);
if (begin_is_constant && end_is_constant) {
@@ -128,84 +119,69 @@
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
OP_REQUIRES_OK(ctx, operand_shape_or.status());
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
- std::vector<bool> begins_are_dynamic;
- OP_REQUIRES_OK(
- ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
- std::vector<bool> ends_are_dynamic;
- OP_REQUIRES_OK(
- ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
- bool begins_are_static = absl::c_all_of(
- begins_are_dynamic, [](bool dynamic) { return !dynamic; });
- OP_REQUIRES(ctx, begins_are_static,
- errors::InvalidArgument(
- "XLA can't use dynamic begin values for slice."));
- bool ends_are_static = absl::c_all_of(
- ends_are_dynamic, [](bool dynamic) { return !dynamic; });
- // Static output shape, return a static slice.
- slice = xla::Reshape(slice, final_shape.dim_sizes());
- if (xla_shape.is_static() && ends_are_static) {
+ if (xla_shape.is_static()) {
+ // Static output shape, return a static slice.
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
+ ctx->SetOutput(0, slice);
+ return;
+ }
+ auto input_dim_sizes = input_shape.dim_sizes();
+
+ for (int64 i = 0; i < xla_shape.rank(); ++i) {
+ if (xla_shape.is_dynamic_dimension(i)) {
+ input_dim_sizes[i] = -1;
+ }
+ }
+ PartialTensorShape input_partial_shape(input_dim_sizes);
+ partial_final_shape.Clear();
+ end.clear();
+ strides.clear();
+ begin.clear();
+ // Run shape inference again with the partial shape.
+ OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+ &begin_tensor, &end_tensor, strides_tensor,
+ input_partial_shape, begin_mask_, end_mask_,
+ ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+ &dummy_processing_shape, &partial_final_shape,
+ &dummy, &dummy, &dummy, &begin, &end, &strides));
+ if (partial_final_shape.AsTensorShape(&final_shape)) {
+ // Static output shape, return a static slice.
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
return;
}
- for (int64 i = 0; i < final_shape.dims(); ++i) {
- int64 input_index = output_to_processing_mapping[i];
- if (input_index == -1) {
- continue;
- }
- bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
-
- int64 sparse_index = output_to_sparse_mapping[i];
- bool end_is_dynamic =
- sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
- bool backward_slice = sparse_index == -1
- ? false
- : end_literal.Get<int32>({sparse_index}) < 0;
- if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
+ // We treat slicing a dynamic tensor t with a negative end index as a
+ // dynamically sized slice. E.g., for t[:-n] the result length is shape(t) - n.
+ for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+ bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
+ bool backward_slice = end[i] < 0;
+ if (dynamic_dim && backward_slice) {
OP_REQUIRES(
- ctx, strides[input_index] == 1,
+ ctx, strides[i] == 1,
errors::InvalidArgument("XLA has not implemented dynamic "
"sized slice with non-trival stride yet. "
"Please file a bug against XLA"));
+
+ OP_REQUIRES(ctx, begin[i] >= 0,
+ errors::InvalidArgument(
+ "XLA has not implemented dynamic "
+ "sized slice with negative begin index %lld. "
+ "Please file a bug against XLA",
+ begin[i]));
// If there is a dynamic dimension, properly set dimension size of
// the result.
- auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
- if (backward_slice) {
- // We consider slicing a dynamic tensor t with negative indices as
- // a dynamic sized slice. E.g., t[: -n], the result length is
- // shape(t) - n.
- OP_REQUIRES(ctx, !end_is_dynamic,
- errors::InvalidArgument(
- "XLA has not implemented dynamic "
- "sized slice with dynamic negative index %lld. "));
- operand_size = xla::Add(
- operand_size,
- xla::ConstantR0<int32>(ctx->builder(),
- end_literal.Get<int32>({sparse_index})));
- } else {
- // The end of slice with dynamic slice size is the min of operand
- // shape and slice size. E.g., t[:end_size], result size is
- // min(shape(t), end_size).
- xla::XlaOp end_size;
- if (end_is_dynamic) {
- end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
- {sparse_index + 1}, {1}),
- {});
- } else {
- end_size =
- xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
- }
- operand_size = xla::Min(operand_size, end_size);
- }
+ auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
+
+ operand_size = xla::Add(
+ operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
slice = xla::SetDimensionSize(
slice,
- xla::Sub(operand_size, xla::ConstantR0<int32>(
- ctx->builder(), begin[input_index])),
+ xla::Sub(operand_size,
+ xla::ConstantR0<int32>(ctx->builder(), begin[i])),
i);
}
}
- ctx->SetOutput(0, slice);
- return;
} else {
// When output shape is fully defined, it must be a size one slice:
//
@@ -263,9 +239,9 @@
std::vector<int64> output_shape_dim_sizes;
slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
- slice = xla::Reshape(slice, final_shape.dim_sizes());
- ctx->SetOutput(0, slice);
}
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
+ ctx->SetOutput(0, slice);
}
private:
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index 1cf9a8c..0df810a 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -59,11 +59,6 @@
// is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
// it will be 1. A shrunk dimension is skipped.
gtl::InlinedVector<int32, 4> final_shape_gather_indices;
- // This vector has the same size as final_shape_gather_indices, but it
- // remembers the sparse index that a dimension comes from, instead of dense
- // index. A -1 in this vector means there the index is not from the sparse
- // input.
- gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
// The dense indexed shrink mask is which processing dimensions
// should be shrunk. For example, if foo.shape = (10,10,10,10)
// foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -113,11 +108,9 @@
dense->begin_mask |= (1 << full_index);
dense->end_mask |= (1 << full_index);
dense->final_shape_gather_indices.push_back(full_index);
- dense->final_shape_gather_indices_sparse.push_back(-1);
}
} else if ((1 << i) & sparse.new_axis_mask) {
dense->final_shape_gather_indices.push_back(kNewAxis);
- dense->final_shape_gather_indices_sparse.push_back(-1);
} else {
if (full_index == dense->begin.size()) {
return errors::InvalidArgument("Index out of range using input dim ",
@@ -145,13 +138,9 @@
// axis (now in dense form) so we can ignore dense->end below.
if (sparse.shrink_axis_mask & (1 << i)) {
dense->final_shape_gather_indices.push_back(kShrinkAxis);
- dense->final_shape_gather_indices_sparse.push_back(-1);
dense->shrink_axis_mask |= (1 << full_index);
} else {
dense->final_shape_gather_indices.push_back(full_index);
- // Remember that where in the sparse shape the dense dim comes
- // from.
- dense->final_shape_gather_indices_sparse.push_back(i);
}
full_index++;
}
@@ -168,9 +157,7 @@
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
- gtl::InlinedVector<int64, 4>* strides,
- gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
- gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+ gtl::InlinedVector<int64, 4>* strides) {
const bool begin_is_wrong =
begin_tensor != nullptr &&
!(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,34 +362,11 @@
// slices like foo[3,...] will reduce dimension by 1.
// This cannot be done earlier, because it depends on Step 3.
final_shape->Clear();
- if (output_to_sparse_mapping != nullptr) {
- output_to_sparse_mapping->clear();
- }
-
- if (output_to_processing_mapping != nullptr) {
- output_to_processing_mapping->clear();
- }
- for (int64 dense_dim = 0;
- dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
- int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
- int64 sparse_index =
- dense_spec.final_shape_gather_indices_sparse[dense_dim];
+ for (auto gather_index : dense_spec.final_shape_gather_indices) {
if (gather_index >= 0) {
final_shape->AddDim(processing_shape->dim_size(gather_index));
- if (output_to_sparse_mapping != nullptr) {
- output_to_sparse_mapping->push_back(sparse_index);
- }
- if (output_to_processing_mapping != nullptr) {
- output_to_processing_mapping->push_back(gather_index);
- }
} else if (gather_index == kNewAxis) {
final_shape->AddDim(1);
- if (output_to_sparse_mapping != nullptr) {
- output_to_sparse_mapping->push_back(-1);
- }
- if (output_to_processing_mapping != nullptr) {
- output_to_processing_mapping->push_back(-1);
- }
}
}
return Status::OK();
@@ -415,17 +379,14 @@
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
- gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
- gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
- gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+ gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
// Validate with PartialTensorShape output
PartialTensorShape partial_processing_shape, partial_final_shape;
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
&partial_processing_shape, &partial_final_shape, is_identity,
- is_simple_slice, slice_dim0, begin, end, strides,
- output_to_sparse_mapping, output_to_processing_mapping));
+ is_simple_slice, slice_dim0, begin, end, strides));
// Verify that the output shapes are fully known
if (!partial_processing_shape.AsTensorShape(processing_shape) ||
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index 9e49477..25ecccd 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -40,17 +40,6 @@
// some dimensions of <processing_shape> and/or <final_shape> may be unknown
// (-1). Any validation that can be done without complete information is
// performed.
-//
-// This function changes the orders of dimensions, output_to_sparse_mapping and
-// output_to_processing_mapping are used to track the order change.
-//
-// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
-// index in the begin_tensor. If
-// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
-// sparse_mapping.
-//
-// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
-// processing_shape.
Status ValidateStridedSliceOp(
const Tensor* begin_tensor, const Tensor* end_tensor,
const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -59,9 +48,7 @@
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
- gtl::InlinedVector<int64, 4>* strides,
- gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
- gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+ gtl::InlinedVector<int64, 4>* strides);
// Same as above, but the outputs are TensorShape, not PartialTensorShape
Status ValidateStridedSliceOp(
@@ -71,9 +58,7 @@
int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
- gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
- gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
- gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+ gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
} // namespace tensorflow