Rotated boxes support for GPU GenerateProposals op (#15470)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15470

On top of D13509114 and D13017791. Pretty straight-forward.

Reviewed By: newstzpz

Differential Revision: D13536671

fbshipit-source-id: ff65981b70c63773ccc9aef3ff28e3c9508f6716
diff --git a/caffe2/operators/generate_proposals_op.cu b/caffe2/operators/generate_proposals_op.cu
index fe89a66..ec92bd0 100644
--- a/caffe2/operators/generate_proposals_op.cu
+++ b/caffe2/operators/generate_proposals_op.cu
@@ -6,6 +6,8 @@
 #include "caffe2/operators/generate_proposals_op_util_nms.h"
 #include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
 
+using caffe2::utils::RotatedBox;
+
 namespace caffe2 {
 namespace {
 __global__ void GeneratePreNMSUprightBoxesKernel(
@@ -15,9 +17,7 @@
     const float4* d_anchors,
     const int H,
     const int W,
-    const int K, // K = H*W
     const int A,
-    const int KA, // KA = K*A
     const float feat_stride,
     const float min_size,
     const float* d_img_info_vec,
@@ -28,6 +28,8 @@
     const int prenms_nboxes, // leading dimension of out_boxes
     float* d_inout_scores,
     char* d_boxes_keep_flags) {
+  const int K = H * W;
+  const int KA = K * A;
   CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) {
     // box_conv_index : # of the same box, but indexed in
     // the scores from the conv layer, of shape (A,H,W)
@@ -48,7 +50,7 @@
     const int w = remaining; // dW = 1
 
     // Loading the anchor a
-    // float is a struct with float x,y,z,w
+    // float4 is a struct with float x,y,z,w
     const float4 anchor = d_anchors[a];
     // x1,y1,x2,y2 :coordinates of anchor a, shifted for position (h,w)
     const float shift_w = feat_stride * w;
@@ -131,7 +133,148 @@
   }
 }
 
-__global__ void WriteOutput(
+__global__ void GeneratePreNMSRotatedBoxesKernel(
+    const int* d_sorted_scores_keys,
+    const int nboxes_to_generate,
+    const float* d_bbox_deltas,
+    const RotatedBox* d_anchors,
+    const int H,
+    const int W,
+    const int A,
+    const float feat_stride,
+    const float min_size,
+    const float* d_img_info_vec,
+    const int num_images,
+    const float bbox_xform_clip,
+    const bool angle_bound_on,
+    const int angle_bound_lo,
+    const int angle_bound_hi,
+    const bool clip_angle_thresh,
+    RotatedBox* d_out_boxes,
+    const int prenms_nboxes, // leading dimension of out_boxes
+    float* d_inout_scores,
+    char* d_boxes_keep_flags) {
+  constexpr float PI = 3.14159265358979323846;
+  const int K = H * W;
+  const int KA = K * A;
+  CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) {
+    // box_conv_index : # of the same box, but indexed in
+    // the scores from the conv layer, of shape (A,H,W)
+    // the num_images dimension was already removed
+    // box_conv_index = a*K + h*W + w
+    const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox];
+
+    // We want to decompose box_conv_index in (a,h,w)
+    // such as box_conv_index = a*K + h*W + w
+    // (avoiding modulos in the process)
+    int remaining = box_conv_index;
+    const int dA = K; // stride of A
+    const int a = remaining / dA;
+    remaining -= a * dA;
+    const int dH = W; // stride of H
+    const int h = remaining / dH;
+    remaining -= h * dH;
+    const int w = remaining; // dW = 1
+
+    // Loading the anchor a and applying shifts.
+    // RotatedBox in [ctr_x, ctr_y, w, h, angle] format.
+    // Zero shift for width, height and angle.
+    RotatedBox box = d_anchors[a];
+    box.x_ctr += feat_stride * w; // x_ctr shifted for w
+    box.y_ctr += feat_stride * h; // y_ctr shifted for h
+
+    // TODO use fast math when possible
+
+    // Deltas for that box
+    // Deltas of shape (num_images,5*A,K)
+    // We're going to compute 5 scattered reads
+    // better than the alternative, ie transposing the complete deltas
+    // array first
+    int deltas_idx = image_index * (KA * 5) + a * 5 * K + h * W + w;
+    // Stride of K between each dimension
+    RotatedBox delta;
+    delta.x_ctr = d_bbox_deltas[deltas_idx + K * 0];
+    delta.y_ctr = d_bbox_deltas[deltas_idx + K * 1];
+    delta.w = d_bbox_deltas[deltas_idx + K * 2];
+    delta.h = d_bbox_deltas[deltas_idx + K * 3];
+    delta.a = d_bbox_deltas[deltas_idx + K * 4];
+
+    // Upper bound on dw,dh
+    delta.w = fmin(delta.w, bbox_xform_clip);
+    delta.h = fmin(delta.h, bbox_xform_clip);
+
+    // Convert back to degrees
+    delta.a *= 180.f / PI;
+
+    // Applying the deltas
+    box.x_ctr += delta.x_ctr * box.w;
+    box.y_ctr += delta.y_ctr * box.h;
+    box.w *= expf(delta.w);
+    box.h *= expf(delta.h);
+    box.a += delta.a;
+
+    if (angle_bound_on) {
+      // Normalize angle to be within [angle_bound_lo, angle_bound_hi].
+      // Deltas are guaranteed to be <= period / 2 while computing training
+      // targets by bbox_transform_inv.
+      const float period = angle_bound_hi - angle_bound_lo;
+      // CAFFE_ENFORCE(period > 0 && period % 180 == 0);
+      if (box.a < angle_bound_lo) {
+        box.a += period;
+      } else if (box.a > angle_bound_hi) {
+        box.a -= period;
+      }
+    }
+
+    // Clipping box to image.
+    // Only clip boxes that are almost upright (with a tolerance of
+    // clip_angle_thresh) for backward compatibility with horizontal boxes.
+    const float img_height = d_img_info_vec[3 * image_index + 0];
+    const float img_width = d_img_info_vec[3 * image_index + 1];
+    const float min_size_scaled =
+        min_size * d_img_info_vec[3 * image_index + 2];
+    if (fabs(box.a) <= clip_angle_thresh) {
+      // Convert from [x_ctr, y_ctr, w, h] to [x1, y1, x2, y2]
+      float x1 = box.x_ctr - (box.w - 1.f) / 2.f;
+      float y1 = box.y_ctr - (box.h - 1.f) / 2.f;
+      float x2 = x1 + box.w - 1.f;
+      float y2 = y1 + box.h - 1.f;
+
+      // Clip
+      x1 = fmax(fmin(x1, img_width - 1.0f), 0.0f);
+      y1 = fmax(fmin(y1, img_height - 1.0f), 0.0f);
+      x2 = fmax(fmin(x2, img_width - 1.0f), 0.0f);
+      y2 = fmax(fmin(y2, img_height - 1.0f), 0.0f);
+
+      // Convert back to [x_ctr, y_ctr, w, h]
+      box.x_ctr = (x1 + x2) / 2.f;
+      box.y_ctr = (y1 + y2) / 2.f;
+      box.w = x2 - x1 + 1.f;
+      box.h = y2 - y1 + 1.f;
+    }
+
+    // Filter boxes.
+    // Removing boxes with one dim < min_size or center outside the image.
+    bool keep_box = (fmin(box.w, box.h) >= min_size_scaled) &&
+        (box.x_ctr < img_width) && (box.y_ctr < img_height);
+
+    // We are not deleting the box right now even if !keep_box
+    // we want to keep the relative order of the elements stable
+    // we'll do it in such a way later
+    // d_boxes_keep_flags size: (num_images,prenms_nboxes)
+    // d_out_boxes size: (num_images,prenms_nboxes)
+    const int out_index = image_index * prenms_nboxes + ibox;
+    d_boxes_keep_flags[out_index] = keep_box;
+    d_out_boxes[out_index] = box;
+
+    // d_inout_scores size: (num_images,KA)
+    if (!keep_box) {
+      d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS
+    }
+  }
+}
+
+__global__ void WriteUprightBoxesOutput(
     const float4* d_image_boxes,
     const float* d_image_scores,
     const int* d_image_boxes_keep_list,
@@ -155,6 +298,31 @@
   }
 }
 
