Add aten mkldnn ops: relu, max_pool2d and avg_pool2d

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19205

Reviewed By: dzhulgakov

Differential Revision: D14850598

fbshipit-source-id: 5bbd5909c06df9c980de680ffb81bf772766c0ba
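
For context, the new ops are reached from Python by moving a dense CPU tensor into the MKL-DNN layout with to_mkldnn() and converting results back with to_dense(), exactly as the new tests in test_mkldnn.py below do. A minimal usage sketch (assumes a build with MKLDNN enabled):

    import torch

    x = torch.randn(1, 3, 64, 64, dtype=torch.float32)
    y = x.to_mkldnn()                                          # dense CPU tensor -> MKL-DNN layout

    relu_out = torch.relu(y)                                   # MkldnnCPU dispatch -> mkldnn_relu
    max_out = torch.nn.MaxPool2d(3, stride=2, padding=1)(y)    # -> mkldnn_max_pool2d
    avg_out = torch.nn.AvgPool2d(3, stride=2, padding=1)(y)    # -> mkldnn_avg_pool2d

    # results stay in the MKL-DNN layout until converted back for comparison
    print(torch.allclose(torch.relu(x), relu_out.to_dense()))
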
diff --git a/aten/src/ATen/native/Pooling.cpp b/aten/src/ATen/native/Pooling.cpp
index 428bab4..9e267ff 100644
--- a/aten/src/ATen/native/Pooling.cpp
+++ b/aten/src/ATen/native/Pooling.cpp
@@ -114,6 +114,10 @@
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode) {
+  if (self.is_mkldnn()) {
+    return at::mkldnn_max_pool2d(
+        self, kernel_size, stride, padding, dilation, ceil_mode);
+  }
   auto output_and_indices = at::max_pool2d_with_indices(
       self, kernel_size, stride, padding, dilation, ceil_mode);
   return std::get<0>(output_and_indices);
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
index 17d336c..1b83f8a 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -43,7 +43,6 @@
 Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) {
   // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
   // TODO: support int64_t dims in ideep::tensor to avoid extra conversion
-  AT_ASSERT(!it.has_extra());
   auto dims = it.get_dims();
   IDeepTensorWrapperPtr handle = c10::make_intrusive<IDeepTensorWrapper>(std::move(it));
   return detail::make_tensor<MKLDNNTensorImpl>(
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h
index 9fafeac..bee406c 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h
+++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h
@@ -23,10 +23,10 @@
   }
 };
 
-// Construct MKL-DNN tensor given an ideep tensor
+// Construct an ATen MKL-DNN tensor given an ideep tensor
 Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options);
 
-// Construct MKL-DNN tensor given `sizes` for allocation
+// Construct an ATen MKL-DNN tensor given `sizes` for allocation
 Tensor new_with_sizes_mkldnn(IntArrayRef sizes, const TensorOptions& options);
 
 // Retrieve `ideep::tensor` from MKL-DNN tensor
diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp
new file mode 100644
index 0000000..f6ddaac
--- /dev/null
+++ b/aten/src/ATen/native/mkldnn/Pooling.cpp
@@ -0,0 +1,142 @@
+#include <ATen/ATen.h>
+#include <ATen/Config.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/utils/ParamUtils.h>
+#include <tuple>
+
+
+#if !AT_MKLDNN_ENABLED()
+
+namespace at {
+namespace native {
+
+Tensor mkldnn_max_pool2d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  AT_ERROR(
+      "mkldnn_max_pool2d: ATen not compiled with MKLDNN support");
+}
+
+Tensor mkldnn_avg_pool2d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad) {
+  AT_ERROR("mkldnn_avg_pool2d: ATen not compiled with MKLDNN support");
+}
+
+Tensor& mkldnn_avg_pool2d_out(
+    Tensor& output,
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad) {
+  AT_ERROR("mkldnn_avg_pool2d_out: ATen not compiled with MKLDNN support");
+}
+} // namespace native
+} // namespace at
+
+#else // AT_MKLDNN_ENABLED
+
+#include <ATen/native/mkldnn/MKLDNNCommon.h>
+#include <ATen/native/mkldnn/Utils.h>
+
+namespace at {
+namespace native {
+
+static Tensor _mkldnn_pool2d(
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode,
+    ideep::algorithm algo) {
+  AT_CHECK(!ceil_mode, "Currently MKL-DNN pooling operators do not support ceil_mode.");
+  auto kernel_size_vec = expand_param_if_needed(kernel_size, "kernel_size", 2);
+  auto stride_vec = expand_param_if_needed(stride, "stride", 2);
+  auto padding_vec = expand_param_if_needed(padding, "padding", 2);
+  auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);
+
+  const ideep::tensor& x = itensor_from_mkldnn(input);
+  const std::vector<int64_t> output_sizes = pool_output_sizes(
+      input.sizes(),
+      kernel_size_vec,
+      stride_vec,
+      padding_vec,
+      dilation_vec,
+      ceil_mode);
+  ideep::tensor y;
+  ideep::pooling_forward::compute<AllocForMKLDNN>(
+      x,
+      {output_sizes.cbegin(), output_sizes.cend()},
+      y,
+      {stride_vec.cbegin(), stride_vec.cend()},
+      {kernel_size_vec.cbegin(), kernel_size_vec.cend()},
+      {padding_vec.cbegin(), padding_vec.cend()},
+      {padding_vec.cbegin(), padding_vec.cend()},
+      algo,
+      ideep::prop_kind::forward);
+
+  return new_with_itensor_mkldnn(std::move(y), input.options());
+}
+
+Tensor mkldnn_max_pool2d(
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  return _mkldnn_pool2d(
+      input,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      ceil_mode,
+      ideep::algorithm::pooling_max);
+}
+
+Tensor mkldnn_avg_pool2d(
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad) {
+  return _mkldnn_pool2d(
+      input,
+      kernel_size,
+      stride,
+      padding,
+      std::vector<int64_t>{1, 1},
+      ceil_mode,
+      count_include_pad ? ideep::algorithm::pooling_avg_include_padding
+                        : ideep::algorithm::pooling_avg_exclude_padding);
+}
+
+Tensor& mkldnn_avg_pool2d_out(
+    Tensor& output,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad) {
+  AT_ERROR(
+      "mkldnn_avg_pool2d_out: in-place mkldnn operations are not supported yet");
+}
+
+} // namespace native
+} // namespace at
+
+#endif // AT_MKLDNN_ENABLED
diff --git a/aten/src/ATen/native/mkldnn/Relu.cpp b/aten/src/ATen/native/mkldnn/Relu.cpp
new file mode 100644
index 0000000..27c6ba1
--- /dev/null
+++ b/aten/src/ATen/native/mkldnn/Relu.cpp
@@ -0,0 +1,43 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+
+
+#if !AT_MKLDNN_ENABLED()
+
+namespace at { namespace native {
+
+Tensor mkldnn_relu(const Tensor& input) {
+  AT_ERROR("mkldnn_relu: ATen not compiled with MKLDNN support");
+}
+
+Tensor& mkldnn_relu_(Tensor& input) {
+  AT_ERROR("mkldnn_relu_: ATen not compiled with MKLDNN support");
+}
+
+}}
+
+#else // AT_MKLDNN_ENABLED
+
+#include <ATen/native/mkldnn/MKLDNNCommon.h>
+
+namespace at { namespace native {
+
+Tensor mkldnn_relu(const Tensor& input) {
+  const ideep::tensor& x = itensor_from_mkldnn(input);
+  ideep::tensor y;
+  ideep::eltwise_forward::compute<AllocForMKLDNN>(
+      x, y, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
+  return new_with_itensor_mkldnn(std::move(y), input.options());
+}
+
+Tensor& mkldnn_relu_(Tensor& input) {
+  ideep::tensor& x = itensor_from_mkldnn(input);
+  ideep::eltwise_forward::compute<AllocForMKLDNN>(
+      x, x, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
+  return input;
+}
+
+}}
+
+#endif // AT_MKLDNN_ENABLED
diff --git a/aten/src/ATen/native/mkldnn/Utils.cpp b/aten/src/ATen/native/mkldnn/Utils.cpp
index 1442cf9..0d66ec6 100644
--- a/aten/src/ATen/native/mkldnn/Utils.cpp
+++ b/aten/src/ATen/native/mkldnn/Utils.cpp
@@ -1,4 +1,5 @@
 #include <ATen/native/mkldnn/Utils.h>
+#include <THNN/generic/pooling_shape.h>
 
 namespace at { namespace native {
 
@@ -20,4 +21,30 @@
   return output_size;
 }
 
+std::vector<int64_t> pool_output_sizes(
+    IntArrayRef input_size,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  std::vector<int64_t> output_size(input_size.size());
+  // copy N and C
+  output_size[0] = input_size[0];
+  output_size[1] = input_size[1];
+
+  for (int i = 2; i < input_size.size(); ++i) {
+    output_size[i] = pooling_output_shape<int64_t>(
+      input_size[i],
+      kernel_size[i - 2],
+      padding[i - 2],
+      stride[i - 2],
+      dilation[i - 2],
+      ceil_mode
+    );
+  }
+
+  return output_size;
+}
+
 }}
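
For reference, pool_output_sizes above delegates the per-dimension arithmetic to THNN's pooling_output_shape helper pulled in via pooling_shape.h. A rough Python sketch of that formula for one spatial dimension, under the usual pooling shape convention (an illustration, not the exact THNN source):

    import math

    def pooling_output_shape(in_size, kernel, pad, stride, dilation, ceil_mode):
        # effective window span, accounting for dilation
        span = in_size + 2 * pad - dilation * (kernel - 1) - 1
        out = (math.ceil if ceil_mode else math.floor)(span / stride) + 1
        # in ceil mode the last window must still start inside the padded input
        if ceil_mode and (out - 1) * stride >= in_size + pad:
            out -= 1
        return out
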
diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h
index 3b70478..cf8b774 100644
--- a/aten/src/ATen/native/mkldnn/Utils.h
+++ b/aten/src/ATen/native/mkldnn/Utils.h
@@ -11,4 +11,12 @@
     IntArrayRef padding,
     IntArrayRef stride,
     IntArrayRef dilation);
+
+std::vector<int64_t> pool_output_sizes(
+    IntArrayRef input_size,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode);
 }}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8a4a54c..041c549 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1141,6 +1141,11 @@
 
 - func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
 
+- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  requires_tensor: True
+  dispatch:
+    MkldnnCPU: mkldnn_max_pool2d
+
 - func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
 
 # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
@@ -1522,9 +1527,17 @@
 
 - func: relu(Tensor self) -> Tensor
   variants: function, method
+  dispatch:
+    CPU: relu
+    CUDA: relu
+    MkldnnCPU: mkldnn_relu
 
 - func: relu_(Tensor(a!) self) -> Tensor(a!)
   variants: function, method
+  dispatch:
+    CPU: relu_
+    CUDA: relu_
+    MkldnnCPU: mkldnn_relu_
 
 - func: prelu(Tensor self, Tensor weight) -> Tensor
   variants: function, method
@@ -3670,9 +3683,17 @@
 
 - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
+  dispatch:
+    CPU: avg_pool2d_out
+    CUDA: avg_pool2d_out
+    MkldnnCPU: mkldnn_avg_pool2d_out
 
 - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
   python_module: nn
+  dispatch:
+    CPU: avg_pool2d
+    CUDA: avg_pool2d
+    MkldnnCPU: mkldnn_avg_pool2d
 
 - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index b3c7e7b..15eb661 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -109,6 +109,45 @@
                     conv2d(x),
                     mkldnn_conv2d(x.to_mkldnn()).to_dense())
 
+    def test_relu(self):
+        x = torch.randn((4, 5), dtype=torch.float32) * 10
+        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
+
+    def test_relu_(self):
+        x1 = torch.randn((4, 5), dtype=torch.float32) * 10
+        x2 = x1.clone().to_mkldnn()
+        self.assertEqual(torch.relu_(x1), torch.relu_(x2).to_dense())
+
+    def test_max_pool2d(self):
+        N = torch.randint(3, 10, (1,)).item()
+        C = torch.randint(3, 10, (1,)).item()
+        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
+
+        max_pool2d = torch.nn.MaxPool2d(
+            kernel_size=3,
+            stride=2,
+            padding=1)
+
+        self.assertEqual(
+            max_pool2d(x),
+            max_pool2d(x.to_mkldnn()).to_dense())
+
+    def test_avg_pool2d(self):
+        N = torch.randint(3, 10, (1,)).item()
+        C = torch.randint(3, 10, (1,)).item()
+        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
+
+        for count_include_pad in [True, False]:
+            avg_pool2d = torch.nn.AvgPool2d(
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                count_include_pad=count_include_pad)
+
+            self.assertEqual(
+                avg_pool2d(x),
+                avg_pool2d(x.to_mkldnn()).to_dense())
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index e6ce9b8..14fa69d 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -60,7 +60,7 @@
 
 
 @weak_module
-class ReLU(Threshold):
+class ReLU(Module):
     r"""Applies the rectified linear unit function element-wise:
 
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -88,9 +88,15 @@
         >>> input = torch.randn(2).unsqueeze(0)
         >>> output = torch.cat((m(input),m(-input)))
     """
+    __constants__ = ['inplace']
 
     def __init__(self, inplace=False):
-        super(ReLU, self).__init__(0., 0., inplace)
+        super(ReLU, self).__init__()
+        self.inplace = inplace
+
+    @weak_script_method
+    def forward(self, input):
+        return F.relu(input, inplace=self.inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''