support different class modes for bbox in box_with_nms_limit_op

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19820

Reviewed By: newstzpz

Differential Revision: D15112955

fbshipit-source-id: a757622a32cff7159c39735607103138dbbafc24
diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc
index 9fcf55f..b780bc2 100644
--- a/caffe2/operators/box_with_nms_limit_op.cc
+++ b/caffe2/operators/box_with_nms_limit_op.cc
@@ -13,18 +13,18 @@
 
   // tscores: (num_boxes, num_classes), 0 for background
   if (tscores.dim() == 4) {
-    CAFFE_ENFORCE_EQ(tscores.size(2), 1, tscores.size(2));
-    CAFFE_ENFORCE_EQ(tscores.size(3), 1, tscores.size(3));
+    CAFFE_ENFORCE_EQ(tscores.size(2), 1);
+    CAFFE_ENFORCE_EQ(tscores.size(3), 1);
   } else {
-    CAFFE_ENFORCE_EQ(tscores.dim(), 2, tscores.dim());
+    CAFFE_ENFORCE_EQ(tscores.dim(), 2);
   }
   CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.dtype().name());
   // tboxes: (num_boxes, num_classes * box_dim)
   if (tboxes.dim() == 4) {
-    CAFFE_ENFORCE_EQ(tboxes.size(2), 1, tboxes.size(2));
-    CAFFE_ENFORCE_EQ(tboxes.size(3), 1, tboxes.size(3));
+    CAFFE_ENFORCE_EQ(tboxes.size(2), 1);
+    CAFFE_ENFORCE_EQ(tboxes.size(3), 1);
   } else {
-    CAFFE_ENFORCE_EQ(tboxes.dim(), 2, tboxes.dim());
+    CAFFE_ENFORCE_EQ(tboxes.dim(), 2);
   }
   CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.dtype().name());
 
@@ -32,7 +32,8 @@
   int num_classes = tscores.size(1);
 
   CAFFE_ENFORCE_EQ(N, tboxes.size(0));
-  CAFFE_ENFORCE_EQ(num_classes * box_dim, tboxes.size(1));
+  int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1;
+  CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1));
 
   int batch_size = 1;
   vector<float> batch_splits_default(1, tscores.size(0));
