| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // XLA specific pooling ops. |
| |
| #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" |
| #include "tensorflow/compiler/tf2xla/shape_util.h" |
| #include "tensorflow/compiler/tf2xla/type_util.h" |
| #include "tensorflow/compiler/tf2xla/xla_helpers.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
| #include "tensorflow/compiler/xla/client/lib/arithmetic.h" |
| #include "tensorflow/compiler/xla/client/lib/constants.h" |
| #include "tensorflow/compiler/xla/client/lib/pooling.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/framework/bounds_check.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/util/tensor_format.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Superclass of pooling ops. |
| class PoolingOp : public XlaOpKernel { |
| public: |
| PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims, |
| const DataType reduction_type) |
| : XlaOpKernel(ctx), |
| num_spatial_dims_(num_spatial_dims), |
| reduction_type_(reduction_type) { |
| if (ctx->num_inputs() == 1) { |
| std::vector<int32> ksize_int; |
| std::vector<int32> stride_int; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); |
| OP_REQUIRES(ctx, ksize_int.size() == num_dims(), |
| errors::InvalidArgument("Sliding window ksize field must " |
| "specify ", |
| num_dims(), " dimensions")); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); |
| OP_REQUIRES(ctx, stride_int.size() == num_dims(), |
| errors::InvalidArgument("Sliding window stride field must " |
| "specify ", |
| num_dims(), " dimensions")); |
| for (int i = 0; i < num_dims(); ++i) { |
| ksize_.push_back(ksize_int[i]); |
| stride_.push_back(stride_int[i]); |
| } |
| } |
| Padding padding; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); |
| padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; |
| |
| OP_REQUIRES_OK( |
| ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); |
| } |
| |
| int num_dims() const { return num_spatial_dims_ + 2; } |
| |
| protected: |
| xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) { |
| if (ctx->num_inputs() == 1) { |
| return ksize_; |
| } |
| const TensorShape ksize_shape = ctx->InputShape(1); |
| // Validate input sizes. |
| if (!TensorShapeUtils::IsVector(ksize_shape)) { |
| return errors::InvalidArgument("ksize must be a vector, not shape ", |
| ksize_shape.DebugString()); |
| } |
| if (ksize_shape.num_elements() != num_dims()) { |
| return errors::InvalidArgument( |
| "Sliding window ksize field must " |
| "specify ", |
| num_dims(), " dimensions"); |
| } |
| std::vector<int64> ksize; |
| auto status = ctx->ConstantInputAsIntVector(1, &ksize); |
| if (!status.ok()) { |
| return status; |
| } |
| return ksize; |
| } |
| |
| xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) { |
| if (ctx->num_inputs() == 1) { |
| return stride_; |
| } |
| const TensorShape stride_shape = ctx->InputShape(2); |
| // Validate input sizes. |
| if (!TensorShapeUtils::IsVector(stride_shape)) { |
| return errors::InvalidArgument("stride must be a vector, not shape ", |
| stride_shape.DebugString()); |
| } |
| if (stride_shape.num_elements() != num_dims()) { |
| return errors::InvalidArgument( |
| "Sliding window stride field must " |
| "specify ", |
| num_dims(), " dimensions"); |
| } |
| std::vector<int64> stride; |
| auto status = ctx->ConstantInputAsIntVector(2, &stride); |
| if (!status.ok()) { |
| return status; |
| } |
| return stride; |
| } |
| |
| protected: |
| const int num_spatial_dims_; |
| std::vector<int64> ksize_; |
| std::vector<int64> stride_; |
| xla::Padding padding_; |
| TensorFormat data_format_ = FORMAT_NHWC; |
| DataType reduction_type_; |
| xla::PrimitiveType xla_reduction_type_; |
| }; |
| |
| // Converts the tensor data format to the one required by the XLA pooling |
| // library. |
| xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, |
| int num_spatial_dims) { |
| int num_dims = num_spatial_dims + 2; |
| int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); |
| int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); |
| absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims); |
| for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { |
| spatial_dimensions[spatial_dim] = |
| GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); |
| } |
| return xla::TensorFormat(/*batch_dimension=*/batch_dimension, |
| /*feature_dimension=*/feature_dimension, |
| /*spatial_dimensions=*/spatial_dimensions); |
| } |
| |
| class MaxPoolOp : public PoolingOp { |
| public: |
| MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) |
| : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, |
| /*reduction_type=*/ctx->input_type(0)) { |
| string data_format_str; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); |
| OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), |
| errors::InvalidArgument("Invalid data format")); |
| OP_REQUIRES( |
| ctx, |
| data_format_ != FORMAT_NCHW_VECT_C && |
| data_format_ != FORMAT_NHWC_VECT_W, |
| errors::Unimplemented("XLA does not support the VECT_* data formats. " |
| "Returning unimplemented from MaxPool to keep " |
| "Tensorflow's intended optimized MaxPool here.")); |
| } |
| |
| void Compile(XlaOpKernelContext* ctx) override { |
| auto ksize_or_error = GetKernelSize(ctx); |
| OP_REQUIRES_OK(ctx, ksize_or_error.status()); |
| std::vector<int64> ksize = ksize_or_error.ValueOrDie(); |
| |
| auto stride_or_error = GetStride(ctx); |
| OP_REQUIRES_OK(ctx, stride_or_error.status()); |
| std::vector<int64> stride = stride_or_error.ValueOrDie(); |
| |
| const TensorShape input_shape = ctx->InputShape(0); |
| OP_REQUIRES(ctx, input_shape.dims() == num_dims(), |
| errors::InvalidArgument("Input to ", type_string(), |
| " operator must have ", num_dims(), |
| " dimensions")); |
| |
| auto pooling = |
| xla::MaxPool(ctx->Input(0), ksize, stride, padding_, |
| XlaTensorFormat(data_format_, input_shape.dims() - 2)); |
| ctx->SetOutput(0, pooling); |
| } |
| }; |
| |
// 2D max pooling (two spatial dimensions).
class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
// MaxPoolV2 receives ksize/strides as inputs rather than attributes; XLA
// requires them to be compile-time constants.
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);
| |
// 3D max pooling (three spatial dimensions).
class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
| |
| class AvgPoolOp : public PoolingOp { |
| public: |
| AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) |
| : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, |
| /*reduction_type=*/ |
| XlaHelpers::SumAccumulationType(ctx->input_type(0))) { |
| string data_format_str; |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); |
| OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), |
| errors::InvalidArgument("Invalid data format")); |
| } |
| |
| void Compile(XlaOpKernelContext* ctx) override { |
| auto ksize_or_error = GetKernelSize(ctx); |
| OP_REQUIRES_OK(ctx, ksize_or_error.status()); |
| std::vector<int64> ksize = ksize_or_error.ValueOrDie(); |
| |
| auto stride_or_error = GetStride(ctx); |
| OP_REQUIRES_OK(ctx, stride_or_error.status()); |
| std::vector<int64> stride = stride_or_error.ValueOrDie(); |
| |
| const TensorShape input_shape = ctx->InputShape(0); |
| OP_REQUIRES(ctx, input_shape.dims() == num_dims(), |
| errors::InvalidArgument("Input to ", type_string(), |
| " operator must have ", num_dims(), |
| " dimensions")); |
| |
| auto xla_data_format = |
| XlaTensorFormat(data_format_, input_shape.dims() - 2); |
| auto spatial_padding = MakeSpatialPadding( |
| input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); |
| |
| // Convert the input to the reduction type. |
| auto converted_input = |
| ConvertElementType(ctx->Input(0), xla_reduction_type_); |
| auto pooling = |
| xla::AvgPool(converted_input, ksize, stride, spatial_padding, |
| xla_data_format, padding_ == xla::Padding::kValid); |
| // Convert the pooling result back to the input type before returning it. |
| ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); |
| } |
| }; |
| |
// 2D average pooling (two spatial dimensions).
class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
| |
// 3D average pooling (three spatial dimensions).
class AvgPool3DOp : public AvgPoolOp {
 public:
  explicit AvgPool3DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);
