Update caffe2 from facebook (#2178)

* [C2] Don't crash kernel in case of invalid shapes for ConcatOp

Enforce correctness of the shapes of input tensors so we don't access an invalid index.
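
A minimal repro sketch (hypothetical blob names; assumes the Python bindings
surface the new enforces as a RuntimeError during shape inference):

```python
import numpy as np
from caffe2.python import core, workspace

net = core.Net("concat_shape_check")
net.Concat(["a", "b"], ["c", "c_split"], axis=1)

workspace.FeedBlob("a", np.zeros((2, 3), dtype=np.float32))
workspace.FeedBlob("b", np.zeros((4, 5), dtype=np.float32))  # dim 0 mismatch

try:
    # Shape inference now enforces that all non-axis dims match before
    # indexing into the input shapes, instead of crashing.
    workspace.InferShapesAndTypes([net])
except RuntimeError as e:
    print("caught:", e)
```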

* [Caffe2] Add analytical performance counters to Dynolog

Initial diff for counting analytical flops and memory writes for C2 operators.
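
For a rough picture of the analytical model (plain-Python sketch, not the C2
API; the real counters live in each op schema's cost inference function, cf.
the Concat, Relu and SpatialBN tweaks in the diff below):

```python
# Mirrors PointwiseCostInference<N>: N analytical flops per output element.
# bytes_moved assumes one float write per element, zeroed for in-place ops.
def pointwise_cost(shape, flops_per_elem, in_place=False, itemsize=4):
    size = 1
    for d in shape:
        size *= d
    flops = size * flops_per_elem   # e.g. 0 for Relu/Concat, 4 for SpatialBN
    bytes_moved = 0 if in_place else size * itemsize
    return flops, bytes_moved
```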

* BBoxTransform op: Handle RoIs from multiple images per batch

The BBoxTransform op used during typical Faster-RCNN inference operates only
on RoIs from a single image (no batching). This adds support for RoIs from
multiple images per batch, with an optional output blob containing the batch
splits (i.e., the number of RoIs belonging to each item in the batch). The
change is fully backward compatible and shouldn't break any existing models.
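
Usage sketch, mirroring the new unit test in this diff (the second output is
the new optional one):

```python
from caffe2.python import core

op = core.CreateOperator(
    "BBoxTransform",
    ["rois", "deltas", "im_info"],
    ["box_out", "roi_batch_splits"],  # omit the 2nd output for old behavior
    apply_scale=False,
    correct_transform_coords=True,
)
```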

* [mkl] Make MKL-DNN cooperate with memongered nets

C2's MKL-DNN implementation caches input dims and reuses intermediate and
output buffers across net runs, which prevents memonger from being used. The
caching isn't always a win anyway, since input dims can vary widely and we
end up reallocating regardless. Added an option to force reallocation when
memonger is in use.
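
The switch is a plain caffe2 gflag (defined in mkl_operator.cc below); a
sketch of enabling it from Python:

```python
from caffe2.python import workspace

# With the flag on, MKL ops treat every run as "dims changed": primitives are
# rebuilt and intermediate buffers that aren't shared with the output are
# freed after each run, so memonger's blob reuse stays safe.
workspace.GlobalInit(["caffe2", "--caffe2_mkl_memonger_in_use=1"])
```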

* [oncall] fix batch gather ops for empty input

We still need to bisect to find the breaking change, but this fixes the case of empty input.

The error log looks like: https://interncache-ftw.fbcdn.net/t49.3276-7/23938497_293562711176943_6500112636590424064_n.txt?_nc_log=1

@[557759185:raychen], can you help subscribe the oncall from the ads side? This may affect the Sigrid online trainer.

* optimize BatchOneHotOp

We want to iterate in row-major order, as opposed to column-major, for better
memory locality.
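
An equivalent NumPy sketch of the new loop order: precompute per-feature
offsets into `vals` once, then fill each output row contiguously:

```python
import numpy as np

def batch_one_hot(x, lens, vals):
    # x: (N, D) dense features; lens[j]: #candidate values for feature j;
    # vals: 1-D array of all candidate values, concatenated per feature.
    N, D = x.shape
    offsets = np.concatenate(([0], np.cumsum(lens)))
    out = np.zeros((N, offsets[-1]), dtype=x.dtype)
    for i in range(N):          # row-major: finish one output row at a time
        for j in range(D):
            lo, hi = offsets[j], offsets[j + 1]
            out[i, lo:hi] = vals[lo:hi] == x[i, j]
    return out
```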

* Support exporting models with int blobs

Adds support for exporting models with int blobs (int32/int64, plus uint8
stored as a string) via the mobile exporter. Needed by condensenet.
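
Condensed from the new test in this diff; the export call is unchanged, but
int32 params now serialize via GivenTensorIntFill:

```python
import numpy as np
from caffe2.python import workspace
from caffe2.python.model_helper import ModelHelper
from caffe2.python.predictor import mobile_exporter

model = ModelHelper(name="int_blob_export")
model.Copy("data_int", "out")
model.params.append("data_int")

workspace.FeedBlob(
    "data_int", np.random.randint(100, size=(1, 1, 28, 28), dtype=np.int32)
)
init_net, predict_net = mobile_exporter.Export(
    workspace, model.net, model.params
)
```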

* BoxWithNMSLimit op: Handle boxes from multiple images per batch

Similar to D7135360. Added support for multiple images per batch in the op.
It takes an optional additional input "batch_splits", as output by the
BBoxTransform op, and returns new batch_splits after applying NMS and
filtering. Otherwise, backward compatibility is maintained.
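
Schema sketch with the new optional input and outputs (per the updated
OPERATOR_SCHEMA in the diff; note the optional keeps/keeps_size outputs move
to slots 4 and 5):

```python
from caffe2.python import core

op = core.CreateOperator(
    "BoxWithNMSLimit",
    ["scores", "boxes", "batch_splits"],          # input 2 is optional
    ["out_scores", "out_boxes", "out_classes",
     "out_batch_splits", "keeps", "keeps_size"],  # outputs 3-5 are optional
)
```
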
diff --git a/caffe2/mkl/mkl_operator.cc b/caffe2/mkl/mkl_operator.cc
index 4ac0a85..47c4967 100644
--- a/caffe2/mkl/mkl_operator.cc
+++ b/caffe2/mkl/mkl_operator.cc
@@ -17,6 +17,12 @@
 #include "caffe2/core/operator.h"
 #include "caffe2/proto/caffe2.pb.h"
 
+CAFFE2_DEFINE_bool(
+    caffe2_mkl_memonger_in_use,
+    false,
+    "Turn on if memonger is used to force reallocate intermediate "
+    "and output buffers within each op");
+
 namespace caffe2 {
 
 CAFFE_DEFINE_REGISTRY(
diff --git a/caffe2/mkl/operators/concat_op.cc b/caffe2/mkl/operators/concat_op.cc
index bac4dda..9fe1674 100644
--- a/caffe2/mkl/operators/concat_op.cc
+++ b/caffe2/mkl/operators/concat_op.cc
@@ -64,7 +64,7 @@
       dims_changed = (input_size_cache_[i] != Input(i).dims());
     }
 
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       input_size_cache_.resize(nInputs);
       int output_channels = 0;
       int canonical_axis = canonical_axis_index_(axis_, nDims);
@@ -96,7 +96,7 @@
       Y->Reset(cached_output_dims_, primitive_, dnnResourceDst);
       buffer_.Reset(cached_output_dims_, primitive_, dnnResourceDst, true);
     }
-    buffer_.ShareFrom(*Y);
+    bool shared = buffer_.ShareFrom(*Y);
 
     for (int i = 0; i < nInputs; ++i) {
       resources_[dnnResourceMultipleSrc + i] = Input(i).buffer();
@@ -104,6 +104,9 @@
     resources_[dnnResourceDst] = buffer_.buffer();
     ExecutePrimitive();
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/conv_op.cc b/caffe2/mkl/operators/conv_op.cc
index ffe3c1f..e4ae7e5 100644
--- a/caffe2/mkl/operators/conv_op.cc
+++ b/caffe2/mkl/operators/conv_op.cc
@@ -67,7 +67,7 @@
 
     bool dims_changed;
     CHECK_INPUT_FILTER_DIMS(X, filter, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       CAFFE_ENFORCE(
           C == filter.dim32(1) * group_,
           "Convolution op: input channels does not match: # of input channels ",
@@ -136,7 +136,7 @@
     // Try to share from the output: this allows us to avoid unnecessary copy
     // operations, if the output is already allocated and is having the same
     // layout as the buffer has.
-    buffer_.ShareFrom(*Y);
+    bool shared = buffer_.ShareFrom(*Y);
 
     std::shared_ptr<void> X_view = X.View(
         input_layout_, primitive_, dnnResourceSrc);
@@ -168,6 +168,11 @@
 
     MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      // buffer_ is not shared with Y. Free memory since it'll
+      // be re-allocated in the next run anyway due to memonger in use.
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/elementwise_sum_op.cc b/caffe2/mkl/operators/elementwise_sum_op.cc
index 2918495..5d3d377 100644
--- a/caffe2/mkl/operators/elementwise_sum_op.cc
+++ b/caffe2/mkl/operators/elementwise_sum_op.cc
@@ -42,7 +42,7 @@
     MKLMemory<T>* Y = Output(0);
     bool dims_changed;
     CHECK_INPUT_DIMS(X0, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       primitive_.Reset(
           dnnSumCreate<T>,
           nullptr,
@@ -63,14 +63,18 @@
       input_views_[i] = Xi.View(X0.layout());
       resources_[dnnResourceMultipleSrc + i] = input_views_[i].get();
     }
+    bool shared = false;
     if (Y != &X0) {
       // TODO: MKLDNN seems broken in the in-place case, so when we specify
       // in-place we will need to use buffer differnt from X0/Y.
-      buffer_.ShareFrom(*Y);
+      shared = buffer_.ShareFrom(*Y);
     }
     resources_[dnnResourceDst] = buffer_.buffer();
     MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/fully_connected_op.cc b/caffe2/mkl/operators/fully_connected_op.cc
index 081e7f0..c788f07 100644
--- a/caffe2/mkl/operators/fully_connected_op.cc
+++ b/caffe2/mkl/operators/fully_connected_op.cc
@@ -42,7 +42,7 @@
 
     bool dims_changed;
     CHECK_INPUT_FILTER_DIMS(X, filter, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       const int N = filter.dim32(0);
       CAFFE_ENFORCE(N == bias.dim32(0));
 
@@ -81,7 +81,7 @@
     // Try to share from the output: this allows us to avoid unnecessary copy
     // operations, if the output is already allocated and is having the same
     // layout as the buffer has.
-    buffer_.ShareFrom(*Y);
+    bool shared = buffer_.ShareFrom(*Y);
 
     std::shared_ptr<void> X_view =
         X.View(input_layout_, primitive_, dnnResourceSrc);
@@ -96,6 +96,9 @@
 
     MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/local_response_normalization_op.cc b/caffe2/mkl/operators/local_response_normalization_op.cc
index 168a886..fd67bb4 100644
--- a/caffe2/mkl/operators/local_response_normalization_op.cc
+++ b/caffe2/mkl/operators/local_response_normalization_op.cc
@@ -50,7 +50,7 @@
 
   bool dims_changed;
   CHECK_INPUT_DIMS(X, dims_changed);
-  if (dims_changed) {
+  if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
     size_t dim = X.ndim();
     CAFFE_ENFORCE(4 == dim);
 
@@ -75,12 +75,15 @@
   // Try to share from the output: this allows us to avoid unnecessary copy
   // operations, if the output is already allocated and is having the same
   // layout as the buffer has.
-  buffer_.ShareFrom(*Y);
+  bool shared = buffer_.ShareFrom(*Y);
   resources_[dnnResourceSrc] = X.buffer();
   resources_[dnnResourceDst] = buffer_.buffer();
   resources_[dnnResourceWorkspace] = workspace_buffer_->buffer();
   MKLDNN_SAFE_CALL(mkl::dnnExecute<float>(primitive_, resources_));
   buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+  if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+    buffer_.Reset();
+  }
   return true;
 }
 
diff --git a/caffe2/mkl/operators/pool_op.cc b/caffe2/mkl/operators/pool_op.cc
index 62a19cb..115d1f6 100644
--- a/caffe2/mkl/operators/pool_op.cc
+++ b/caffe2/mkl/operators/pool_op.cc
@@ -74,7 +74,7 @@
 
   bool dims_changed;
   CHECK_INPUT_DIMS(X, dims_changed);
-  if (dims_changed) {
+  if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
     // We will utilize the SetOutputSize() function in the base class
     // with dummy TensorCPU input and output to calculate the sizes.
     TensorCPU dummy_input(X.dims());
@@ -111,12 +111,15 @@
   // Try to share from the output: this allows us to avoid unnecessary copy
   // operations, if the output is already allocated and is having the same
   // layout as the buffer has.
-  buffer_.ShareFrom(*Y);
+  bool shared = buffer_.ShareFrom(*Y);
   resources_[dnnResourceSrc] = X.buffer();
   resources_[dnnResourceDst] = buffer_.buffer();
   resources_[dnnResourceWorkspace] = workspace_buffer_->buffer();
   MKLDNN_SAFE_CALL(mkl::dnnExecute<float>(primitive_, resources_));
   buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+  if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+    buffer_.Reset();
+  }
   return true;
 }
 
diff --git a/caffe2/mkl/operators/relu_op.cc b/caffe2/mkl/operators/relu_op.cc
index 4644943..e16fc41 100644
--- a/caffe2/mkl/operators/relu_op.cc
+++ b/caffe2/mkl/operators/relu_op.cc
@@ -35,7 +35,7 @@
 
     bool dims_changed;
     CHECK_INPUT_DIMS(X, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       // First run or changed input size, will need to recreate environment
       primitive_.Reset(dnnReLUCreateForward<T>, nullptr, X.layout(), 0.f);
       if (&X != Y) {
@@ -46,12 +46,15 @@
     // Try to share from the output: this allows us to avoid unnecessary copy
     // operations, if the output is already allocated and is having the same
     // layout as the buffer has.
-    buffer_.ShareFrom(*Y);
+    bool shared = buffer_.ShareFrom(*Y);
     CAFFE_ENFORCE(dnnLayoutCompare_F32(X.layout(), buffer_.layout()));
     resources_[dnnResourceSrc] = X.buffer();
     resources_[dnnResourceDst] = buffer_.buffer();
     ExecutePrimitive();
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/spatial_batch_norm_op.cc b/caffe2/mkl/operators/spatial_batch_norm_op.cc
index eefc318..eab10ee 100644
--- a/caffe2/mkl/operators/spatial_batch_norm_op.cc
+++ b/caffe2/mkl/operators/spatial_batch_norm_op.cc
@@ -66,7 +66,7 @@
 
     bool dims_changed;
     CHECK_INPUT_DIMS(X, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       // Create main primitive.
       if (is_test_) {
         primitive_.Reset(
@@ -111,7 +111,7 @@
     // Try to share from the output: this allows us to avoid unnecessary copy
     // operations, if the output is already allocated and is having the same
     // layout as the buffer has.
-    buffer_.ShareFrom(*Y);
+    bool shared = buffer_.ShareFrom(*Y);
     resources_[dnnResourceSrc] = X.buffer();
     resources_[dnnResourceDst] = buffer_.buffer();
     resources_[dnnResourceScaleShift] = scale_bias_buffer_->buffer();
@@ -143,6 +143,9 @@
       }
     }
     buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    if (FLAGS_caffe2_mkl_memonger_in_use && !shared) {
+      buffer_.Reset();
+    }
     return true;
   }
 
diff --git a/caffe2/mkl/operators/squeeze_op.cc b/caffe2/mkl/operators/squeeze_op.cc
index f431a1f..a41b372 100644
--- a/caffe2/mkl/operators/squeeze_op.cc
+++ b/caffe2/mkl/operators/squeeze_op.cc
@@ -55,7 +55,7 @@
 
     bool dims_changed;
     CHECK_INPUT_DIMS(X, dims_changed);
-    if (dims_changed) {
+    if (dims_changed || FLAGS_caffe2_mkl_memonger_in_use) {
       // Temp buffer mainly to convert the input to plain layout before
       // Reshape() if the input has a custom layout.
       buffer_.Reset(X.dims());
diff --git a/caffe2/mkl/utils/mkl_memory.h b/caffe2/mkl/utils/mkl_memory.h
index ee7bf14..43b6f11 100644
--- a/caffe2/mkl/utils/mkl_memory.h
+++ b/caffe2/mkl/utils/mkl_memory.h
@@ -61,6 +61,13 @@
     creator(&primitive_, args...);
   }
 
+  void Reset() {
+    if (primitive_) {
+      MKLDNN_SAFE_CALL(dnnDelete<T>(primitive_));
+      primitive_ = nullptr;
+    }
+  }
+
   operator dnnPrimitive_t() const {
     return primitive_;
   }
@@ -134,6 +141,13 @@
     MKLDNN_SAFE_CALL(dnnLayoutCreate<T>(&layout_, dimension, size, strides));
   }
 
+  void Reset() {
+    if (layout_) {
+      MKLDNN_CHECK(dnnLayoutDelete<T>(layout_));
+      layout_ = nullptr;
+    }
+  }
+
   operator dnnLayout_t() const {
     return layout_;
   }
@@ -251,6 +265,16 @@
     }
   }
 
+  void Reset() {
+    buffer_.reset();
+    dims_.clear();
+    size_ = 0;
+    user_layout_.Reset();
+    layout_.Reset();
+    convert_in_.Reset();
+    convert_out_.Reset();
+  }
+
   /**
    * Resizes the tensor without touching underlying storage.
    * This requires the total size of the tensor to remains constant.
diff --git a/caffe2/mkl/utils/mkl_operator.h b/caffe2/mkl/utils/mkl_operator.h
index 0fd37e9..35bc280 100644
--- a/caffe2/mkl/utils/mkl_operator.h
+++ b/caffe2/mkl/utils/mkl_operator.h
@@ -22,6 +22,8 @@
 #include "caffe2/mkl/utils/mkl_memory.h"
 #include "caffe2/proto/caffe2.pb.h"
 
+CAFFE2_DECLARE_bool(caffe2_mkl_memonger_in_use);
+
 namespace caffe2 {
 
 CAFFE_DECLARE_REGISTRY(
diff --git a/caffe2/operators/batch_gather_ops.h b/caffe2/operators/batch_gather_ops.h
index c282c78..4e29009 100644
--- a/caffe2/operators/batch_gather_ops.h
+++ b/caffe2/operators/batch_gather_ops.h
@@ -111,6 +111,10 @@
 
     output->ResizeLike(data);
     TData* out_data = output->template mutable_data<TData>();
+    if (data.size() <= 0) {
+      return true;
+    }
+
     memset(out_data, 0, output->nbytes());
 
     const TData* grad_data = grad.template data<TData>();
diff --git a/caffe2/operators/bbox_transform_op.cc b/caffe2/operators/bbox_transform_op.cc
index b63b418..369db13 100644
--- a/caffe2/operators/bbox_transform_op.cc
+++ b/caffe2/operators/bbox_transform_op.cc
@@ -19,7 +19,7 @@
 // Input: box, delta Output: box
 OPERATOR_SCHEMA(BBoxTransform)
     .NumInputs(3)
-    .NumOutputs(1)
+    .NumOutputs(1, 2)
     .SetDoc(R"DOC(
 Transform proposal bounding boxes to target bounding box using bounding box
     regression deltas.
@@ -42,7 +42,9 @@
         "rois",
         "Bounding box proposals in pixel coordinates, "
         "Size (M, 4), format [x1, y1, x2, y2], or"
-        "Size (M, 5), format [img_index_IGNORED, x1, y1, x2, y2]")
+        "Size (M, 5), format [batch_index, x1, y1, x2, y2]. "
+        "If proposals from multiple images in a batch are present, they "
+        "should be grouped sequentially and in incremental order.")
     .Input(
         1,
         "deltas",
@@ -51,13 +53,18 @@
     .Input(
         2,
         "im_info",
-        "Image dimensions, size (1, 3), "
-        "format [img_height, img_width, img_scale_IGNORED]")
+        "Image dimensions, size (batch_size, 3), "
+        "format [img_height, img_width, img_scale]")
     .Output(
         0,
         "box_out",
         "Pixel coordinates of the transformed bounding boxes,"
-        "Size (M, 4*K), format [x1, y1, x2, y2]");
+        "Size (M, 4*K), format [x1, y1, x2, y2]")
+    .Output(
+        1,
+        "roi_batch_splits",
+        "Tensor of shape (batch_size) with each element denoting the number "
+        "of RoIs belonging to the corresponding image in batch");
 
 SHOULD_NOT_DO_GRADIENT(BBoxTransform);
 } // namespace
@@ -71,43 +78,82 @@
 
   const int N = roi_in.dim32(0);
   CAFFE_ENFORCE_EQ(roi_in.ndim(), 2);
-  CAFFE_ENFORCE_GE(roi_in.dim32(1), 4);
+  CAFFE_ENFORCE(roi_in.dim32(1) == 4 || roi_in.dim32(1) == 5);
 
-  CAFFE_ENFORCE_EQ(roi_in.ndim(), 2);
+  CAFFE_ENFORCE_EQ(delta_in.ndim(), 2);
   CAFFE_ENFORCE_EQ(delta_in.dim32(0), N);
   CAFFE_ENFORCE_EQ(delta_in.dim32(1) % 4, 0);
+  const int num_classes = delta_in.dim32(1) / 4;
+
+  CAFFE_ENFORCE_EQ(iminfo_in.ndim(), 2);
+  CAFFE_ENFORCE_EQ(iminfo_in.dim32(1), 3);
+  const int batch_size = iminfo_in.dim32(0);
 
   DCHECK_EQ(weights_.size(), 4);
 
-  CAFFE_ENFORCE_EQ(iminfo_in.size(), 3);
-  ConstEigenVectorArrayMap<float> iminfo(iminfo_in.data<float>(), 3);
-  const float scale_before = iminfo(2);
-  const float scale_after = apply_scale_ ? iminfo(2) : 1.0;
-  int img_h = int(iminfo(0) / scale_before + 0.5);
-  int img_w = int(iminfo(1) / scale_before + 0.5);
-
   Eigen::Map<const ERArrXXf> boxes0(
       roi_in.data<float>(), roi_in.dim32(0), roi_in.dim32(1));
-  auto boxes = boxes0.rightCols(4) / scale_before;
-
   Eigen::Map<const ERArrXXf> deltas0(
       delta_in.data<float>(), delta_in.dim32(0), delta_in.dim32(1));
 
+  // Count the number of RoIs per batch
+  vector<int> num_rois_per_batch(batch_size, 0);
+  if (roi_in.dim32(1) == 4) {
+    CAFFE_ENFORCE_EQ(batch_size, 1);
+    num_rois_per_batch[0] = N;
+  } else {
+    const auto& roi_batch_ids = boxes0.col(0);
+    for (int i = 0; i < roi_batch_ids.size(); ++i) {
+      const int roi_batch_id = roi_batch_ids(i);
+      CAFFE_ENFORCE_LT(roi_batch_id, batch_size);
+      num_rois_per_batch[roi_batch_id]++;
+    }
+  }
+
+  CAFFE_ENFORCE_EQ(iminfo_in.dims(), (vector<TIndex>{batch_size, 3}));
+  Eigen::Map<const ERArrXXf> iminfo(
+      iminfo_in.data<float>(), iminfo_in.dim(0), iminfo_in.dim(1));
+
   box_out->ResizeLike(delta_in);
   Eigen::Map<ERArrXXf> new_boxes(
       box_out->mutable_data<float>(), box_out->dim32(0), box_out->dim32(1));
 
-  int num_classes = deltas0.cols() / 4;
-  for (int k = 0; k < num_classes; k++) {
-    auto deltas = deltas0.block(0, k * 4, N, 4);
-    auto trans_boxes = utils::bbox_transform(
-        boxes,
-        deltas,
-        weights_,
-        utils::BBOX_XFORM_CLIP_DEFAULT,
-        correct_transform_coords_);
-    auto clip_boxes = utils::clip_boxes(trans_boxes, img_h, img_w);
-    new_boxes.block(0, k * 4, N, 4) = clip_boxes * scale_after;
+  // We assume roi_in and delta_in over multiple batches are grouped
+  // together in increasing order as generated by GenerateProposalsOp
+  int offset = 0;
+  for (int i = 0; i < batch_size; ++i) {
+    const int num_rois = num_rois_per_batch[i];
+    const auto& cur_iminfo = iminfo.row(i);
+    const float scale_before = cur_iminfo(2);
+    const float scale_after = apply_scale_ ? cur_iminfo(2) : 1.0;
+    int img_h = int(cur_iminfo(0) / scale_before + 0.5);
+    int img_w = int(cur_iminfo(1) / scale_before + 0.5);
+
+    const auto& cur_boxes =
+        boxes0.rightCols(4).block(offset, 0, num_rois, 4) / scale_before;
+    for (int k = 0; k < num_classes; k++) {
+      const auto& cur_deltas = deltas0.block(offset, k * 4, num_rois, 4);
+      const auto& trans_boxes = utils::bbox_transform(
+          cur_boxes,
+          cur_deltas,
+          weights_,
+          utils::BBOX_XFORM_CLIP_DEFAULT,
+          correct_transform_coords_);
+      const auto& clip_boxes = utils::clip_boxes(trans_boxes, img_h, img_w);
+      new_boxes.block(offset, k * 4, num_rois, 4) = clip_boxes * scale_after;
+    }
+
+    offset += num_rois;
+  }
+
+  if (OutputSize() > 1) {
+    auto* roi_batch_splits = Output(1);
+    roi_batch_splits->Resize(batch_size);
+    Eigen::Map<EArrXf> roi_batch_splits_map(
+        roi_batch_splits->mutable_data<float>(), batch_size);
+    roi_batch_splits_map =
+        Eigen::Map<const EArrXi>(num_rois_per_batch.data(), batch_size)
+            .cast<float>();
   }
 
   return true;
diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc
index ad06025..9caadb3 100644
--- a/caffe2/operators/box_with_nms_limit_op.cc
+++ b/caffe2/operators/box_with_nms_limit_op.cc
@@ -51,142 +51,196 @@
   }
   CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.meta().name());
 
+  int N = tscores.dim(0);
   int num_classes = tscores.dim(1);
 
-  CAFFE_ENFORCE_EQ(tscores.dim(0), tboxes.dim(0));
+  CAFFE_ENFORCE_EQ(N, tboxes.dim(0));
   CAFFE_ENFORCE_EQ(num_classes * 4, tboxes.dim(1));
 
-  Eigen::Map<const ERArrXXf> scores(
-      tscores.data<float>(), tscores.dim(0), tscores.dim(1));
-  Eigen::Map<const ERArrXXf> boxes(
-      tboxes.data<float>(), tboxes.dim(0), tboxes.dim(1));
+  int batch_size = 1;
+  vector<float> batch_splits_default(1, tscores.dim(0));
+  const float* batch_splits_data = batch_splits_default.data();
+  if (InputSize() > 2) {
+    // tscores and tboxes have items from multiple images in a batch. Get the
+    // corresponding batch splits from input.
+    const auto& tbatch_splits = Input(2);
+    CAFFE_ENFORCE_EQ(tbatch_splits.ndim(), 1);
+    batch_size = tbatch_splits.dim(0);
+    batch_splits_data = tbatch_splits.data<float>();
+  }
+  Eigen::Map<const EArrXf> batch_splits(batch_splits_data, batch_size);
+  CAFFE_ENFORCE_EQ(batch_splits.sum(), N);
 
-  // To store updated scores if SoftNMS is used
-  ERArrXXf soft_nms_scores(tscores.dim(0), tscores.dim(1));
+  out_scores->Resize(0);
+  out_boxes->Resize(0, 4);
+  out_classes->Resize(0);
 
-  vector<vector<int>> keeps(num_classes);
-
-  // Perform nms to each class
-  // skip j = 0, because it's the background class
-  int total_keep_count = 0;
-  for (int j = 1; j < num_classes; j++) {
-    auto cur_scores = scores.col(j);
-    auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
-    auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
-
-    if (soft_nms_enabled_) {
-      auto out_scores = soft_nms_scores.col(j);
-      keeps[j] = utils::soft_nms_cpu(
-          &out_scores,
-          cur_boxes,
-          cur_scores,
-          inds,
-          soft_nms_sigma_,
-          nms_thres_,
-          soft_nms_min_score_thres_,
-          soft_nms_method_);
-    } else {
-      std::sort(
-          inds.data(),
-          inds.data() + inds.size(),
-          [&cur_scores](int lhs, int rhs) {
-            return cur_scores(lhs) > cur_scores(rhs);
-          });
-      keeps[j] = utils::nms_cpu(cur_boxes, cur_scores, inds, nms_thres_);
-    }
-    total_keep_count += keeps[j].size();
+  TensorCPU* out_keeps = nullptr;
+  TensorCPU* out_keeps_size = nullptr;
+  if (OutputSize() > 4) {
+    out_keeps = Output(4);
+    out_keeps_size = Output(5);
+    out_keeps->Resize(0);
+    out_keeps_size->Resize(batch_size, num_classes);
   }
 
-  if (soft_nms_enabled_) {
-    // Re-map scores to the updated SoftNMS scores
-    new (&scores) Eigen::Map<const ERArrXXf>(
-        soft_nms_scores.data(), soft_nms_scores.rows(), soft_nms_scores.cols());
-  }
+  vector<int> total_keep_per_batch(batch_size);
+  int offset = 0;
+  for (int b = 0; b < batch_splits.size(); ++b) {
+    int num_boxes = batch_splits(b);
+    Eigen::Map<const ERArrXXf> scores(
+        tscores.data<float>() + offset * tscores.dim(1),
+        num_boxes,
+        tscores.dim(1));
+    Eigen::Map<const ERArrXXf> boxes(
+        tboxes.data<float>() + offset * tboxes.dim(1),
+        num_boxes,
+        tboxes.dim(1));
 
-  // Limit to max_per_image detections *over all classes*
-  if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
-    // merge all scores together and sort
-    auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
-      EArrXf ret(total_keep_count);
+    // To store updated scores if SoftNMS is used
+    ERArrXXf soft_nms_scores(num_boxes, tscores.dim(1));
+    vector<vector<int>> keeps(num_classes);
 
-      int ret_idx = 0;
-      for (int i = 1; i < keeps.size(); i++) {
-        auto& cur_keep = keeps[i];
-        auto cur_scores = scores.col(i);
-        auto cur_ret = ret.segment(ret_idx, cur_keep.size());
-        utils::GetSubArray(cur_scores, utils::AsEArrXt(keeps[i]), &cur_ret);
-        ret_idx += cur_keep.size();
-      }
-
-      std::sort(ret.data(), ret.data() + ret.size());
-
-      return ret;
-    };
-
-    // Compute image thres based on all classes
-    auto all_scores_sorted = get_all_scores_sorted();
-    DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
-    auto image_thresh =
-        all_scores_sorted[all_scores_sorted.size() - detections_per_im_];
-
-    total_keep_count = 0;
-    // filter results with image_thresh
+    // Perform nms to each class
+    // skip j = 0, because it's the background class
+    int total_keep_count = 0;
     for (int j = 1; j < num_classes; j++) {
-      auto& cur_keep = keeps[j];
       auto cur_scores = scores.col(j);
-      keeps[j] =
-          filter_with_indices(cur_scores, cur_keep, [&image_thresh](float sc) {
-            return sc >= image_thresh;
-          });
+      auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
+      auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
+
+      if (soft_nms_enabled_) {
+        auto cur_soft_nms_scores = soft_nms_scores.col(j);
+        keeps[j] = utils::soft_nms_cpu(
+            &cur_soft_nms_scores,
+            cur_boxes,
+            cur_scores,
+            inds,
+            soft_nms_sigma_,
+            nms_thres_,
+            soft_nms_min_score_thres_,
+            soft_nms_method_);
+      } else {
+        std::sort(
+            inds.data(),
+            inds.data() + inds.size(),
+            [&cur_scores](int lhs, int rhs) {
+              return cur_scores(lhs) > cur_scores(rhs);
+            });
+        keeps[j] = utils::nms_cpu(cur_boxes, cur_scores, inds, nms_thres_);
+      }
       total_keep_count += keeps[j].size();
     }
-  }
 
-  // Write results
-  out_scores->Resize(total_keep_count);
-  out_boxes->Resize(total_keep_count, 4);
-  out_classes->Resize(total_keep_count);
-  int cur_out_idx = 0;
-  for (int j = 1; j < num_classes; j++) {
-    auto cur_scores = scores.col(j);
-    auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
-    auto& cur_keep = keeps[j];
-    Eigen::Map<EArrXf> cur_out_scores(
-        out_scores->mutable_data<float>() + cur_out_idx, cur_keep.size());
-    Eigen::Map<ERArrXXf> cur_out_boxes(
-        out_boxes->mutable_data<float>() + cur_out_idx * 4, cur_keep.size(), 4);
-    Eigen::Map<EArrXf> cur_out_classes(
-        out_classes->mutable_data<float>() + cur_out_idx, cur_keep.size());
-
-    utils::GetSubArray(cur_scores, utils::AsEArrXt(cur_keep), &cur_out_scores);
-    utils::GetSubArrayRows(
-        cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
-    for (int k = 0; k < cur_keep.size(); k++) {
-      cur_out_classes[k] = static_cast<float>(j);
+    if (soft_nms_enabled_) {
+      // Re-map scores to the updated SoftNMS scores
+      new (&scores) Eigen::Map<const ERArrXXf>(
+          soft_nms_scores.data(),
+          soft_nms_scores.rows(),
+          soft_nms_scores.cols());
     }
 
-    cur_out_idx += cur_keep.size();
+    // Limit to max_per_image detections *over all classes*
+    if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
+      // merge all scores together and sort
+      auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
+        EArrXf ret(total_keep_count);
+
+        int ret_idx = 0;
+        for (int i = 1; i < keeps.size(); i++) {
+          auto& cur_keep = keeps[i];
+          auto cur_scores = scores.col(i);
+          auto cur_ret = ret.segment(ret_idx, cur_keep.size());
+          utils::GetSubArray(cur_scores, utils::AsEArrXt(keeps[i]), &cur_ret);
+          ret_idx += cur_keep.size();
+        }
+
+        std::sort(ret.data(), ret.data() + ret.size());
+
+        return ret;
+      };
+
+      // Compute image thres based on all classes
+      auto all_scores_sorted = get_all_scores_sorted();
+      DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
+      auto image_thresh =
+          all_scores_sorted[all_scores_sorted.size() - detections_per_im_];
+
+      total_keep_count = 0;
+      // filter results with image_thresh
+      for (int j = 1; j < num_classes; j++) {
+        auto& cur_keep = keeps[j];
+        auto cur_scores = scores.col(j);
+        keeps[j] = filter_with_indices(
+            cur_scores, cur_keep, [&image_thresh](float sc) {
+              return sc >= image_thresh;
+            });
+        total_keep_count += keeps[j].size();
+      }
+    }
+    total_keep_per_batch[b] = total_keep_count;
+
+    // Write results
+    int cur_start_idx = out_scores->dim(0);
+    out_scores->Extend(total_keep_count, 50, &context_);
+    out_boxes->Extend(total_keep_count, 50, &context_);
+    out_classes->Extend(total_keep_count, 50, &context_);
+
+    int cur_out_idx = 0;
+    for (int j = 1; j < num_classes; j++) {
+      auto cur_scores = scores.col(j);
+      auto cur_boxes = boxes.block(0, j * 4, boxes.rows(), 4);
+      auto& cur_keep = keeps[j];
+      Eigen::Map<EArrXf> cur_out_scores(
+          out_scores->mutable_data<float>() + cur_start_idx + cur_out_idx,
+          cur_keep.size());
+      Eigen::Map<ERArrXXf> cur_out_boxes(
+          out_boxes->mutable_data<float>() + (cur_start_idx + cur_out_idx) * 4,
+          cur_keep.size(),
+          4);
+      Eigen::Map<EArrXf> cur_out_classes(
+          out_classes->mutable_data<float>() + cur_start_idx + cur_out_idx,
+          cur_keep.size());
+
+      utils::GetSubArray(
+          cur_scores, utils::AsEArrXt(cur_keep), &cur_out_scores);
+      utils::GetSubArrayRows(
+          cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
+      for (int k = 0; k < cur_keep.size(); k++) {
+        cur_out_classes[k] = static_cast<float>(j);
+      }
+
+      cur_out_idx += cur_keep.size();
+    }
+
+    if (out_keeps) {
+      out_keeps->Extend(total_keep_count, 50, &context_);
+
+      Eigen::Map<EArrXi> out_keeps_arr(
+          out_keeps->mutable_data<int>() + cur_start_idx, total_keep_count);
+      Eigen::Map<EArrXi> cur_out_keeps_size(
+          out_keeps_size->mutable_data<int>() + b * num_classes, num_classes);
+
+      cur_out_idx = 0;
+      for (int j = 0; j < num_classes; j++) {
+        out_keeps_arr.segment(cur_out_idx, keeps[j].size()) =
+            utils::AsEArrXt(keeps[j]);
+        cur_out_keeps_size[j] = keeps[j].size();
+        cur_out_idx += keeps[j].size();
+      }
+    }
+
+    offset += num_boxes;
   }
 
   if (OutputSize() > 3) {
-    auto* out_keeps = Output(3);
-    auto* out_keeps_size = Output(4);
-    out_keeps->Resize(total_keep_count);
-    out_keeps_size->Resize(num_classes);
-
-    Eigen::Map<EArrXi> cur_out_keeps_size(
-        out_keeps_size->mutable_data<int>(), num_classes);
-
-    cur_out_idx = 0;
-    Eigen::Map<EArrXi> out_keeps_arr(
-        out_keeps->mutable_data<int>(), total_keep_count);
-    for (int j = 0; j < num_classes; j++) {
-      out_keeps_arr.segment(cur_out_idx, keeps[j].size()) =
-          utils::AsEArrXt(keeps[j]);
-
-      cur_out_keeps_size[j] = keeps[j].size();
-      cur_out_idx += keeps[j].size();
-    }
+    auto* batch_splits_out = Output(3);
+    batch_splits_out->Resize(batch_size);
+    Eigen::Map<EArrXf> batch_splits_out_map(
+        batch_splits_out->mutable_data<float>(), batch_size);
+    batch_splits_out_map =
+        Eigen::Map<const EArrXi>(total_keep_per_batch.data(), batch_size)
+            .cast<float>();
   }
 
   return true;
@@ -203,8 +257,8 @@
 #endif // CAFFE2_HAS_MKL_DNN
 
 OPERATOR_SCHEMA(BoxWithNMSLimit)
-    .NumInputs(2)
-    .NumOutputs(3, 5)
+    .NumInputs(2, 3)
+    .NumOutputs(3, 6)
     .SetDoc(R"DOC(
 Apply NMS to each class (except background) and limit the number of
 returned boxes.
@@ -223,12 +277,22 @@
         1,
         "boxes",
         "Bounding box for each class, size (count, num_classes * 4)")
+    .Input(
+        2,
+        "batch_splits",
+        "Tensor of shape (batch_size) with each element denoting the number "
+        "of RoIs/boxes belonging to the corresponding image in batch. "
+        "Sum should add up to total count of scores/boxes.")
     .Output(0, "scores", "Filtered scores, size (n)")
     .Output(1, "boxes", "Filtered boxes, size (n, 4)")
     .Output(2, "classes", "Class id for each filtered score/box, size (n)")
-    .Output(3, "keeps", "Optional filtered indices, size (n)")
     .Output(
-        4,
+        3,
+        "batch_splits",
+        "Output batch splits for scores/boxes after applying NMS")
+    .Output(4, "keeps", "Optional filtered indices, size (n)")
+    .Output(
+        5,
         "keeps_size",
         "Optional number of filtered indices per class, size (num_classes)");
 
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index 05c4dff..7c11e03 100644
--- a/caffe2/operators/concat_split_op.cc
+++ b/caffe2/operators/concat_split_op.cc
@@ -60,7 +60,7 @@
   }
 
   struct OpSchema::Cost cost;
-  cost.flops = size;
+  cost.flops = 0;
   cost.bytes_moved = size * sizeof(float);
   cost.params_bytes = 0;
   return cost;
@@ -89,9 +89,53 @@
       vector<int> split_shape(1, in.size());
       vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
       if (add_axis) {
+        for (int i = 1; i < in.size(); ++i) {
+          CAFFE_ENFORCE_EQ(
+              in[0].dims().size(),
+              in[i].dims().size(),
+              "All inputs of Concat should have same dims when add_axis = 1. "
+              "Got different sizes for inputs 0 and ",
+              i);
+          for (int j = 0; j < in[0].dims().size(); ++j) {
+            CAFFE_ENFORCE_EQ(
+                in[0].dims(j),
+                in[i].dims(j),
+                "All inputs of Concat should have same dims when add_axis = 1. "
+                "Got different dims for inputs 0 and ",
+                i,
+                ". At dim: ",
+                j);
+          }
+        }
         out_shape.insert(out_shape.begin() + canonical_axis, in.size());
       } else {
         for (int i = 1; i < in.size(); ++i) {
+          CAFFE_ENFORCE_EQ(
+              in[0].dims().size(),
+              in[i].dims().size(),
+              "All inputs of Concat should have same dims except "
+              "canonical_axis dim that is equal to ",
+              canonical_axis,
+              "Got different sizes for inputs 0 and ",
+              i);
+          for (int j = 0; j < in[0].dims().size(); ++j) {
+            if (j == canonical_axis) {
+              continue;
+            }
+            CAFFE_ENFORCE_EQ(
+                in[0].dims(j),
+                in[i].dims(j),
+                "All inputs of Concat should have same dims except "
+                "canonical_axis dim that is equal to ",
+                canonical_axis,
+                "Got different dims for inputs 0 and ",
+                i,
+                ". At dim: ",
+                j);
+          }
+        }
+
+        for (int i = 1; i < in.size(); ++i) {
           out_shape[canonical_axis] += in[i].dims(canonical_axis);
         }
       }
diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc
index 614a954..aa7161c 100644
--- a/caffe2/operators/one_hot_ops.cc
+++ b/caffe2/operators/one_hot_ops.cc
@@ -34,10 +34,14 @@
 
   const auto* lens_data = lens.template data<int32_t>();
   TIndex output_dim = 0;
+  valsOffsets_.resize(D + 1);
   for (TIndex i = 0; i < D; i++) {
     CAFFE_ENFORCE_GE(lens_data[i], 0);
+    valsOffsets_[i] = output_dim;
     output_dim += lens_data[i];
   }
+  valsOffsets_[D] = output_dim;
+
   CAFFE_ENFORCE_EQ(vals.size(), output_dim);
   auto* output = Output(ONE_HOT);
   output->Resize(N, output_dim);
@@ -45,21 +49,17 @@
   const auto* input_data = input.template data<T>();
   const auto* vals_data = vals.template data<T>();
   auto* output_data = output->template mutable_data<T>();
-  // eigen is column-major
-  auto input_m = ConstEigenMatrixMap<T>(input_data, D, N);
-  auto output_m = EigenMatrixMap<T>(output_data, output_dim, N);
 
-  // `p` is the column position in output_data, that points to the next
-  // column to be filled.
-  TIndex p = 0;
-  // one-hot encoding for each example.
-  for (TIndex j = 0; j < D; j++) {
-    for (TIndex t = 0; t < lens_data[j]; t++) {
-      output_m.row(p) =
-          input_m.row(j).cwiseEqual(vals_data[p]).template cast<T>();
-      p++;
+  for (TIndex i = 0; i < N; ++i) {
+    for (TIndex j = 0; j < D; j++) {
+      const auto input_val = input_data[i * D + j];
+      for (TIndex k = valsOffsets_[j]; k < valsOffsets_[j + 1]; ++k) {
+        output_data[k] = vals_data[k] == input_val;
+      }
     }
+    output_data += output_dim;
   }
+
   return true;
 }
 
diff --git a/caffe2/operators/one_hot_ops.h b/caffe2/operators/one_hot_ops.h
index 42d70dc..84bd976 100644
--- a/caffe2/operators/one_hot_ops.h
+++ b/caffe2/operators/one_hot_ops.h
@@ -84,6 +84,10 @@
  protected:
   INPUT_TAGS(X, LENS, VALS);
   OUTPUT_TAGS(ONE_HOT);
+
+ private:
+  // allows for fast random access to a given dict and is re-used across runs
+  std::vector<TIndex> valsOffsets_;
 };
 
 template <class Context>
diff --git a/caffe2/operators/relu_op.cc b/caffe2/operators/relu_op.cc
index 759cfdf..0206b6f 100644
--- a/caffe2/operators/relu_op.cc
+++ b/caffe2/operators/relu_op.cc
@@ -71,7 +71,7 @@
 OpSchema::Cost CostInferenceForRelu(
     const OperatorDef& def,
     const vector<TensorShape>& in) {
-  struct OpSchema::Cost cost = PointwiseCostInference<2>(def, in);
+  struct OpSchema::Cost cost = PointwiseCostInference<0>(def, in);
   if (def.input(0) == def.output(0)) {
     cost.bytes_moved = 0;
   }
diff --git a/caffe2/operators/roi_align_op.cc b/caffe2/operators/roi_align_op.cc
index 8f5d12e..29f00ee 100644
--- a/caffe2/operators/roi_align_op.cc
+++ b/caffe2/operators/roi_align_op.cc
@@ -364,9 +364,11 @@
     .Input(
         1,
         "RoIs",
-        "2D input of shape (R, 5) specifying R RoIs with five columns "
+        "2D input of shape (R, 4 or 5) specifying R RoIs "
         "representing: batch index in [0, N - 1], x1, y1, x2, y2. The RoI "
-        "coordinates are in the coordinate system of the input image.")
+        "coordinates are in the coordinate system of the input image. For "
+        "inputs corresponding to a single image, batch index can be excluded "
+        "to have just 4 columns.")
     .Output(
         0,
         "Y",
diff --git a/caffe2/operators/spatial_batch_norm_op.cc b/caffe2/operators/spatial_batch_norm_op.cc
index bcb1aa8..dcd25cd 100644
--- a/caffe2/operators/spatial_batch_norm_op.cc
+++ b/caffe2/operators/spatial_batch_norm_op.cc
@@ -179,7 +179,7 @@
 OpSchema::Cost CostInferenceForSpatialBN(
     const OperatorDef& def,
     const vector<TensorShape>& in) {
-  struct OpSchema::Cost cost = PointwiseCostInference<2>(def, in);
+  struct OpSchema::Cost cost = PointwiseCostInference<4>(def, in);
   ArgumentHelper helper(def);
   auto order =
       StringToStorageOrder(helper.GetSingleArgument<string>("order", "NCHW"));
diff --git a/caffe2/python/operator_test/bbox_transform_test.py b/caffe2/python/operator_test/bbox_transform_test.py
new file mode 100644
index 0000000..a930cb4
--- /dev/null
+++ b/caffe2/python/operator_test/bbox_transform_test.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2016-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+from caffe2.python import core
+from hypothesis import given
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+
+
+# Reference implementation from detectron/lib/utils/boxes.py
+def bbox_transform(boxes, deltas, weights=(1.0, 1.0, 1.0, 1.0)):
+    """Forward transform that maps proposal boxes to predicted ground-truth
+    boxes using bounding-box regression deltas. See bbox_transform_inv for a
+    description of the weights argument.
+    """
+    if boxes.shape[0] == 0:
+        return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)
+
+    boxes = boxes.astype(deltas.dtype, copy=False)
+
+    widths = boxes[:, 2] - boxes[:, 0] + 1.0
+    heights = boxes[:, 3] - boxes[:, 1] + 1.0
+    ctr_x = boxes[:, 0] + 0.5 * widths
+    ctr_y = boxes[:, 1] + 0.5 * heights
+
+    wx, wy, ww, wh = weights
+    dx = deltas[:, 0::4] / wx
+    dy = deltas[:, 1::4] / wy
+    dw = deltas[:, 2::4] / ww
+    dh = deltas[:, 3::4] / wh
+
+    # Prevent sending too large values into np.exp()
+    BBOX_XFORM_CLIP = np.log(1000. / 16.)
+    dw = np.minimum(dw, BBOX_XFORM_CLIP)
+    dh = np.minimum(dh, BBOX_XFORM_CLIP)
+
+    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
+    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
+    pred_w = np.exp(dw) * widths[:, np.newaxis]
+    pred_h = np.exp(dh) * heights[:, np.newaxis]
+
+    pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
+    # x1
+    pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
+    # y1
+    pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
+    # x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
+    pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
+    # y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
+    pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
+
+    return pred_boxes
+
+
+# Reference implementation from detectron/lib/utils/boxes.py
+def clip_tiled_boxes(boxes, im_shape):
+    """Clip boxes to image boundaries. im_shape is [height, width] and boxes
+    has shape (N, 4 * num_tiled_boxes)."""
+    assert boxes.shape[1] % 4 == 0, \
+        'boxes.shape[1] is {:d}, but must be divisible by 4.'.format(
+        boxes.shape[1]
+    )
+    # x1 >= 0
+    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
+    # y1 >= 0
+    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
+    # x2 < im_shape[1]
+    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
+    # y2 < im_shape[0]
+    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
+    return boxes
+
+
+def generate_rois(roi_counts, im_dims):
+    assert len(roi_counts) == len(im_dims)
+    all_rois = []
+    for i, num_rois in enumerate(roi_counts):
+        if num_rois == 0:
+            continue
+        # [batch_idx, x1, y1, x2, y2]
+        rois = np.random.uniform(
+            0, im_dims[i], size=(roi_counts[i], 5)
+        ).astype(np.float32)
+        rois[:, 0] = i  # batch_idx
+        # Swap (x1, x2) if x1 > x2
+        rois[:, 1], rois[:, 3] = np.minimum(rois[:, 1], rois[:, 3]), \
+                np.maximum(rois[:, 1], rois[:, 3])
+        # Swap (y1, y2) if y1 > y2
+        rois[:, 2], rois[:, 4] = np.minimum(rois[:, 2], rois[:, 4]), \
+                np.maximum(rois[:, 2], rois[:, 4])
+        all_rois.append(rois)
+    if len(all_rois) > 0:
+        return np.vstack(all_rois)
+    return np.empty((0, 5)).astype(np.float32)
+
+
+class TestBBoxTransformOp(hu.HypothesisTestCase):
+    @given(
+        num_rois=st.integers(1, 10),
+        num_classes=st.integers(1, 10),
+        im_dim=st.integers(100, 600),
+        skip_batch_id=st.booleans(),
+        **hu.gcs_cpu_only
+    )
+    def test_bbox_transform(
+        self, num_rois, num_classes, im_dim, skip_batch_id, gc, dc
+    ):
+        """
+        Test with all rois belonging to a single image per run.
+        """
+        rois = generate_rois([num_rois], [im_dim])
+        if skip_batch_id:
+            rois = rois[:, 1:5]
+        deltas = np.random.randn(num_rois, 4 * num_classes).astype(np.float32)
+        im_info = np.array([im_dim, im_dim,
+                            1.0]).astype(np.float32).reshape(1, 3)
+
+        def bbox_transform_ref(rois, deltas, im_info):
+            boxes = rois if rois.shape[1] == 4 else rois[:, 1:5]
+            box_out = bbox_transform(boxes, deltas)
+            im_shape = im_info[0, 0:2]
+            box_out = clip_tiled_boxes(box_out, im_shape)
+            return [box_out]
+
+        op = core.CreateOperator(
+            "BBoxTransform",
+            ["rois", "deltas", "im_info"],
+            ["box_out"],
+            apply_scale=False,
+            correct_transform_coords=True,
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[rois, deltas, im_info],
+            reference=bbox_transform_ref,
+        )
+
+    @given(
+        roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10),
+        num_classes=st.integers(1, 10),
+        **hu.gcs_cpu_only
+    )
+    def test_bbox_transform_batch(self, roi_counts, num_classes, gc, dc):
+        """
+        Test with rois for multiple images in a batch
+        """
+        batch_size = len(roi_counts)
+        total_rois = sum(roi_counts)
+        im_dims = np.random.randint(100, 600, batch_size)
+        rois = generate_rois(roi_counts, im_dims)
+        deltas = np.random.randn(total_rois, 4 * num_classes).astype(np.float32)
+        im_info = np.zeros((batch_size, 3)).astype(np.float32)
+        im_info[:, 0] = im_dims
+        im_info[:, 1] = im_dims
+        im_info[:, 2] = 1.0
+
+        def bbox_transform_ref(rois, deltas, im_info):
+            box_out = []
+            offset = 0
+            for i, num_rois in enumerate(roi_counts):
+                if num_rois == 0:
+                    continue
+                cur_boxes = rois[offset:offset + num_rois, 1:5]
+                cur_deltas = deltas[offset:offset + num_rois]
+                cur_box_out = bbox_transform(cur_boxes, cur_deltas)
+                im_shape = im_info[i, 0:2]
+                cur_box_out = clip_tiled_boxes(cur_box_out, im_shape)
+                box_out.append(cur_box_out)
+                offset += num_rois
+
+            if len(box_out) > 0:
+                box_out = np.vstack(box_out)
+            else:
+                box_out = np.empty(deltas.shape).astype(np.float32)
+            return [box_out, roi_counts]
+
+        op = core.CreateOperator(
+            "BBoxTransform",
+            ["rois", "deltas", "im_info"],
+            ["box_out", "roi_batch_splits"],
+            apply_scale=False,
+            correct_transform_coords=True,
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[rois, deltas, im_info],
+            reference=bbox_transform_ref,
+        )
diff --git a/caffe2/python/predictor/mobile_exporter.py b/caffe2/python/predictor/mobile_exporter.py
index ee93b50..a5c8211 100644
--- a/caffe2/python/predictor/mobile_exporter.py
+++ b/caffe2/python/predictor/mobile_exporter.py
@@ -22,6 +22,38 @@
 from __future__ import unicode_literals
 from caffe2.python import core, utils
 from caffe2.proto import caffe2_pb2
+import numpy as np
+
+
+def add_tensor(net, name, blob):
+    ''' Create an operator to store the tensor 'blob',
+        run the operator to put the blob to workspace.
+        uint8 is stored as an array of string with one element.
+    '''
+    kTypeNameMapper = {
+        np.dtype('float32'): "GivenTensorFill",
+        np.dtype('int32'): "GivenTensorIntFill",
+        np.dtype('int64'): "GivenTensorInt64Fill",
+        np.dtype('uint8'): "GivenTensorStringFill",
+    }
+
+    shape = blob.shape
+    values = blob
+    # pass array of uint8 as a string to save storage
+    # storing uint8_t has a large overhead for now
+    if blob.dtype == np.dtype('uint8'):
+        shape = [1]
+        values = [str(blob.data)]
+
+    op = core.CreateOperator(
+        kTypeNameMapper[blob.dtype],
+        [], [name],
+        arg=[
+            utils.MakeArgument("shape", shape),
+            utils.MakeArgument("values", values),
+        ]
+    )
+    net.op.extend([op])
 
 
 def Export(workspace, net, params):
@@ -49,17 +81,7 @@
     for blob_ref in params:
         blob_name = str(blob_ref)
         blob = workspace.FetchBlob(blob_name)
-        init_net.op.extend(
-            [
-                core.CreateOperator(
-                    "GivenTensorFill", [], [blob_name],
-                    arg=[
-                        utils.MakeArgument("shape", blob.shape),
-                        utils.MakeArgument("values", blob)
-                    ]
-                )
-            ]
-        )
+        add_tensor(init_net, blob_name, blob)
     # We have to make sure the blob exists in the namespace
     # and we can do so with fake data. (Which is immediately overwritten
     # by any typical usage)
diff --git a/caffe2/python/predictor/mobile_exporter_test.py b/caffe2/python/predictor/mobile_exporter_test.py
index 1707431..0ae91da 100644
--- a/caffe2/python/predictor/mobile_exporter_test.py
+++ b/caffe2/python/predictor/mobile_exporter_test.py
@@ -83,3 +83,52 @@
         np.testing.assert_allclose(
             ref_out, predictor_out, atol=1e-10, rtol=1e-10
         )
+
+    def test_mobile_exporter_datatypes(self):
+        model = ModelHelper(name="mobile_exporter_test_model")
+        model.Copy("data_int", "out")
+        model.params.append("data_int")
+
+        # Create our mobile exportable networks
+        workspace.RunNetOnce(model.param_init_net)
+        np_data_int = np.random.randint(100, size=(1, 1, 28, 28), dtype=np.int32)
+        workspace.FeedBlob("data_int", np_data_int)
+
+        init_net, predict_net = mobile_exporter.Export(
+            workspace, model.net, model.params
+        )
+
+        workspace.CreateNet(model.net)
+        workspace.RunNet(model.net)
+        ref_out = workspace.FetchBlob("out")
+
+        # Clear the workspace
+        workspace.ResetWorkspace()
+
+        # Populate the workspace with data
+        workspace.RunNetOnce(init_net)
+
+        # Overwrite the old net
+        workspace.CreateNet(predict_net, True)
+        workspace.RunNet(predict_net.name)
+        manual_run_out = workspace.FetchBlob("out")
+        np.testing.assert_allclose(
+            ref_out, manual_run_out, atol=1e-10, rtol=1e-10
+        )
+
+        # Clear the workspace
+        workspace.ResetWorkspace()
+
+        # Predictor interface test (simulates writing to disk)
+        predictor = workspace.Predictor(
+            init_net.SerializeToString(), predict_net.SerializeToString()
+        )
+
+        # Output is a vector of outputs but we only care about the first and only result
+        predictor_out = predictor.run([])
+        assert len(predictor_out) == 1
+        predictor_out = predictor_out[0]
+
+        np.testing.assert_allclose(
+            ref_out, predictor_out, atol=1e-10, rtol=1e-10
+        )