Removed legacy `use_detectron_offset` support from GenerateBoundingBoxProposals
diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
index e4405be..648b23e 100644
--- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "scores"
     description: <<END
-A 4-D float tensor of shape `[num_images, height, width, num_achors]` containing scores of the boxes for given anchors.
+A 4-D float tensor of shape `[num_images, height, width, num_anchors]` containing scores of the boxes for given anchors; the scores need not be sorted.
 END
   }
   in_arg {
@@ -63,12 +63,6 @@
 An integer. Maximum number of rois in the output.
 END
   }
-  attr {
-    name: "use_detectron_offset"
-    description: <<END
-Boolean to decide whether apply +1 offset used in detectron like networks.
-END
-  }
   summary: "This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrt anchors according to eq.2 in arXiv:1506.01497"
   description: <<END
       The op selects top `pre_nms_topn` scoring boxes, decodes them with respect to anchors,
@@ -81,6 +75,6 @@
       `anchors`: A 1D tensor of shape [4 x Num Anchors], representing the anchors.
       Outputs:
       `rois`: output RoIs, a 3D tensor of shape [Batch, post_nms_topn, 4], padded by 0 if less than post_nms_topn candidates found.
-      `roi_probabilities`: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed.
+      `roi_probabilities`: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed, sorted by scores.
 END
 }
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9f2e44d..fa2f14e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2838,7 +2838,7 @@
 
 tf_kernel_library(
     name = "generate_box_proposals_op",
-    prefix = "generate_box_proposals_op",
+    gpu_srcs = ["generate_box_proposals_op.cu.cc"],
     deps = [":non_max_suppression_op_gpu"] + if_cuda(["@cub_archive//:cub"]),
 )
 
diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
index 2f6c4a1..da4a0a4 100644
--- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
+++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc
@@ -38,6 +38,15 @@
 
 namespace tensorflow {
 typedef Eigen::GpuDevice GPUDevice;
+#define TF_RETURN_IF_CUDA_ERROR(result)                   \
+  do {                                                    \
+    cudaError_t error(result);                            \
+    if (!SE_PREDICT_TRUE(error == cudaSuccess)) {         \
+      return errors::Internal("Cuda call failed with ",   \
+                              cudaGetErrorString(error)); \
+    }                                                     \
+  } while (0)
+
 #define TF_OP_REQUIRES_CUDA_SUCCESS(context, result)                   \
   do {                                                                 \
     cudaError_t error(result);                                         \
@@ -50,34 +59,12 @@
 
 namespace {
 
-template <bool T>
-__device__ float AddLegacyOffset(float);
-template <>
-__device__ float AddLegacyOffset<true>(float a) {
-  return a + 1.;
-}
-template <>
-__device__ float AddLegacyOffset<false>(float a) {
-  return a;
-}
-
-template <bool T>
-__device__ float SubtractLegacyOffset(float);
-template <>
-__device__ float SubtractLegacyOffset<true>(float a) {
-  return a - 1.;
-}
-template <>
-__device__ float SubtractLegacyOffset<false>(float a) {
-  return a;
-}
 // Decode d_bbox_deltas with respect to anchors into absolute coordinates,
 // clipping if necessary.
 // prenms_nboxes maximum number of boxes per image to decode.
 // d_boxes_keep_flags mask for boxes to consider in NMS.
 // min_size is the lower bound of the shortest edge for the boxes to consider.
 // bbox_xform_clip is the upper bound of encoded width and height.
-template <bool T>
 __global__ void GeneratePreNMSUprightBoxesKernel(
     const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys,
     const float4* d_bbox_deltas, const float4* d_anchors, const int height,
@@ -133,36 +120,35 @@
       dh = fmin(dh, bbox_xform_clip);
 
       // Applying the deltas
-      float width = AddLegacyOffset<T>(x2 - x1);
+      float width = x2 - x1;
       const float ctr_x = x1 + 0.5f * width;
       const float pred_ctr_x = ctr_x + width * dx;  // TODO fuse madd
       const float pred_w = width * expf(dw);
       x1 = pred_ctr_x - 0.5f * pred_w;
-      x2 = SubtractLegacyOffset<T>(pred_ctr_x + 0.5f * pred_w);
+      x2 = pred_ctr_x + 0.5f * pred_w;
 
-      float height = AddLegacyOffset<T>(y2 - y1);
+      float height = y2 - y1;
       const float ctr_y = y1 + 0.5f * height;
       const float pred_ctr_y = ctr_y + height * dy;
       const float pred_h = height * expf(dh);
       y1 = pred_ctr_y - 0.5f * pred_h;
-      y2 = SubtractLegacyOffset<T>(pred_ctr_y +
-                                   0.5f * pred_h);  // -1 if legacy_op
+      y2 = pred_ctr_y + 0.5f * pred_h;
 
       // Clipping box to image
       const float img_height = d_img_info_vec[5 * image_index + 0];
       const float img_width = d_img_info_vec[5 * image_index + 1];
       const float min_size_scaled =
           min_size * d_img_info_vec[5 * image_index + 2];
-      x1 = fmax(fmin(x1, SubtractLegacyOffset<T>(img_width)), 0.0f);
-      y1 = fmax(fmin(y1, SubtractLegacyOffset<T>(img_height)), 0.0f);
-      x2 = fmax(fmin(x2, SubtractLegacyOffset<T>(img_width)), 0.0f);
-      y2 = fmax(fmin(y2, SubtractLegacyOffset<T>(img_height)), 0.0f);
+      x1 = fmax(fmin(x1, img_width), 0.0f);
+      y1 = fmax(fmin(y1, img_height), 0.0f);
+      x2 = fmax(fmin(x2, img_width), 0.0f);
+      y2 = fmax(fmin(y2, img_height), 0.0f);
 
       // Filter boxes
       // Removing boxes with one dim < min_size
       // (center of box is in image, because of previous step)
-      width = AddLegacyOffset<T>(x2 - x1);  // may have changed
-      height = AddLegacyOffset<T>(y2 - y1);
+      width = x2 - x1;  // may have changed
+      height = y2 - y1;
       bool keep_box = fmin(width, height) >= min_size_scaled;
 
       // We are not deleting the box right now even if !keep_box
@@ -332,9 +318,6 @@
       tensorflow::OpKernelConstruction* context)
       : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_));
-    // compatibility for detectron like networks. False for generic case
-    OP_REQUIRES_OK(context, context->GetAttr("use_detectron_offset",
-                                             &use_detectron_offset_));
     CHECK_GT(post_nms_topn_, 0);
     bbox_xform_clip_default_ = log(1000.0 / 16.);
   }
