Removed legacy support (the `use_detectron_offset` attribute) from GenerateBoundingBoxProposals
diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
index e4405be..648b23e 100644
--- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
@@ -3,7 +3,7 @@
in_arg {
name: "scores"
description: <<END
-A 4-D float tensor of shape `[num_images, height, width, num_achors]` containing scores of the boxes for given anchors.
+A 4-D float tensor of shape `[num_images, height, width, num_anchors]` containing scores of the boxes for given anchors; the scores may be unsorted.
END
}
in_arg {
@@ -63,12 +63,6 @@
An integer. Maximum number of rois in the output.
END
}
- attr {
- name: "use_detectron_offset"
- description: <<END
-Boolean to decide whether apply +1 offset used in detectron like networks.
-END
- }
summary: "This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrt anchors according to eq.2 in arXiv:1506.01497"
description: <<END
The op selects top `pre_nms_topn` scoring boxes, decodes them with respect to anchors,
@@ -81,6 +75,6 @@
`anchors`: A 1D tensor of shape [4 x Num Anchors], representing the anchors.
Outputs:
`rois`: output RoIs, a 3D tensor of shape [Batch, post_nms_topn, 4], padded by 0 if less than post_nms_topn candidates found.
- `roi_probabilities`: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed.
+ `roi_probabilities`: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed, sorted by scores.
END
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9f2e44d..fa2f14e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2838,7 +2838,7 @@
tf_kernel_library(
name = "generate_box_proposals_op",
- prefix = "generate_box_proposals_op",
+ gpu_srcs = ["generate_box_proposals_op.cu.cc"],
deps = [":non_max_suppression_op_gpu"] + if_cuda(["@cub_archive//:cub"]),
)
diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
index 2f6c4a1..da4a0a4 100644
--- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
+++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
@@ -38,6 +38,15 @@
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+#define TF_RETURN_IF_CUDA_ERROR(result) \
+ do { \
+ cudaError_t error(result); \
+ if (!SE_PREDICT_TRUE(error == cudaSuccess)) { \
+ return errors::Internal("Cuda call failed with ", \
+ cudaGetErrorString(error)); \
+ } \
+ } while (0)
+
#define TF_OP_REQUIRES_CUDA_SUCCESS(context, result) \
do { \
cudaError_t error(result); \
@@ -50,34 +59,12 @@
namespace {
-template <bool T>
-__device__ float AddLegacyOffset(float);
-template <>
-__device__ float AddLegacyOffset<true>(float a) {
- return a + 1.;
-}
-template <>
-__device__ float AddLegacyOffset<false>(float a) {
- return a;
-}
-
-template <bool T>
-__device__ float SubtractLegacyOffset(float);
-template <>
-__device__ float SubtractLegacyOffset<true>(float a) {
- return a - 1.;
-}
-template <>
-__device__ float SubtractLegacyOffset<false>(float a) {
- return a;
-}
// Decode d_bbox_deltas with respect to anchors into absolute coordinates,
// clipping if necessary.
// prenms_nboxes maximum number of boxes per image to decode.
// d_boxes_keep_flags mask for boxes to consider in NMS.
// min_size is the lower bound of the shortest edge for the boxes to consider.
// bbox_xform_clip is the upper bound of encoded width and height.
-template <bool T>
__global__ void GeneratePreNMSUprightBoxesKernel(
const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys,
const float4* d_bbox_deltas, const float4* d_anchors, const int height,
@@ -133,36 +120,35 @@
dh = fmin(dh, bbox_xform_clip);
// Applying the deltas
- float width = AddLegacyOffset<T>(x2 - x1);
+ float width = x2 - x1;
const float ctr_x = x1 + 0.5f * width;
const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd
const float pred_w = width * expf(dw);
x1 = pred_ctr_x - 0.5f * pred_w;
- x2 = SubtractLegacyOffset<T>(pred_ctr_x + 0.5f * pred_w);
+ x2 = pred_ctr_x + 0.5f * pred_w;
- float height = AddLegacyOffset<T>(y2 - y1);
+ float height = y2 - y1;
const float ctr_y = y1 + 0.5f * height;
const float pred_ctr_y = ctr_y + height * dy;
const float pred_h = height * expf(dh);
y1 = pred_ctr_y - 0.5f * pred_h;
- y2 = SubtractLegacyOffset<T>(pred_ctr_y +
- 0.5f * pred_h); // -1 if legacy_op
+ y2 = pred_ctr_y + 0.5f * pred_h;
// Clipping box to image
const float img_height = d_img_info_vec[5 * image_index + 0];
const float img_width = d_img_info_vec[5 * image_index + 1];
const float min_size_scaled =
min_size * d_img_info_vec[5 * image_index + 2];
- x1 = fmax(fmin(x1, SubtractLegacyOffset<T>(img_width)), 0.0f);
- y1 = fmax(fmin(y1, SubtractLegacyOffset<T>(img_height)), 0.0f);
- x2 = fmax(fmin(x2, SubtractLegacyOffset<T>(img_width)), 0.0f);
- y2 = fmax(fmin(y2, SubtractLegacyOffset<T>(img_height)), 0.0f);
+ x1 = fmax(fmin(x1, img_width), 0.0f);
+ y1 = fmax(fmin(y1, img_height), 0.0f);
+ x2 = fmax(fmin(x2, img_width), 0.0f);
+ y2 = fmax(fmin(y2, img_height), 0.0f);
// Filter boxes
// Removing boxes with one dim < min_size
// (center of box is in image, because of previous step)
- width = AddLegacyOffset<T>(x2 - x1); // may have changed
- height = AddLegacyOffset<T>(y2 - y1);
+ width = x2 - x1; // may have changed
+ height = y2 - y1;
bool keep_box = fmin(width, height) >= min_size_scaled;
// We are not deleting the box right now even if !keep_box
@@ -332,9 +318,6 @@
tensorflow::OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_));
- // compatibility for detectron like networks. False for generic case
- OP_REQUIRES_OK(context, context->GetAttr("use_detectron_offset",
- &use_detectron_offset_));
CHECK_GT(post_nms_topn_, 0);
bbox_xform_clip_default_ = log(1000.0 / 16.);
}
@@ -454,29 +437,16 @@
// create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and
// mark the boxes which are smaller that min_size ignored.
- if (use_detectron_offset_) {
- TF_CHECK_OK(GpuLaunchKernel(
- GeneratePreNMSUprightBoxesKernel<true>, conf2d.block_count,
- conf2d.thread_per_block, 0, d.stream(), conf2d,
- d_sorted_conv_layer_indexes.flat<int>().data(),
- reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
- reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
- width, num_anchors, min_size, image_info.flat<float>().data(),
- bbox_xform_clip_default_,
- reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
- nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
- } else {
- TF_CHECK_OK(GpuLaunchKernel(
- GeneratePreNMSUprightBoxesKernel<false>, conf2d.block_count,
- conf2d.thread_per_block, 0, d.stream(), conf2d,
- d_sorted_conv_layer_indexes.flat<int>().data(),
- reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
- reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
- width, num_anchors, min_size, image_info.flat<float>().data(),
- bbox_xform_clip_default_,
- reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
- nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
- }
+ TF_CHECK_OK(GpuLaunchKernel(
+ GeneratePreNMSUprightBoxesKernel, conf2d.block_count,
+ conf2d.thread_per_block, 0, d.stream(), conf2d,
+ d_sorted_conv_layer_indexes.flat<int>().data(),
+ reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
+ reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
+ width, num_anchors, min_size, image_info.flat<float>().data(),
+ bbox_xform_clip_default_,
+ reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
+ nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
const int nboxes_generated = nboxes_to_generate;
const int roi_cols = box_dim;
const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_);
@@ -520,6 +490,8 @@
&output_roi_probs));
float* d_postnms_rois = (*output_rois).flat<float>().data();
float* d_postnms_rois_probs = (*output_roi_probs).flat<float>().data();
+ cudaEvent_t copy_done;
+ cudaEventCreate(&copy_done);
// Do per-image nms
for (int image_index = 0; image_index < num_images; ++image_index) {
@@ -559,14 +531,15 @@
d_image_prenms_scores, d_prenms_nboxes, nboxes_generated,
d.stream()));
d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int));
- d.synchronize();
+ TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventRecord(copy_done, d.stream()));
+ TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventSynchronize(copy_done));
// We know prenms_boxes <= topN_prenms, because nboxes_generated <=
// topN_prenms. Calling NMS on the generated boxes
const int prenms_nboxes = h_prenms_nboxes;
int nkeep;
OP_REQUIRES_OK(context,
NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold,
- d_image_boxes_keep_list, &nkeep, context));
+ d_image_boxes_keep_list, &nkeep, context, post_nms_topn_));
// All operations done after previous sort were keeping the relative order
// of the elements the elements are still sorted keep topN <=> truncate
// the array
@@ -589,7 +562,6 @@
private:
int post_nms_topn_;
float bbox_xform_clip_default_;
- bool use_detectron_offset_;
};
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 50945c3..57c032f 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -994,7 +994,6 @@
.Output("rois: float")
.Output("roi_probabilities: float")
.Attr("post_nms_topn: int = 300")
- .Attr("use_detectron_offset: bool = false")
.SetShapeFn([](InferenceContext* c) -> Status {
// make sure input tensors have are correct rank
ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold,
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 685644e..6fd8b63 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1478,7 +1478,7 @@
}
member_method {
name: "GenerateBoundingBoxProposals"
- argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], "
+ argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 685644e..6fd8b63 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1478,7 +1478,7 @@
}
member_method {
name: "GenerateBoundingBoxProposals"
- argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], "
+ argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"