Rotated boxes support for GPU GenerateProposals op (#15470)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15470
On top of D13509114 and D13017791. Pretty straight-forward.
Reviewed By: newstzpz
Differential Revision: D13536671
fbshipit-source-id: ff65981b70c63773ccc9aef3ff28e3c9508f6716
diff --git a/caffe2/operators/generate_proposals_op.cu b/caffe2/operators/generate_proposals_op.cu
index fe89a66..ec92bd0 100644
--- a/caffe2/operators/generate_proposals_op.cu
+++ b/caffe2/operators/generate_proposals_op.cu
@@ -6,6 +6,8 @@
#include "caffe2/operators/generate_proposals_op_util_nms.h"
#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
+using caffe2::utils::RotatedBox;
+
namespace caffe2 {
namespace {
__global__ void GeneratePreNMSUprightBoxesKernel(
@@ -15,9 +17,7 @@
const float4* d_anchors,
const int H,
const int W,
- const int K, // K = H*W
const int A,
- const int KA, // KA = K*A
const float feat_stride,
const float min_size,
const float* d_img_info_vec,
@@ -28,6 +28,8 @@
const int prenms_nboxes, // leading dimension of out_boxes
float* d_inout_scores,
char* d_boxes_keep_flags) {
+ const int K = H * W;
+ const int KA = K * A;
CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) {
// box_conv_index : # of the same box, but indexed in
// the scores from the conv layer, of shape (A,H,W)
@@ -48,7 +50,7 @@
const int w = remaining; // dW = 1
// Loading the anchor a
- // float is a struct with float x,y,z,w
+ // float4 is a struct with float x,y,z,w
const float4 anchor = d_anchors[a];
// x1,y1,x2,y2 :coordinates of anchor a, shifted for position (h,w)
const float shift_w = feat_stride * w;
@@ -131,7 +133,148 @@
}
}
-__global__ void WriteOutput(
+// Decodes the top-scoring rotated proposals of every image before NMS.
+//
+// For each (ibox, image_index) pair visited by CUDA_2D_KERNEL_LOOP the kernel:
+//   1. looks up the conv-layer index of the ibox-th best score and splits it
+//      into (anchor a, row h, column w),
+//   2. shifts anchor `a` to that feature-map location,
+//   3. applies the 5 regression deltas (dx, dy, dw, dh, dangle),
+//   4. optionally wraps the angle into [angle_bound_lo, angle_bound_hi],
+//   5. clips almost-upright boxes to the image,
+//   6. writes the box, a keep flag, and overwrites the score of rejected boxes.
+//
+// Boxes and anchors use the RotatedBox layout [x_ctr, y_ctr, w, h, angle],
+// with the angle in degrees; the angle delta is read in radians.
+// d_bbox_deltas is indexed as (num_images, 5*A, H, W).
+// d_img_info_vec holds 3 floats per image: height, width, scale.
+//
+// clip_angle_thresh is an angle tolerance in degrees and therefore a float:
+// declaring it as bool would collapse every non-zero threshold to 1.
+__global__ void GeneratePreNMSRotatedBoxesKernel(
+    const int* d_sorted_scores_keys,
+    const int nboxes_to_generate,
+    const float* d_bbox_deltas,
+    const RotatedBox* d_anchors,
+    const int H,
+    const int W,
+    const int A,
+    const float feat_stride,
+    const float min_size,
+    const float* d_img_info_vec,
+    const int num_images,
+    const float bbox_xform_clip,
+    const bool angle_bound_on,
+    const int angle_bound_lo,
+    const int angle_bound_hi,
+    const float clip_angle_thresh,
+    RotatedBox* d_out_boxes,
+    const int prenms_nboxes, // leading dimension of out_boxes
+    float* d_inout_scores,
+    char* d_boxes_keep_flags) {
+  constexpr float PI = 3.14159265358979323846f;
+  const int K = H * W;
+  const int KA = K * A;
+  CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, num_images) {
+    // box_conv_index : # of the same box, but indexed in
+    // the scores from the conv layer, of shape (A,H,W)
+    // the num_images dimension was already removed
+    // box_conv_index = a*K + h*W + w
+    const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox];
+
+    // We want to decompose box_conv_index in (a,h,w)
+    // such as box_conv_index = a*K + h*W + w
+    // (avoiding modulos in the process)
+    int remaining = box_conv_index;
+    const int dA = K; // stride of A
+    const int a = remaining / dA;
+    remaining -= a * dA;
+    const int dH = W; // stride of H
+    const int h = remaining / dH;
+    remaining -= h * dH;
+    const int w = remaining; // dW = 1
+
+    // Loading the anchor a and applying shifts.
+    // RotatedBox in [ctr_x, ctr_y, w, h, angle] format.
+    // Zero shift for width, height and angle.
+    RotatedBox box = d_anchors[a];
+    box.x_ctr += feat_stride * w; // x_ctr shifted for w
+    box.y_ctr += feat_stride * h; // y_ctr shifted for h
+
+    // TODO use fast math when possible
+
+    // Deltas for that box
+    // Deltas of shape (num_images,5*A,K)
+    // We're going to compute 5 scattered reads
+    // better than the alternative, ie transposing the complete deltas
+    // array first
+    int deltas_idx = image_index * (KA * 5) + a * 5 * K + h * W + w;
+    // Stride of K between each dimension
+    RotatedBox delta;
+    delta.x_ctr = d_bbox_deltas[deltas_idx + K * 0];
+    delta.y_ctr = d_bbox_deltas[deltas_idx + K * 1];
+    delta.w = d_bbox_deltas[deltas_idx + K * 2];
+    delta.h = d_bbox_deltas[deltas_idx + K * 3];
+    delta.a = d_bbox_deltas[deltas_idx + K * 4];
+
+    // Upper bound on dw,dh (keeps expf() below from overflowing)
+    delta.w = fminf(delta.w, bbox_xform_clip);
+    delta.h = fminf(delta.h, bbox_xform_clip);
+
+    // Convert the angle delta from radians back to degrees
+    delta.a *= 180.f / PI;
+
+    // Applying the deltas
+    box.x_ctr += delta.x_ctr * box.w;
+    box.y_ctr += delta.y_ctr * box.h;
+    box.w *= expf(delta.w);
+    box.h *= expf(delta.h);
+    box.a += delta.a;
+
+    if (angle_bound_on) {
+      // Normalize angle to be within [angle_bound_lo, angle_bound_hi].
+      // Deltas are guaranteed to be <= period / 2 while computing training
+      // targets by bbox_transform_inv, so a single wrap is enough.
+      const float period = angle_bound_hi - angle_bound_lo;
+      // CAFFE_ENFORCE(period > 0 && period % 180 == 0);
+      if (box.a < angle_bound_lo) {
+        box.a += period;
+      } else if (box.a > angle_bound_hi) {
+        box.a -= period;
+      }
+    }
+
+    // Clipping box to image.
+    // Only clip boxes that are almost upright (with a tolerance of
+    // clip_angle_thresh) for backward compatibility with horizontal boxes.
+    const float img_height = d_img_info_vec[3 * image_index + 0];
+    const float img_width = d_img_info_vec[3 * image_index + 1];
+    const float min_size_scaled =
+        min_size * d_img_info_vec[3 * image_index + 2];
+    if (fabsf(box.a) <= clip_angle_thresh) {
+      // Convert from [x_ctr, y_ctr, w, h] to [x1, y1, x2, y2]
+      float x1 = box.x_ctr - (box.w - 1.f) / 2.f;
+      float y1 = box.y_ctr - (box.h - 1.f) / 2.f;
+      float x2 = x1 + box.w - 1.f;
+      float y2 = y1 + box.h - 1.f;
+
+      // Clip
+      x1 = fmaxf(fminf(x1, img_width - 1.0f), 0.0f);
+      y1 = fmaxf(fminf(y1, img_height - 1.0f), 0.0f);
+      x2 = fmaxf(fminf(x2, img_width - 1.0f), 0.0f);
+      y2 = fmaxf(fminf(y2, img_height - 1.0f), 0.0f);
+
+      // Convert back to [x_ctr, y_ctr, w, h]
+      box.x_ctr = (x1 + x2) / 2.f;
+      box.y_ctr = (y1 + y2) / 2.f;
+      box.w = x2 - x1 + 1.f;
+      box.h = y2 - y1 + 1.f;
+    }
+
+    // Filter boxes.
+    // Removing boxes with one dim < min_size or center outside the image.
+    bool keep_box = (fminf(box.w, box.h) >= min_size_scaled) &&
+        (box.x_ctr < img_width) && (box.y_ctr < img_height);
+
+    // We are not deleting the box right now even if !keep_box
+    // we want to keep the relative order of the elements stable
+    // we'll do it in such a way later
+    // d_boxes_keep_flags size: (num_images,prenms_nboxes)
+    // d_out_boxes size: (num_images,prenms_nboxes)
+    const int out_index = image_index * prenms_nboxes + ibox;
+    d_boxes_keep_flags[out_index] = keep_box;
+    d_out_boxes[out_index] = box;
+
+    // d_inout_scores size: (num_images,KA)
+    // Note: FLT_MIN is the smallest positive normalized float.
+    if (!keep_box) {
+      d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS
+    }
+  }
+}
+
+__global__ void WriteUprightBoxesOutput(
const float4* d_image_boxes,
const float* d_image_scores,
const int* d_image_boxes_keep_list,
@@ -155,6 +298,31 @@
}
}
+// Gathers the rotated boxes selected by NMS for one image into the output
+// buffers. Each output ROI is a row of 6 floats:
+//   [image_index, x_ctr, y_ctr, w, h, angle]
+// d_image_boxes_keep_list[i] is the index (into d_image_boxes and
+// d_image_scores) of the i-th kept box.
+__global__ void WriteRotatedBoxesOutput(
+    const RotatedBox* d_image_boxes,
+    const float* d_image_scores,
+    const int* d_image_boxes_keep_list,
+    const int nboxes,
+    const int image_index,
+    float* d_image_out_rois,
+    float* d_image_out_rois_probs) {
+  CUDA_1D_KERNEL_LOOP(out_idx, nboxes) {
+    // Reads are scattered through the keep list;
+    // postnms_nboxes is small anyway
+    const int src_idx = d_image_boxes_keep_list[out_idx];
+    const RotatedBox kept = d_image_boxes[src_idx];
+
+    // One 6-float row per output ROI
+    float* roi_row = d_image_out_rois + 6 * out_idx;
+    roi_row[0] = image_index;
+    roi_row[1] = kept.x_ctr;
+    roi_row[2] = kept.y_ctr;
+    roi_row[3] = kept.w;
+    roi_row[4] = kept.h;
+    roi_row[5] = kept.a;
+
+    d_image_out_rois_probs[out_idx] = d_image_scores[src_idx];
+  }
+}
+
__global__ void InitializeDataKernel(
const int num_images,
const int KA,
@@ -191,11 +359,10 @@
const auto A = scores.dim(1);
const auto H = scores.dim(2);
const auto W = scores.dim(3);
- const auto box_dim_conv = anchors.dim(1);
+ const auto box_dim = anchors.dim(1);
- CAFFE_ENFORCE(box_dim_conv == 4); // only upright boxes in GPU version for now
+ CAFFE_ENFORCE(box_dim == 4 || box_dim == 5);
- constexpr int box_dim = 4;
const int K = H * W;
const int conv_layer_nboxes = K * A;
// Getting data members ready
@@ -301,40 +468,65 @@
const float* d_bbox_deltas = bbox_deltas.data<float>();
const float* d_anchors = anchors.data<float>();
const float* d_im_info_vec = im_info_tensor.data<float>();
- float4* d_boxes =
- reinterpret_cast<float4*>(dev_boxes_.template mutable_data<float>());
+ float* d_boxes = dev_boxes_.template mutable_data<float>();
;
char* d_boxes_keep_flags =
dev_boxes_keep_flags_.template mutable_data<char>();
- GeneratePreNMSUprightBoxesKernel<<<
- (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
- CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
- 0,
- context_.cuda_stream()>>>(
- d_sorted_conv_layer_indexes,
- nboxes_to_generate,
- d_bbox_deltas,
- reinterpret_cast<const float4*>(d_anchors),
- H,
- W,
- K,
- A,
- K * A,
- feat_stride_,
- rpn_min_size_,
- d_im_info_vec,
- num_images,
- utils::BBOX_XFORM_CLIP_DEFAULT,
- correct_transform_coords_,
- d_boxes,
- nboxes_to_generate,
- d_sorted_scores,
- d_boxes_keep_flags);
+ if (box_dim == 4) {
+ GeneratePreNMSUprightBoxesKernel<<<
+ (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+ CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+ 0,
+ context_.cuda_stream()>>>(
+ d_sorted_conv_layer_indexes,
+ nboxes_to_generate,
+ d_bbox_deltas,
+ reinterpret_cast<const float4*>(d_anchors),
+ H,
+ W,
+ A,
+ feat_stride_,
+ rpn_min_size_,
+ d_im_info_vec,
+ num_images,
+ utils::BBOX_XFORM_CLIP_DEFAULT,
+ correct_transform_coords_,
+ reinterpret_cast<float4*>(d_boxes),
+ nboxes_to_generate,
+ d_sorted_scores,
+ d_boxes_keep_flags);
+ } else {
+ GeneratePreNMSRotatedBoxesKernel<<<
+ (CAFFE_GET_BLOCKS(nboxes_to_generate), num_images),
+ CAFFE_CUDA_NUM_THREADS, // blockDim.y == 1
+ 0,
+ context_.cuda_stream()>>>(
+ d_sorted_conv_layer_indexes,
+ nboxes_to_generate,
+ d_bbox_deltas,
+ reinterpret_cast<const RotatedBox*>(d_anchors),
+ H,
+ W,
+ A,
+ feat_stride_,
+ rpn_min_size_,
+ d_im_info_vec,
+ num_images,
+ utils::BBOX_XFORM_CLIP_DEFAULT,
+ angle_bound_on_,
+ angle_bound_lo_,
+ angle_bound_hi_,
+ clip_angle_thresh_,
+ reinterpret_cast<RotatedBox*>(d_boxes),
+ nboxes_to_generate,
+ d_sorted_scores,
+ d_boxes_keep_flags);
+ }
const int nboxes_generated = nboxes_to_generate;
dev_image_prenms_boxes_.Resize(box_dim * nboxes_generated);
- float4* d_image_prenms_boxes = reinterpret_cast<float4*>(
- dev_image_prenms_boxes_.template mutable_data<float>());
+ float* d_image_prenms_boxes =
+ dev_image_prenms_boxes_.template mutable_data<float>();
dev_image_prenms_scores_.Resize(nboxes_generated);
float* d_image_prenms_scores =
dev_image_prenms_scores_.template mutable_data<float>();
@@ -342,8 +534,9 @@
int* d_image_boxes_keep_list =
dev_image_boxes_keep_list_.template mutable_data<int>();
+ const int roi_cols = box_dim + 1;
const int max_postnms_nboxes = std::min(nboxes_generated, rpn_post_nms_topN_);
- dev_postnms_rois_.Resize(5 * num_images * max_postnms_nboxes);
+ dev_postnms_rois_.Resize(roi_cols * num_images * max_postnms_nboxes);
dev_postnms_rois_probs_.Resize(num_images * max_postnms_nboxes);
float* d_postnms_rois = dev_postnms_rois_.template mutable_data<float>();
float* d_postnms_rois_probs =
@@ -357,26 +550,39 @@
int nrois_in_output = 0;
for (int image_index = 0; image_index < num_images; ++image_index) {
// Sub matrices for current image
- const float4* d_image_boxes = &d_boxes[image_index * nboxes_generated];
+ const float* d_image_boxes =
+ &d_boxes[image_index * nboxes_generated * box_dim];
const float* d_image_sorted_scores = &d_sorted_scores[image_index * K * A];
char* d_image_boxes_keep_flags =
&d_boxes_keep_flags[image_index * nboxes_generated];
- float* d_image_postnms_rois = &d_postnms_rois[5 * nrois_in_output];
+ float* d_image_postnms_rois = &d_postnms_rois[roi_cols * nrois_in_output];
float* d_image_postnms_rois_probs = &d_postnms_rois_probs[nrois_in_output];
// Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true)
// to the output tensors
- cub::DeviceSelect::Flagged(
- d_cub_select_temp_storage,
- cub_select_temp_storage_bytes,
- d_image_boxes,
- d_image_boxes_keep_flags,
- d_image_prenms_boxes,
- d_prenms_nboxes,
- nboxes_generated,
- context_.cuda_stream());
+ if (box_dim == 4) {
+ cub::DeviceSelect::Flagged(
+ d_cub_select_temp_storage,
+ cub_select_temp_storage_bytes,
+ reinterpret_cast<const float4*>(d_image_boxes),
+ d_image_boxes_keep_flags,
+ reinterpret_cast<float4*>(d_image_prenms_boxes),
+ d_prenms_nboxes,
+ nboxes_generated,
+ context_.cuda_stream());
+ } else {
+ cub::DeviceSelect::Flagged(
+ d_cub_select_temp_storage,
+ cub_select_temp_storage_bytes,
+ reinterpret_cast<const RotatedBox*>(d_image_boxes),
+ d_image_boxes_keep_flags,
+ reinterpret_cast<RotatedBox*>(d_image_prenms_boxes),
+ d_prenms_nboxes,
+ nboxes_generated,
+ context_.cuda_stream());
+ }
cub::DeviceSelect::Flagged(
d_cub_select_temp_storage,
@@ -391,18 +597,19 @@
host_prenms_nboxes_.CopyFrom(dev_prenms_nboxes_);
// We know prenms_boxes <= topN_prenms, because nboxes_generated <=
- // topN_prenms Calling NMS on the generated boxes
+ // topN_prenms. Calling NMS on the generated boxes
const int prenms_nboxes = *h_prenms_nboxes;
int nkeep;
- utils::nms_gpu_upright(
- reinterpret_cast<const float*>(d_image_prenms_boxes),
+ utils::nms_gpu(
+ d_image_prenms_boxes,
prenms_nboxes,
rpn_nms_thresh_,
d_image_boxes_keep_list,
&nkeep,
dev_nms_mask_,
host_nms_mask_,
- &context_);
+ &context_,
+ box_dim);
// All operations done after previous sort were keeping the relative order
// of the elements the elements are still sorted keep topN <=> truncate the
@@ -411,24 +618,39 @@
// Moving the out boxes to the output tensors,
// adding the image_index dimension on the fly
- WriteOutput<<<
- CAFFE_GET_BLOCKS(postnms_nboxes),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- d_image_prenms_boxes,
- d_image_prenms_scores,
- d_image_boxes_keep_list,
- postnms_nboxes,
- image_index,
- d_image_postnms_rois,
- d_image_postnms_rois_probs);
+ if (box_dim == 4) {
+ WriteUprightBoxesOutput<<<
+ CAFFE_GET_BLOCKS(postnms_nboxes),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(
+ reinterpret_cast<const float4*>(d_image_prenms_boxes),
+ d_image_prenms_scores,
+ d_image_boxes_keep_list,
+ postnms_nboxes,
+ image_index,
+ d_image_postnms_rois,
+ d_image_postnms_rois_probs);
+ } else {
+ WriteRotatedBoxesOutput<<<
+ CAFFE_GET_BLOCKS(postnms_nboxes),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(
+ reinterpret_cast<const RotatedBox*>(d_image_prenms_boxes),
+ d_image_prenms_scores,
+ d_image_boxes_keep_list,
+ postnms_nboxes,
+ image_index,
+ d_image_postnms_rois,
+ d_image_postnms_rois_probs);
+ }
nrois_in_output += postnms_nboxes;
}
// Using a buffer because we cannot call ShrinkTo
- out_rois->Resize(nrois_in_output, 5);
+ out_rois->Resize(nrois_in_output, roi_cols);
out_rois_probs->Resize(nrois_in_output);
float* d_out_rois = out_rois->template mutable_data<float>();
float* d_out_rois_probs = out_rois_probs->template mutable_data<float>();
@@ -436,7 +658,7 @@
CUDA_CHECK(cudaMemcpyAsync(
d_out_rois,
d_postnms_rois,
- nrois_in_output * 5 * sizeof(float),
+ nrois_in_output * roi_cols * sizeof(float),
cudaMemcpyDeviceToDevice,
context_.cuda_stream()));
CUDA_CHECK(cudaMemcpyAsync(
diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc
index 36ad5aa..817f345 100644
--- a/caffe2/operators/generate_proposals_op_gpu_test.cc
+++ b/caffe2/operators/generate_proposals_op_gpu_test.cc
@@ -6,6 +6,7 @@
#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/generate_proposals_op_util_boxes.h"
#ifdef CAFFE2_USE_OPENCV
#include <opencv2/opencv.hpp>
@@ -236,4 +237,410 @@
0,
1e-4);
}
+
+#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
+// Runs the CUDA GenerateProposals op on 5-value (rotated) anchors whose angle
+// and angle delta are both 0, and checks that the proposals match the upright
+// ground truth converted to [batch, x_ctr, y_ctr, w, h, angle] rows.
+TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) {
+  // Similar to TestRealDownSampledGPU but for rotated boxes with angle info.
+  if (!HasCudaGPU())
+    return;
+
+  const float angle = 0;
+  const float delta_angle = 0;
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
+
+  Workspace ws;
+  OperatorDef def;
+  def.set_name("test");
+  def.set_type("GenerateProposals");
+  def.add_input("scores");
+  def.add_input("bbox_deltas");
+  def.add_input("im_info");
+  def.add_input("anchors");
+  def.add_output("rois");
+  def.add_output("rois_probs");
+  def.mutable_device_option()->set_device_type(PROTO_CUDA);
+  const int img_count = 2;
+  const int A = 2;
+  const int H = 4;
+  const int W = 5;
+
+  vector<float> scores{
+      5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f,
+      1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f,
+      9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f,
+      1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f,
+      1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f,
+      2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f,
+      5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f,
+      3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f,
+      5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f,
+      5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f};
+  vector<float> bbx{
+      -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f,
+      -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f,
+      -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f,
+      -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f,  5.71171008e-02f,
+      -1.75665006e-01f, 2.30021998e-02f,  3.08554992e-02f,  -1.39333997e-02f,
+      3.40579003e-01f,  3.91070992e-01f,  3.91624004e-01f,  3.92527014e-01f,
+      3.91445011e-01f,  3.79328012e-01f,  4.26631987e-01f,  3.64892989e-01f,
+      2.76894987e-01f,  5.13985991e-01f,  3.79999995e-01f,  1.80457994e-01f,
+      4.37402993e-01f,  4.18545991e-01f,  2.51549989e-01f,  4.48318988e-01f,
+      1.68564007e-01f,  4.65440989e-01f,  4.21891987e-01f,  4.45928007e-01f,
+      3.27155995e-03f,  3.71480011e-03f,  3.60032008e-03f,  4.27092984e-03f,
+      3.74579988e-03f,  5.95752988e-03f,  -3.14473989e-03f, 3.52022005e-03f,
+      -1.88564006e-02f, 1.65188999e-03f,  1.73791999e-03f,  -3.56074013e-02f,
+      -1.66615995e-04f, 3.14146001e-03f,  -1.11830998e-02f, -5.35363983e-03f,
+      6.49790000e-03f,  -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f,
+      -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f,
+      -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f,
+      -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f,
+      -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f,
+      -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f,
+      -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f,
+      -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f,
+      -4.94231991e-02f, 6.63906988e-03f,  3.20469006e-03f,  -6.44695014e-02f,
+      -3.11607006e-03f, 2.02738005e-03f,  1.48096997e-02f,  4.39785011e-02f,
+      -8.28424022e-02f, 3.62076014e-02f,  2.71668993e-02f,  1.38250999e-02f,
+      6.76669031e-02f,  1.03252999e-01f,  1.03255004e-01f,  9.89722982e-02f,
+      1.03646003e-01f,  4.79663983e-02f,  1.11014001e-01f,  9.31736007e-02f,
+      1.15768999e-01f,  1.04014002e-01f,  -8.90677981e-03f, 1.13103002e-01f,
+      1.33085996e-01f,  1.25405997e-01f,  1.50051996e-01f,  -1.13038003e-01f,
+      7.01059997e-02f,  1.79651007e-01f,  1.41055003e-01f,  1.62841007e-01f,
+      -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f,
+      -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f,
+      -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, -4.10550982e-02f,
+      -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f,
+      -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f,
+      1.29676998e-01f,  1.16080999e-01f,  1.15947001e-01f,  1.21797003e-01f,
+      1.16089001e-01f,  1.44875005e-01f,  1.15617000e-01f,  1.31586999e-01f,
+      1.74735002e-02f,  1.21973999e-01f,  1.31596997e-01f,  2.48907991e-02f,
+      6.18605018e-02f,  1.12855002e-01f,  -6.99798986e-02f, 9.58312973e-02f,
+      1.53593004e-01f,  -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f};
+
+  // Add angle in bbox deltas
+  int num_boxes = scores.size();
+  CHECK_EQ(bbx.size() / 4, num_boxes);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
+  // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
+  // at each spatial location for each anchor.
+  int i = 0, j = 0;
+  for (int a = 0; a < A; ++a) {
+    for (int k = 0; k < 4 * H * W; ++k) {
+      bbx_with_angle[i++] = bbx[j++];
+    }
+    for (int k = 0; k < H * W; ++k) {
+      bbx_with_angle[i++] = delta_angle;
+    }
+  }
+
+  vector<float> im_info{60, 80, 0.166667f};
+  // vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
+  // Anchors in [x_ctr, y_ctr, w, h, angle] format
+  vector<float> anchors{7.5, 7.5, 92, 48, angle, 7.5, 7.5, 256, 256, angle};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  scores.insert(scores.begin(), scores.begin(), scores.end());
+  bbx_with_angle.insert(
+      bbx_with_angle.begin(), bbx_with_angle.begin(), bbx_with_angle.end());
+  im_info.insert(im_info.begin(), im_info.begin(), im_info.end());
+
+  // Results should exactly be the same as TestRealDownSampledGPU since
+  // angle = 0 for all boxes and clip_angle_thresh > 0 (which means
+  // all horizontal boxes will be clipped to maintain backward compatibility).
+  ERMatXf rois_gt_xyxy(18, 5);
+  rois_gt_xyxy << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0,
+      24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f,
+      45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
+      51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
+      41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f, 1, 0, 0, 79,
+      59, 1, 0, 5.0005703f, 51.6324f, 42.6950f, 1, 24.13628387f, 7.51243401f,
+      79, 45.0663f, 1, 0, 7.50924301f, 67.4779f, 45.0336, 1, 0, 23.09477997f,
+      50.61448669f, 59, 1, 0, 39.52141571f, 51.44710541f, 59, 1, 23.57396317f,
+      29.98791885f, 79, 59, 1, 0, 41.90219116f, 79, 59, 1, 0, 23.30098343f,
+      78.2413f, 58.7287f;
+  ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
+  // Batch ID: copy exactly one column (column 0). The source block must be
+  // rows x 1 to match the destination block.
+  rois_gt.block(0, 0, rois_gt.rows(), 1) =
+      rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1);
+  // rois_gt in [x_ctr, y_ctr, w, h] format
+  rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
+      rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
+  // Angle
+  rois_gt.block(0, 5, rois_gt.rows(), 1) =
+      ERMatXf::Constant(rois_gt.rows(), 1, angle);
+
+  vector<float> rois_probs_gt{2.66913995e-02f,
+                              5.44218998e-03f,
+                              1.20544003e-03f,
+                              1.19207997e-03f,
+                              6.17993006e-04f,
+                              4.72735002e-04f,
+                              6.09605013e-05f,
+                              1.50015003e-05f,
+                              8.91025957e-06f};
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  rois_probs_gt.insert(
+      rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end());
+
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, box_dim * A, H, W},
+      bbx_with_angle,
+      "bbox_deltas",
+      &ws);
+  AddInput<CUDAContext>(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
+
+  def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
+  def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
+  def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
+  def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+  def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
+
+  // Fatal assertions: everything below dereferences these results.
+  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
+  ASSERT_NE(nullptr, op.get());
+  ASSERT_TRUE(op->Run());
+
+  // test rois
+  Blob* rois_blob = ws.GetBlob("rois");
+  ASSERT_NE(nullptr, rois_blob);
+  auto& rois_gpu = rois_blob->Get<TensorCUDA>();
+  Tensor rois{CPU};
+  rois.CopyFrom(rois_gpu);
+
+  // The shapes must match before the element-wise comparison is meaningful.
+  ASSERT_EQ(rois.dims(), (vector<int64_t>{rois_gt.rows(), rois_gt.cols()}));
+  auto rois_data =
+      Eigen::Map<const ERMatXf>(rois.data<float>(), rois.dim(0), rois.dim(1));
+  EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-4);
+
+  // test rois_probs
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  ASSERT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs_gpu = rois_probs_blob->Get<TensorCUDA>();
+  Tensor rois_probs{CPU};
+  rois_probs.CopyFrom(rois_probs_gpu);
+  ASSERT_EQ(
+      rois_probs.dims(), (vector<int64_t>{int64_t(rois_probs_gt.size())}));
+  auto rois_probs_data = ConstEigenVectorArrayMap<float>(
+      rois_probs.data<float>(), rois_probs.dim(0));
+  EXPECT_NEAR(
+      (rois_probs_data.matrix() - utils::AsEArrXt(rois_probs_gt).matrix())
+          .cwiseAbs()
+          .maxCoeff(),
+      0,
+      1e-4);
+}
+
+// Runs the CUDA GenerateProposals op on 5-value (rotated) anchors at 45
+// degrees with a constant angle delta of 0.174533 rad (10 degrees), and checks
+// the resulting [batch, x_ctr, y_ctr, w, h, angle] rows and their scores.
+TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) {
+  // Similar to TestRealDownSampledGPU but for rotated boxes with angle info.
+  if (!HasCudaGPU())
+    return;
+
+  const float angle = 45.0;
+  const float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
+  const float expected_angle = 55.0; // angle + delta_angle in degrees
+  const float clip_angle_thresh = 1.0;
+  const int box_dim = 5;
+
+  Workspace ws;
+  OperatorDef def;
+  def.set_name("test");
+  def.set_type("GenerateProposals");
+  def.add_input("scores");
+  def.add_input("bbox_deltas");
+  def.add_input("im_info");
+  def.add_input("anchors");
+  def.add_output("rois");
+  def.add_output("rois_probs");
+  def.mutable_device_option()->set_device_type(PROTO_CUDA);
+  const int img_count = 2;
+  const int A = 2;
+  const int H = 4;
+  const int W = 5;
+
+  vector<float> scores{
+      5.44218998e-03f, 1.19207997e-03f, 1.12379994e-03f, 1.17181998e-03f,
+      1.20544003e-03f, 6.17993006e-04f, 1.05261997e-05f, 8.91025957e-06f,
+      9.29536981e-09f, 6.09605013e-05f, 4.72735002e-04f, 1.13482002e-10f,
+      1.50015003e-05f, 4.45032993e-06f, 3.21612994e-08f, 8.02662980e-04f,
+      1.40488002e-04f, 3.12508007e-07f, 3.02616991e-06f, 1.97759000e-08f,
+      2.66913995e-02f, 5.26766013e-03f, 5.05053019e-03f, 5.62100019e-03f,
+      5.37420018e-03f, 5.26280981e-03f, 2.48894998e-04f, 1.06842002e-04f,
+      3.92931997e-06f, 1.79388002e-03f, 4.79440019e-03f, 3.41609990e-07f,
+      5.20430971e-04f, 3.34090000e-05f, 2.19159006e-07f, 2.28786003e-03f,
+      5.16703985e-05f, 4.04523007e-06f, 1.79227004e-06f, 5.32449000e-08f};
+  vector<float> bbx{
+      -1.65040009e-02f, -1.84051003e-02f, -1.85930002e-02f, -2.08263006e-02f,
+      -1.83814000e-02f, -2.89172009e-02f, -3.89706008e-02f, -7.52277970e-02f,
+      -1.54091999e-01f, -2.55433004e-02f, -1.77490003e-02f, -1.10340998e-01f,
+      -4.20190990e-02f, -2.71421000e-02f, 6.89801015e-03f,  5.71171008e-02f,
+      -1.75665006e-01f, 2.30021998e-02f,  3.08554992e-02f,  -1.39333997e-02f,
+      3.40579003e-01f,  3.91070992e-01f,  3.91624004e-01f,  3.92527014e-01f,
+      3.91445011e-01f,  3.79328012e-01f,  4.26631987e-01f,  3.64892989e-01f,
+      2.76894987e-01f,  5.13985991e-01f,  3.79999995e-01f,  1.80457994e-01f,
+      4.37402993e-01f,  4.18545991e-01f,  2.51549989e-01f,  4.48318988e-01f,
+      1.68564007e-01f,  4.65440989e-01f,  4.21891987e-01f,  4.45928007e-01f,
+      3.27155995e-03f,  3.71480011e-03f,  3.60032008e-03f,  4.27092984e-03f,
+      3.74579988e-03f,  5.95752988e-03f,  -3.14473989e-03f, 3.52022005e-03f,
+      -1.88564006e-02f, 1.65188999e-03f,  1.73791999e-03f,  -3.56074013e-02f,
+      -1.66615995e-04f, 3.14146001e-03f,  -1.11830998e-02f, -5.35363983e-03f,
+      6.49790000e-03f,  -9.27671045e-03f, -2.83346009e-02f, -1.61233004e-02f,
+      -2.15505004e-01f, -2.19910994e-01f, -2.20872998e-01f, -2.12831005e-01f,
+      -2.19145000e-01f, -2.27687001e-01f, -3.43973994e-01f, -2.75869995e-01f,
+      -3.19516987e-01f, -2.50418007e-01f, -2.48537004e-01f, -5.08224010e-01f,
+      -2.28724003e-01f, -2.82402009e-01f, -3.75815988e-01f, -2.86352992e-01f,
+      -5.28333001e-02f, -4.43836004e-01f, -4.55134988e-01f, -4.34897989e-01f,
+      -5.65053988e-03f, -9.25739005e-04f, -1.06790999e-03f, -2.37016007e-03f,
+      -9.71166010e-04f, -8.90910998e-03f, -1.17592998e-02f, -2.08992008e-02f,
+      -4.94231991e-02f, 6.63906988e-03f,  3.20469006e-03f,  -6.44695014e-02f,
+      -3.11607006e-03f, 2.02738005e-03f,  1.48096997e-02f,  4.39785011e-02f,
+      -8.28424022e-02f, 3.62076014e-02f,  2.71668993e-02f,  1.38250999e-02f,
+      6.76669031e-02f,  1.03252999e-01f,  1.03255004e-01f,  9.89722982e-02f,
+      1.03646003e-01f,  4.79663983e-02f,  1.11014001e-01f,  9.31736007e-02f,
+      1.15768999e-01f,  1.04014002e-01f,  -8.90677981e-03f, 1.13103002e-01f,
+      1.33085996e-01f,  1.25405997e-01f,  1.50051996e-01f,  -1.13038003e-01f,
+      7.01059997e-02f,  1.79651007e-01f,  1.41055003e-01f,  1.62841007e-01f,
+      -1.00247003e-02f, -8.17587040e-03f, -8.32176022e-03f, -8.90108012e-03f,
+      -8.13035015e-03f, -1.77263003e-02f, -3.69572006e-02f, -3.51580009e-02f,
+      -5.92143014e-02f, -1.80795006e-02f, -5.46086021e-03f, -4.10550982e-02f,
+      -1.83081999e-02f, -2.15411000e-02f, -1.17953997e-02f, 3.33894007e-02f,
+      -5.29635996e-02f, -6.97528012e-03f, -3.15250992e-03f, -3.27355005e-02f,
+      1.29676998e-01f,  1.16080999e-01f,  1.15947001e-01f,  1.21797003e-01f,
+      1.16089001e-01f,  1.44875005e-01f,  1.15617000e-01f,  1.31586999e-01f,
+      1.74735002e-02f,  1.21973999e-01f,  1.31596997e-01f,  2.48907991e-02f,
+      6.18605018e-02f,  1.12855002e-01f,  -6.99798986e-02f, 9.58312973e-02f,
+      1.53593004e-01f,  -8.75087008e-02f, -4.92327996e-02f, -3.32239009e-02f};
+
+  // Add angle in bbox deltas
+  int num_boxes = scores.size();
+  CHECK_EQ(bbx.size() / 4, num_boxes);
+  vector<float> bbx_with_angle(num_boxes * box_dim);
+  // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
+  // at each spatial location for each anchor.
+  int i = 0, j = 0;
+  for (int a = 0; a < A; ++a) {
+    for (int k = 0; k < 4 * H * W; ++k) {
+      bbx_with_angle[i++] = bbx[j++];
+    }
+    for (int k = 0; k < H * W; ++k) {
+      bbx_with_angle[i++] = delta_angle;
+    }
+  }
+
+  vector<float> im_info{60, 80, 0.166667f};
+  // vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
+  // Anchors in [x_ctr, y_ctr, w, h, angle] format
+  vector<float> anchors{8, 8, 92, 48, angle, 8, 8, 256, 256, angle};
+
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  scores.insert(scores.begin(), scores.begin(), scores.end());
+  bbx_with_angle.insert(
+      bbx_with_angle.begin(), bbx_with_angle.begin(), bbx_with_angle.end());
+  im_info.insert(im_info.begin(), im_info.begin(), im_info.end());
+
+  // Expected rows: [batch, x_ctr, y_ctr, w, h, angle]
+  ERMatXf rois_gt(26, 6);
+  rois_gt <<
+      0, 6.55346, 25.3227, 253.447, 291.446, expected_angle,
+      0, 55.3932, 33.3369, 253.731, 289.158, expected_angle,
+      0, 6.48163, 24.3478, 92.3015, 38.6944, expected_angle,
+      0, 70.3089, 26.7894, 92.3453, 38.5539, expected_angle,
+      0, 22.3067, 26.7714, 92.3424, 38.5243, expected_angle,
+      0, 54.084, 26.8413, 92.3938, 38.798, expected_angle,
+      0, 38.2894, 26.798, 92.3318, 38.4873, expected_angle,
+      0, 5.33962, 42.2077, 92.5497, 38.2259, expected_angle,
+      0, 6.36709, 58.24, 92.16, 37.4372, expected_angle,
+      0, 69.65, 48.6713, 92.1521, 37.3668, expected_angle,
+      0, 20.4147, 44.4783, 91.7111, 34.0295, expected_angle,
+      0, 33.079, 41.5149, 92.3244, 36.4278, expected_angle,
+      0, 41.8235, 37.291, 90.2815, 34.872, expected_angle,
+      1, 6.55346, 25.3227, 253.447, 291.446, expected_angle,
+      1, 55.3932, 33.3369, 253.731, 289.158, expected_angle,
+      1, 6.48163, 24.3478, 92.3015, 38.6944, expected_angle,
+      1, 70.3089, 26.7894, 92.3453, 38.5539, expected_angle,
+      1, 22.3067, 26.7714, 92.3424, 38.5243, expected_angle,
+      1, 54.084, 26.8413, 92.3938, 38.798, expected_angle,
+      1, 38.2894, 26.798, 92.3318, 38.4873, expected_angle,
+      1, 5.33962, 42.2077, 92.5497, 38.2259, expected_angle,
+      1, 6.36709, 58.24, 92.16, 37.4372, expected_angle,
+      1, 69.65, 48.6713, 92.1521, 37.3668, expected_angle,
+      1, 20.4147, 44.4783, 91.7111, 34.0295, expected_angle,
+      1, 33.079, 41.5149, 92.3244, 36.4278, expected_angle,
+      1, 41.8235, 37.291, 90.2815, 34.872, expected_angle;
+
+  vector<float> rois_probs_gt{2.66913995e-02f,
+                              5.621e-03f,
+                              5.44218998e-03f,
+                              1.20544003e-03f,
+                              1.19207997e-03f,
+                              1.17182e-03f,
+                              1.1238e-03f,
+                              6.17993006e-04f,
+                              4.72735002e-04f,
+                              6.09605013e-05f,
+                              1.50015003e-05f,
+                              8.91025957e-06f,
+                              9.29537e-09f};
+  // Doubling everything related to images, to simulate
+  // num_images = 2
+  rois_probs_gt.insert(
+      rois_probs_gt.begin(), rois_probs_gt.begin(), rois_probs_gt.end());
+
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
+  AddInput<CUDAContext>(
+      vector<int64_t>{img_count, box_dim * A, H, W},
+      bbx_with_angle,
+      "bbox_deltas",
+      &ws);
+  AddInput<CUDAContext>(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
+  AddInput<CUDAContext>(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
+
+  def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
+  def.add_arg()->CopyFrom(MakeArgument("post_nms_topN", 300));
+  def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
+  def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
+  def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+  def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
+
+  // Fatal assertions: everything below dereferences these results.
+  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
+  ASSERT_NE(nullptr, op.get());
+  ASSERT_TRUE(op->Run());
+
+  // test rois
+  Blob* rois_blob = ws.GetBlob("rois");
+  ASSERT_NE(nullptr, rois_blob);
+  auto& rois_gpu = rois_blob->Get<TensorCUDA>();
+  Tensor rois{CPU};
+  rois.CopyFrom(rois_gpu);
+  // The shapes must match before the element-wise comparison is meaningful.
+  ASSERT_EQ(rois.dims(), (vector<int64_t>{rois_gt.rows(), rois_gt.cols()}));
+  auto rois_data =
+      Eigen::Map<const ERMatXf>(rois.data<float>(), rois.size(0), rois.size(1));
+  EXPECT_NEAR((rois_data.matrix() - rois_gt).cwiseAbs().maxCoeff(), 0, 1e-3);
+
+  // test rois_probs
+  Blob* rois_probs_blob = ws.GetBlob("rois_probs");
+  ASSERT_NE(nullptr, rois_probs_blob);
+  auto& rois_probs_gpu = rois_probs_blob->Get<TensorCUDA>();
+  Tensor rois_probs{CPU};
+  rois_probs.CopyFrom(rois_probs_gpu);
+  ASSERT_EQ(
+      rois_probs.dims(), (vector<int64_t>{int64_t(rois_probs_gt.size())}));
+  auto rois_probs_data = ConstEigenVectorArrayMap<float>(
+      rois_probs.data<float>(), rois_probs.size(0));
+  EXPECT_NEAR(
+      (rois_probs_data.matrix() - utils::AsEArrXt(rois_probs_gt).matrix())
+          .cwiseAbs()
+          .maxCoeff(),
+      0,
+      1e-4);
+}
+#endif // CV_MAJOR_VERSION >= 3
+
} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc
index 4d76075..004402b 100644
--- a/caffe2/operators/generate_proposals_op_test.cc
+++ b/caffe2/operators/generate_proposals_op_test.cc
@@ -4,6 +4,8 @@
#include "caffe2/core/flags.h"
#include "caffe2/core/macros.h"
+#include "caffe2/operators/generate_proposals_op_util_boxes.h"
+
#ifdef CAFFE2_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif // CAFFE2_USE_OPENCV
@@ -142,22 +144,6 @@
}
}
-namespace {
-
-template <class Derived>
-ERMatXf boxes_xyxy_to_xywh(const Eigen::MatrixBase<Derived>& boxes) {
- CAFFE_ENFORCE_EQ(boxes.cols(), 4);
- ERMatXf res(boxes.rows(), 4);
- auto ones = ERMatXf::Constant(boxes.rows(), 1, 1.0);
- res.col(0) = (boxes.col(0) + boxes.col(2)) / 2.0; // ctr_x = (x1 + x2)/2
- res.col(1) = (boxes.col(1) + boxes.col(3)) / 2.0; // ctr_y = (y1 + y2)/2
- res.col(2) = boxes.col(2) - boxes.col(0) + ones; // w = x2 - x1 + 1
- res.col(3) = boxes.col(3) - boxes.col(1) + ones; // h = y2 - y1 + 1
- return res;
-}
-
-} // namespace
-
TEST(GenerateProposalsTest, TestComputeAllAnchorsRotated) {
// Similar to TestComputeAllAnchors but for rotated boxes with angle info.
ERMatXf anchors_xyxy(3, 4);
@@ -165,7 +151,7 @@
// Convert to RRPN format and add angles
ERMatXf anchors(3, 5);
- anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+ anchors.block(0, 0, 3, 4) = utils::bbox_xyxy_to_ctrwh(anchors_xyxy.array());
std::vector<float> angles{0.0, 45.0, -120.0};
for (int i = 0; i < anchors.rows(); ++i) {
anchors(i, 4) = angles[i % angles.size()];
@@ -188,7 +174,8 @@
// Convert gt to RRPN format and add angles
ERMatXf all_anchors_gt(36, 5);
- all_anchors_gt.block(0, 0, 36, 4) = boxes_xyxy_to_xywh(all_anchors_gt_xyxy);
+ all_anchors_gt.block(0, 0, 36, 4) =
+ utils::bbox_xyxy_to_ctrwh(all_anchors_gt_xyxy.array());
for (int i = 0; i < all_anchors_gt.rows(); ++i) {
all_anchors_gt(i, 4) = angles[i % angles.size()];
}
@@ -213,7 +200,7 @@
// Convert to RRPN format and add angles
ERMatXf anchors(3, 5);
- anchors.block(0, 0, 3, 4) = boxes_xyxy_to_xywh(anchors_xyxy);
+ anchors.block(0, 0, 3, 4) = utils::bbox_xyxy_to_ctrwh(anchors_xyxy.array());
std::vector<float> angles{0.0, 45.0, -120.0};
for (int i = 0; i < anchors.rows(); ++i) {
anchors(i, 4) = angles[i % angles.size()];
@@ -433,9 +420,10 @@
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
// Similar to TestRealDownSampled but for rotated boxes with angle info.
- float angle = 0;
- float delta_angle = 0;
- float clip_angle_thresh = 1.0;
+ const float angle = 0;
+ const float delta_angle = 0;
+ const float clip_angle_thresh = 1.0;
+ const int box_dim = 5;
Workspace ws;
OperatorDef def;
@@ -508,7 +496,7 @@
// Add angle in bbox deltas
int num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
- vector<float> bbx_with_angle(num_boxes * 5);
+ vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.
int i = 0, j = 0;
@@ -535,13 +523,13 @@
45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f;
- ERMatXf rois_gt(9, 6);
+ ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
// Batch ID
rois_gt.block(0, 0, rois_gt.rows(), 1) =
- ERMatXf::Constant(rois_gt.rows(), 1, 0.0);
+ rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1); // batch-id column
// rois_gt in [x_ctr, y_ctr, w, h] format
- rois_gt.block(0, 1, rois_gt.rows(), 4) =
- boxes_xyxy_to_xywh(rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4));
+ rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
+ rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
// Angle
rois_gt.block(0, 5, rois_gt.rows(), 1) =
ERMatXf::Constant(rois_gt.rows(), 1, angle);
@@ -557,12 +545,12 @@
AddInput(vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
AddInput(
- vector<int64_t>{img_count, 5 * A, H, W},
+ vector<int64_t>{img_count, box_dim * A, H, W},
bbx_with_angle,
"bbox_deltas",
&ws);
AddInput(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
- AddInput(vector<int64_t>{A, 5}, anchors, "anchors", &ws);
+ AddInput(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
@@ -603,10 +591,11 @@
TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
// Similar to TestRealDownSampled but for rotated boxes with angle info.
- float angle = 45.0;
- float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
- float expected_angle = 55.0;
- float clip_angle_thresh = 1.0;
+ const float angle = 45.0;
+ const float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
+ const float expected_angle = 55.0;
+ const float clip_angle_thresh = 1.0;
+ const int box_dim = 5;
Workspace ws;
OperatorDef def;
@@ -679,7 +668,7 @@
// Add angle in bbox deltas
int num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
- vector<float> bbx_with_angle(num_boxes * 5);
+ vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.
{
@@ -700,12 +689,12 @@
AddInput(vector<int64_t>{img_count, A, H, W}, scores, "scores", &ws);
AddInput(
- vector<int64_t>{img_count, 5 * A, H, W},
+ vector<int64_t>{img_count, box_dim * A, H, W},
bbx_with_angle,
"bbox_deltas",
&ws);
AddInput(vector<int64_t>{img_count, 3}, im_info, "im_info", &ws);
- AddInput(vector<int64_t>{A, 5}, anchors, "anchors", &ws);
+ AddInput(vector<int64_t>{A, box_dim}, anchors, "anchors", &ws);
def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
def.add_arg()->CopyFrom(MakeArgument("pre_nms_topN", 6000));
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
index 584c2f4..0cf157c 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu
@@ -195,10 +195,6 @@
}
namespace {
-struct RotatedBox {
- float x_ctr, y_ctr, w, h, a;
-};
-
struct Point {
float x, y;
};
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.h b/caffe2/operators/generate_proposals_op_util_nms_gpu.h
index 04ba8f5..da7a840 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_gpu.h
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.h
@@ -33,6 +33,10 @@
TensorCPU& host_delete_mask,
CUDAContext* context);
+struct RotatedBox { // one rotated box, five consecutive floats
+ float x_ctr, y_ctr, w, h, a; // center x, center y, width, height, angle
+};
+
// Same as nms_gpu_upright, but for rotated boxes with angle info.
// d_desc_sorted_boxes : pixel coordinates of proposed bounding boxes
// size: (N,5), format: [x_ct; y_ctr; width; height; angle]