support different class modes for bbox in box_with_nms_limit_op
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19820
Reviewed By: newstzpz
Differential Revision: D15112955
fbshipit-source-id: a757622a32cff7159c39735607103138dbbafc24
diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc
index 9fcf55f..b780bc2 100644
--- a/caffe2/operators/box_with_nms_limit_op.cc
+++ b/caffe2/operators/box_with_nms_limit_op.cc
@@ -13,18 +13,18 @@
// tscores: (num_boxes, num_classes), 0 for background
if (tscores.dim() == 4) {
- CAFFE_ENFORCE_EQ(tscores.size(2), 1, tscores.size(2));
- CAFFE_ENFORCE_EQ(tscores.size(3), 1, tscores.size(3));
+ CAFFE_ENFORCE_EQ(tscores.size(2), 1);
+ CAFFE_ENFORCE_EQ(tscores.size(3), 1);
} else {
- CAFFE_ENFORCE_EQ(tscores.dim(), 2, tscores.dim());
+ CAFFE_ENFORCE_EQ(tscores.dim(), 2);
}
CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.dtype().name());
// tboxes: (num_boxes, num_classes * box_dim)
if (tboxes.dim() == 4) {
- CAFFE_ENFORCE_EQ(tboxes.size(2), 1, tboxes.size(2));
- CAFFE_ENFORCE_EQ(tboxes.size(3), 1, tboxes.size(3));
+ CAFFE_ENFORCE_EQ(tboxes.size(2), 1);
+ CAFFE_ENFORCE_EQ(tboxes.size(3), 1);
} else {
- CAFFE_ENFORCE_EQ(tboxes.dim(), 2, tboxes.dim());
+ CAFFE_ENFORCE_EQ(tboxes.dim(), 2);
}
CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.dtype().name());
@@ -32,7 +32,8 @@
int num_classes = tscores.size(1);
CAFFE_ENFORCE_EQ(N, tboxes.size(0));
- CAFFE_ENFORCE_EQ(num_classes * box_dim, tboxes.size(1));
+ int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1;
+ CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1));
int batch_size = 1;
vector<float> batch_splits_default(1, tscores.size(0));
@@ -82,12 +83,13 @@
// skip j = 0, because it's the background class
int total_keep_count = 0;
for (int j = 1; j < num_classes; j++) {
- auto cur_scores = scores.col(j);
+ auto cur_scores = scores.col(get_score_cls_index(j));
auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
- auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
+ auto cur_boxes =
+ boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);
if (soft_nms_enabled_) {
- auto cur_soft_nms_scores = soft_nms_scores.col(j);
+ auto cur_soft_nms_scores = soft_nms_scores.col(get_score_cls_index(j));
keeps[j] = utils::soft_nms_cpu(
&cur_soft_nms_scores,
cur_boxes,
@@ -173,8 +175,9 @@
int cur_out_idx = 0;
for (int j = 1; j < num_classes; j++) {
- auto cur_scores = scores.col(j);
- auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
+ auto cur_scores = scores.col(get_score_cls_index(j));
+ auto cur_boxes =
+ boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);
auto& cur_keep = keeps[j];
Eigen::Map<EArrXf> cur_out_scores(
out_scores->template mutable_data<float>() + cur_start_idx +
@@ -195,7 +198,8 @@
utils::GetSubArrayRows(
cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
for (int k = 0; k < cur_keep.size(); k++) {
- cur_out_classes[k] = static_cast<float>(j);
+ cur_out_classes[k] =
+ static_cast<float>(j - !output_classes_include_bg_cls_);
}
cur_out_idx += cur_keep.size();
@@ -309,7 +313,10 @@
"str soft_nms_method, "
"float soft_nms_sigma, "
"float soft_nms_min_score_thres, "
- "bool rotated"
+ "bool rotated, "
+ "bool cls_agnostic_bbox_reg, "
+ "bool input_boxes_include_bg_cls, "
+ "bool output_classes_include_bg_cls "
") -> ("
"Tensor scores, "
"Tensor boxes, "
diff --git a/caffe2/operators/box_with_nms_limit_op.h b/caffe2/operators/box_with_nms_limit_op.h
index 885c62b..090993f 100644
--- a/caffe2/operators/box_with_nms_limit_op.h
+++ b/caffe2/operators/box_with_nms_limit_op.h
@@ -35,11 +35,26 @@
soft_nms_min_score_thres_(this->template GetSingleArgument<float>(
"soft_nms_min_score_thres",
0.001)),
- rotated_(this->template GetSingleArgument<bool>("rotated", false)) {
+ rotated_(this->template GetSingleArgument<bool>("rotated", false)),
+ cls_agnostic_bbox_reg_(this->template GetSingleArgument<bool>(
+ "cls_agnostic_bbox_reg",
+ false)),
+ input_boxes_include_bg_cls_(this->template GetSingleArgument<bool>(
+ "input_boxes_include_bg_cls",
+ true)),
+ output_classes_include_bg_cls_(this->template GetSingleArgument<bool>(
+ "output_classes_include_bg_cls",
+ true)) {
CAFFE_ENFORCE(
soft_nms_method_str_ == "linear" || soft_nms_method_str_ == "gaussian",
"Unexpected soft_nms_method");
soft_nms_method_ = (soft_nms_method_str_ == "linear") ? 1 : 2;
+
+ // When input `boxes` doesn't inlcude background class, the score will skip
+ // background class and start with foreground classes directly, and put the
+ // background class in the end, i.e. score[:, 0:NUM_CLASSES-1] represents
+ // foreground classes and score[:,NUM_CLASSES] represents background class.
+ input_scores_fg_cls_starting_id_ = (int)input_boxes_include_bg_cls_;
}
~BoxWithNMSLimitOp() {}
@@ -65,6 +80,35 @@
// Set for RRPN case to handle rotated boxes. Inputs should be in format
// [ctr_x, ctr_y, width, height, angle (in degrees)].
bool rotated_{false};
+ // MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
+ bool cls_agnostic_bbox_reg_{false};
+ // Whether input `boxes` includes background class. If true, boxes will have
+ // shape of (N, (num_fg_class+1) * 4or5), otherwise (N, num_fg_class * 4or5)
+ bool input_boxes_include_bg_cls_{true};
+ // Whether output `classes` includes background class. If true, index 0 will
+ // represent background, and valid outputs start from 1.
+ bool output_classes_include_bg_cls_{true};
+ // The index where foreground starts in scoures. Eg. if 0 represents
+ // background class then foreground class starts with 1.
+ int input_scores_fg_cls_starting_id_{1};
+
+ // Map a class id (starting with background and then foreground) from (0, 1,
+ // ..., NUM_FG_CLASSES) to it's matching value in box
+ inline int get_box_cls_index(int bg_fg_cls_id) {
+ if (cls_agnostic_bbox_reg_) {
+ return 0;
+ } else if (!input_boxes_include_bg_cls_) {
+ return bg_fg_cls_id - 1;
+ } else {
+ return bg_fg_cls_id;
+ }
+ }
+
+ // Map a class id (starting with background and then foreground) from (0, 1,
+ // ..., NUM_FG_CLASSES) to it's matching value in score
+ inline int get_score_cls_index(int bg_fg_cls_id) {
+ return bg_fg_cls_id - 1 + input_scores_fg_cls_starting_id_;
+ }
};
} // namespace caffe2
diff --git a/caffe2/python/operator_test/box_with_nms_limit_op_test.py b/caffe2/python/operator_test/box_with_nms_limit_op_test.py
index 52155c0..cd86990 100644
--- a/caffe2/python/operator_test/box_with_nms_limit_op_test.py
+++ b/caffe2/python/operator_test/box_with_nms_limit_op_test.py
@@ -120,12 +120,32 @@
self.assertReferenceChecks(gc, op, [scores, boxes], ref)
- @given(num_classes=st.integers(2, 10), **HU_CONFIG)
- def test_multiclass(self, num_classes, gc):
+ @given(
+ num_classes=st.integers(2, 10),
+ cls_agnostic_bbox_reg=st.booleans(),
+ input_boxes_include_bg_cls=st.booleans(),
+ output_classes_include_bg_cls=st.booleans(),
+ **HU_CONFIG
+ )
+ def test_multiclass(
+ self,
+ num_classes,
+ cls_agnostic_bbox_reg,
+ input_boxes_include_bg_cls,
+ output_classes_include_bg_cls,
+ gc
+ ):
in_centers = [(0, 0), (20, 20), (50, 50)]
in_scores = [0.7, 0.85, 0.6]
boxes, scores = gen_multiple_boxes(in_centers, in_scores, 10, num_classes)
+ if not input_boxes_include_bg_cls:
+ # remove backgound class
+ boxes = boxes[:, 4:]
+ if cls_agnostic_bbox_reg:
+ # only leave one class
+ boxes = boxes[:, :4]
+
gt_centers = [(20, 20), (0, 0), (50, 50)]
gt_scores = [0.85, 0.7, 0.6]
gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1)
@@ -133,12 +153,22 @@
gt_classes = np.tile(
np.array(range(1, num_classes), dtype=np.float32),
(gt_boxes.shape[0], 1)).T.flatten()
+ if not output_classes_include_bg_cls:
+ # remove backgound class
+ gt_classes -= 1
gt_boxes = np.tile(gt_boxes, (num_classes - 1, 1))
gt_scores = np.tile(gt_scores, (num_classes - 1, 1)).flatten()
op = get_op(
2, 3,
- {"score_thresh": 0.5, "nms": 0.9, "detections_per_im": 100}
+ {
+ "score_thresh": 0.5,
+ "nms": 0.9,
+ "detections_per_im": 100,
+ "cls_agnostic_bbox_reg": cls_agnostic_bbox_reg,
+ "input_boxes_include_bg_cls": input_boxes_include_bg_cls,
+ "output_classes_include_bg_cls": output_classes_include_bg_cls
+ }
)
def ref(*args, **kwargs):
diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py
index 893d314..4e811a1 100644
--- a/caffe2/python/operator_test/torch_integration_test.py
+++ b/caffe2/python/operator_test/torch_integration_test.py
@@ -203,6 +203,9 @@
soft_nms_sigma=0.5,
soft_nms_min_score_thres=0.001,
rotated=rotated,
+ cls_agnostic_bbox_reg=False,
+ input_boxes_include_bg_cls=True,
+ output_classes_include_bg_cls=True,
)
for o, o_ref in zip(outputs, output_refs):
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 589ce03..7813966 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1513,6 +1513,9 @@
soft_nms_sigma=0.5,
soft_nms_min_score_thres=0.001,
rotated=rotated,
+ cls_agnostic_bbox_reg=False,
+ input_boxes_include_bg_cls=True,
+ output_classes_include_bg_cls=True,
)
return a, b, c, d