@@ -454,29 +437,16 @@
 
     // create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and
     // mark the boxes which are smaller that min_size ignored.
-    if (use_detectron_offset_) {
-      TF_CHECK_OK(GpuLaunchKernel(
-          GeneratePreNMSUprightBoxesKernel<true>, conf2d.block_count,
-          conf2d.thread_per_block, 0, d.stream(), conf2d,
-          d_sorted_conv_layer_indexes.flat<int>().data(),
-          reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
-          reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
-          width, num_anchors, min_size, image_info.flat<float>().data(),
-          bbox_xform_clip_default_,
-          reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
-          nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
-    } else {
-      TF_CHECK_OK(GpuLaunchKernel(
-          GeneratePreNMSUprightBoxesKernel<false>, conf2d.block_count,
-          conf2d.thread_per_block, 0, d.stream(), conf2d,
-          d_sorted_conv_layer_indexes.flat<int>().data(),
-          reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
-          reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
-          width, num_anchors, min_size, image_info.flat<float>().data(),
-          bbox_xform_clip_default_,
-          reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
-          nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
-    }
+    TF_CHECK_OK(GpuLaunchKernel(
+        GeneratePreNMSUprightBoxesKernel, conf2d.block_count,
+        conf2d.thread_per_block, 0, d.stream(), conf2d,
+        d_sorted_conv_layer_indexes.flat<int>().data(),
+        reinterpret_cast<const float4*>(bbox_deltas.flat<float>().data()),
+        reinterpret_cast<const float4*>(anchors.flat<float>().data()), height,
+        width, num_anchors, min_size, image_info.flat<float>().data(),
+        bbox_xform_clip_default_,
+        reinterpret_cast<float4*>(dev_boxes.flat<float>().data()),
+        nboxes_to_generate, (char*)dev_boxes_keep_flags.flat<int8>().data()));
     const int nboxes_generated = nboxes_to_generate;
     const int roi_cols = box_dim;
     const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_);
@@ -520,6 +490,8 @@
                                 &output_roi_probs));
     float* d_postnms_rois = (*output_rois).flat<float>().data();
     float* d_postnms_rois_probs = (*output_roi_probs).flat<float>().data();
+    cudaEvent_t copy_done;  // marks completion of the box-count D2H copy
+    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaEventCreate(&copy_done));
 
     // Do  per-image nms
     for (int image_index = 0; image_index < num_images; ++image_index) {
@@ -559,14 +531,15 @@
                        d_image_prenms_scores, d_prenms_nboxes, nboxes_generated,
                        d.stream()));
       d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int));
-      d.synchronize();
+      TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventRecord(copy_done, d.stream()));
+      TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventSynchronize(copy_done));
       // We know prenms_boxes <= topN_prenms, because nboxes_generated <=
       // topN_prenms. Calling NMS on the generated boxes
       const int prenms_nboxes = h_prenms_nboxes;
       int nkeep;
       OP_REQUIRES_OK(context,
                      NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold,
-                            d_image_boxes_keep_list, &nkeep, context));
+                            d_image_boxes_keep_list, &nkeep, context, post_nms_topn_));
       // All operations done after previous sort were keeping the relative order
       // of the elements the elements are still sorted keep topN <=> truncate
       // the array
@@ -589,7 +562,6 @@
  private:
   int post_nms_topn_;
   float bbox_xform_clip_default_;
-  bool use_detectron_offset_;
 };
 
 REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 50945c3..57c032f 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -994,7 +994,6 @@
     .Output("rois: float")
     .Output("roi_probabilities: float")
     .Attr("post_nms_topn: int = 300")
-    .Attr("use_detectron_offset: bool = false")
     .SetShapeFn([](InferenceContext* c) -> Status {
       // make sure input tensors have are correct rank
       ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold,
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 685644e..6fd8b63 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1478,7 +1478,7 @@
   }
   member_method {
     name: "GenerateBoundingBoxProposals"
-    argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], "
+    argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], "
   }
   member_method {
     name: "GenerateVocabRemapping"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 685644e..6fd8b63 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1478,7 +1478,7 @@
   }
   member_method {
     name: "GenerateBoundingBoxProposals"
-    argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], "
+    argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], "
   }
   member_method {
     name: "GenerateVocabRemapping"