| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" |
| #include "tensorflow/compiler/tf2xla/lib/util.h" |
| #include "tensorflow/compiler/tf2xla/xla_helpers.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
| #include "tensorflow/compiler/xla/client/lib/arithmetic.h" |
| #include "tensorflow/compiler/xla/client/lib/comparators.h" |
| #include "tensorflow/compiler/xla/client/lib/constants.h" |
| #include "tensorflow/compiler/xla/client/lib/loops.h" |
| #include "tensorflow/compiler/xla/client/lib/sorting.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Converts the red/green/blue channels in 'rgb' from RGB format to HSV |
| // format. 'shape' is the shape of each of the red/green/blue tensors. |
| std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, |
| const std::array<xla::XlaOp, 3>& rgb, |
| DataType dtype, const TensorShape& shape) { |
| auto zero = XlaHelpers::Zero(b, dtype); |
| auto one = XlaHelpers::One(b, dtype); |
| |
| auto red = rgb[0]; |
| auto green = rgb[1]; |
| auto blue = rgb[2]; |
| auto value = xla::Max(xla::Max(red, green), blue); |
| auto minimum = xla::Min(xla::Min(red, green), blue); |
| auto range = xla::Sub(value, minimum); |
| |
| auto zeros = xla::Broadcast(zero, shape.dim_sizes()); |
| auto saturation = |
| xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); |
| |
| auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); |
| |
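| // Hue is computed piecewise from whichever channel holds the maximum value: |
| //   max == R: hue = (G - B) / (6 * range) |
| //   max == G: hue = (B - R) / (6 * range) + 2/6 |
| //   otherwise: hue = (R - G) / (6 * range) + 4/6 |
| // A zero range yields a hue of 0, and negative hues wrap around by adding 1. |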
| auto hue = |
| xla::Select(xla::Eq(green, value), |
| xla::Add(xla::Mul(norm, xla::Sub(blue, red)), |
| XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), |
| xla::Add(xla::Mul(norm, xla::Sub(red, green)), |
| XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); |
| hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)), |
| hue); |
| hue = xla::Select(xla::Gt(range, zero), hue, zeros); |
| hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue); |
| return {hue, saturation, value}; |
| } |
| |
| // Converts the hue/saturation/value channels in 'hsv' from HSV format to |
| // RGB format. |
| std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b, |
| const std::array<xla::XlaOp, 3>& hsv, |
| DataType dtype) { |
| xla::XlaOp hue = hsv[0]; |
| xla::XlaOp saturation = hsv[1]; |
| xla::XlaOp value = hsv[2]; |
| auto zero = XlaHelpers::Zero(b, dtype); |
| auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); |
| auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); |
| auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0); |
| auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); |
| auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); |
| |
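| // dr/dg/db are piecewise-linear ramps of the hue (scaled to [0, 6)) that |
| // give each channel's contribution; each output channel is then |
| // value * (1 - saturation + saturation * d). |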
| auto dh = xla::Mul(hue, six); |
| auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one); |
| auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one); |
| auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one); |
| auto one_minus_s = xla::Sub(one, saturation); |
| |
| auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value); |
| auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value); |
| auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value); |
| return {red, green, blue}; |
| } |
| |
| class RGBToHSVOp : public XlaOpKernel { |
| public: |
| explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {} |
| |
| void Compile(XlaOpKernelContext* context) override { |
| const TensorShape input_shape = context->InputShape(0); |
| OP_REQUIRES(context, input_shape.dims() >= 1, |
| errors::InvalidArgument("input must be at least 1D", |
| input_shape.DebugString())); |
| int channel_dim = input_shape.dims() - 1; |
| int64 channels = input_shape.dim_size(channel_dim); |
| OP_REQUIRES( |
| context, channels == 3, |
| errors::FailedPrecondition("input must have 3 channels but input has ", |
| channels, " channels.")); |
| |
| xla::XlaBuilder* b = context->builder(); |
| xla::XlaOp input = context->Input(0); |
| |
| xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, |
| /*limit_index=*/1, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, |
| /*limit_index=*/2, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, |
| /*limit_index=*/3, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| TensorShape channel_shape = input_shape; |
| channel_shape.set_dim(channel_dim, 1); |
| auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), |
| channel_shape); |
| |
| context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim)); |
| } |
| }; |
| REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); |
| |
| class HSVToRGBOp : public XlaOpKernel { |
| public: |
| explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {} |
| |
| void Compile(XlaOpKernelContext* context) override { |
| const TensorShape input_shape = context->InputShape(0); |
| OP_REQUIRES(context, input_shape.dims() >= 1, |
| errors::InvalidArgument("input must be at least 1D", |
| input_shape.DebugString())); |
| int channel_dim = input_shape.dims() - 1; |
| int64 channels = input_shape.dim_size(channel_dim); |
| OP_REQUIRES( |
| context, channels == 3, |
| errors::FailedPrecondition("input must have 3 channels but input has ", |
| channels, " channels.")); |
| |
| xla::XlaBuilder* b = context->builder(); |
| xla::XlaOp input = context->Input(0); |
| xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0, |
| /*limit_index=*/1, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1, |
| /*limit_index=*/2, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2, |
| /*limit_index=*/3, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| |
| auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, |
| context->input_type(0)); |
| |
| context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); |
| } |
| }; |
| REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); |
| |
| class AdjustContrastOpV2 : public XlaOpKernel { |
| public: |
| explicit AdjustContrastOpV2(OpKernelConstruction* context) |
| : XlaOpKernel(context) {} |
| |
| void Compile(XlaOpKernelContext* context) override { |
| const TensorShape& input_shape = context->InputShape(0); |
| const TensorShape& factor_shape = context->InputShape(1); |
| OP_REQUIRES(context, input_shape.dims() >= 3, |
| errors::InvalidArgument("input must be at least 3-D, got shape", |
| input_shape.DebugString())); |
| int height_dim = input_shape.dims() - 3; |
| int width_dim = input_shape.dims() - 2; |
| int channel_dim = input_shape.dims() - 1; |
| const int64 height = input_shape.dim_size(height_dim); |
| const int64 width = input_shape.dim_size(width_dim); |
| |
| OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape), |
| errors::InvalidArgument("contrast_factor must be scalar: ", |
| factor_shape.DebugString())); |
| |
| xla::XlaBuilder* b = context->builder(); |
| DataType type = context->input_type(0); |
| |
| xla::XlaOp input = context->Input(0); |
| xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type); |
| |
| const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); |
| auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); |
| auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), |
| *context->GetOrCreateAdd(accumulation_type), |
| {height_dim, width_dim}); |
| |
| auto output = xla::Div( |
| reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); |
| output = XlaHelpers::ConvertElementType(output, type); |
| |
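| // The adjustment (input - mean) * factor + mean is computed as |
| // input * factor + mean * (1 - factor); the per-channel mean is broadcast |
| // back over the height and width dimensions. |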
| std::vector<int64> broadcast_dims(input_shape.dims() - 2); |
| std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); |
| broadcast_dims.back() = channel_dim; |
| output = |
| xla::Add(xla::Mul(input, factor), |
| xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)), |
| broadcast_dims); |
| context->SetOutput(0, output); |
| } |
| }; |
| REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2); |
| |
| class AdjustSaturationOp : public XlaOpKernel { |
| public: |
| explicit AdjustSaturationOp(OpKernelConstruction* context) |
| : XlaOpKernel(context) {} |
| |
| void Compile(XlaOpKernelContext* context) override { |
| const TensorShape& input_shape = context->InputShape(0); |
| const TensorShape& scale_shape = context->InputShape(1); |
| OP_REQUIRES(context, input_shape.dims() >= 3, |
| errors::InvalidArgument("input must be at least 3-D, got shape", |
| input_shape.DebugString())); |
| OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape), |
| errors::InvalidArgument("scale must be scalar: ", |
| scale_shape.DebugString())); |
| const int channel_dim = input_shape.dims() - 1; |
| const int64 channels = input_shape.dim_size(channel_dim); |
| OP_REQUIRES( |
| context, channels == 3, |
| errors::InvalidArgument("input must have 3 channels but instead has ", |
| channels, " channels.")); |
| |
| xla::XlaBuilder* b = context->builder(); |
| xla::XlaOp input = |
| XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); |
| xla::XlaOp scale = |
| XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); |
| |
| DataType type = context->input_type(0); |
| |
| xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, |
| /*limit_index=*/1, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, |
| /*limit_index=*/2, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, |
| /*limit_index=*/3, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| TensorShape channel_shape = input_shape; |
| channel_shape.set_dim(channel_dim, 1); |
| auto hsv = |
| RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); |
| |
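| // Scale the saturation channel and clamp it to [0, 1] before converting |
| // back to RGB. |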
| hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale), |
| XlaHelpers::One(b, DT_FLOAT)); |
| |
| auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); |
| |
| auto output = XlaHelpers::ConvertElementType( |
| xla::ConcatInDim(b, rgb, channel_dim), type); |
| context->SetOutput(0, output); |
| } |
| }; |
| REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); |
| |
| class AdjustHueOp : public XlaOpKernel { |
| public: |
| explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {} |
| |
| void Compile(XlaOpKernelContext* context) override { |
| const TensorShape& input_shape = context->InputShape(0); |
| const TensorShape& delta_shape = context->InputShape(1); |
| OP_REQUIRES(context, input_shape.dims() >= 3, |
| errors::InvalidArgument("input must be at least 3-D, got shape", |
| input_shape.DebugString())); |
| OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape), |
| errors::InvalidArgument("delta must be scalar: ", |
| delta_shape.DebugString())); |
| const int channel_dim = input_shape.dims() - 1; |
| const int64 channels = input_shape.dim_size(channel_dim); |
| OP_REQUIRES( |
| context, channels == 3, |
| errors::InvalidArgument("input must have 3 channels but instead has ", |
| channels, " channels.")); |
| |
| xla::XlaBuilder* b = context->builder(); |
| xla::XlaOp input = |
| XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); |
| xla::XlaOp delta = |
| XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); |
| |
| DataType type = context->input_type(0); |
| |
| xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, |
| /*limit_index=*/1, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, |
| /*limit_index=*/2, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, |
| /*limit_index=*/3, /*stride=*/1, |
| /*dimno=*/channel_dim); |
| TensorShape channel_shape = input_shape; |
| channel_shape.set_dim(channel_dim, 1); |
| auto hsv = |
| RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); |
| |
| auto zero = XlaHelpers::Zero(b, DT_FLOAT); |
| auto one = XlaHelpers::One(b, DT_FLOAT); |
| |
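| // Shift the hue by delta, modulo 1.0. xla::Rem can produce a negative |
| // remainder, so wrap negative results back into [0, 1) by adding one. |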
| auto& hue = hsv[0]; |
| hue = xla::Rem(xla::Add(hsv[0], delta), one); |
| hue = |
| xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); |
| |
| auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); |
| |
| auto output = XlaHelpers::ConvertElementType( |
| xla::ConcatInDim(b, rgb, channel_dim), type); |
| context->SetOutput(0, output); |
| } |
| }; |
| REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); |
| |
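| // Condition for the suppression loop below: keep iterating while there are |
| // box rows left to visit and fewer than 'output_size' boxes have been |
| // selected. |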
| struct WhileCondFn { |
| const int64 num_boxes; |
| const int64 output_size; |
| |
| explicit WhileCondFn(int64 num_boxes, int64 output_size) |
| : num_boxes(num_boxes), output_size(output_size) {} |
| |
| xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values, |
| xla::XlaBuilder* cond_builder) const { |
| xla::XlaOp row_idx = values[0]; |
| xla::XlaOp row_in_bounds = |
| xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes)); |
| xla::XlaOp num_outputs_so_far = values[1]; |
| xla::XlaOp results_not_full = xla::Lt( |
| num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size)); |
| return xla::And(row_in_bounds, results_not_full); |
| } |
| }; |
| |
| // Process the boxes one-by-one using the iou matrix mask. |
| // This implementation uses a correct, but greedy, sequential algorithm |
| // to ensure that suppressed boxes cannot themselves suppress other |
| // boxes. |
| struct SuppressBodyFn { |
| const int64 num_boxes; |
| |
| explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {} |
| |
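| // Loop state: {row index, number of boxes selected so far, the static |
| // [num_boxes, num_boxes] IOU-threshold mask, the per-box inclusion mask}. |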
| xla::StatusOr<std::vector<xla::XlaOp>> operator()( |
| absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const { |
| auto row_idx = values[0]; |
| auto num_outputs_so_far = values[1]; |
| auto iou_mask = values[2]; |
| auto included_iou = values[3]; |
| auto zero = xla::ConstantR0<int32>(builder, 0); |
| // Determine if current elem is active using a slice. |
| // TODO(b/118437727): The only reason we need an explicit vector is because |
| // some old GCCs can't deduce the right type for MakeConstSpan, and |
| // providing a single-value initializer list directly uses the wrong |
| // overload. Delete this once the deprecated overload is gone. |
| std::vector<xla::XlaOp> row_idx_vector = {row_idx}; |
| auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1}); |
| active_elem = xla::Reshape(active_elem, {}); |
| // Increment output count iff current elem is not suppressed. |
| num_outputs_so_far = xla::Select( |
| active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1), |
| num_outputs_so_far); |
| // Slice out the row_idx. |
| auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); |
| // Remove the diagonal from consideration. An elem cannot suppress |
| // itself. |
| row_iou = xla::DynamicUpdateSlice( |
| row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}), |
| {zero, row_idx}); |
| // Invert the polarity: supp_mask is true for the boxes that the current |
| // box does not suppress (pairwise IOU at or below the threshold). |
| row_iou = xla::Reshape(row_iou, {num_boxes}); |
| auto supp_mask = xla::Not(row_iou); |
| // Update mask iff current elem is not suppressed. |
| included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), |
| xla::And(included_iou, supp_mask), included_iou); |
| row_idx = row_idx + xla::ConstantR0<int32>(builder, 1); |
| return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask, |
| included_iou}; |
| } |
| }; |
| |
| class NonMaxSuppressionOp : public XlaOpKernel { |
| public: |
| explicit NonMaxSuppressionOp(OpKernelConstruction* context) |
| : XlaOpKernel(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", |
| &pad_to_max_output_size_)); |
| } |
| |
| void Compile(XlaOpKernelContext* context) override { |
| // TODO(b/111646731): Improve scalability of this op, using blocking. |
| const TensorShape& boxes_shape = context->InputShape("boxes"); |
| OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), |
| errors::InvalidArgument("boxes must be 2-D, currently: ", |
| boxes_shape.DebugString())); |
| const int64 num_boxes = boxes_shape.dim_size(0); |
| OP_REQUIRES(context, boxes_shape.dim_size(1) == 4, |
| errors::InvalidArgument("boxes must have 4 columns", |
| boxes_shape.DebugString())); |
| const TensorShape& scores_shape = context->InputShape("scores"); |
| OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), |
| errors::InvalidArgument("scores must be 1-D, currently: ", |
| scores_shape.DebugString())); |
| OP_REQUIRES( |
| context, scores_shape.dim_size(0) == num_boxes, |
| errors::InvalidArgument("scores size must equal number of boxes", |
| scores_shape.DebugString())); |
| OP_REQUIRES(context, pad_to_max_output_size_, |
| errors::Unimplemented( |
| "XLA compilation requires pad_to_max_output_size == True")); |
| OP_REQUIRES(context, num_boxes <= kint32max, |
| errors::InvalidArgument("XLA compilation requires number of " |
| "boxes to be <= kint32max, got ", |
| num_boxes)); |
| xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes"); |
| xla::PrimitiveType scores_xla_type = context->InputXlaType("scores"); |
| const xla::XlaOp boxes_input = context->Input("boxes"); |
| const xla::XlaOp scores_input = context->Input("scores"); |
| int64 output_size; |
| OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); |
| OP_REQUIRES( |
| context, output_size >= 0, |
| errors::InvalidArgument("Need output_size >= 0, got ", output_size)); |
| OP_REQUIRES(context, output_size <= kint32max, |
| errors::InvalidArgument("Need output_size <= kint32Max, got ", |
| output_size)); |
| const xla::XlaOp score_thresh = context->Input("score_threshold"); |
| const xla::XlaOp iou_thresh = context->Input("iou_threshold"); |
| xla::XlaBuilder* const builder = context->builder(); |
| |
| // Choose a more convenient layout. |
| const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0}); |
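| // Sort the box rows by descending score: the scores, broadcast to |
| // [4, num_boxes], act as the sort key for every row of the transposed |
| // boxes; the key/value sort along dimension 1 returns the reordered boxes |
| // as tuple element 1. |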
| const xla::XlaOp boxes_sorted = xla::GetTupleElement( |
| xla::Sort({xla::Broadcast(scores_input, {4}), boxes}, |
| xla::CreateScalarGtComputation( |
| {scores_xla_type, boxes_xla_type}, builder), |
| /*dimension=*/1), |
| 1); |
| // Track the mapping of indices into the sorted domain. |
| const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes); |
| const xla::XlaOp indices_sort = xla::Sort( |
| {scores_input, iota_indices}, |
| xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder)); |
| const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); |
| const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0); |
| |
| // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0. |
| const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, |
| /*start_index=*/0, |
| /*limit_index=*/1, |
| /*stride=*/1, |
| /*dimno=*/0), |
| {num_boxes}); |
| const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, |
| /*start_index=*/1, |
| /*limit_index=*/2, |
| /*stride=*/1, |
| /*dimno=*/0), |
| {num_boxes}); |
| const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, |
| /*start_index=*/2, |
| /*limit_index=*/3, |
| /*stride=*/1, |
| /*dimno=*/0), |
| {num_boxes}); |
| const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, |
| /*start_index=*/3, |
| /*limit_index=*/4, |
| /*stride=*/1, |
| /*dimno=*/0), |
| {num_boxes}); |
| |
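| // Canonicalize the coordinates so that (y1, x1) is always the smaller |
| // corner and (y2, x2) the larger one, whatever the input ordering. |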
| xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); |
| xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); |
| xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); |
| xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); |
| xla::XlaOp area = (y2 - y1) * (x2 - x1); |
| |
| // Shapes are henceforth [1, num_boxes]. |
| y1 = xla::Broadcast(y1, {1}); |
| y2 = xla::Broadcast(y2, {1}); |
| x1 = xla::Broadcast(x1, {1}); |
| x2 = xla::Broadcast(x2, {1}); |
| area = xla::Broadcast(area, {1}); |
| |
| // Shapes are henceforth [num_boxes, num_boxes]. |
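| // The pairwise intersection uses the max of the lower corners and the min |
| // of the upper corners, clamped at zero; the union is |
| // area_i + area_j - intersection, and iou is their ratio. |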
| xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); |
| xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); |
| xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); |
| xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); |
| auto square_zero = xla::ZerosLike(i_xmin); |
| |
| xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * |
| xla::Max(i_ymax - i_ymin, square_zero); |
| xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; |
| xla::XlaOp iou = i_area / u_area; |
| |
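| // Adding square_zero broadcasts the scalar IOU threshold to |
| // [num_boxes, num_boxes] so every pair can be compared against it. |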
| xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); |
| xla::XlaOp included_iou = |
| xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes}); |
| |
| std::vector<xla::XlaOp> init_values; |
| init_values.reserve(4); |
| init_values.push_back(xla::ConstantR0<int32>(builder, 0)); // row_idx |
| init_values.push_back(xla::ConstantR0<int32>(builder, 0)); // num_outputs |
| init_values.push_back(iou_thresh_mask); |
| init_values.push_back(included_iou); |
| |
| auto suppress_loop_result = |
| xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), |
| SuppressBodyFn(num_boxes), init_values, |
| "suppress_loop", builder) |
| .ValueOrDie(); |
| |
| xla::XlaOp included_score = |
| xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); |
| xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); |
| |
| // Only consider boxes over which we have iterated. This allows for accurate |
| // counting. DynamicSlice would require knowledge of the size of the output. |
| auto valid_elem = xla::Lt( |
| iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); |
| included = xla::And(included, valid_elem); |
| |
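| // Give excluded boxes the lowest possible score so that TopK returns the |
| // indices of the 'output_size' highest-scoring surviving boxes (still in |
| // the sorted domain). |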
| xla::XlaOp neg_inf = |
| xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); |
| xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); |
| xla::XlaOp output_tuple = TopK(scores_included, output_size); |
| xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); |
| // Calculate num_valid. |
| // Note: num_valid cannot be taken from the loop outputs, because outputs |
| // can still be suppressed by the score threshold. |
| xla::XlaOp ones_included = xla::Select( |
| included, |
| xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}), |
| xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes})); |
| // num_valid is a scalar; its value is bounded above by output_size. |
| xla::XlaOp num_valid_total = xla::Reduce( |
| ones_included, |
| /*init_value=*/xla::ConstantR0<int32>(builder, 0), |
| /*computation=*/CreateScalarAddComputation(xla::S32, builder), |
| /*dimensions_to_reduce=*/{0}); |
| xla::XlaOp num_valid = |
| xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size)); |
| |
| // Re-index into the original scores input tensor, using a Gather. |
| // Boxes were suppressed in the sorted domain. |
| xla::XlaOp selected_indices; |
| DataType gather_type = context->expected_output_dtype(0); |
| OP_REQUIRES_OK( |
| context, |
| XlaGather(indices_sorted, scores_shape, selected_indices_sorted, |
| TensorShape({output_size}), |
| /*axis=*/0, |
| /*indices_are_nd=*/false, |
| /*dtype=*/gather_type, DT_INT32, builder, &selected_indices)); |
| |
| context->SetOutput(0, selected_indices); |
| context->SetOutput(1, num_valid); |
| } |
| |
| private: |
| bool pad_to_max_output_size_; |
| }; |
| |
| REGISTER_XLA_OP( |
| Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"), |
| NonMaxSuppressionOp); |
| |
| } // namespace |
| } // namespace tensorflow |