| |
| // The operation to compute MaxPool gradients. |
| // It takes three inputs: |
| // - The original input tensor |
| // - The original output tensor |
| // - Backprop tensor for output |
| // It produces one output: backprop tensor for input. |
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // 3-input variants (MaxPoolGrad/MaxPool3DGrad) carry ksize/strides as
    // attributes; the 5-input V2 variant supplies them as constant inputs
    // read in Compile() instead.
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // V2 variant: ksize (input 3) and strides (input 4) are compile-time
    // constant vector inputs.
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

    // SelectAndScatter routes each out_backprop value to the window position
    // selected by the >= comparison (the max) and sums contributions that
    // land on the same input position.
    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
| |
// 2D max-pooling gradient; additionally reads the data_format attribute.
class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
// Note: MaxPoolGrad is registered with the MLIR-based kernel rather than
// MaxPool2DGradOp; only MaxPoolGradV2 (ksize/strides as compile-time
// constant inputs) uses the class above.
REGISTER_XLA_OP(Name("MaxPoolGrad"), MlirXlaOpKernel);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);
| |
// 3D max-pooling gradient. Uses the default NHWC-style data format (no
// data_format attribute is read here).
class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);
| |
| // Average-pooling gradient |
// Takes two inputs: the shape of the original input (a compile-time
// constant) and the backprop of the pooling output. Produces the backprop
// for the input.
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    // Index 0 is the batch dimension in both NHWC and NCHW-style formats.
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // Input 0 holds the shape of the original (pre-pooling) input.
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    // stride_ is stored as int32 (attribute type); the XLA helpers take
    // int64, so widen it here.
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    // Accumulate in a wider type (e.g. F32 for F16 inputs) to reduce
    // rounding error, then convert back below.
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values, xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0,
                   xla::ConvertElementType(in_backprop, xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  // int32 because it is read directly from the "strides" attribute.
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
| |
// 2D average-pooling gradient.
class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);
| |
// 3D average-pooling gradient.
class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);
| |
// Second-order gradient of MaxPool, implemented via the 16-bit packing hack
// described in Compile(). Registered only for T = DT_FLOAT, which the
// bit-manipulation below relies on (32-bit IEEE layout).
class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // 3-input variants carry ksize/strides as attributes; the 5-input V2
    // variant supplies them as constant inputs read in Compile() instead.
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // V2 variant: ksize (input 3) and strides (input 4) are compile-time
    // constant vector inputs.
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the maximal
    // value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of input for
    // separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into 16 lo bits, which we then glue back together at
    // the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
    // 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      // Mask off the low 16 bits (the packed backprop payload) so only the
      // truncated input value participates in the comparison.
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    // Two pools over the same truncated input, one carrying each 16-bit half
    // of the backprop payload.
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    // Reassemble the full 32-bit gradient from the two selected halves.
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
| |
// 2D second-order max-pooling gradient; additionally reads data_format.
class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
// Restricted to DT_FLOAT: the kernel's bit-packing trick assumes 32-bit
// float inputs.
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);
| |
// 3D second-order max-pooling gradient; additionally reads data_format.
class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
// Restricted to DT_FLOAT: the kernel's bit-packing trick assumes 32-bit
// float inputs.
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);
| |
| } // anonymous namespace |
| } // namespace tensorflow |