@@ -82,12 +83,13 @@
     // skip j = 0, because it's the background class
     int total_keep_count = 0;
     for (int j = 1; j < num_classes; j++) {
-      auto cur_scores = scores.col(j);
+      auto cur_scores = scores.col(get_score_cls_index(j));
       auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
-      auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
+      auto cur_boxes =
+          boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);
 
       if (soft_nms_enabled_) {
-        auto cur_soft_nms_scores = soft_nms_scores.col(j);
+        auto cur_soft_nms_scores = soft_nms_scores.col(get_score_cls_index(j));
         keeps[j] = utils::soft_nms_cpu(
             &cur_soft_nms_scores,
             cur_boxes,
@@ -173,8 +175,9 @@
 
     int cur_out_idx = 0;
     for (int j = 1; j < num_classes; j++) {
-      auto cur_scores = scores.col(j);
-      auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
+      auto cur_scores = scores.col(get_score_cls_index(j));
+      auto cur_boxes =
+          boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);
       auto& cur_keep = keeps[j];
       Eigen::Map<EArrXf> cur_out_scores(
           out_scores->template mutable_data<float>() + cur_start_idx +
@@ -195,7 +198,8 @@
       utils::GetSubArrayRows(
           cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
       for (int k = 0; k < cur_keep.size(); k++) {
-        cur_out_classes[k] = static_cast<float>(j);
+        cur_out_classes[k] =
+            static_cast<float>(j - !output_classes_include_bg_cls_);
       }
 
       cur_out_idx += cur_keep.size();
@@ -309,7 +313,10 @@
       "str soft_nms_method, "
       "float soft_nms_sigma, "
       "float soft_nms_min_score_thres, "
-      "bool rotated"
+      "bool rotated, "
+      "bool cls_agnostic_bbox_reg, "
+      "bool input_boxes_include_bg_cls, "
+      "bool output_classes_include_bg_cls "
     ") -> ("
       "Tensor scores, "
       "Tensor boxes, "
diff --git a/caffe2/operators/box_with_nms_limit_op.h b/caffe2/operators/box_with_nms_limit_op.h
index 885c62b..090993f 100644
--- a/caffe2/operators/box_with_nms_limit_op.h
+++ b/caffe2/operators/box_with_nms_limit_op.h
@@ -35,11 +35,26 @@
         soft_nms_min_score_thres_(this->template GetSingleArgument<float>(
             "soft_nms_min_score_thres",
             0.001)),
-        rotated_(this->template GetSingleArgument<bool>("rotated", false)) {
+        rotated_(this->template GetSingleArgument<bool>("rotated", false)),
+        cls_agnostic_bbox_reg_(this->template GetSingleArgument<bool>(
+            "cls_agnostic_bbox_reg",
+            false)),
+        input_boxes_include_bg_cls_(this->template GetSingleArgument<bool>(
+            "input_boxes_include_bg_cls",
+            true)),
+        output_classes_include_bg_cls_(this->template GetSingleArgument<bool>(
+            "output_classes_include_bg_cls",
+            true)) {
     CAFFE_ENFORCE(
         soft_nms_method_str_ == "linear" || soft_nms_method_str_ == "gaussian",
         "Unexpected soft_nms_method");
     soft_nms_method_ = (soft_nms_method_str_ == "linear") ? 1 : 2;
+
+    // When input `boxes` doesn't include background class, the score will skip
+    // background class and start with foreground classes directly, and put the
+    // background class in the end, i.e. score[:, 0:NUM_CLASSES-1] represents
+    // foreground classes and score[:, NUM_CLASSES-1] represents background.
+    input_scores_fg_cls_starting_id_ = (int)input_boxes_include_bg_cls_;
   }
 
   ~BoxWithNMSLimitOp() {}
@@ -65,6 +80,35 @@
   // Set for RRPN case to handle rotated boxes. Inputs should be in format
   // [ctr_x, ctr_y, width, height, angle (in degrees)].
   bool rotated_{false};
+  // MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
+  bool cls_agnostic_bbox_reg_{false};
+  // Whether input `boxes` includes background class. If true, boxes will have
+  // shape of (N, (num_fg_class+1) * 4or5), otherwise (N, num_fg_class * 4or5)
+  bool input_boxes_include_bg_cls_{true};
+  // Whether output `classes` includes background class. If true, index 0 will
+  // represent background, and valid outputs start from 1.
+  bool output_classes_include_bg_cls_{true};
+  // The index where foreground starts in scores. E.g. if 0 represents
+  // background class then foreground class starts with 1.
+  int input_scores_fg_cls_starting_id_{1};
+
+  // Map a class id (starting with background and then foreground) from (0, 1,
+  // ..., NUM_FG_CLASSES) to its matching value in box
+  inline int get_box_cls_index(int bg_fg_cls_id) {
+    if (cls_agnostic_bbox_reg_) {
+      return 0; // every class id maps to box index 0
+    } else if (!input_boxes_include_bg_cls_) {
+      return bg_fg_cls_id - 1; // boxes have no background entry: shift by 1
+    } else {
+      return bg_fg_cls_id;
+    }
+  }
+
+  // Map a class id (starting with background and then foreground) from (0, 1,
+  // ..., NUM_FG_CLASSES) to its matching value in score
+  inline int get_score_cls_index(int bg_fg_cls_id) {
+    return bg_fg_cls_id - 1 + input_scores_fg_cls_starting_id_;
+  }
 };
 
 } // namespace caffe2
diff --git a/caffe2/python/operator_test/box_with_nms_limit_op_test.py b/caffe2/python/operator_test/box_with_nms_limit_op_test.py
index 52155c0..cd86990 100644
--- a/caffe2/python/operator_test/box_with_nms_limit_op_test.py
+++ b/caffe2/python/operator_test/box_with_nms_limit_op_test.py
@@ -120,12 +120,32 @@
 
         self.assertReferenceChecks(gc, op, [scores, boxes], ref)
 
-    @given(num_classes=st.integers(2, 10), **HU_CONFIG)
-    def test_multiclass(self, num_classes, gc):
+    @given(
+        num_classes=st.integers(2, 10),
+        cls_agnostic_bbox_reg=st.booleans(),
+        input_boxes_include_bg_cls=st.booleans(),
+        output_classes_include_bg_cls=st.booleans(),
+        **HU_CONFIG
+    )
+    def test_multiclass(
+        self,
+        num_classes,
+        cls_agnostic_bbox_reg,
+        input_boxes_include_bg_cls,
+        output_classes_include_bg_cls,
+        gc
+    ):
         in_centers = [(0, 0), (20, 20), (50, 50)]
         in_scores = [0.7, 0.85, 0.6]
         boxes, scores = gen_multiple_boxes(in_centers, in_scores, 10, num_classes)
 
+        if not input_boxes_include_bg_cls:
+            # remove background class
+            boxes = boxes[:, 4:]
+        if cls_agnostic_bbox_reg:
+            # only leave one class
+            boxes = boxes[:, :4]
+
         gt_centers = [(20, 20), (0, 0), (50, 50)]
         gt_scores = [0.85, 0.7, 0.6]
         gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1)
@@ -133,12 +153,22 @@
         gt_classes = np.tile(
             np.array(range(1, num_classes), dtype=np.float32),
             (gt_boxes.shape[0], 1)).T.flatten()
+        if not output_classes_include_bg_cls:
+            # remove background class
+            gt_classes -= 1
         gt_boxes = np.tile(gt_boxes, (num_classes - 1, 1))
         gt_scores = np.tile(gt_scores, (num_classes - 1, 1)).flatten()
 
         op = get_op(
             2, 3,
-            {"score_thresh": 0.5, "nms": 0.9, "detections_per_im": 100}
+            {
+                "score_thresh": 0.5,
+                "nms": 0.9,
+                "detections_per_im": 100,
+                "cls_agnostic_bbox_reg": cls_agnostic_bbox_reg,
+                "input_boxes_include_bg_cls": input_boxes_include_bg_cls,
+                "output_classes_include_bg_cls": output_classes_include_bg_cls
+            }
         )
 
         def ref(*args, **kwargs):
diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py
index 893d314..4e811a1 100644
--- a/caffe2/python/operator_test/torch_integration_test.py
+++ b/caffe2/python/operator_test/torch_integration_test.py
@@ -203,6 +203,9 @@
             soft_nms_sigma=0.5,
             soft_nms_min_score_thres=0.001,
             rotated=rotated,
+            cls_agnostic_bbox_reg=False,
+            input_boxes_include_bg_cls=True,
+            output_classes_include_bg_cls=True,
         )
 
         for o, o_ref in zip(outputs, output_refs):
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 589ce03..7813966 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1513,6 +1513,9 @@
                     soft_nms_sigma=0.5,
                     soft_nms_min_score_thres=0.001,
                     rotated=rotated,
+                    cls_agnostic_bbox_reg=False,
+                    input_boxes_include_bg_cls=True,
+                    output_classes_include_bg_cls=True,
                 )
                 return a, b, c, d