DNNL: enable conv3d (#35662)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35662

Test Plan: Imported from OSS

Differential Revision: D22102408

Pulled By: VitalyFedyunin

fbshipit-source-id: 1e95cede429f1a950f26bc7052ab33d198857df3
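
For context, a minimal usage sketch of what this change enables, assuming a PyTorch build with MKL-DNN support; the conversion path mirrors the new test_conv3d test below and the helper names come from torch.utils.mkldnn:

import torch
import torch.utils.mkldnn as mkldnn_utils

# A float32 Conv3d on CPU can now be converted to its MKL-DNN counterpart
# and run on an NCDHW input held in MKL-DNN (opaque) layout.
conv3d = torch.nn.Conv3d(in_channels=4, out_channels=8, kernel_size=3,
                         stride=2, padding=1)
x = torch.randn(2, 4, 16, 16, 16, dtype=torch.float32)

mkldnn_conv3d = mkldnn_utils.to_mkldnn(conv3d)      # scripted MkldnnConv3d
y = mkldnn_conv3d(x.to_mkldnn()).to_dense()         # matches conv3d(x)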
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 48b14e5..aae3c75 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -236,8 +236,8 @@
     (input.options().backend() == at::Backend::CPU &&
      input.scalar_type() == kFloat && // only on CPU Float Tensors
      !is_dilated() && // doesn't support dilation
-     !transposed && // or transposed tensors
-     input.ndimension() == 4); // must be in NCHW format
+     !transposed // or transposed tensors
+    );
 #endif
   return false;
 }
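
Stated loosely in Python (an illustrative restatement only, not an actual API): the NCHW-only restriction is dropped, so 5-d NCDHW inputs used by conv3d can now take the MKL-DNN path as well.

import torch

def mkldnn_convolution_eligible(input, is_dilated, transposed):
    # Simplified sketch of the C++ condition after this change: CPU float
    # tensors, no dilation, not transposed; the ndimension() == 4 check is
    # gone, so both 4-d (NCHW) and 5-d (NCDHW) inputs qualify.
    return (torch.backends.mkldnn.is_available()
            and input.device.type == 'cpu'
            and input.dtype == torch.float32
            and not is_dilated
            and not transposed)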
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index b6df8de..0479aba 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -55,7 +55,7 @@
 
 namespace at { namespace native {
 
-ideep::tensor _mkldnn_conv2d(
+ideep::tensor _mkldnn_convolution(
     const ideep::tensor& x,
     const ideep::tensor& w,
     const c10::optional<ideep::tensor>& b,
@@ -113,7 +113,7 @@
     mkldnn_bias = get_mkldnn_tensor(bias);
   }
 
-  ideep::tensor mkldnn_output = _mkldnn_conv2d(
+  ideep::tensor mkldnn_output = _mkldnn_convolution(
       mkldnn_input,
       mkldnn_weight,
       mkldnn_bias,
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
index 13250b7..6871026 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -52,7 +52,7 @@
 }
 
 ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
-  AT_ASSERTM(mkldnn_tensor.is_mkldnn(),
+  TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
              "mkldnn_to_dense expects MKL-DNN tensor input");
   TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
   MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
@@ -60,13 +60,13 @@
 }
 
 ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
-  AT_ASSERTM(
+  TORCH_CHECK(
       tensor.device().type() == DeviceType::CPU,
       "itensor_view_from_dense expects CPU tensor input");
-  AT_ASSERTM(
+  TORCH_CHECK(
       tensor.layout() == Layout::Strided,
       "itensor_view_from_dense expects dense tensor input");
-  AT_ASSERTM(tensor.scalar_type() == ScalarType::Float,
+  TORCH_CHECK(tensor.scalar_type() == ScalarType::Float,
              "itensor_view_from_dense expects float tensor input");
   TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
   return {{{tensor.sizes().cbegin(), tensor.sizes().cend()},
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
index c6736ce..971fa7a 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -22,13 +22,13 @@
 }
 
 Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
-  AT_ASSERTM(cpu_tensor.device().type() == DeviceType::CPU,
+  TORCH_CHECK(cpu_tensor.device().type() == DeviceType::CPU,
              "dense_to_mkldnn expects CPU tensor input");
-  AT_ASSERTM(cpu_tensor.layout() == Layout::Strided,
+  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
              "dense_to_mkldnn expects strided tensor input");
-  AT_ASSERTM(cpu_tensor.scalar_type() == ScalarType::Float,
+  TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float,
              "dense_to_mkldnn expects float tensor input");
-  AT_ASSERTM(cpu_tensor.dim() <= 5,
+  TORCH_CHECK(cpu_tensor.dim() <= 5,
              "Can't convert cpu tensor with the number of dimensions > 5");
   // TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
   auto cpu_tensor_cont = cpu_tensor.contiguous();
@@ -53,10 +53,6 @@
     IntArrayRef dilation,
     int64_t groups) {
 
-  auto stride_vec = expand_param_if_needed(stride, "stride", 2);
-  auto padding_vec = expand_param_if_needed(padding, "padding", 2);
-  auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);
-
   auto w = itensor_from_mkldnn(self);
 
   // Legacy mkldnn conv2d jitted module may contain a 5-d weight with an extra
@@ -73,10 +69,36 @@
       ideep::convolution_forward::expected_weights_desc(
           w.get_dims(),
           w.get_data_type(),
-          {stride_vec.cbegin(), stride_vec.cend()},
-          {padding_vec.cbegin(), padding_vec.cend()},
-          {padding_vec.cbegin(), padding_vec.cend()},
-          {dilation_vec.cbegin(), dilation_vec.cend()},
+          {stride.begin(), stride.end()},
+          {padding.begin(), padding.end()},
+          {padding.begin(), padding.end()},
+          {dilation.begin(), dilation.end()},
+          groups,
+          ideep::algorithm::convolution_direct);
+  ideep::tensor result;
+  result.init(desc);
+  result.feed_from(w);
+
+  return new_with_itensor_mkldnn(std::move(result), self.options());
+}
+
+Tensor mkldnn_reorder_conv3d_weight(
+    const Tensor& self,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups) {
+
+  auto w = itensor_from_mkldnn(self);
+
+  auto desc =
+      ideep::convolution_forward::expected_weights_desc(
+          w.get_dims(),
+          w.get_data_type(),
+          {stride.begin(), stride.end()},
+          {padding.begin(), padding.end()},
+          {padding.begin(), padding.end()},
+          {dilation.begin(), dilation.end()},
           groups,
           ideep::algorithm::convolution_direct);
   ideep::tensor result;
@@ -89,11 +111,11 @@
 #else
 
 Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
-  AT_ERROR("MKL-DNN build is disabled");
+  TORCH_CHECK(false, "MKL-DNN build is disabled");
 }
 
 Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
-  AT_ERROR("MKL-DNN build is disabled");
+  TORCH_CHECK(false, "MKL-DNN build is disabled");
 }
 
 Tensor mkldnn_reorder_conv2d_weight(
@@ -102,7 +124,16 @@
     IntArrayRef stride,
     IntArrayRef dilation,
     int64_t groups) {
-  AT_ERROR("mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
+  TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
+}
+
+Tensor mkldnn_reorder_conv3d_weight(
+    const Tensor& self,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups) {
+  TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
 }
 
 #endif // AT_MKLDNN_ENABLED()
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 06aa828..66db2f6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3558,6 +3558,12 @@
   dispatch:
     MkldnnCPU: mkldnn_reorder_conv2d_weight
 
+- func: mkldnn_reorder_conv3d_weight(Tensor self, int[3] padding=0, int[3] stride=1, int[3] dilation=1, int groups=1) -> Tensor
+  variants: function
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_reorder_conv3d_weight
+
 - func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor
   use_c10_dispatcher: full
 
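
The new schema is reachable from Python through torch._C._nn, which is how the MkldnnConv3d module below calls it; a minimal sketch, assuming an MKL-DNN-enabled build (the weight shape here is only illustrative):

import torch

# A 5-d conv3d weight in OIDHW order, moved into MKL-DNN layout and then
# reordered into the blocked format the convolution kernels expect.
w = torch.randn(8, 4, 3, 3, 3, dtype=torch.float32)
w_blocked = torch._C._nn.mkldnn_reorder_conv3d_weight(
    w.to_mkldnn(),
    [1, 1, 1],   # padding
    [2, 2, 2],   # stride
    [1, 1, 1],   # dilation
    1)           # groups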
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 2043095..3bae59c 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -177,6 +177,29 @@
                 conv2d(x),
                 conv2d_loaded(x.to_mkldnn()).to_dense())
 
+    def test_conv3d(self):
+        for groups in [1, 4]:
+            N = torch.randint(3, 10, (1,)).item()
+            C = torch.randint(1, 3, (1,)).item() * groups
+            M = torch.randint(1, 3, (1,)).item() * groups
+            x = torch.randn(N, C, 55, 55, 55, dtype=torch.float32)
+            for bias in [True, False]:
+                conv3d = torch.nn.Conv3d(in_channels=C,
+                                         out_channels=M,
+                                         kernel_size=3,
+                                         stride=2,
+                                         padding=1,
+                                         bias=bias,
+                                         groups=groups).float()
+                mkldnn_conv3d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv3d))
+                with torch.backends.mkldnn.flags(enabled=False):
+                    y_aten = conv3d(x)
+                y_mkldnn = mkldnn_conv3d(x.to_mkldnn()).to_dense()
+                self.assertEqual(y_aten, y_mkldnn)
+
+                self._test_serialization(mkldnn_conv3d, (x.to_mkldnn(),))
+                self._test_tracing(mkldnn_conv3d, (x.to_mkldnn(),))
+
     def test_relu(self):
         x = torch.randn((4, 5), dtype=torch.float32) * 10
         self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py
index 1a9c595..b084c26 100644
--- a/torch/utils/mkldnn.py
+++ b/torch/utils/mkldnn.py
@@ -105,6 +105,29 @@
         self.training = state[2]
 
 
+class MkldnnConv3d(_MkldnnConvNd):
+    def __init__(self, dense_module):
+        super(MkldnnConv3d, self).__init__(dense_module)
+
+        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
+            dense_module.weight.to_mkldnn(),
+            self.padding,
+            self.stride,
+            self.dilation,
+            self.groups))
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
+            state[0].to_mkldnn(),
+            self.padding,
+            self.stride,
+            self.dilation,
+            self.groups)
+        self.bias = state[1].to_mkldnn()
+        self.training = state[2]
+
+
 class MkldnnBatchNorm2d(torch.jit.ScriptModule):
     __constants__ = ['exponential_average_factor', 'eps']
 
@@ -165,6 +188,8 @@
             return MkldnnConv1d(m)
         elif isinstance(m, torch.nn.Conv2d):
             return MkldnnConv2d(m)
+        elif isinstance(m, torch.nn.Conv3d):
+            return MkldnnConv3d(m)
         elif isinstance(m, torch.nn.BatchNorm2d):
             return MkldnnBatchNorm2d(m)
         else:
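
A save/load round trip for the converted module, mirroring the _test_serialization pattern exercised by the test; a minimal sketch, assuming an MKL-DNN-enabled build:

import io
import torch
import torch.utils.mkldnn as mkldnn_utils

conv3d = torch.nn.Conv3d(4, 8, kernel_size=3, stride=2, padding=1)
scripted = mkldnn_utils.to_mkldnn(conv3d)   # MkldnnConv3d ScriptModule

# The __getstate__/__setstate__ pair keeps the weight serializable: it is
# stored as a dense tensor and reordered back into MKL-DNN blocked format
# on load via mkldnn_reorder_conv3d_weight.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(2, 4, 16, 16, 16)
y = loaded(x.to_mkldnn()).to_dense()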