Fix pool op custom path issue 2, wrongful routing to global pooling

Summary:
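D5681122 routed pooling to the global max-pool / average-pool fast path under an incorrect condition: it inferred global pooling from the output size (e.g. Y->size() == N * C) instead of checking whether global_pooling was set.
See T24876217 for discussion.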

Reviewed By: Yangqing

Differential Revision: D6665466

fbshipit-source-id: dcb5b4686249e6ee8e1e976ab66b003ef09b32fd
diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h
index 4d36f98..ec80b06 100644
--- a/caffe2/operators/conv_pool_op_base.h
+++ b/caffe2/operators/conv_pool_op_base.h
@@ -153,7 +153,7 @@
         CAFFE_ENFORCE(
             pads_[2 * dim] == 0 && pads_[2 * dim + 1] == 0 &&
                 dilation_[dim] == 1 && stride_[dim] == 1,
-            "If global_pooling is set dilation and stride shouldn't be set.");
+            "If global_pooling is set pad, dilation and stride shouldn't be set.");
       }
     }
 
@@ -322,7 +322,6 @@
         output_dims.push_back(dim_size);
       }
     }
-
   }
 
   // ComputePads could be used in backward functions to figure out the padding
diff --git a/caffe2/operators/pool_op_cudnn.cu b/caffe2/operators/pool_op_cudnn.cu
index f7b867f..d4e9a87 100644
--- a/caffe2/operators/pool_op_cudnn.cu
+++ b/caffe2/operators/pool_op_cudnn.cu
@@ -194,8 +194,8 @@
 
     // Fast path for global pooling, as cudnn is slow. But only
     // on float, because fp16 not supported for CUB.
-    if (sizeof(T) == 4) {
-      if (order_ == StorageOrder::NCHW && Y->size() == N * C) {
+    if (std::is_same<T, float>::value) {
+      if (order_ == StorageOrder::NCHW && global_pooling_) {
         if (mode_ == CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING) {
           global_avgpool_kernel_NCHW<float>
               <<<std::min(N * C, CAFFE_MAXIMUM_NUM_BLOCKS),
@@ -275,9 +275,9 @@
     auto* Y = Output(0);
 
     if (X.IsType<float>()) {
-      return DoRunWithType<float,float>();
+      return DoRunWithType<float, float>();
     } else if (X.IsType<float16>()) {
-      return DoRunWithType<float16,float>();
+      return DoRunWithType<float16, float>();
     } else {
       LOG(FATAL) << "Unsupported input types";
     }
@@ -292,6 +292,7 @@
   cudnnTensorDescriptor_t top_desc_;
   cudnnPoolingDescriptor_t pooling_desc_;
   cudnnPoolingMode_t mode_;
+
  private:
 };
 
@@ -347,34 +348,34 @@
     int N = 0, C = 0, H = 0, W = 0, D = 0;
     int H_out = 0, W_out = 0, D_out = 0;
     switch (order_) {
-    case StorageOrder::NHWC:
-      N = X.dim32(0);
-      H = X.dim32(1);
-      W = X.ndim() > 3 ? X.dim32(2) : 1;
-      D = X.ndim() > 4 ? X.dim32(3) : 1;
-      C = X.dim32(X.ndim() - 1);
-      H_out = Y.dim32(1);
-      W_out = Y.ndim() > 3 ? Y.dim32(2) : 1;
-      D_out = Y.ndim() > 4 ? Y.dim32(3) : 1;
-      break;
-    case StorageOrder::NCHW:
-      N = X.dim32(0);
-      C = X.dim32(1);
-      H = X.dim32(2);
-      W = X.ndim() > 3 ? X.dim32(3) : 1;
-      D = X.ndim() > 4 ? X.dim32(4) : 1;
-      H_out = Y.dim32(2);
-      W_out = Y.ndim() > 3 ? Y.dim32(3) : 1;
-      D_out = Y.ndim() > 4 ? Y.dim32(4) : 1;
-      break;
-    default:
-      LOG(FATAL) << "Unknown storage order: " << order_;
+      case StorageOrder::NHWC:
+        N = X.dim32(0);
+        H = X.dim32(1);
+        W = X.ndim() > 3 ? X.dim32(2) : 1;
+        D = X.ndim() > 4 ? X.dim32(3) : 1;
+        C = X.dim32(X.ndim() - 1);
+        H_out = Y.dim32(1);
+        W_out = Y.ndim() > 3 ? Y.dim32(2) : 1;
+        D_out = Y.ndim() > 4 ? Y.dim32(3) : 1;
+        break;
+      case StorageOrder::NCHW:
+        N = X.dim32(0);
+        C = X.dim32(1);
+        H = X.dim32(2);
+        W = X.ndim() > 3 ? X.dim32(3) : 1;
+        D = X.ndim() > 4 ? X.dim32(4) : 1;
+        H_out = Y.dim32(2);
+        W_out = Y.ndim() > 3 ? Y.dim32(3) : 1;
+        D_out = Y.ndim() > 4 ? Y.dim32(4) : 1;
+        break;
+      default:
+        LOG(FATAL) << "Unknown storage order: " << order_;
     }
 
     // Fast path for global pooling, as cudnn is slow. But only
     // on float, because fp16 not supported for CUB.
-    if (sizeof(T) == 4) {
-      if (order_ == StorageOrder::NCHW && dY.size() == N * C) {
+    if (std::is_same<T, float>::value) {
+      if (order_ == StorageOrder::NCHW && global_pooling_) {
         if (mode_ == CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING) {
           global_avgpool_backward_NCHW<float>
               <<<CAFFE_GET_BLOCKS(dX->size()),
@@ -487,9 +488,9 @@
     dX->ResizeLike(X);
 
     if (X.IsType<float>()) {
-      return DoRunWithType<float,float>();
+      return DoRunWithType<float, float>();
     } else if (X.IsType<float16>()) {
-      return DoRunWithType<float16,float>();
+      return DoRunWithType<float16, float>();
     } else {
       LOG(FATAL) << "Unsupported input types";
     }
@@ -530,5 +531,5 @@
 
 REGISTER_CUDNN_OPERATOR(MaxPool3D, CuDNNPoolOp);
 REGISTER_CUDNN_OPERATOR(MaxPool3DGradient, CuDNNPoolGradientOp);
-}  // namespace
-}  // namespace caffe2
+} // namespace
+} // namespace caffe2
diff --git a/caffe2/python/models/resnet.py b/caffe2/python/models/resnet.py
index 822b185..8106af9 100644
--- a/caffe2/python/models/resnet.py
+++ b/caffe2/python/models/resnet.py
@@ -292,6 +292,7 @@
         'final_avg',
         kernel=final_avg_kernel,
         stride=1,
+        global_pooling=True,
     )
 
     # Final dimension of the "image" is reduced to 7x7
diff --git a/caffe2/python/operator_test/pooling_test.py b/caffe2/python/operator_test/pooling_test.py
index b2b96c9..6203c78 100644
--- a/caffe2/python/operator_test/pooling_test.py
+++ b/caffe2/python/operator_test/pooling_test.py
@@ -128,7 +128,7 @@
 
     @given(stride=st.integers(1, 3),
            pad=st.integers(0, 2),
-           kernel=st.integers(1, 3),
+           kernel=st.integers(1, 6),
            size=st.integers(3, 5),
            input_channels=st.integers(1, 3),
            batch_size=st.integers(1, 3),
@@ -140,6 +140,10 @@
     def test_pooling_3d(self, stride, pad, kernel, size, input_channels,
                         batch_size, order, op_type, engine, gc, dc):
         assume(pad < kernel)
+        assume(size + pad + pad >= kernel)
+        # some cases here could be computed with global pooling, but are
+        # instead computed with the general implementation, which is slower
+        # but should still be correct.
         op = core.CreateOperator(
             op_type,
             ["X"],
@@ -155,9 +159,82 @@
         if order == "NCHW":
             X = X.transpose((0, 4, 1, 2, 3))
 
-        self.assertDeviceChecks(dc, op, [X], [0])
+        self.assertDeviceChecks(dc, op, [X], [0], threshold=0.001)
         if 'MaxPool' not in op_type:
-            self.assertGradientChecks(gc, op, [X], 0, [0])
+            self.assertDeviceChecks(dc, op, [X], [0], threshold=0.001)
+
+    @given(stride=st.integers(1, 3),
+           pad=st.integers(0, 2),
+           kernel=st.integers(1, 6),
+           size=st.integers(3, 5),
+           input_channels=st.integers(1, 3),
+           batch_size=st.integers(1, 3),
+           order=st.sampled_from(["NCHW", "NHWC"]),
+           op_type=st.sampled_from(["MaxPool", "AveragePool",
+                                    "MaxPool3D", "AveragePool3D"]),
+           engine=st.sampled_from(["", "CUDNN"]),
+           **hu.gcs)
+    def test_global_pooling_3d(self, stride, pad, kernel, size, input_channels,
+                               batch_size, order, op_type, engine, gc, dc):
+        assume(pad < kernel)
+        assume(size + pad + pad >= kernel)
+        # Conditions under which average or max pooling may use global pooling.
+        # The assumptions here are:
+        # 1. kernel can be greater than input dim, but never greater than dim +
+        #    pads on both sides, i.e.
+        #         dim.H + pad_t + pad_b >= kernel.H
+        #         dim.W + pad_l + pad_r >= kernel.W
+        #         dim.D + pad_f + pad_e >= kernel.D         (f = front e = end)
+        # 2. padding is applied to both sides of the input dim
+        # 3. pooling first aligns the kernel with one side of the padding, then
+        #    shifts the kernel by one stride towards the other side of padding
+        # 4. the kernel keeps shifting by one stride until one more shift would
+        #    move the kernel beyond input dim plus padding.
+        # So if the stride value is large, some input dim elements may not be
+        # covered. Consider these cases:
+        #
+        # case 1:
+        # kernel = 4, dim = 3, pad_l = 2, pad_r = 2, stride = 4
+        # when kernel is applied for the first time, pad_l and dim up to 2
+        # is covered then we have 1 unit left of dim and pad_r not covered, but
+        # because stride is 4, shift kernel by 4 will go beyond pad_r, we should not
+        # apply another kernel, the out_size will be 1, and some element (last of
+        # dim) is ignored, therefore we can not use global pooling
+        #
+        # case 2:
+        # k = 4, dim = 3, pad_l = 1, pad_r = 2, stride = 1
+        # after kernel applied first time, pad_l and dim and 1st pad_r element all
+        # covered, shift kernel by stride move it to the end of pad_r, covering dim +
+        # pad_r, not beyond pad_r, so we should apply the kernel for a second time.
+        # out_size = 2 and we should not use global pooling either because dim is
+        # covered twice.
+        #
+        # case 3:
+        # k = 4, dim = 3, pad_l = 1, pad_r = 1, stride = 2
+        # first kernel apply cover all dim, but can not shift by stride because
+        # kernel go beyond pad_r so kernel is only applied once and cover entire dim
+        # this is the only case we can use global pooling.
+        #
+        # Summary: use global pooling when all dim is covered and only covered once
+        assume(kernel >= size)
+        assume(kernel + stride > size + pad + pad)
+        op = core.CreateOperator(
+            op_type,
+            ["X"],
+            ["Y"],
+            kernels=[kernel] * 3,
+            order=order,
+            global_pooling=True,
+            engine=engine,
+        )
+        X = np.random.rand(
+            batch_size, size, size, size, input_channels).astype(np.float32)
+        if order == "NCHW":
+            X = X.transpose((0, 4, 1, 2, 3))
+
+        self.assertDeviceChecks(dc, op, [X], [0], threshold=0.001)
+        if 'MaxPool' not in op_type:
+            self.assertGradientChecks(gc, op, [X], 0, [0])
 
     @unittest.skipIf(not workspace.has_gpu_support, "No GPU support")
     @given(stride=st.integers(1, 3),