+__global__ void WriteRotatedBoxesOutput(
+    const RotatedBox* d_image_boxes,
+    const float* d_image_scores,
+    const int* d_image_boxes_keep_list,
+    const int nboxes,
+    const int image_index,
+    float* d_image_out_rois,
+    float* d_image_out_rois_probs) {
+  CUDA_1D_KERNEL_LOOP(i, nboxes) {
+    const int ibox = d_image_boxes_keep_list[i];
+    const RotatedBox box = d_image_boxes[ibox];
+    const float score = d_image_scores[ibox];
+    // Scattered memory accesses
+    // postnms_nboxes is small anyway
+    d_image_out_rois_probs[i] = score;
+    const int base_idx = 6 * i;
+    d_image_out_rois[base_idx + 0] = image_index;
+    d_image_out_rois[base_idx + 1] = box.x_ctr;
+    d_image_out_rois[base_idx + 2] = box.y_ctr;
+    d_image_out_rois[base_idx + 3] = box.w;
+    d_image_out_rois[base_idx + 4] = box.h;
+    d_image_out_rois[base_idx + 5] = box.a;
+  }
+}
+
 __global__ void InitializeDataKernel(
     const int num_images,
     const int KA,
@@ -191,11 +359,10 @@
   const auto A = scores.dim(1);
   const auto H = scores.dim(2);
   const auto W = scores.dim(3);
-  const auto box_dim_conv = anchors.dim(1);
+  const auto box_dim = anchors.dim(1);
 
-  CAFFE_ENFORCE(box_dim_conv == 4); // only upright boxes in GPU version for now
+  CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
 
-  constexpr int box_dim = 4;
   const int K = H * W;
   const int conv_layer_nboxes = K * A;
   // Getting data members ready
@@ -301,40 +468,65 @@
   const float* d_bbox_deltas = bbox_deltas.data<float>();
   const float* d_anchors = anchors.data<float>();
   const float* d_im_info_vec = im_info_tensor.data<float>();
-  float4* d_boxes =
-      reinterpret_cast<float4*>(dev_boxes_.template mutable_data<float>());
+  float* d_boxes = dev_boxes_.template mutable_data<float>();
   ;
   char* d_boxes_keep_flags =
       dev_boxes_keep_flags_.template mutable_data<char>();
 
-  GeneratePreNMSUprightBoxesKernel<<<
-      (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
-      CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
-      0,
-      context_.cuda_stream()>>>(
-      d_sorted_conv_layer_indexes,
-      nboxes_to_generate,
-      d_bbox_deltas,
-      reinterpret_cast<const float4*>(d_anchors),
-      H,
-      W,
-      K,
-      A,
-      K * A,
-      feat_stride_,
-      rpn_min_size_,
-      d_im_info_vec,
-      num_images,
-      utils::BBOX_XFORM_CLIP_DEFAULT,
-      correct_transform_coords_,
-      d_boxes,
-      nboxes_to_generate,
-      d_sorted_scores,
-      d_boxes_keep_flags);
+  if (box_dim == 4) {
+    GeneratePreNMSUprightBoxesKernel<<<
+        (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+        CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+        0,
+        context_.cuda_stream()>>>(
+        d_sorted_conv_layer_indexes,
+        nboxes_to_generate,
+        d_bbox_deltas,
+        reinterpret_cast<const float4*>(d_anchors),
+        H,
+        W,
+        A,
+        feat_stride_,
+        rpn_min_size_,
+        d_im_info_vec,
+        num_images,
+        utils::BBOX_XFORM_CLIP_DEFAULT,
+        correct_transform_coords_,
+        reinterpret_cast<float4*>(d_boxes),
+        nboxes_to_generate,
+        d_sorted_scores,
+        d_boxes_keep_flags);
+  } else {
+    GeneratePreNMSRotatedBoxesKernel<<<
+        (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+        CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+        0,
+        context_.cuda_stream()>>>(
+        d_sorted_conv_layer_indexes,
+        nboxes_to_generate,
+        d_bbox_deltas,
+        reinterpret_cast<const RotatedBox*>(d_anchors),
+        H,
+        W,
+        A,
+        feat_stride_,
+        rpn_min_size_,
+        d_im_info_vec,
+        num_images,
+        utils::BBOX_XFORM_CLIP_DEFAULT,
+        angle_bound_on_,
+        angle_bound_lo_,
+        angle_bound_hi_,
+        clip_angle_thresh_,
+        reinterpret_cast<RotatedBox*>(d_boxes),
+        nboxes_to_generate,
+        d_sorted_scores,
+        d_boxes_keep_flags);
+  }
   const int nboxes_generated = nboxes_to_generate;
   dev_image_prenms_boxes_.Resize(box_dim * nboxes_generated);
-  float4* d_image_prenms_boxes = reinterpret_cast<float4*>(
-      dev_image_prenms_boxes_.template mutable_data<float>());
+  float* d_image_prenms_boxes =
+      dev_image_prenms_boxes_.template mutable_data<float>();
   dev_image_prenms_scores_.Resize(nboxes_generated);
   float* d_image_prenms_scores =
       dev_image_prenms_scores_.template mutable_data<float>();
@@ -342,8 +534,9 @@
   int* d_image_boxes_keep_list =
       dev_image_boxes_keep_list_.template mutable_data<int>();
 
+  const int roi_cols = box_dim + 1;
   const int max_postnms_nboxes = std::min(nboxes_generated, rpn_post_nms_topN_);
-  dev_postnms_rois_.Resize(5 * num_images * max_postnms_nboxes);
+  dev_postnms_rois_.Resize(roi_cols * num_images * max_postnms_nboxes);
   dev_postnms_rois_probs_.Resize(num_images * max_postnms_nboxes);
   float* d_postnms_rois = dev_postnms_rois_.template mutable_data<float>();
   float* d_postnms_rois_probs =
@@ -357,26 +550,39 @@
   int nrois_in_output = 0;
   for (int image_index = 0; image_index < num_images; ++image_index) {
     // Sub matrices for current image
-    const float4* d_image_boxes = &d_boxes[image_index * nboxes_generated];
+    const float* d_image_boxes =
+        &d_boxes[image_index * nboxes_generated * box_dim];
     const float* d_image_sorted_scores = &d_sorted_scores[image_index * K * A];
     char* d_image_boxes_keep_flags =
         &d_boxes_keep_flags[image_index * nboxes_generated];
 
-    float* d_image_postnms_rois = &d_postnms_rois[5 * nrois_in_output];
+    float* d_image_postnms_rois = &d_postnms_rois[roi_cols * nrois_in_output];
     float* d_image_postnms_rois_probs = &d_postnms_rois_probs[nrois_in_output];
 
     // Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true)
     // to the output tensors
 
-    cub::DeviceSelect::Flagged(
-        d_cub_select_temp_storage,
-        cub_select_temp_storage_bytes,
-        d_image_boxes,
-        d_image_boxes_keep_flags,
-        d_image_prenms_boxes,
-        d_prenms_nboxes,
-        nboxes_generated,
-        context_.cuda_stream());
+    if (box_dim == 4) {
+      cub::DeviceSelect::Flagged(
+          d_cub_select_temp_storage,
+          cub_select_temp_storage_bytes,
+          reinterpret_cast<const float4*>(d_image_boxes),
+          d_image_boxes_keep_flags,
+          reinterpret_cast<float4*>(d_image_prenms_boxes),
+          d_prenms_nboxes,
+          nboxes_generated,
+          context_.cuda_stream());
+    } else {
+      cub::DeviceSelect::Flagged(
+          d_cub_select_temp_storage,
+          cub_select_temp_storage_bytes,
+          reinterpret_cast<const RotatedBox*>(d_image_boxes),
+          d_image_boxes_keep_flags,
+          reinterpret_cast<RotatedBox*>(d_image_prenms_boxes),
+          d_prenms_nboxes,
+          nboxes_generated,
+          context_.cuda_stream());
+    }
 
     cub::DeviceSelect::Flagged(
         d_cub_select_temp_storage,
@@ -391,18 +597,19 @@
     host_prenms_nboxes_.CopyFrom(dev_prenms_nboxes_);
 
     // We know prenms_boxes <= topN_prenms, because nboxes_generated <=
-    // topN_prenms Calling NMS on the generated boxes
+    // topN_prenms. Calling NMS on the generated boxes
     const int prenms_nboxes = *h_prenms_nboxes;
     int nkeep;
-    utils::nms_gpu_upright(
-        reinterpret_cast<const float*>(d_image_prenms_boxes),
+    utils::nms_gpu(
+        d_image_prenms_boxes,
         prenms_nboxes,
         rpn_nms_thresh_,
         d_image_boxes_keep_list,
         &nkeep,
         dev_nms_mask_,
         host_nms_mask_,
-        &context_);
+        &context_,
+        box_dim);
 
     // All operations done after previous sort were keeping the relative order
     // of the elements the elements are still sorted keep topN <=> truncate the
@@ -411,24 +618,39 @@
 
     // Moving the out boxes to the output tensors,
     // adding the image_index dimension on the fly
-    WriteOutput<<<
-        CAFFE_GET_BLOCKS(postnms_nboxes),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context_.cuda_stream()>>>(
-        d_image_prenms_boxes,
-        d_image_prenms_scores,
-        d_image_boxes_keep_list,
-        postnms_nboxes,
-        image_index,
-        d_image_postnms_rois,
-        d_image_postnms_rois_probs);
+    if (box_dim == 4) {
+      WriteUprightBoxesOutput<<<
+          CAFFE_GET_BLOCKS(postnms_nboxes),
+          CAFFE_CUDA_NUM_THREADS,
+          0,
+          context_.cuda_stream()>>>(
+          reinterpret_cast<const float4*>(d_image_prenms_boxes),
+          d_image_prenms_scores,
+          d_image_boxes_keep_list,
+          postnms_nboxes,
+          image_index,
+          d_image_postnms_rois,
+          d_image_postnms_rois_probs);
+    } else {
+      WriteRotatedBoxesOutput<<<
+          CAFFE_GET_BLOCKS(postnms_nboxes),
+          CAFFE_CUDA_NUM_THREADS,
+          0,
+          context_.cuda_stream()>>>(
+          reinterpret_cast<const RotatedBox*>(d_image_prenms_boxes),
+          d_image_prenms_scores,
+          d_image_boxes_keep_list,
+          postnms_nboxes,
+          image_index,
+          d_image_postnms_rois,
+          d_image_postnms_rois_probs);
+    }
 
     nrois_in_output += postnms_nboxes;
   }
 
   // Using a buffer because we cannot call ShrinkTo
-  out_rois->Resize(nrois_in_output, 5);
+  out_rois->Resize(nrois_in_output, roi_cols);
   out_rois_probs->Resize(nrois_in_output);
   float* d_out_rois = out_rois->template mutable_data<float>();
   float* d_out_rois_probs = out_rois_probs->template mutable_data<float>();
@@ -436,7 +658,7 @@
   CUDA_CHECK(cudaMemcpyAsync(
       d_out_rois,
       d_postnms_rois,
-      nrois_in_output * 5 * sizeof(float),
+      nrois_in_output * roi_cols * sizeof(float),
       cudaMemcpyDeviceToDevice,
       context_.cuda_stream()));
   CUDA_CHECK(cudaMemcpyAsync(
diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc
index 36ad5aa..817f345 100644
--- a/caffe2/operators/generate_proposals_op_gpu_test.cc
+++ b/caffe2/operators/generate_proposals_op_gpu_test.cc
@@ -6,6 +6,7 @@
 
 #include "caffe2/core/context.h"
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/generate_proposals_op_util_boxes.h"
 
 #ifdef CAFFE2_USE_OPENCV
 #include <opencv2/opencv.hpp>
@@ -236,4 +237,410 @@
       0,
       1e-4);
 }
+
+#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
+TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) {
+  // Similar to TestRealDownSampledGPU but for rotated boxes with angle info.
+  if (!HasCudaGPU())
+    return;
+
+  const float angle = 0;
+  const float delta_angle = 0;
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
+
+  Workspace ws;
+  OperatorDef def;
+  def.set_name("test");
+  def.set_type("GenerateProposals");
+  def.add_input("scores");
+  def.add_input("bbox_deltas");
+  def.add_input("im_info");
+  def.add_input("anchors");
+  def.add_output("rois");
+  def.add_output("rois_probs");
+  def.mutable_device_option()->set_device_type(PROTO_CUDA);
+  const int img_count = 2;
+  const int A = 2;
+  const int H = 4;
+  const int W = 5;
+
+  vector<float> scores{
+      5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f,
+      1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f,
+      9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f,
+      1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f,
+      1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f,
+      2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f,
+      5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f,
+      3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f,
+      5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f,
+      5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f};
+  vector<float> bbx{
+      -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f,
+      -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f,
+      -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f,
+      -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f,  5.71171008e-02f,
+      -1.75665006e-01f, 2.30021998e-02f,  3.08554992e-02f,  -1.39333997e-02f,
+      3.40579003e-01f,  3.91070992e-01f,  3.91624004e-01f,  3.92527014e-01f,
+      3.91445011e-01f,  3.79328012e-01f,  4.26631987e-01f,  3.64892989e-01f,
+      2.76894987e-01f,  5.13985991e-01f,  3.79999995e-01f,  1.80457994e-01f,
+      4.37402993e-01f,  4.18545991e-01f,  2.51549989e-01f,  4.48318988e-01f,
+      1.68564007e-01f,  4.65440989e-01f,  4.21891987e-01f,  4.45928007e-01f,
+      3.27155995e-03f,  3.71480011e-03f,  3.60032008e-03f,  4.27092984e-03f,
+      3.74579988e-03f,  5.95752988e-03f,  -3.14473989e-03f, 3.52022005e-03f,
+      -1.88564006e-02f, 1.65188999e-03f,  1.73791999e-03f,  -3.56074013e-02f,
+      -1.66615995e-04f, 3.14146001e-03f,  -1.11830998e-02f, -5.35363983e-03f,
+      6.49790000e-03f,  -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f,
+      -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f,
+      -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f,
+      -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f,
+      -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f,
+      -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f,
+      -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f,
+      -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f,
+      -4.94231991e-02f, 6.63906988e-03f,  3.20469006e-03f,  -6.44695014e-02f,
+      -3.11607006e-03f, 2.02738005e-03f,  1.48096997e-02f,  4.39785011e-02f,
+      -8.28424022e-02f, 3.62076014e-02f,  2.71668993e-02f,  1.38250999e-02f,
+      6.76669031e-02f,  1.03252999e-01f,  1.03255004e-01f,  9.89722982e-02f,
+      1.03646003e-01f,  4.79663983e-02f,  1.11014001e-01f,  9.31736007e-02f,
+      1.15768999e-01f,  1.04014002e-01f,  -8.90677981e-03f, 1.13103002e-01f,
+      1.33085996e-01f,  1.25405997e-01f,  1.50051996e-01f,  -1.13038003e-01f,
+      7.01059997e-02f,  1.79651007e-01f,  1.41055003e-01f,  1.62841007e-01f,
+      -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f,
+      -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f,
+      -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, -4.10550982e-02f,
+      -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f,
+      -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f,
+      1.29676998e-01f,  1.16080999e-01f,  1.15947001e-01f,  1.21797003e-01f,
+      1.16089001e-01f,  1.44875005e-01f,  1.15617000e-01f,  1.31586999e-01f,
+      1.74735002e-02f,  1.21973999e-01f,  1.31596997e-01f,  2.48907991e-02f,
+      6.18605018e-02f,  1.12855002e-01f,  -6.99798986e-02f, 9.58312973e-02f,
+      1.53593004e-01f,  -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f};
+
+  // Add angle in bbox deltas
+  int num_boxes = scores.size();
+  CHECK_EQ(bbx.size() / 4, num_boxes);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
+  // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
+  // at each spatial location for each anchor.
+  int i = 0, j = 0;
+  for (int a = 0; a < A; ++a) {
+    for (int k = 0; k < 4 * H * W; ++k) {
+      bbx_with_angle[i++] = bbx[j++];
+    }
+    for (int k = 0; k < H * W; ++k) {
+      bbx_with_angle[i++] = delta_angle;
+    }
+  }
+
+  vector<float> im_info{60, 80, 0.166667f};
+  // vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
+  // Anchors in [x_ctr, y_ctr, w, h, angle] format
+  vector<float> anchors{7.5, 7.5, 92, 48, angle, 7.5, 7.5, 256, 256, angle};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  scores.insert(scores.begin(), scores.begin(), scores.end());
+  bbx_with_angle.insert(
+      bbx_with_angle.begin(), bbx_with_angle.begin(), bbx_with_angle.end());
+  im_info.insert(im_info.begin(), im_info.begin(), im_info.end());
+
+  // Results should exactly be the same as TestRealDownSampledGPU since
+  // angle = 0 for all boxes and clip_angle_thresh > 0 (which means
+  // all horizontal boxes will be clipped to maintain backward compatibility).
+  ERMatXf rois_gt_xyxy(18, 5);
+  rois_gt_xyxy << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0,
+      24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f,
+      45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
+      51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
+      41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f, 1, 0, 0, 79,
+      59, 1, 0, 5.0005703f, 51.6324f, 42.6950f, 1, 24.13628387f, 7.51243401f,
+      79, 45.0663f, 1, 0, 7.50924301f, 67.4779f, 45.0336, 1, 0, 23.09477997f,
+      50.61448669f, 59, 1, 0, 39.52141571f, 51.44710541f, 59, 1, 23.57396317f,
+      29.98791885f, 79, 59, 1, 0, 41.90219116f, 79, 59, 1, 0, 23.30098343f,
+      78.2413f, 58.7287f;
+  ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
+  // Batch ID
+  rois_gt.block(0, 0, rois_gt.rows(), 1) =
+      rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0);
+  // rois_gt in [x_ctr, y_ctr, w, h] format
+  rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
+      rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
+  // Angle
+  rois_gt.block(0, 5, rois_gt.rows(), 1) =
+      ERMatXf::Constant(rois_gt.rows(), 1, angle);
+
+  vector<float> rois_probs_gt{2.66913995e-02f,
+                              5.44218998e-03f,
+                              1.20544003e-03f,
+                              1.19207997e-03f,
+                              6.17993006e-04f,
+                              4.72735002e-04f,
+                              6.09605013e-05f,
+                              1.50015003e-05f,
+                              8.91025957e-06f};
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  rois_probs_gt.insert(
+      rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end());
+
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, box_dim * A, H, W},
+      bbx_with_angle,
+      "bbox_deltas",
+      &ws);
+  AddInput<CUDAContext>(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
+
+  def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
+  def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
+  def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
+  def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+  def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
+
+  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
+  EXPECT_NE(nullptr, op.get());
+  EXPECT_TRUE(op->Run());
+
+  // test rois
+  Blob* rois_blob = ws.GetBlob("rois");
+  EXPECT_NE(nullptr, rois_blob);
+  auto& rois_gpu = rois_blob->Get<TensorCUDA>();
+  Tensor rois{CPU};
+  rois.CopyFrom(rois_gpu);
+
+  EXPECT_EQ(rois.dims(), (vector<int64_t>{rois_gt.rows(), rois_gt.cols()}));
+  auto rois_data =
+      Eigen::Map<const ERMatXf>(rois.data<float>(), rois.dim(0), rois.dim(1));
+  EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-4);
+
+  // test rois_probs
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  EXPECT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs_gpu = rois_probs_blob->Get<TensorCUDA>();
+  Tensor rois_probs{CPU};
+  rois_probs.CopyFrom(rois_probs_gpu);
+  EXPECT_EQ(
+      rois_probs.dims(), (vector<int64_t>{int64_t(rois_probs_gt.size())}));
+  auto rois_probs_data =
+      ConstEigenVectorArrayMap<float>(rois_probs.data<float>(), rois.dim(0));
+  EXPECT_NEAR(
+      (rois_probs_data.matrix() - utils::AsEArrXt(rois_probs_gt).matrix())
+          .cwiseAbs()
+          .maxCoeff(),
+      0,
+      1e-4);
+}
+
+TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) {
+  // Similar to TestRealDownSampledGPU but for rotated boxes with angle info.
+  if (!HasCudaGPU())
+    return;
+
+  const float angle = 45.0;
+  const float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
+  const float expected_angle = 55.0;
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
+
+  Workspace ws;
+  OperatorDef def;
+  def.set_name("test");
+  def.set_type("GenerateProposals");
+  def.add_input("scores");
+  def.add_input("bbox_deltas");
+  def.add_input("im_info");
+  def.add_input("anchors");
+  def.add_output("rois");
+  def.add_output("rois_probs");
+  def.mutable_device_option()->set_device_type(PROTO_CUDA);
+  const int img_count = 2;
+  const int A = 2;
+  const int H = 4;
+  const int W = 5;
+
+  vector<float> scores{
+      5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f,
+      1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f,
+      9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f,
+      1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f,
+      1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f,
+      2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f,
+      5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f,
+      3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f,
+      5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f,
+      5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f};
+  vector<float> bbx{
+      -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f,
+      -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f,
+      -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f,
+      -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f,  5.71171008e-02f,
+      -1.75665006e-01f, 2.30021998e-02f,  3.08554992e-02f,  -1.39333997e-02f,
+      3.40579003e-01f,  3.91070992e-01f,  3.91624004e-01f,  3.92527014e-01f,
+      3.91445011e-01f,  3.79328012e-01f,  4.26631987e-01f,  3.64892989e-01f,
+      2.76894987e-01f,  5.13985991e-01f,  3.79999995e-01f,  1.80457994e-01f,
+      4.37402993e-01f,  4.18545991e-01f,  2.51549989e-01f,  4.48318988e-01f,
+      1.68564007e-01f,  4.65440989e-01f,  4.21891987e-01f,  4.45928007e-01f,
+      3.27155995e-03f,  3.71480011e-03f,  3.60032008e-03f,  4.27092984e-03f,
+      3.74579988e-03f,  5.95752988e-03f,  -3.14473989e-03f, 3.52022005e-03f,
+      -1.88564006e-02f, 1.65188999e-03f,  1.73791999e-03f,  -3.56074013e-02f,
+      -1.66615995e-04f, 3.14146001e-03f,  -1.11830998e-02f, -5.35363983e-03f,
+      6.49790000e-03f,  -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f,
+      -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f,
+      -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f,
+      -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f,
+      -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f,
+      -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f,
+      -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f,
+      -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f,
+      -4.94231991e-02f, 6.63906988e-03f,  3.20469006e-03f,  -6.44695014e-02f,
+      -3.11607006e-03f, 2.02738005e-03f,  1.48096997e-02f,  4.39785011e-02f,
+      -8.28424022e-02f, 3.62076014e-02f,  2.71668993e-02f,  1.38250999e-02f,
+      6.76669031e-02f,  1.03252999e-01f,  1.03255004e-01f,  9.89722982e-02f,
+      1.03646003e-01f,  4.79663983e-02f,  1.11014001e-01f,  9.31736007e-02f,
+      1.15768999e-01f,  1.04014002e-01f,  -8.90677981e-03f, 1.13103002e-01f,
+      1.33085996e-01f,  1.25405997e-01f,  1.50051996e-01f,  -1.13038003e-01f,
+      7.01059997e-02f,  1.79651007e-01f,  1.41055003e-01f,  1.62841007e-01f,
+      -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f,
+      -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f,
+      -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, -4.10550982e-02f,
+      -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f,
+      -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f,
+      1.29676998e-01f,  1.16080999e-01f,  1.15947001e-01f,  1.21797003e-01f,
+      1.16089001e-01f,  1.44875005e-01f,  1.15617000e-01f,  1.31586999e-01f,
+      1.74735002e-02f,  1.21973999e-01f,  1.31596997e-01f,  2.48907991e-02f,
+      6.18605018e-02f,  1.12855002e-01f,  -6.99798986e-02f, 9.58312973e-02f,
+      1.53593004e-01f,  -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f};
+
+  // Add angle in bbox deltas
+  int num_boxes = scores.size();
+  CHECK_EQ(bbx.size() / 4, num_boxes);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
+  // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
+  // at each spatial location for each anchor.
+  int i = 0, j = 0;
+  for (int a = 0; a < A; ++a) {
+    for (int k = 0; k < 4 * H * W; ++k) {
+      bbx_with_angle[i++] = bbx[j++];
+    }
+    for (int k = 0; k < H * W; ++k) {
+      bbx_with_angle[i++] = delta_angle;
+    }
+  }
+
+  vector<float> im_info{60, 80, 0.166667f};
+  // vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
+  vector<float> anchors{8, 8, 92, 48, angle, 8, 8, 256, 256, angle};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  scores.insert(scores.begin(), scores.begin(), scores.end());
+  bbx_with_angle.insert(
+      bbx_with_angle.begin(), bbx_with_angle.begin(), bbx_with_angle.end());
+  im_info.insert(im_info.begin(), im_info.begin(), im_info.end());
+
+  ERMatXf rois_gt(26, 6);
+  rois_gt <<
+      0, 6.55346, 25.3227, 253.447, 291.446, expected_angle,
+      0, 55.3932, 33.3369, 253.731, 289.158, expected_angle,
+      0, 6.48163, 24.3478, 92.3015, 38.6944, expected_angle,
+      0, 70.3089, 26.7894, 92.3453, 38.5539, expected_angle,
+      0, 22.3067, 26.7714, 92.3424, 38.5243, expected_angle,
+      0, 54.084, 26.8413, 92.3938, 38.798, expected_angle,
+      0, 38.2894, 26.798, 92.3318, 38.4873, expected_angle,
+      0, 5.33962, 42.2077, 92.5497, 38.2259, expected_angle,
+      0, 6.36709, 58.24, 92.16, 37.4372, expected_angle,
+      0, 69.65, 48.6713, 92.1521, 37.3668, expected_angle,
+      0, 20.4147, 44.4783, 91.7111, 34.0295, expected_angle,
+      0, 33.079, 41.5149, 92.3244, 36.4278, expected_angle,
+      0, 41.8235, 37.291, 90.2815, 34.872, expected_angle,
+      1, 6.55346, 25.3227, 253.447, 291.446, expected_angle,
+      1, 55.3932, 33.3369, 253.731, 289.158, expected_angle,
+      1, 6.48163, 24.3478, 92.3015, 38.6944, expected_angle,
+      1, 70.3089, 26.7894, 92.3453, 38.5539, expected_angle,
+      1, 22.3067, 26.7714, 92.3424, 38.5243, expected_angle,
+      1, 54.084, 26.8413, 92.3938, 38.798, expected_angle,
+      1, 38.2894, 26.798, 92.3318, 38.4873, expected_angle,
+      1, 5.33962, 42.2077, 92.5497, 38.2259, expected_angle,
+      1, 6.36709, 58.24, 92.16, 37.4372, expected_angle,
+      1, 69.65, 48.6713, 92.1521, 37.3668, expected_angle,
+      1, 20.4147, 44.4783, 91.7111, 34.0295, expected_angle,
+      1, 33.079, 41.5149, 92.3244, 36.4278, expected_angle,
+      1, 41.8235, 37.291, 90.2815, 34.872, expected_angle;
+
+  vector<float> rois_probs_gt{2.66913995e-02f,
+                              5.621e-03f,
+                              5.44218998e-03f,
+                              1.20544003e-03f,
+                              1.19207997e-03f,
+                              1.17182e-03f,
+                              1.1238e-03f,
+                              6.17993006e-04f,
+                              4.72735002e-04f,
+                              6.09605013e-05f,
+                              1.50015003e-05f,
+                              8.91025957e-06f,
+                              9.29537e-09f};
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  rois_probs_gt.insert(
+      rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end());
+
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, box_dim * A, H, W},
+      bbx_with_angle,
+      "bbox_deltas",
+      &ws);
+  AddInput<CUDAContext>(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
+
+  def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
+  def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
+  def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
+  def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+  def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
+
+  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
+  EXPECT_NE(nullptr, op.get());
+  EXPECT_TRUE(op->Run());
+
+  // test rois
+  Blob* rois_blob = ws.GetBlob("rois");
+  EXPECT_NE(nullptr, rois_blob);
+  auto& rois_gpu = rois_blob->Get<TensorCUDA>();
+  Tensor rois{CPU};
+  rois.CopyFrom(rois_gpu);
+  EXPECT_EQ(rois.dims(), (vector<int64_t>{26, 6}));
+  auto rois_data =
+      Eigen::Map<const ERMatXf>(rois.data<float>(), rois.size(0), rois.size(1));
+  EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-3);
+
+  // test rois_probs
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  EXPECT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs_gpu = rois_probs_blob->Get<TensorCUDA>();
+  Tensor rois_probs{CPU};
+  rois_probs.CopyFrom(rois_probs_gpu);
+  EXPECT_EQ(
+      rois_probs.dims(), (vector<int64_t>{int64_t(rois_probs_gt.size())}));
+  auto rois_probs_data =
+      ConstEigenVectorArrayMap<float>(rois_probs.data<float>(), rois.size(0));
+  EXPECT_NEAR(
+      (rois_probs_data.matrix() - utils::AsEArrXt(rois_probs_gt).matrix())
+          .cwiseAbs()
+          .maxCoeff(),
+      0,
+      1e-4);
+}
+#endif // CV_MAJOR_VERSION >= 3
+
 } // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc
