Clip horizontal bounding boxes during rotated detection for backward compatibility (#9403)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9403
In BBoxTransform and GenerateProposal ops, clip_boxes makes sure the bbox fits
within the images. For rotated boxes, this doesn't always make sense as there
could be multiple ways to clip a rotated box within an image boundary.
Moreover, clipping to a horizontal box means we leave out pixels of interest
potentially. Therefore, we clip only boxes with angle almost equal to 0 (with a
specified `angle_thresh` tolerance).
Reviewed By: pjh5
Differential Revision: D8828588
fbshipit-source-id: 39c1eafdb5d39d383780faa0a47e76149145e50c
diff --git a/caffe2/operators/bbox_transform_op.cc b/caffe2/operators/bbox_transform_op.cc
index 5dde4b1..0d2b5a3 100644
--- a/caffe2/operators/bbox_transform_op.cc
+++ b/caffe2/operators/bbox_transform_op.cc
@@ -54,6 +54,11 @@
"angle_bound_hi",
"int (default 90 degrees). If set, for rotated boxes, angle is "
"normalized to be within [angle_bound_lo, angle_bound_hi].")
+ .Arg(
+ "clip_angle_thresh",
+ "float (default 1.0 degrees). For RRPN, clip almost horizontal boxes "
+ "within this threshold of tolerance for backward compatibility. "
+ "Set to negative value for no clipping.")
.Input(
0,
"rois",
@@ -168,7 +173,8 @@
angle_bound_on_,
angle_bound_lo_,
angle_bound_hi_);
- EArrXXf clip_boxes = utils::clip_boxes(trans_boxes, img_h, img_w);
+ EArrXXf clip_boxes =
+ utils::clip_boxes(trans_boxes, img_h, img_w, clip_angle_thresh_);
// Do not apply scale for angle in rotated boxes
clip_boxes.leftCols(4) *= scale_after;
new_boxes.block(offset, k * box_dim, num_rois, box_dim) = clip_boxes;
diff --git a/caffe2/operators/bbox_transform_op.h b/caffe2/operators/bbox_transform_op.h
index e57d90e..8d76973 100644
--- a/caffe2/operators/bbox_transform_op.h
+++ b/caffe2/operators/bbox_transform_op.h
@@ -29,7 +29,9 @@
angle_bound_lo_(
OperatorBase::GetSingleArgument<int>("angle_bound_lo", -90)),
angle_bound_hi_(
- OperatorBase::GetSingleArgument<int>("angle_bound_hi", 90)) {
+ OperatorBase::GetSingleArgument<int>("angle_bound_hi", 90)),
+ clip_angle_thresh_(
+ OperatorBase::GetSingleArgument<float>("clip_angle_thresh", 1.0)) {
CAFFE_ENFORCE_EQ(
weights_.size(),
4,
@@ -59,6 +61,10 @@
bool angle_bound_on_{true};
int angle_bound_lo_{-90};
int angle_bound_hi_{90};
+ // For RRPN, clip almost horizontal boxes within this threshold of
+ // tolerance for backward compatibility. Set to negative value for
+ // no clipping.
+ float clip_angle_thresh_{1.0};
};
} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc
index dff52aa..0b4f3a6 100644
--- a/caffe2/operators/generate_proposals_op.cc
+++ b/caffe2/operators/generate_proposals_op.cc
@@ -197,8 +197,8 @@
// 2. clip proposals to image (may result in proposals with zero area
// that will be removed in the next step)
- // TODO (viswanath): Should we clip rotated boxes as well?
- proposals = utils::clip_boxes(proposals, im_info[0], im_info[1]);
+ proposals =
+ utils::clip_boxes(proposals, im_info[0], im_info[1], clip_angle_thresh_);
// 3. remove predicted boxes with either height or width < min_size
auto keep = utils::filter_boxes(proposals, min_size, im_info);
@@ -342,6 +342,29 @@
.Arg("post_nms_topN", "(int) RPN_POST_NMS_TOP_N")
.Arg("nms_thresh", "(float) RPN_NMS_THRESH")
.Arg("min_size", "(float) RPN_MIN_SIZE")
+ .Arg(
+ "correct_transform_coords",
+ "bool (default false), Correct bounding box transform coordates,"
+ " see bbox_transform() in boxes.py "
+ "Set to true to match the detectron code, set to false for backward"
+ " compatibility")
+ .Arg(
+ "angle_bound_on",
+ "bool (default true). If set, for rotated boxes, angle is "
+ "normalized to be within [angle_bound_lo, angle_bound_hi].")
+ .Arg(
+ "angle_bound_lo",
+ "int (default -90 degrees). If set, for rotated boxes, angle is "
+ "normalized to be within [angle_bound_lo, angle_bound_hi].")
+ .Arg(
+ "angle_bound_hi",
+ "int (default 90 degrees). If set, for rotated boxes, angle is "
+ "normalized to be within [angle_bound_lo, angle_bound_hi].")
+ .Arg(
+ "clip_angle_thresh",
+ "float (default 1.0 degrees). For RRPN, clip almost horizontal boxes "
+ "within this threshold of tolerance for backward compatibility. "
+ "Set to negative value for no clipping.")
.Input(0, "scores", "Scores from conv layer, size (img_count, A, H, W)")
.Input(
1,
diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h
index c1ae488..81f7d9a 100644
--- a/caffe2/operators/generate_proposals_op.h
+++ b/caffe2/operators/generate_proposals_op.h
@@ -84,7 +84,9 @@
angle_bound_lo_(
OperatorBase::GetSingleArgument<int>("angle_bound_lo", -90)),
angle_bound_hi_(
- OperatorBase::GetSingleArgument<int>("angle_bound_hi", 90)) {}
+ OperatorBase::GetSingleArgument<int>("angle_bound_hi", 90)),
+ clip_angle_thresh_(
+ OperatorBase::GetSingleArgument<float>("clip_angle_thresh", 1.0)) {}
~GenerateProposalsOp() {}
@@ -127,6 +129,10 @@
bool angle_bound_on_{true};
int angle_bound_lo_{-90};
int angle_bound_hi_{90};
+ // For RRPN, clip almost horizontal boxes within this threshold of
+ // tolerance for backward compatibility. Set to negative value for
+ // no clipping.
+ float clip_angle_thresh_{1.0};
};
} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc
index d8e1021..3fb7ed9 100644
--- a/caffe2/operators/generate_proposals_op_test.cc
+++ b/caffe2/operators/generate_proposals_op_test.cc
@@ -320,6 +320,7 @@
// Similar to TestRealDownSampled but for rotated boxes with angle info.
float angle = 0;
float delta_angle = 0;
+ float clip_angle_thresh = 1.0;
Workspace ws;
OperatorDef def;
@@ -407,33 +408,37 @@
vector<float> im_info{60, 80, 0.166667f};
// vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
- vector<float> anchors{8, 8, 92, 48, angle, 8, 8, 256, 256, angle};
+ // Anchors in [x_ctr, y_ctr, w, h, angle] format
+ vector<float> anchors{7.5, 7.5, 92, 48, angle, 7.5, 7.5, 256, 256, angle};
- // Although angle == 0, the results aren't exactly the same as
- // TestRealDownSampled because because clip_boxes() is not performed
- // for RRPN style boxes.
- ERMatXf rois_gt(13, 6);
- rois_gt << 0, 6.55346, 25.3227, 253.447, 291.446, 0, 0, 55.3932, 33.3369,
- 253.731, 289.158, 0, 0, 6.48163, 24.3478, 92.3015, 38.6944, 0, 0, 70.3089,
- 26.7894, 92.3453, 38.5539, 0, 0, 22.3067, 26.7714, 92.3424, 38.5243, 0, 0,
- 054.084, 26.8413, 92.3938, 38.798, 0, 0, 5.33962, 42.2077, 92.5497,
- 38.2259, 0, 0, 6.36709, 58.24, 92.16, 37.4372, 0, 0, 69.65, 48.6713,
- 92.1521, 37.3668, 0, 0, 20.4147, 44.4783, 91.7111, 34.0295, 0, 0, 033.079,
- 41.5149, 92.3244, 36.4278, 0, 0, 41.8235, 037.291, 90.2815, 034.872, 0, 0,
- 13.8486, 48.662, 88.7818, 28.875, 0;
- vector<float> rois_probs_gt{0.0266914,
- 0.005621,
- 0.00544219,
- 0.00120544,
- 0.00119208,
- 0.00117182,
- 0.000617993,
- 0.000472735,
- 6.09605e-05,
- 1.05262e-05,
- 8.91026e-06,
- 9.29537e-09,
- 1.13482e-10};
+ // Results should exactly be the same as TestRealDownSampled since
+ // angle = 0 for all boxes and clip_angle_thresh > 0 (which means
+ // all horizontal boxes will be clipped to maintain backward compatibility).
+ ERMatXf rois_gt_xyxy(9, 5);
+ rois_gt_xyxy << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0,
+ 24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f,
+ 45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f,
+ 51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0,
+ 41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f;
+ ERMatXf rois_gt(9, 6);
+ // Batch ID
+ rois_gt.block(0, 0, rois_gt.rows(), 1) =
+ ERMatXf::Constant(rois_gt.rows(), 1, 0.0);
+ // rois_gt in [x_ctr, y_ctr, w, h] format
+ rois_gt.block(0, 1, rois_gt.rows(), 4) =
+ boxes_xyxy_to_xywh(rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4));
+ // Angle
+ rois_gt.block(0, 5, rois_gt.rows(), 1) =
+ ERMatXf::Constant(rois_gt.rows(), 1, angle);
+ vector<float> rois_probs_gt{2.66913995e-02f,
+ 5.44218998e-03f,
+ 1.20544003e-03f,
+ 1.19207997e-03f,
+ 6.17993006e-04f,
+ 4.72735002e-04f,
+ 6.09605013e-05f,
+ 1.50015003e-05f,
+ 8.91025957e-06f};
AddInput(vector<TIndex>{img_count, A, H, W}, scores, "scores", &ws);
AddInput(
@@ -450,6 +455,7 @@
def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+ def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
EXPECT_NE(nullptr, op.get());
@@ -484,6 +490,7 @@
float angle = 45.0;
float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees
float expected_angle = 55.0;
+ float clip_angle_thresh = 1.0;
Workspace ws;
OperatorDef def;
@@ -588,6 +595,7 @@
def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f));
def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f));
def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true));
+ def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh));
unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
EXPECT_NE(nullptr, op.get());
diff --git a/caffe2/operators/generate_proposals_op_util_boxes.h b/caffe2/operators/generate_proposals_op_util_boxes.h
index 440d141..0c4c345 100644
--- a/caffe2/operators/generate_proposals_op_util_boxes.h
+++ b/caffe2/operators/generate_proposals_op_util_boxes.h
@@ -192,23 +192,50 @@
}
}
+template <class Derived>
+EArrXXt<typename Derived::Scalar> bbox_xyxy_to_ctrwh(
+ const Eigen::ArrayBase<Derived>& boxes) {
+ CAFFE_ENFORCE_EQ(boxes.cols(), 4);
+
+ const auto& x1 = boxes.col(0);
+ const auto& y1 = boxes.col(1);
+ const auto& x2 = boxes.col(2);
+ const auto& y2 = boxes.col(3);
+
+ EArrXXt<typename Derived::Scalar> ret(boxes.rows(), 4);
+ ret.col(0) = (x1 + x2) / 2.0; // x_ctr
+ ret.col(1) = (y1 + y2) / 2.0; // y_ctr
+ ret.col(2) = x2 - x1 + 1.0; // w
+ ret.col(3) = y2 - y1 + 1.0; // h
+ return ret;
+}
+
+template <class Derived>
+EArrXXt<typename Derived::Scalar> bbox_ctrwh_to_xyxy(
+ const Eigen::ArrayBase<Derived>& boxes) {
+ CAFFE_ENFORCE_EQ(boxes.cols(), 4);
+
+ const auto& x_ctr = boxes.col(0);
+ const auto& y_ctr = boxes.col(1);
+ const auto& w = boxes.col(2);
+ const auto& h = boxes.col(3);
+
+ EArrXXt<typename Derived::Scalar> ret(boxes.rows(), 4);
+ ret.col(0) = x_ctr - (w - 1) / 2.0; // x1
+ ret.col(1) = y_ctr - (h - 1) / 2.0; // y1
+ ret.col(2) = x_ctr + (w - 1) / 2.0; // x2
+ ret.col(3) = y_ctr + (h - 1) / 2.0; // y2
+ return ret;
+}
+
// Clip boxes to image boundaries
// boxes: pixel coordinates of bounding box, size (M * 4)
-//
-// For rotated boxes with angle support (M * 5), we don't clip and just
-// return early. It's tricky to make the entire rectangular box fit within the
-// image and still be able to not leave out pixels of interest.
-// We rely on upstream ops like RoIAlignRotated safely handling such cases.
template <class Derived>
-EArrXXt<typename Derived::Scalar>
-clip_boxes(const Eigen::ArrayBase<Derived>& boxes, int height, int width) {
- CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5);
- if (boxes.cols() == 5) {
- // No clipping for rotated boxes.
- // TODO (viswanath): Should this be implemented for backward compatibility
- // with angle=0 case?
- return boxes;
- }
+EArrXXt<typename Derived::Scalar> clip_boxes_upright(
+ const Eigen::ArrayBase<Derived>& boxes,
+ int height,
+ int width) {
+ CAFFE_ENFORCE(boxes.cols() == 4);
EArrXXt<typename Derived::Scalar> ret(boxes.rows(), boxes.cols());
@@ -224,6 +251,69 @@
return ret;
}
+// Similar to clip_boxes_upright but handles rotated boxes with angle info.
+// boxes: size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)]
+//
+// Clipping is only performed for boxes that are almost upright
+// (within a given `angle_thresh` tolerance) to maintain backward compatibility
+// for non-rotated boxes.
+//
+// We don't clip rotated boxes due to a couple of reasons:
+// (1) There are potentially multiple ways to clip a rotated box to make it
+// fit within the image.
+// (2) It's tricky to make the entire rectangular box fit within the image and
+// still be able to not leave out pixels of interest.
+// Therefore, we rely on upstream ops like RoIAlignRotated safely handling this.
+template <class Derived>
+EArrXXt<typename Derived::Scalar> clip_boxes_rotated(
+ const Eigen::ArrayBase<Derived>& boxes,
+ int height,
+ int width,
+ float angle_thresh = 1.0) {
+ CAFFE_ENFORCE(boxes.cols() == 5);
+
+ const auto& angles = boxes.col(4);
+
+ // Filter boxes that are upright (with a tolerance of angle_thresh)
+ EArrXXt<typename Derived::Scalar> upright_boxes;
+ const auto& indices = GetArrayIndices(angles.abs() <= angle_thresh);
+ GetSubArrayRows(boxes, AsEArrXt(indices), &upright_boxes);
+
+ // Convert to [x1, y1, x2, y2] format and clip them
+ const auto& upright_boxes_xyxy =
+ bbox_ctrwh_to_xyxy(upright_boxes.leftCols(4));
+ const auto& clipped_upright_boxes_xyxy =
+ clip_boxes_upright(upright_boxes_xyxy, height, width);
+
+ // Convert back to [x_ctr, y_ctr, w, h, angle] and update upright boxes
+ upright_boxes.block(0, 0, upright_boxes.rows(), 4) =
+ bbox_xyxy_to_ctrwh(clipped_upright_boxes_xyxy);
+
+ EArrXXt<typename Derived::Scalar> ret(boxes.rows(), boxes.cols());
+ ret = boxes;
+ for (int i = 0; i < upright_boxes.rows(); ++i) {
+ ret.row(indices[i]) = upright_boxes.row(i);
+ }
+ return ret;
+}
+
+// Clip boxes to image boundaries.
+template <class Derived>
+EArrXXt<typename Derived::Scalar> clip_boxes(
+ const Eigen::ArrayBase<Derived>& boxes,
+ int height,
+ int width,
+ float angle_thresh = 1.0) {
+ CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5);
+ if (boxes.cols() == 4) {
+ // Upright boxes
+ return clip_boxes_upright(boxes, height, width);
+ } else {
+ // Rotated boxes with angle info
+ return clip_boxes_rotated(boxes, height, width, angle_thresh);
+ }
+}
+
// Only keep boxes with both sides >= min_size and center within the image.
// boxes: pixel coordinates of bounding box, size (M * 4)
// im_info: [height, width, img_scale]
diff --git a/caffe2/operators/generate_proposals_op_util_boxes_test.cc b/caffe2/operators/generate_proposals_op_util_boxes_test.cc
index a8d4f4c..f9ff7e9 100644
--- a/caffe2/operators/generate_proposals_op_util_boxes_test.cc
+++ b/caffe2/operators/generate_proposals_op_util_boxes_test.cc
@@ -105,4 +105,33 @@
EXPECT_NEAR((result.matrix() - result_gt).norm(), 0.0, 1e-2);
}
+TEST(UtilsBoxesTest, ClipRotatedBoxes) {
+ // Test utils::clip_boxes_rotated()
+ using EMatXf = Eigen::MatrixXf;
+
+ int height = 800;
+ int width = 600;
+ EMatXf bbox(5, 5);
+ bbox << 20, 20, 200, 150, 0, // Horizontal
+ 20, 20, 200, 150, 0.5, // Almost horizontal
+ 20, 20, 200, 150, 30, // Rotated
+ 300, 300, 200, 150, 30, // Rotated
+ 579, 779, 200, 150, -0.5; // Almost horizontal
+
+ // Test with no clipping
+ float angle_thresh = -1.0;
+ auto result = utils::clip_boxes(bbox.array(), height, width, angle_thresh);
+ EXPECT_NEAR((result.matrix() - bbox).norm(), 0.0, 1e-4);
+
+ EMatXf result_gt(5, 5);
+ result_gt << 59.75, 47.25, 120.5, 95.5, 0, 59.75, 47.25, 120.5, 95.5, 0.5, 20,
+ 20, 200, 150, 30, 300, 300, 200, 150, 30, 539.25, 751.75, 120.5, 95.5,
+ -0.5;
+
+ // Test clipping with tolerance
+ angle_thresh = 1.0;
+ result = utils::clip_boxes(bbox.array(), height, width, angle_thresh);
+ EXPECT_NEAR((result.matrix() - result_gt).norm(), 0.0, 1e-4);
+}
+
} // namespace caffe2
diff --git a/caffe2/python/operator_test/bbox_transform_test.py b/caffe2/python/operator_test/bbox_transform_test.py
index 20008ac..b54a443 100644
--- a/caffe2/python/operator_test/bbox_transform_test.py
+++ b/caffe2/python/operator_test/bbox_transform_test.py
@@ -152,6 +152,42 @@
return pred_boxes
+def clip_tiled_boxes_rotated(boxes, im_shape, angle_thresh=1.0):
+ """
+ Similar to clip_tiled_boxes but for rotated boxes with angle info.
+ Only clips almost horizontal boxes within angle_thresh. The rest are
+ left unchanged.
+ """
+ assert (
+ boxes.shape[1] % 5 == 0
+ ), "boxes.shape[1] is {:d}, but must be divisible by 5.".format(
+ boxes.shape[1]
+ )
+
+ (H, W) = im_shape[:2]
+
+ # Filter boxes that are almost upright within angle_thresh tolerance
+ idx = np.where(np.abs(boxes[:, 4::5]) <= angle_thresh)
+ idx5 = idx[1] * 5
+ # convert to (x1, y1, x2, y2)
+ x1 = boxes[idx[0], idx5] - (boxes[idx[0], idx5 + 2] - 1) / 2.0
+ y1 = boxes[idx[0], idx5 + 1] - (boxes[idx[0], idx5 + 3] - 1) / 2.0
+ x2 = boxes[idx[0], idx5] + (boxes[idx[0], idx5 + 2] - 1) / 2.0
+ y2 = boxes[idx[0], idx5 + 1] + (boxes[idx[0], idx5 + 3] - 1) / 2.0
+ # clip
+ x1 = np.maximum(np.minimum(x1, W - 1), 0)
+ y1 = np.maximum(np.minimum(y1, H - 1), 0)
+ x2 = np.maximum(np.minimum(x2, W - 1), 0)
+ y2 = np.maximum(np.minimum(y2, H - 1), 0)
+ # convert back to (xc, yc, w, h)
+ boxes[idx[0], idx5] = (x1 + x2) / 2.0
+ boxes[idx[0], idx5 + 1] = (y1 + y2) / 2.0
+ boxes[idx[0], idx5 + 2] = x2 - x1 + 1
+ boxes[idx[0], idx5 + 3] = y2 - y1 + 1
+
+ return boxes
+
+
def generate_rois_rotated(roi_counts, im_dims):
rois = generate_rois(roi_counts, im_dims)
# [batch_id, ctr_x, ctr_y, w, h, angle]
@@ -161,7 +197,7 @@
rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2. # ctr_y = (y1 + y2) / 2
rotated_rois[:, 3] = rois[:, 3] - rois[:, 1] + 1.0 # w = x2 - x1 + 1
rotated_rois[:, 4] = rois[:, 4] - rois[:, 2] + 1.0 # h = y2 - y1 + 1
- rotated_rois[:, 5] = np.random.uniform(0.0, 360.0) # angle in degrees
+ rotated_rois[:, 5] = np.random.uniform(-90.0, 90.0) # angle in degrees
return rotated_rois
@@ -173,6 +209,7 @@
skip_batch_id=st.booleans(),
rotated=st.booleans(),
angle_bound_on=st.booleans(),
+ clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
**hu.gcs_cpu_only
)
def test_bbox_transform(
@@ -183,6 +220,7 @@
skip_batch_id,
rotated,
angle_bound_on,
+ clip_angle_thresh,
gc,
dc,
):
@@ -202,14 +240,16 @@
def bbox_transform_ref(rois, deltas, im_info):
boxes = rois if rois.shape[1] == box_dim else rois[:, 1:]
+ im_shape = im_info[0, 0:2]
if rotated:
box_out = bbox_transform_rotated(
boxes, deltas, angle_bound_on=angle_bound_on
)
- # No clipping for rotated boxes
+ box_out = clip_tiled_boxes_rotated(
+ box_out, im_shape, angle_thresh=clip_angle_thresh
+ )
else:
box_out = bbox_transform(boxes, deltas)
- im_shape = im_info[0, 0:2]
box_out = clip_tiled_boxes(box_out, im_shape)
return [box_out]
@@ -221,6 +261,7 @@
correct_transform_coords=True,
rotated=rotated,
angle_bound_on=angle_bound_on,
+ clip_angle_thresh=clip_angle_thresh,
)
self.assertReferenceChecks(
@@ -235,10 +276,18 @@
num_classes=st.integers(1, 10),
rotated=st.booleans(),
angle_bound_on=st.booleans(),
+ clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
**hu.gcs_cpu_only
)
def test_bbox_transform_batch(
- self, roi_counts, num_classes, rotated, angle_bound_on, gc, dc
+ self,
+ roi_counts,
+ num_classes,
+ rotated,
+ angle_bound_on,
+ clip_angle_thresh,
+ gc,
+ dc,
):
"""
Test with rois for multiple images in a batch
@@ -266,14 +315,16 @@
continue
cur_boxes = rois[offset : offset + num_rois, 1:]
cur_deltas = deltas[offset : offset + num_rois]
+ im_shape = im_info[i, 0:2]
if rotated:
cur_box_out = bbox_transform_rotated(
cur_boxes, cur_deltas, angle_bound_on=angle_bound_on
)
- # No clipping for rotated boxes
+ cur_box_out = clip_tiled_boxes_rotated(
+ cur_box_out, im_shape, angle_thresh=clip_angle_thresh
+ )
else:
cur_box_out = bbox_transform(cur_boxes, cur_deltas)
- im_shape = im_info[i, 0:2]
cur_box_out = clip_tiled_boxes(cur_box_out, im_shape)
box_out.append(cur_box_out)
offset += num_rois
@@ -292,6 +343,7 @@
correct_transform_coords=True,
rotated=rotated,
angle_bound_on=angle_bound_on,
+ clip_angle_thresh=clip_angle_thresh,
)
self.assertReferenceChecks(