Add channels_last3d support for mkldnn conv and mkldnn deconv (#95271)

### Motivation

- Add channels_last3d support for mkldnn conv and mkldnn deconv.
- Use `ideep::convolution_transpose_forward::compute_v3` instead of `ideep::convolution_transpose_forward::compute`. `compute_v3` takes an `is_channels_last` flag that tells ideep whether to use the channels-last path, keeping its choice aligned with PyTorch's memory format check.
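
The ideep call-site change is internal; from the Python side the new path is exercised simply by running 3D conv/deconv on CPU with `torch.channels_last_3d` tensors. A minimal usage sketch (shapes are illustrative; the memory-format assertions reflect the expected propagation after this change):

```python
import torch

# Run 3D conv / deconv through the CPU (mkldnn) path with channels_last_3d inputs.
x = torch.randn(4, 8, 5, 30, 30).to(memory_format=torch.channels_last_3d)
conv = torch.nn.Conv3d(8, 8, kernel_size=3).to(memory_format=torch.channels_last_3d)
deconv = torch.nn.ConvTranspose3d(8, 8, kernel_size=3).to(memory_format=torch.channels_last_3d)

y = conv(x)
z = deconv(x)

# With this PR, the outputs are expected to stay in channels_last_3d
# rather than being converted back to contiguous (NCDHW) layout.
assert y.is_contiguous(memory_format=torch.channels_last_3d)
assert z.is_contiguous(memory_format=torch.channels_last_3d)
```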

### Testing
Benchmarked on 1 socket (28 cores):

- memory format: torch.contiguous_format

module | shape | forward / ms | backward / ms
-- | -- | -- | --
conv3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 64.56885 | 150.1796
conv3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 100.6754 | 231.8883
conv3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 19.31751 | 68.31131

module | shape | forward / ms | backward / ms
-- | -- | -- | --
ConvTranspose3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 122.7646 | 207.5125
ConvTranspose3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 202.4542 | 368.5492
ConvTranspose3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 122.959 | 84.62577

- memory format: torch.channels_last_3d

module | shape | forward / ms | backward / ms
-- | -- | -- | --
conv3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 40.06993 | 114.317
conv3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 49.08249 | 133.4079
conv3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 5.873911 | 17.58647

module | shape | forward / ms | backward / ms
-- | -- | -- | --
ConvTranspose3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 88.4246 | 208.2269
ConvTranspose3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 140.0725 | 270.4172
ConvTranspose3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 23.0223 | 37.16972
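
The benchmark script itself is not included in the PR; the sketch below shows one way such per-iteration forward/backward timings could be collected (using the last conv3d configuration above, with the memory format switched between `torch.contiguous_format` and `torch.channels_last_3d`):

```python
import time
import torch

def bench(module, x, iters=20, warmup=5):
    """Return average (forward_ms, backward_ms) per iteration."""
    for _ in range(warmup):
        module(x).sum().backward()
    fwd = bwd = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        y = module(x)
        t1 = time.perf_counter()
        y.sum().backward()
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / iters * 1e3, bwd / iters * 1e3

memory_format = torch.channels_last_3d  # or torch.contiguous_format
x = torch.randn(16, 4, 5, 300, 300).to(memory_format=memory_format).requires_grad_()
conv = torch.nn.Conv3d(4, 4, kernel_size=3).to(memory_format=memory_format)
print(bench(conv, x))
```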

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95271
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index 065ed57..fd52628 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -389,8 +389,9 @@
       (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
       (weight_memory_format == at::MemoryFormat::ChannelsLast);
 
-  // TODO: add channels last 3d support
-  bool can_use_mkldnn_channels_last_3d = false;
+  bool can_use_mkldnn_channels_last_3d =
+      (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
 
   return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
 }
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index bbb6727..7af281a 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -508,9 +508,6 @@
     if (transposed && is_output_padding_big()) {
       return false;
     }
-    if (transposed && groups > 1 && at::symint::size<T>(input, 1) == groups) {
-      return false;
-    }
     if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
       return true;
     }
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 72bae7aa..ee1d103 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -727,7 +727,7 @@
 
   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias);
-    ideep::convolution_transpose_forward::compute(
+    ideep::convolution_transpose_forward::compute_v3(
         x,
         w,
         b,
@@ -738,9 +738,10 @@
         padding_r(padding_expanded, output_padding_expanded),
         dilation.vec(),
         groups,
+        use_channels_last,
         op_attr);
   } else {
-    ideep::convolution_transpose_forward::compute(
+    ideep::convolution_transpose_forward::compute_v3(
         x,
         w,
         output_sizes,
@@ -750,6 +751,7 @@
         padding_r(padding_expanded, output_padding_expanded),
         dilation.vec(),
         groups,
+        use_channels_last,
         op_attr);
   }
   if (input.is_mkldnn()) {
@@ -988,7 +990,7 @@
     grad_input.resize_(input_size, memory_format);
     grad_x = itensor_from_tensor(grad_input);
   }
-  ideep::convolution_transpose_backward_data::compute(
+  ideep::convolution_transpose_backward_data::compute_v3(
       grad_y,
       w,
       input_size.vec(),
@@ -997,7 +999,8 @@
       padding.vec(),
       padding_r(padding, output_padding),
       dilation.vec(),
-      groups);
+      groups,
+      is_channels_last);
 
   if (grad_output.is_mkldnn()) {
     return MKLDNNTensor(grad_x, grad_output.options());
@@ -1024,7 +1027,7 @@
 
   ideep::tensor grad_w, grad_b;
   if (bias_defined) {
-    ideep::convolution_transpose_backward_weights::compute(
+    ideep::convolution_transpose_backward_weights::compute_v3(
         x,
         grad_y,
         weight_size.vec(),
@@ -1034,9 +1037,10 @@
         padding.vec(),
         padding_r(padding, output_padding),
         dilation.vec(),
-        groups);
+        groups,
+        is_channels_last);
   } else {
-    ideep::convolution_transpose_backward_weights::compute(
+    ideep::convolution_transpose_backward_weights::compute_v3(
         x,
         grad_y,
         weight_size.vec(),
@@ -1045,7 +1049,8 @@
         padding.vec(),
         padding_r(padding, output_padding),
         dilation.vec(),
-        groups);
+        groups,
+        is_channels_last);
   }
 
   if (!is_channels_last) {
@@ -1061,18 +1066,21 @@
 }
 
 std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
-    const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
+    const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     std::array<bool,3> output_mask)
 {
-  bool is_channels_last = mkldnn_conv_use_channels_last(input, weight);
-  auto memory_format = mkldnn_convolution_memory_format(input.ndimension(), is_channels_last);
+  bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
+  auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
   Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);
+  auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
+  auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
   int64_t dim = input.ndimension() - 2;
   const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
   const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
   const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
   const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", dim);
+
   Tensor grad_input, grad_weight, grad_bias;
   if (output_mask[0]) {
     grad_input = mkldnn_convolution_transpose_backward_input(
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 3f87544..bad7192 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -321,22 +321,25 @@
     def test_conv3d_bf16(self):
         self._test_conv_bf16_base(dim=3)
 
-    def _test_conv2d_nhwc_base(self, conv_module, weight_memory_format, dtype):
-        input_shapes = (55, 55)
+    def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
+        input_shapes = {2: (55, 55), 3: (14, 14, 14)}
         options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
+        if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+            cl_format = torch.channels_last
+            input_shape = input_shapes[2]
+        elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
+            cl_format = torch.channels_last_3d
+            input_shape = input_shapes[3]
+
         for train, bias, dilation, groups in options:
             N = torch.randint(3, 10, (1,)).item()
             M = torch.randint(1, 3, (1,)).item() * groups
             C = torch.randint(1, 3, (1,)).item() * groups
-            x_shape = (N, C) + input_shapes
+            x_shape = (N, C) + input_shape
             x = torch.randn(x_shape, dtype=dtype)
 
-            # TODO: remove this when group depthwise is supported:
-            if conv_module is torch.nn.ConvTranspose2d and groups > 1 and C == groups:
-                continue
-
-            # conv1: mkldnn conv in contiguous memory format (nchw)
-            # conv2: mkldnn conv in channels last memory format (nhwc)
+            # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
+            # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
             conv1 = conv_module(in_channels=C,
                                 out_channels=M,
                                 kernel_size=3,
@@ -347,46 +350,67 @@
                                 groups=groups).to(dtype=dtype)
             conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
             x1 = x.clone()
-            x2 = x.clone().to(memory_format=torch.channels_last)
+            x2 = x.clone().to(memory_format=cl_format)
             if train:
                 x1.requires_grad_()
                 x2.requires_grad_()
             y1 = conv1(x1)
             y2 = conv2(x2)
-            self.assertEqual(y1, y2)
+            self.assertEqual(y1, y2, atol=prec, rtol=prec)
+
             if train:
                 y1.sum().backward()
                 y2.sum().backward()
-                self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
+                self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
                 self.assertEqual(conv1.weight.grad,
                                  conv2.weight.grad,
                                  atol=1e-3,
                                  rtol=1e-3)
                 if bias:
-                    self.assertEqual(conv1.bias.grad, conv2.bias.grad)
-                self.assertEqual(x1.grad, x2.grad)
+                    self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
+                self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
 
-    def test_conv2d_nhwc(self):
-        self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
-        self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
+    def test_conv_nhwc_fp32(self):
+        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
-    def test_conv2d_nhwc_bf16(self):
+    def test_conv_nhwc_bf16(self):
         # when torch.ops.mkldnn._is_mkldnn_bf16_supported() returns false, bf16 CPU conv will fall back to thnn impl
         if torch.ops.mkldnn._is_mkldnn_bf16_supported():
-            self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16)
-            self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.bfloat16)
+        # test fall back to thnn impl
+        with torch.backends.mkldnn.flags(enabled=False):
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-2)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16, prec=1e-2)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-3)
+            self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.bfloat16, prec=1e-3)
 
-    def test_conv_transpose2d_nhwc(self):
-        self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
-        self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
+    def test_conv_transpose_nhwc_fp32(self):
+        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32)
+        self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
-    def test_conv_transpose2d_nhwc_bf16(self):
+    def test_conv_transpose_nhwc_bf16(self):
         # when torch.ops.mkldnn._is_mkldnn_bf16_supported() returns false, bf16 CPU conv will fall back to thnn impl
         if torch.ops.mkldnn._is_mkldnn_bf16_supported():
-            self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16)
-            self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.bfloat16)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.bfloat16)
+        # test fall back to thnn impl
+        with torch.backends.mkldnn.flags(enabled=False):
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16, prec=2e-2)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16, prec=2e-2)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-3)
+            self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.bfloat16, prec=1e-3)
 
     def _test_conv_transpose_base(self, dim):
         conv_module = {
diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py
index 4858a27..5cccd85 100644
--- a/test/test_mkldnn_fusion.py
+++ b/test/test_mkldnn_fusion.py
@@ -268,7 +268,7 @@
                 x = self.binary(x, other)
                 return x
 
-        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
+        input_shapes = {2: (112, 112), 3: (22, 22, 22)}
         for pointwise_name, pointwise_fn in self._binary_list().items():
             for dim in [2, 3]:
                 channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
@@ -301,7 +301,7 @@
                             self.assertEqual(ref, other)
                             self.assertEqual(ref, fused_inplace)
 
-                        self.assertEqual(ref, fused)
+                        self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
 
 
     def test_linear_binary_fusion_ops(self):