index 4d76075..004402b 100644
--- a/caffe2/operators/generate_proposals_op_test.cc
+++ b/caffe2/operators/generate_proposals_op_test.cc
@@ -4,6 +4,8 @@
 #include "caffe2/core/flags.h"
 #include "caffe2/core/macros.h"
 
+#include "caffe2/operators/generate_proposals_op_util_boxes.h"
+
 #ifdef CAFFE2_USE_OPENCV
 #include <opencv2/opencv.hpp>
 #endif // CAFFE2_USE_OPENCV
@@ -142,22 +144,6 @@
   }
 }
 
-namespace {
-
-template <class Derived>
-ERMatXf boxes_xyxy_to_xywh(const Eigen::MatrixBase<Derived>& boxes) {
-  CAFFE_ENFORCE_EQ(boxes.cols(), 4);
-  ERMatXf res(boxes.rows(), 4);
-  auto ones = ERMatXf::Constant(boxes.rows(), 1, 1.0);
-  res.col(0) = (boxes.col(0) + boxes.col(2)) / 2.0; // ctr_x = (x1 + x2)/2
-  res.col(1) = (boxes.col(1) + boxes.col(3)) / 2.0; // ctr_y = (y1 + y2)/2
-  res.col(2) = boxes.col(2) - boxes.col(0) + ones; // w = x2 - x1 + 1
-  res.col(3) = boxes.col(3) - boxes.col(1) + ones; // h = y2 - y1 + 1
-  return res;
-}
-
-} // namespace
-
 TEST(GenerateProposalsTest, TestComputeAllAnchorsRotated) {
   // Similar to TestComputeAllAnchors but for rotated boxes with angle info.
   ERMatXf anchors_xyxy(3, 4);
@@ -165,7 +151,7 @@
 
   // Convert to RRPN format and add angles
   ERMatXf anchors(3, 5);
-  anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+  anchors.block(0, 0, 3, 4) = utils::bbox_xyxy_to_ctrwh(anchors_xyxy.array());
   std::vector<float> angles{0.0, 45.0, -120.0};
   for (int i = 0; i < anchors.rows(); ++i) {
     anchors(i, 4) = angles[i % angles.size()];
@@ -188,7 +174,8 @@
 
   // Convert gt to RRPN format and add angles
   ERMatXf all_anchors_gt(36, 5);
-  all_anchors_gt.block(0, 0, 36, 4) = boxes_xyxy_to_xywh(all_anchors_gt_xyxy);
+  all_anchors_gt.block(0, 0, 36, 4) =
+      utils::bbox_xyxy_to_ctrwh(all_anchors_gt_xyxy.array());
   for (int i = 0; i < all_anchors_gt.rows(); ++i) {
     all_anchors_gt(i, 4) = angles[i % angles.size()];
   }
@@ -213,7 +200,7 @@
 
   // Convert to RRPN format and add angles
   ERMatXf anchors(3, 5);
-  anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+  anchors.block(0, 0, 3, 4) = utils::bbox_xyxy_to_ctrwh(anchors_xyxy.array());
   std::vector<float> angles{0.0, 45.0, -120.0};
   for (int i = 0; i < anchors.rows(); ++i) {
     anchors(i, 4) = angles[i % angles.size()];
@@ -433,9 +420,10 @@
 #if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
 TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
   // Similar to TestRealDownSampled but for rotated boxes with angle info.
-  float angle = 0;
-  float delta_angle = 0;
-  float clip_angle_thresh = 1.0;
+  const float angle = 0;
+  const float delta_angle = 0;
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
 
   Workspace ws;
   OperatorDef def;
@@ -508,7 +496,7 @@
   // Add angle in bbox deltas
   int num_boxes = scores.size();
   CHECK_EQ(bbx.size() / 4, num_boxes);
-  vector<float> bbx_with_angle(num_boxes * 5);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
   // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
   // at each spatial location for each anchor.
   int i = 0, j = 0;
@@ -535,13 +523,13 @@
       45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
       51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
       41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f;
-  ERMatXf rois_gt(9, 6);
+  ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
   // Batch ID
   rois_gt.block(0, 0, rois_gt.rows(), 1) =
-      ERMatXf::Constant(rois_gt.rows(), 1, 0.0);
+      rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0);
   // rois_gt in [x_ctr, y_ctr, w, h] format
-  rois_gt.block(0, 1, rois_gt.rows(), 4) =
-      boxes_xyxy_to_xywh(rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4));
+  rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
+      rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
   // Angle
   rois_gt.block(0, 5, rois_gt.rows(), 1) =
       ERMatXf::Constant(rois_gt.rows(), 1, angle);
@@ -557,12 +545,12 @@
 
   AddInput(vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
   AddInput(
-      vector<int64_t>{img_count, 5 * A, H, W},
+      vector<int64_t>{img_count, box_dim * A, H, W},
       bbx_with_angle,
       "bbox_deltas",
       &ws);
   AddInput(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
-  AddInput(vector<int64_t>{A, 5}, anchors, "anchors", &ws);
+  AddInput(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
 
   def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
   def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
@@ -603,10 +591,11 @@
 
 TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
   // Similar to TestRealDownSampled but for rotated boxes with angle info.
-  float angle = 45.0;
-  float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
-  float expected_angle = 55.0;
-  float clip_angle_thresh = 1.0;
+  const float angle = 45.0;
+  const float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
+  const float expected_angle = 55.0;
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
 
   Workspace ws;
   OperatorDef def;
@@ -679,7 +668,7 @@
   // Add angle in bbox deltas
   int num_boxes = scores.size();
   CHECK_EQ(bbx.size() / 4, num_boxes);
-  vector<float> bbx_with_angle(num_boxes * 5);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
   // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
   // at each spatial location for each anchor.
   {
@@ -700,12 +689,12 @@
 
   AddInput(vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
   AddInput(
-      vector<int64_t>{img_count, 5 * A, H, W},
+      vector<int64_t>{img_count, box_dim * A, H, W},
       bbx_with_angle,
       "bbox_deltas",
       &ws);
   AddInput(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
-  AddInput(vector<int64_t>{A, 5}, anchors, "anchors", &ws);
+  AddInput(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
 
   def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
   def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
index 584c2f4..0cf157c 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
@@ -195,10 +195,6 @@
 }
 
 namespace {
-struct RotatedBox {
-  float x_ctr, y_ctr, w, h, a;
-};
-
 struct Point {
   float x, y;
 };
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.h b/caffe2/operators/generate_proposals_op_util_nms_gpu.h
index 04ba8f5..da7a840 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_gpu.h
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.h
@@ -33,6 +33,10 @@
     TensorCPU& host_delete_mask,
     CUDAContext* context);
 
+struct RotatedBox {
+  float x_ctr, y_ctr, w, h, a;
+};
+
 // Same as nms_gpu_upright, but for rotated boxes with angle info.
 // d_desc_sorted_boxes : pixel coordinates of proposed bounding boxes
 //    size: (N,5), format: [x_ct; y_ctr; width; height; angle]