| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/util/strided_slice_op.h" |
| |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/tf2xla/literal_util.h" |
| #include "tensorflow/compiler/tf2xla/type_util.h" |
| #include "tensorflow/compiler/tf2xla/xla_helpers.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
| #include "tensorflow/compiler/xla/client/lib/constants.h" |
| #include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h" |
| #include "tensorflow/compiler/xla/client/value_inference.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/ops_util.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/mem.h" |
| |
| namespace tensorflow { |
| namespace { |
| using errors::InvalidArgument; |
| |
| class StridedSliceOp : public XlaOpKernel { |
| public: |
| explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); |
| } |
| |
| void EmitDynamicSlice(XlaOpKernelContext* ctx, |
| const absl::InlinedVector<int64_t, 4>& strides, |
| TensorShape processing_shape, TensorShape final_shape, |
| PartialTensorShape partial_processing_shape, |
| PartialTensorShape partial_final_shape, |
| const StridedSliceShapeSpec& shape_spec, |
| const std::vector<bool>& begins_are_dynamic, |
| const std::vector<bool>& ends_are_dynamic) { |
| const TensorShape input_shape = ctx->InputShape(0); |
| xla::XlaOp slice = ctx->Input(0); |
| for (int64_t i = 0; i < ctx->InputShape("begin").dims(); ++i) { |
| OP_REQUIRES(ctx, strides[i] == 1, |
| errors::InvalidArgument( |
| "Strides have to be one when inputs are not constant.")); |
| } |
| // Infer static output shape, reconcile unknown dimension with input dim |
| // size. |
| for (int64_t i = 0; i < partial_final_shape.dims(); ++i) { |
| if (partial_final_shape.dim_size(i) == -1) { |
| // Use input shape to update unknown dimension of partial shape -- if a |
| // dimension is unknown, we use input shape as bound. |
| partial_final_shape.set_dim( |
| i, |
| input_shape.dim_size(shape_spec.output_to_processing_mapping[i])); |
| } |
| } |
| |
| OP_REQUIRES( |
| ctx, partial_final_shape.AsTensorShape(&final_shape), |
| InvalidArgument("XLA can't deduce compile time constant output " |
| "shape for strided slice: ", |
| partial_final_shape.DebugString(), |
| ", output shape must be a compile-time constant")); |
| for (int64_t i = 0; i < partial_processing_shape.dims(); ++i) { |
| if (partial_processing_shape.dim_size(i) == -1) { |
| // Use input shape to update unknown dimension of partial shape -- if a |
| // dimension is unknown, we use input shape as bound. |
| partial_processing_shape.set_dim(i, input_shape.dim_size(i)); |
| } |
| } |
| OP_REQUIRES( |
| ctx, partial_processing_shape.AsTensorShape(&processing_shape), |
| InvalidArgument("XLA can't deduce compile time constant processing " |
| "shape for strided slice: ", |
| partial_processing_shape.DebugString(), |
| ", output shape must be a compile-time constant")); |
| // When inputs are not compile time constants, shape inference can only |
| // inference size 1 slice. |
| std::vector<int64_t> slice_sizes(input_shape.dims(), 1); |
| // If there is dynamic begin/end (and if the dimension is not shrunk), we |
| // need to use dynamic shape infrastructure -- we slice the output with |
| // full size, then call SetDimensionSize on the output. However, if we |
| // slice with the full size at a non-zero dimension we may get OOB access. |
| // To avoid that, we first pad the input to 2x before calling slice. |
| xla::PaddingConfig padding_config; |
| bool need_padding = false; |
| std::vector<bool> result_dims_are_dynamic; |
| for (int64_t i = 0; i < input_shape.dims(); ++i) { |
| int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; |
| bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask; |
| auto* dims = padding_config.add_dimensions(); |
| dims->set_edge_padding_low(0); |
| |
| dims->set_interior_padding(0); |
| if ((begins_are_dynamic[sparse_index] || |
| ends_are_dynamic[sparse_index]) && |
| !shrink_axis_set) { |
| // Need to slice this dimension so pad first. |
| dims->set_edge_padding_high(input_shape.dim_size(i)); |
| need_padding = true; |
| result_dims_are_dynamic.push_back(true); |
| } else { |
| dims->set_edge_padding_high(0); |
| result_dims_are_dynamic.push_back(false); |
| } |
| } |
| |
| if (need_padding) { |
| // Pad input to 2x to avoid OOB access. |
| slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)), |
| padding_config); |
| for (int64 i = 0; i < result_dims_are_dynamic.size(); ++i) { |
| if (result_dims_are_dynamic[i]) { |
| slice = xla::RemoveDynamicDimension(slice, i); |
| } |
| } |
| } |
| std::vector<xla::XlaOp> start_indices; |
| std::vector<xla::XlaOp> slice_sizes_dynamic; |
| xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie(); |
| for (int64_t i = 0; i < input_shape.dims(); ++i) { |
| bool begin_mask = (1 << i) & shape_spec.begin_dense_mask; |
| bool end_mask = (1 << i) & shape_spec.end_dense_mask; |
| auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); |
| xla::XlaOp begin_index, end_index; |
| int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; |
| bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i); |
| xla::XlaOp dim_size; |
| if (xla_input_is_dynamic) { |
| dim_size = xla::GetDimensionSize(ctx->Input(0), i); |
| OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32, |
| errors::InvalidArgument("'begin shape has to be int32 when " |
| "indices to slice op are dynamic")); |
| } else { |
| dim_size = |
| xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"), |
| input_xla_shape.dimensions(i)); |
| } |
| |
| auto scalar_must_be_non_negative = [ctx](xla::XlaOp value) -> bool { |
| // Check if the lower-bound of a value is always >= 0 |
| auto lower_bound = ctx->value_inference().AnalyzeConstant( |
| value, xla::ValueInferenceMode::kLowerBound); |
| if (!lower_bound.ok() || !lower_bound->AllValid()) { |
| // Can't infer a lower bound. |
| return false; |
| } |
| return lower_bound->Get<int32>({}) >= 0; |
| }; |
| if (begin_mask) { |
| begin_index = zero; |
| } else { |
| begin_index = xla::Slice(ctx->Input("begin"), {sparse_index}, |
| {sparse_index + 1}, {1}); |
| begin_index = xla::Reshape(begin_index, {}); |
| if (!scalar_must_be_non_negative(begin_index)) { |
| // begin could be negative. |
| auto index_negative = xla::Lt(begin_index, zero); |
| auto wrapped_index = xla::Add(dim_size, begin_index); |
| // Wrap negative indices around. |
| begin_index = xla::Select(index_negative, wrapped_index, begin_index); |
| } |
| } |
| start_indices.push_back(begin_index); |
| if (end_mask) { |
| end_index = dim_size; |
| } else { |
| end_index = xla::Slice(ctx->Input("end"), {sparse_index}, |
| {sparse_index + 1}, {1}); |
| end_index = xla::Reshape(end_index, {}); |
| if (!scalar_must_be_non_negative(end_index)) { |
| // end could be negative. |
| auto index_negative = xla::Lt(end_index, zero); |
| auto wrapped_index = xla::Add(dim_size, end_index); |
| end_index = xla::Select(index_negative, wrapped_index, end_index); |
| } |
| } |
| slice_sizes_dynamic.push_back( |
| xla::Max(xla::Sub(end_index, begin_index), zero)); |
| } |
| |
| slice = |
| xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes()); |
| // new_axis_mask_, ellipsis_mask_ and shrink_axis_mask_ may add or remove |
| // size 1 dims of a shape. |
| slice = xla::Reshape(slice, final_shape.dim_sizes()); |
| for (int64_t i = 0; i < final_shape.dims(); ++i) { |
| int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i]; |
| // If processing_shape_dim is -1, it means the output dimension was newly |
| // added by new_axis_mask_, which doesn't show up in input. |
| if (processing_shape_dim != -1 && |
| result_dims_are_dynamic[processing_shape_dim]) { |
| // We gave a generous bound (same as input) to the output, try reset |
| // the bound if a tighter one can be found. |
| auto status = xla::SetDimensionSizeWithRebound( |
| &ctx->value_inference(), slice, |
| slice_sizes_dynamic[processing_shape_dim], i); |
| OP_REQUIRES_OK(ctx, status.status()); |
| slice = status.ValueOrDie(); |
| } |
| } |
| ctx->SetOutput(0, slice); |
| } |
| |
| void Compile(XlaOpKernelContext* ctx) override { |
| const TensorShape input_shape = ctx->InputShape(0); |
| const TensorShape begin_shape = ctx->InputShape("begin"); |
| OP_REQUIRES( |
| ctx, begin_shape.dims() == 1, |
| errors::InvalidArgument("'begin' input has to be a rank 1 vector")); |
| |
| absl::InlinedVector<int64_t, 4> begin; |
| absl::InlinedVector<int64_t, 4> end; |
| absl::InlinedVector<int64_t, 4> strides; |
| |
| xla::Literal begin_literal, end_literal, strides_literal; |
| bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok(); |
| bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok(); |
| |
| // Strides have to be static. |
| OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); |
| |
| Tensor begin_tensor, end_tensor, strides_tensor; |
| if (begin_is_constant) { |
| OP_REQUIRES_OK( |
| ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); |
| } |
| if (end_is_constant) { |
| OP_REQUIRES_OK( |
| ctx, LiteralToHostTensor(end_literal, index_type_, &end_tensor)); |
| } |
| OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, |
| &strides_tensor)); |
| TensorShape processing_shape, final_shape; |
| PartialTensorShape partial_processing_shape, partial_final_shape; |
| bool dummy = false; |
| StridedSliceShapeSpec shape_spec; |
| OP_REQUIRES_OK( |
| ctx, |
| ValidateStridedSliceOp( |
| begin_is_constant ? &begin_tensor : nullptr, |
| end_is_constant ? &end_tensor : nullptr, strides_tensor, |
| input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, |
| shrink_axis_mask_, &partial_processing_shape, &partial_final_shape, |
| &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec)); |
| |
| xla::XlaOp slice = ctx->Input(0); |
| std::vector<bool> begins_are_dynamic; |
| OP_REQUIRES_OK( |
| ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic)); |
| std::vector<bool> ends_are_dynamic; |
| OP_REQUIRES_OK( |
| ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic)); |
| if (begin_is_constant && end_is_constant) { |
| OP_REQUIRES( |
| ctx, partial_final_shape.AsTensorShape(&final_shape), |
| InvalidArgument("XLA can't deduce compile time constant output " |
| "shape for strided slice: ", |
| partial_final_shape.DebugString(), |
| ", output shape must be a compile-time constant")); |
| absl::InlinedVector<int64_t, 4> dimensions_to_reverse; |
| absl::InlinedVector<int64_t, 4> slice_begin, slice_end, slice_strides; |
| for (int i = 0; i < begin.size(); ++i) { |
| if (strides[i] > 0) { |
| slice_begin.push_back(begin[i]); |
| slice_end.push_back(std::max(end[i], begin[i])); |
| slice_strides.push_back(strides[i]); |
| } else { |
| // Negative stride: swap begin and end, add 1 because the interval |
| // is semi-open, and mark the dimension to be reversed. |
| slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); |
| slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, |
| input_shape.dim_size(i) - begin[i] - 1)); |
| slice_strides.push_back(-strides[i]); |
| dimensions_to_reverse.push_back(i); |
| } |
| } |
| if (!dimensions_to_reverse.empty()) { |
| slice = xla::Rev(slice, dimensions_to_reverse); |
| } |
| slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); |
| auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); |
| OP_REQUIRES_OK(ctx, operand_shape_or.status()); |
| xla::Shape xla_shape = operand_shape_or.ValueOrDie(); |
| |
| bool begins_are_static = absl::c_all_of( |
| begins_are_dynamic, [](bool dynamic) { return !dynamic; }); |
| OP_REQUIRES(ctx, begins_are_static, |
| errors::InvalidArgument( |
| "XLA can't use dynamic begin values for slice.")); |
| bool ends_are_static = absl::c_all_of( |
| ends_are_dynamic, [](bool dynamic) { return !dynamic; }); |
| // Static output shape, return a static slice. |
| slice = xla::Reshape(slice, final_shape.dim_sizes()); |
| if (xla_shape.is_static() && ends_are_static) { |
| ctx->SetOutput(0, slice); |
| return; |
| } |
| |
| for (int64_t i = 0; i < final_shape.dims(); ++i) { |
| int64_t input_index = shape_spec.output_to_processing_mapping[i]; |
| if (input_index == -1) { |
| continue; |
| } |
| bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index); |
| |
| int64_t sparse_index = shape_spec.output_to_sparse_mapping[i]; |
| bool end_is_dynamic = |
| sparse_index == -1 ? false : ends_are_dynamic[sparse_index]; |
| bool backward_slice = sparse_index == -1 |
| ? false |
| : end_literal.Get<int32>({sparse_index}) < 0; |
| if (input_is_dynamic || end_is_dynamic) { |
| OP_REQUIRES( |
| ctx, strides[input_index] == 1, |
| errors::InvalidArgument("XLA has not implemented dynamic " |
| "sized slice with non-trival stride yet. " |
| "Please file a bug against XLA")); |
| // If there is a dynamic dimension, properly set dimension size of |
| // the result. |
| auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index); |
| if (backward_slice) { |
| // We consider slicing a dynamic tensor t with negative indices as |
| // a dynamic sized slice. E.g., t[: -n], the result length is |
| // shape(t) - n. |
| OP_REQUIRES(ctx, !end_is_dynamic, |
| errors::InvalidArgument( |
| "XLA has not implemented dynamic " |
| "sized slice with dynamic negative index %lld. ")); |
| operand_size = xla::Add( |
| operand_size, |
| xla::ConstantR0<int32>(ctx->builder(), |
| end_literal.Get<int32>({sparse_index}))); |
| } else { |
| // The end of slice with dynamic slice size is the min of operand |
| // shape and slice size. E.g., t[:end_size], result size is |
| // min(shape(t), end_size). |
| xla::XlaOp end_size; |
| if (end_is_dynamic) { |
| end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index}, |
| {sparse_index + 1}, {1}), |
| {}); |
| } else { |
| end_size = |
| xla::ConstantR0<int32>(ctx->builder(), end[input_index]); |
| } |
| operand_size = xla::Min(operand_size, end_size); |
| } |
| slice = xla::SetDimensionSize( |
| slice, |
| xla::Sub(operand_size, xla::ConstantR0<int32>( |
| ctx->builder(), begin[input_index])), |
| i); |
| } |
| } |
| ctx->SetOutput(0, slice); |
| return; |
| } else { |
| EmitDynamicSlice(ctx, strides, processing_shape, final_shape, |
| partial_processing_shape, partial_final_shape, |
| shape_spec, begins_are_dynamic, ends_are_dynamic); |
| } |
| } |
| |
| private: |
| int32 begin_mask_, end_mask_; |
| int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; |
| DataType index_type_; |
| }; |
| |
| REGISTER_XLA_OP(Name("StridedSlice") |
| .CompileTimeConstantInput("begin") |
| .CompileTimeConstantInput("end") |
| .CompileTimeConstantInput("strides"), |
| StridedSliceOp); |
| |
| class StridedSliceGradOp : public XlaOpKernel { |
| public: |
| explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); |
| } |
| |
| // When the begin / end is unknown, compile the gradient into dynamic update |
| // slice into a broadcasted 0s. |
| // |
| // Broadcasted 0 |
| // +----------------------+ |
| // | +----+ | |
| // |<-begin->|grad|<-end->| <== Dynamic update grad into 0s. |
| // | +----+ | |
| // +----------------------+ |
| void CompileAsDynamicUpdateSlice(XlaOpKernelContext* ctx, |
| const TensorShape& input_shape, |
| const xla::Literal& strides_literal) { |
| bool dummy = false; |
| Tensor strides_tensor; |
| PartialTensorShape processing_shape, final_shape; |
| absl::InlinedVector<int64_t, 4> begin; |
| absl::InlinedVector<int64_t, 4> end; |
| absl::InlinedVector<int64_t, 4> strides; |
| StridedSliceShapeSpec shape_spec; |
| OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, |
| &strides_tensor)); |
| OP_REQUIRES_OK( |
| ctx, ValidateStridedSliceOp( |
| nullptr, nullptr, strides_tensor, input_shape, begin_mask_, |
| end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, |
| &processing_shape, &final_shape, &dummy, &dummy, &dummy, |
| &begin, &end, &strides, &shape_spec)); |
| for (int64_t i = 0; i < processing_shape.dims(); ++i) { |
| OP_REQUIRES( |
| ctx, strides[i] == 1, |
| errors::InvalidArgument("Strides in strided slice grad have to be " |
| "one when inputs are not constant.")); |
| } |
| |
| xla::XlaOp grad = ctx->Input(4); |
| xla::Shape grad_shape = ctx->InputXlaShape(4).ValueOrDie(); |
| VLOG(1) << "xla grad shape" << grad_shape; |
| VLOG(1) << "xla final_shape" << final_shape; |
| VLOG(1) << "input_shape" << input_shape.DebugString(); |
| auto input_sizes = input_shape.dim_sizes(); |
| // For unknown output dim the bound of the output shape is input. Pad and |
| // double the size of input shape to leave enough buffer to avoid OOB |
| // dynamic update slice. |
| auto input_sizes_padded = input_shape.dim_sizes(); |
| bool need_padding = false; |
| for (int64_t i = 0; i < processing_shape.dims(); ++i) { |
| if (processing_shape.dim_size(i) == -1) { |
| input_sizes_padded[i] *= 2; |
| need_padding = true; |
| } |
| } |
| for (int64_t i = 0; i < grad_shape.rank(); ++i) { |
| // Use grad shape, which is known, to update unknown processing shape. |
| // Grad shape is the output of the ValidateStridedSliceOp function in |
| // forward pass, thus we use output_to_processing_mapping. |
| if (shape_spec.output_to_processing_mapping[i] != -1) { |
| processing_shape.set_dim(shape_spec.output_to_processing_mapping[i], |
| grad_shape.dimensions(i)); |
| } |
| } |
| |
| std::vector<xla::XlaOp> begins; |
| begins.reserve(processing_shape.dims()); |
| for (int64_t i = 0; i < input_shape.dims(); ++i) { |
| bool begin_mask = (1 << i) & shape_spec.begin_dense_mask; |
| // Similarly, use processing_to_sparse_mapping to find out corresponding |
| // begin dim of the gradient, as indices for dynamic update slice. |
| int64_t begin_dim = shape_spec.processing_to_sparse_mapping[i]; |
| xla::XlaOp begin_index; |
| auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); |
| if (begin_mask) { |
| begin_index = zero; |
| } else { |
| xla::XlaOp dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1}); |
| dim_size = xla::Reshape(dim_size, {}); |
| begin_index = |
| xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1}); |
| begin_index = xla::Reshape(begin_index, {}); |
| auto index_negative = xla::Lt(begin_index, zero); |
| auto wrapped_index = xla::Add(dim_size, begin_index); |
| // Wrap negative indices around. |
| begin_index = xla::Select(index_negative, wrapped_index, begin_index); |
| } |
| begins.push_back(begin_index); |
| } |
| auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); |
| |
| zero = xla::Broadcast(zero, input_sizes_padded); |
| grad = xla::Reshape(grad, processing_shape.dim_sizes()); |
| grad = xla::DynamicUpdateSlice(zero, grad, begins); |
| if (need_padding) { |
| // We padded the input shape to avoid OOB when DUS. Now slice out the |
| // padding in the final result. |
| std::vector<int64_t> strides(input_shape.dims(), 1); |
| std::vector<int64_t> start_indices(input_shape.dims(), 0); |
| grad = xla::Slice(grad, start_indices, input_sizes, strides); |
| } |
| ctx->SetOutput(0, grad); |
| } |
| void Compile(XlaOpKernelContext* ctx) override { |
| TensorShape processing_shape, final_shape; |
| absl::InlinedVector<int64_t, 4> begin; |
| absl::InlinedVector<int64_t, 4> end; |
| absl::InlinedVector<int64_t, 4> strides; |
| |
| TensorShape input_shape; |
| OP_REQUIRES_OK( |
| ctx, ctx->ConstantInputAsShape(0, &input_shape, |
| xla::ValueInferenceMode::kUpperBound)); |
| xla::Literal begin_literal, end_literal, strides_literal; |
| |
| bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok(); |
| bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok(); |
| OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); |
| if (!(begin_is_constant && end_is_constant)) { |
| CompileAsDynamicUpdateSlice(ctx, input_shape, strides_literal); |
| return; |
| } |
| Tensor begin_tensor, end_tensor, strides_tensor; |
| OP_REQUIRES_OK( |
| ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); |
| OP_REQUIRES_OK(ctx, |
| LiteralToHostTensor(end_literal, index_type_, &end_tensor)); |
| OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, |
| &strides_tensor)); |
| |
| bool dummy = false; |
| OP_REQUIRES_OK( |
| ctx, ValidateStridedSliceOp( |
| &begin_tensor, &end_tensor, strides_tensor, input_shape, |
| begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, |
| shrink_axis_mask_, &processing_shape, &final_shape, &dummy, |
| &dummy, &dummy, &begin, &end, &strides)); |
| |
| // Check to make sure dy is consistent with the original slice |
| const TensorShape dy_shape = ctx->InputShape(4); |
| OP_REQUIRES( |
| ctx, final_shape == dy_shape, |
| errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(), |
| " instead of ", final_shape.DebugString())); |
| |
| OP_REQUIRES( |
| ctx, input_shape.dims() == processing_shape.dims(), |
| errors::Internal( |
| "input shape and processing shape must have same number of dims")); |
| |
| auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); |
| |
| xla::XlaOp grad = ctx->Input(4); |
| |
| // Undo any new/shrink axes. |
| grad = xla::Reshape(grad, processing_shape.dim_sizes()); |
| |
| // Pad the input gradients. |
| absl::InlinedVector<int64_t, 4> dimensions_to_reverse; |
| xla::PaddingConfig padding_config; |
| |
| for (int i = 0; i < processing_shape.dims(); ++i) { |
| auto* dims = padding_config.add_dimensions(); |
| if (strides[i] > 0) { |
| dims->set_edge_padding_low(begin[i]); |
| dims->set_interior_padding(strides[i] - 1); |
| |
| // Pad the upper dimension up to the expected input shape. (It's |
| // not sufficient simply to use "end[i]" to compute the padding in |
| // cases where the stride does not divide evenly into the interval |
| // between begin[i] and end[i].) |
| int64_t size = |
| dims->edge_padding_low() + processing_shape.dim_size(i) + |
| (processing_shape.dim_size(i) - 1) * dims->interior_padding(); |
| dims->set_edge_padding_high(input_shape.dim_size(i) - size); |
| } else { |
| dimensions_to_reverse.push_back(i); |
| dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1); |
| dims->set_interior_padding(-strides[i] - 1); |
| |
| // Pad the lower dimension up to the expected input shape. |
| int64_t size = |
| dims->edge_padding_high() + processing_shape.dim_size(i) + |
| (processing_shape.dim_size(i) - 1) * dims->interior_padding(); |
| dims->set_edge_padding_low(input_shape.dim_size(i) - size); |
| } |
| } |
| if (!dimensions_to_reverse.empty()) { |
| grad = xla::Rev(grad, dimensions_to_reverse); |
| } |
| grad = xla::Pad(grad, zero, padding_config); |
| |
| xla::XlaOp dynamic_shape = ctx->Input(0); |
| xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie(); |
| std::vector<bool> dynamic_input; |
| OP_REQUIRES_OK(ctx, |
| ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_input)); |
| // Input of strided_slice_op has to have the same shape as output. |
| DCHECK_EQ(grad_shape.rank(), input_shape.dims()); |
| for (int64_t dim = 0; dim < input_shape.dims(); ++dim) { |
| DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim)); |
| if (dynamic_input[dim]) { |
| // Input is a dynamic dimension, set the same dynamic dimension size in |
| // the output. |
| auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1}); |
| dim_size = xla::ConvertElementType(dim_size, xla::S32); |
| auto dim_size_scalar = xla::Reshape(dim_size, {}); |
| grad = xla::SetDimensionSize(grad, dim_size_scalar, dim); |
| } else if (grad_shape.is_dynamic_dimension(dim)) { |
| // Input is static but output is dynamic, respect input and remove any |
| // dynamic dim in the output. |
| grad = xla::RemoveDynamicDimension(grad, dim); |
| } |
| } |
| |
| ctx->SetOutput(0, grad); |
| } |
| |
| private: |
| int32 begin_mask_, end_mask_; |
| int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; |
| DataType index_type_; |
| }; |
| |
| REGISTER_XLA_OP(Name("StridedSliceGrad") |
| .CompileTimeConstantInput("shape") |
| .CompileTimeConstantInput("begin") |
| .CompileTimeConstantInput("end") |
| .CompileTimeConstantInput("strides"), |
| StridedSliceGradOp); |
| |
| class StridedSliceAssignOp : public XlaOpKernel { |
| public: |
| explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); |
| } |
| |
| void Compile(XlaOpKernelContext* ctx) override { |
| TensorShape final_shape; |
| absl::InlinedVector<int64_t, 4> begin; |
| absl::InlinedVector<int64_t, 4> end; |
| absl::InlinedVector<int64_t, 4> strides; |
| |
| xla::Literal begin_literal, end_literal, strides_literal; |
| OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); |
| OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); |
| OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); |
| |
| Tensor begin_tensor, end_tensor, strides_tensor; |
| OP_REQUIRES_OK( |
| ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); |
| OP_REQUIRES_OK(ctx, |
| LiteralToHostTensor(end_literal, index_type_, &end_tensor)); |
| OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, |
| &strides_tensor)); |
| |
| TensorShape lhs_shape; |
| xla::XlaOp lhs; |
| if (ctx->input_type(0) == DT_RESOURCE) { |
| OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); |
| } else { |
| lhs_shape = ctx->InputShape(0); |
| lhs = ctx->Input(0); |
| } |
| |
| const TensorShape rhs_shape = ctx->InputShape(4); |
| |
| TensorShape dummy_processing_shape; |
| bool dummy = false; |
| OP_REQUIRES_OK(ctx, |
| ValidateStridedSliceOp( |
| &begin_tensor, &end_tensor, strides_tensor, lhs_shape, |
| begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, |
| shrink_axis_mask_, &dummy_processing_shape, &final_shape, |
| &dummy, &dummy, &dummy, &begin, &end, &strides)); |
| |
| if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { |
| // DynamicUpdateSlice does not allow 0-element updates. We should probably |
| // check that rhs_shape can be broadcast to final_shape, but that is |
| // probably better handled when implementing broadcasting more generally. |
| return; |
| } |
| |
| // TODO(aselle): This check is too strong, we only should need |
| // input_shape to be broadcastable to final_shape |
| OP_REQUIRES(ctx, final_shape == rhs_shape, |
| errors::Unimplemented( |
| "sliced l-value shape ", final_shape.DebugString(), |
| " does not match r-value shape ", rhs_shape.DebugString(), |
| ". Automatic broadcasting not yet implemented.")); |
| |
| xla::XlaOp rhs = ctx->Input(4); |
| |
| absl::InlinedVector<int64_t, 4> dimensions_to_reverse; |
| absl::InlinedVector<xla::XlaOp, 4> slice_begin; |
| absl::InlinedVector<int64_t, 4> slice_dims; |
| for (int i = 0; i < begin.size(); ++i) { |
| // TODO(b/121179231): implement strides != 1 |
| OP_REQUIRES( |
| ctx, strides[i] == 1 || strides[i] == -1, |
| errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); |
| if (strides[i] > 0) { |
| slice_begin.push_back( |
| xla::ConstantR0<int64_t>(ctx->builder(), begin[i])); |
| slice_dims.push_back(end[i] - begin[i]); |
| } else { |
| // Negative stride: swap begin and end, add 1 because the interval |
| // is semi-open, and mark the dimension to be reversed. |
| slice_begin.push_back( |
| xla::ConstantR0<int64_t>(ctx->builder(), end[i] + 1)); |
| slice_dims.push_back(begin[i] - end[i]); |
| dimensions_to_reverse.push_back(i); |
| } |
| } |
| |
| if (!dimensions_to_reverse.empty()) { |
| rhs = xla::Rev(rhs, dimensions_to_reverse); |
| } |
| rhs = xla::Reshape(rhs, slice_dims); |
| |
| lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); |
| |
| if (ctx->input_type(0) == DT_RESOURCE) { |
| OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); |
| } else { |
| ctx->SetOutput(0, lhs); |
| } |
| } |
| |
| private: |
| int32 begin_mask_, end_mask_; |
| int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; |
| DataType index_type_; |
| DataType dtype_; |
| }; |
| |
| REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") |
| .CompileTimeConstantInput("begin") |
| .CompileTimeConstantInput("end") |
| .CompileTimeConstantInput("strides"), |
| StridedSliceAssignOp); |
| |
| REGISTER_XLA_OP(Name("TensorStridedSliceUpdate") |
| .CompileTimeConstantInput("begin") |
| .CompileTimeConstantInput("end") |
| .CompileTimeConstantInput("strides"), |
| StridedSliceAssignOp); |
| |
| } // namespace |
| } // namespace tensorflow |