/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

// Converts the RGB channels in 'rgb' to HSV.
// 'shape' is the shape of each of the red/green/blue channel tensors.
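// With V = max(R, G, B) and range = V - min(R, G, B), the code below computes
//   S = range / V                    (0 when V == 0), and
//   H = ((G - B) / range) / 6        when V == R,
//       ((B - R) / range + 2) / 6    when V == G,
//       ((R - G) / range + 4) / 6    when V == B,
// with H wrapped into [0, 1).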
std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
const std::array<xla::XlaOp, 3>& rgb,
DataType dtype, const TensorShape& shape) {
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto red = rgb[0];
auto green = rgb[1];
auto blue = rgb[2];
auto value = xla::Max(xla::Max(red, green), blue);
auto minimum = xla::Min(xla::Min(red, green), blue);
auto range = xla::Sub(value, minimum);
auto zeros = xla::Broadcast(zero, shape.dim_sizes());
auto saturation =
xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
auto hue =
xla::Select(xla::Eq(green, value),
xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
xla::Add(xla::Mul(norm, xla::Sub(red, green)),
XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
hue);
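  // Gray pixels (range == 0) have no defined hue; use 0. Negative hues wrap
  // around into [0, 1).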
hue = xla::Select(xla::Gt(range, zero), hue, zeros);
hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
return {hue, saturation, value};
}

// Converts the HSV channels in 'hsv' to RGB format.
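// Uses the triangular-wave formulation: with dh = 6 * H, each of dr/dg/db
// below is a piecewise-linear ramp of the hue clamped to [0, 1], and each
// output channel is V * ((1 - S) + S * ramp).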
std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
const std::array<xla::XlaOp, 3>& hsv,
DataType dtype) {
xla::XlaOp hue = hsv[0];
xla::XlaOp saturation = hsv[1];
xla::XlaOp value = hsv[2];
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
auto dh = xla::Mul(hue, six);
auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
auto one_minus_s = xla::Sub(one, saturation);
auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
return {red, green, blue};
}

class RGBToHSVOp : public XlaOpKernel {
public:
explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D, got shape ",
                                        input_shape.DebugString()));
int channel_dim = input_shape.dims() - 1;
int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but instead has ",
                                   channels, " channels."));
xla::XlaBuilder* b = context->builder();
xla::XlaOp input = context->Input(0);
xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
/*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
/*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
/*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
channel_shape);
context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
}
};
REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
class HSVToRGBOp : public XlaOpKernel {
public:
explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D, got shape ",
                                        input_shape.DebugString()));
int channel_dim = input_shape.dims() - 1;
int64 channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but instead has ",
                                   channels, " channels."));
xla::XlaBuilder* b = context->builder();
xla::XlaOp input = context->Input(0);
xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
/*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
/*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
/*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
context->input_type(0));
context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
class AdjustContrastOpV2 : public XlaOpKernel {
public:
explicit AdjustContrastOpV2(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape& input_shape = context->InputShape(0);
const TensorShape& factor_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
int height_dim = input_shape.dims() - 3;
int width_dim = input_shape.dims() - 2;
int channel_dim = input_shape.dims() - 1;
const int64 height = input_shape.dim_size(height_dim);
const int64 width = input_shape.dim_size(width_dim);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
errors::InvalidArgument("contrast_factor must be scalar: ",
factor_shape.DebugString()));
xla::XlaBuilder* b = context->builder();
DataType type = context->input_type(0);
xla::XlaOp input = context->Input(0);
xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*context->GetOrCreateAdd(accumulation_type),
{height_dim, width_dim});
auto output = xla::Div(
reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
output = XlaHelpers::ConvertElementType(output, type);
std::vector<int64> broadcast_dims(input_shape.dims() - 2);
std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
broadcast_dims.back() = channel_dim;
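    // Map the reduced mean's dimensions back onto the input: the leading batch
    // dimensions plus the channel dimension (height and width are broadcast).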
output =
xla::Add(xla::Mul(input, factor),
xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
broadcast_dims);
context->SetOutput(0, output);
}
};
REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
class AdjustSaturationOp : public XlaOpKernel {
public:
explicit AdjustSaturationOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape& input_shape = context->InputShape(0);
const TensorShape& scale_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
errors::InvalidArgument("scale must be scalar: ",
scale_shape.DebugString()));
const int channel_dim = input_shape.dims() - 1;
const int64 channels = input_shape.dim_size(channel_dim);
OP_REQUIRES(
context, channels == 3,
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
xla::XlaBuilder* b = context->builder();
xla::XlaOp input =
XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
xla::XlaOp scale =
XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
DataType type = context->input_type(0);
xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
/*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
/*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
/*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv =
RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
XlaHelpers::One(b, DT_FLOAT));
auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
auto output = XlaHelpers::ConvertElementType(
xla::ConcatInDim(b, rgb, channel_dim), type);
context->SetOutput(0, output);
}
};
REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
class AdjustHueOp : public XlaOpKernel {
public:
explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape& input_shape = context->InputShape(0);
const TensorShape& delta_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
errors::InvalidArgument("delta must be scalar: ",
delta_shape.DebugString()));
const int channel_dim = input_shape.dims() - 1;
const int64 channels = input_shape.dim_size(channel_dim);
OP_REQUIRES(
context, channels == 3,
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
xla::XlaBuilder* b = context->builder();
xla::XlaOp input =
XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
xla::XlaOp delta =
XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
DataType type = context->input_type(0);
xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
/*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
/*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
/*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv =
RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
auto zero = XlaHelpers::Zero(b, DT_FLOAT);
auto one = XlaHelpers::One(b, DT_FLOAT);
auto& hue = hsv[0];
hue = xla::Rem(xla::Add(hsv[0], delta), one);
hue =
xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
auto output = XlaHelpers::ConvertElementType(
xla::ConcatInDim(b, rgb, channel_dim), type);
context->SetOutput(0, output);
}
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
struct WhileCondFn {
const int64 num_boxes;
const int64 output_size;
explicit WhileCondFn(int64 num_boxes, int64 output_size)
: num_boxes(num_boxes), output_size(output_size) {}
xla::StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* cond_builder) const {
xla::XlaOp row_idx = values[0];
xla::XlaOp row_in_bounds =
xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
xla::XlaOp num_outputs_so_far = values[1];
xla::XlaOp results_not_full = xla::Lt(
num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
return xla::And(row_in_bounds, results_not_full);
}
};

// Process the boxes one-by-one using the iou matrix mask.
// This implementation uses a correct, but greedy, sequential algorithm
// to ensure that suppressed boxes cannot themselves suppress other
// boxes.
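// Loop state: values[0] is the current row index, values[1] the number of
// boxes accepted so far, values[2] the [num_boxes, num_boxes] over-threshold
// IoU mask, and values[3] the per-box "still included" flags.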
struct SuppressBodyFn {
const int64 num_boxes;
explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {}
xla::StatusOr<std::vector<xla::XlaOp>> operator()(
absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
auto row_idx = values[0];
auto num_outputs_so_far = values[1];
auto iou_mask = values[2];
auto included_iou = values[3];
auto zero = xla::ConstantR0<int32>(builder, 0);
// Determine if current elem is active using a slice.
// TODO(b/118437727): The only reason we need an explicit vector is because
// some old GCCs can't deduce the right type for MakeConstSpan, and
// providing a single-value initializer list directly uses the wrong
// overload. Delete this once the deprecated overload is gone.
std::vector<xla::XlaOp> row_idx_vector = {row_idx};
auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
active_elem = xla::Reshape(active_elem, {});
// Increment output count iff current elem is not suppressed.
num_outputs_so_far = xla::Select(
active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
num_outputs_so_far);
// Slice out the row_idx.
auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
// Remove the diagonal from consideration. An elem cannot suppress
// itself.
row_iou = xla::DynamicUpdateSlice(
row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
{zero, row_idx});
    // Invert the over-threshold flags so that 'true' marks the boxes that the
    // current box does not suppress.
row_iou = xla::Reshape(row_iou, {num_boxes});
auto supp_mask = xla::Not(row_iou);
// Update mask iff current elem is not suppressed.
included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
xla::And(included_iou, supp_mask), included_iou);
row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
included_iou};
}
};

class NonMaxSuppressionOp : public XlaOpKernel {
public:
explicit NonMaxSuppressionOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
&pad_to_max_output_size_));
}
void Compile(XlaOpKernelContext* context) override {
// TODO(b/111646731): Improve scalability of this op, using blocking.
const TensorShape& boxes_shape = context->InputShape("boxes");
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
errors::InvalidArgument("boxes must be 2-D, currently: ",
boxes_shape.DebugString()));
const int64 num_boxes = boxes_shape.dim_size(0);
    OP_REQUIRES(context, boxes_shape.dim_size(1) == 4,
                errors::InvalidArgument("boxes must have 4 columns, got shape ",
                                        boxes_shape.DebugString()));
const TensorShape& scores_shape = context->InputShape("scores");
OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
errors::InvalidArgument("scores must be 1-D, currently: ",
scores_shape.DebugString()));
    OP_REQUIRES(
        context, scores_shape.dim_size(0) == num_boxes,
        errors::InvalidArgument("scores size must equal number of boxes, got ",
                                scores_shape.DebugString()));
OP_REQUIRES(context, pad_to_max_output_size_,
errors::Unimplemented(
"XLA compilation requires pad_to_max_output_size == True"));
OP_REQUIRES(context, num_boxes <= kint32max,
errors::InvalidArgument("XLA compilation requires number of "
"boxes to be <= kint32max, got ",
num_boxes));
xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes");
xla::PrimitiveType scores_xla_type = context->InputXlaType("scores");
const xla::XlaOp boxes_input = context->Input("boxes");
const xla::XlaOp scores_input = context->Input("scores");
int64 output_size;
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
    OP_REQUIRES(context, output_size <= kint32max,
                errors::InvalidArgument("Need output_size <= kint32max, got ",
                                        output_size));
const xla::XlaOp score_thresh = context->Input("score_threshold");
const xla::XlaOp iou_thresh = context->Input("iou_threshold");
xla::XlaBuilder* const builder = context->builder();
    // Choose a more convenient layout: transpose boxes to [4, num_boxes] so
    // that each coordinate occupies a single row.
const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
const xla::XlaOp boxes_sorted = xla::GetTupleElement(
xla::Sort({xla::Broadcast(scores_input, {4}), boxes},
xla::CreateScalarGtComputation(
{scores_xla_type, boxes_xla_type}, builder),
/*dimension=*/1),
1);
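    // Sorting with the scores broadcast to [4, num_boxes] as keys permutes
    // every coordinate row identically; columns of boxes_sorted are now
    // ordered by decreasing score.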
    // Track the mapping of indices into the sorted domain.
const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
const xla::XlaOp indices_sort = xla::Sort(
{scores_input, iota_indices},
xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder));
const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0);
    // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0.
const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
/*start_index=*/0,
/*limit_index=*/1,
/*stride=*/1,
/*dimno=*/0),
{num_boxes});
const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
/*start_index=*/1,
/*limit_index=*/2,
/*stride=*/1,
/*dimno=*/0),
{num_boxes});
const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
/*start_index=*/2,
/*limit_index=*/3,
/*stride=*/1,
/*dimno=*/0),
{num_boxes});
const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
/*start_index=*/3,
/*limit_index=*/4,
/*stride=*/1,
/*dimno=*/0),
{num_boxes});
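    // Canonicalize the corners so that (y1, x1) is the lower and (y2, x2) the
    // upper corner of each box.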
xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
xla::XlaOp area = (y2 - y1) * (x2 - x1);
// Shapes are henceforth [1, num_boxes].
y1 = xla::Broadcast(y1, {1});
y2 = xla::Broadcast(y2, {1});
x1 = xla::Broadcast(x1, {1});
x2 = xla::Broadcast(x2, {1});
area = xla::Broadcast(area, {1});
// Shapes are henceforth [num_boxes, num_boxes].
xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
auto square_zero = xla::ZerosLike(i_xmin);
xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
xla::Max(i_ymax - i_ymin, square_zero);
xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
xla::XlaOp iou = i_area / u_area;
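    // Mask of box pairs whose overlap exceeds the threshold; adding
    // 'square_zero' broadcasts the scalar threshold to [num_boxes, num_boxes].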
xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
xla::XlaOp included_iou =
xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});
std::vector<xla::XlaOp> init_values;
init_values.reserve(4);
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // row_idx
init_values.push_back(xla::ConstantR0<int32>(builder, 0)); // num_outputs
init_values.push_back(iou_thresh_mask);
init_values.push_back(included_iou);
auto suppress_loop_result =
xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values,
"suppress_loop", builder)
.ValueOrDie();
xla::XlaOp included_score =
xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);
// Only consider boxes over which we have iterated. This allows for accurate
// counting. DynamicSlice would require knowledge of the size of the output.
auto valid_elem = xla::Lt(
iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
included = xla::And(included, valid_elem);
xla::XlaOp neg_inf =
xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
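    // TopK returns the 'output_size' best scores; suppressed entries (-inf)
    // sort to the end, which implements pad_to_max_output_size.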
    xla::XlaOp output_tuple = xla::TopK(scores_included, output_size);
xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
// Calculate num_valid.
// Note: num_valid cannot be taken from the loop outputs, because outputs
// can be suppressed by score threshold.
xla::XlaOp ones_included = xla::Select(
included,
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
// num_valid is scalar. Value should be bound by output_size.
    xla::XlaOp num_valid_total = xla::Reduce(
        ones_included,
        /*init_value=*/xla::ConstantR0<int32>(builder, 0),
        /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
xla::XlaOp num_valid =
xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
// Re-index into the original scores input tensor, using a Gather.
// Boxes were suppressed in the sorted domain.
xla::XlaOp selected_indices;
DataType gather_type = context->expected_output_dtype(0);
OP_REQUIRES_OK(
context,
XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
TensorShape({output_size}),
/*axis=*/0,
/*indices_are_nd=*/false,
/*dtype=*/gather_type, DT_INT32, builder, &selected_indices));
context->SetOutput(0, selected_indices);
context->SetOutput(1, num_valid);
}

 private:
bool pad_to_max_output_size_;
};

REGISTER_XLA_OP(
Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
NonMaxSuppressionOp);

}  // namespace
}  // namespace tensorflow