Add channels_last3d support for mkldnn conv and mkldnn deconv (#95271)
### Motivation
- Add channels_last3d support for mkldnn conv and mkldnn deconv.
- Use `ideep::convolution_transpose_forward::compute_v3` instead of `ideep::convolution_transpose_forward::compute`. compute_v3 uses `is_channels_last` to notify ideep whether to go CL or not to align with the memory format check of PyTorch.
### Testing
1 socket (28 cores):
- memory format: torch.contiguous_format
module | shape | forward / ms | backward / ms
-- | -- | -- | --
conv3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 64.56885 | 150.1796
conv3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 100.6754 | 231.8883
conv3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 19.31751 | 68.31131
module | shape | forward / ms | backward / ms
-- | -- | -- | --
ConvTranspose3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 122.7646 | 207.5125
ConvTranspose3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 202.4542 | 368.5492
ConvTranspose3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 122.959 | 84.62577
- memory format: torch.channels_last_3d
module | shape | forward / ms | backward / ms
-- | -- | -- | --
conv3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 40.06993 | 114.317
conv3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 49.08249 | 133.4079
conv3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 5.873911 | 17.58647
module | shape | forward / ms | backward / ms
-- | -- | -- | --
ConvTranspose3d | input size: (32, 32, 10, 100, 100), weight size: (32, 32, 3, 3, 3) | 88.4246 | 208.2269
ConvTranspose3d | input size: (32, 16, 10, 200, 200), weight size: (16, 16, 3, 3, 3) | 140.0725 | 270.4172
ConvTranspose3d | input size: (16, 4, 5, 300, 300), weight size: (4, 4, 3, 3, 3) | 23.0223 | 37.16972
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95271
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index 065ed57..fd52628 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -389,8 +389,9 @@
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast);
- // TODO: add channels last 3d support
- bool can_use_mkldnn_channels_last_3d = false;
+ bool can_use_mkldnn_channels_last_3d =
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index bbb6727..7af281a 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -508,9 +508,6 @@
if (transposed && is_output_padding_big()) {
return false;
}
- if (transposed && groups > 1 && at::symint::size<T>(input, 1) == groups) {
- return false;
- }
if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
return true;
}
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 72bae7aa..ee1d103 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -727,7 +727,7 @@
if (bias.defined()) {
const ideep::tensor b = itensor_from_tensor(bias);
- ideep::convolution_transpose_forward::compute(
+ ideep::convolution_transpose_forward::compute_v3(
x,
w,
b,
@@ -738,9 +738,10 @@
padding_r(padding_expanded, output_padding_expanded),
dilation.vec(),
groups,
+ use_channels_last,
op_attr);
} else {
- ideep::convolution_transpose_forward::compute(
+ ideep::convolution_transpose_forward::compute_v3(
x,
w,
output_sizes,
@@ -750,6 +751,7 @@
padding_r(padding_expanded, output_padding_expanded),
dilation.vec(),
groups,
+ use_channels_last,
op_attr);
}
if (input.is_mkldnn()) {
@@ -988,7 +990,7 @@
grad_input.resize_(input_size, memory_format);
grad_x = itensor_from_tensor(grad_input);
}
- ideep::convolution_transpose_backward_data::compute(
+ ideep::convolution_transpose_backward_data::compute_v3(
grad_y,
w,
input_size.vec(),
@@ -997,7 +999,8 @@
padding.vec(),
padding_r(padding, output_padding),
dilation.vec(),
- groups);
+ groups,
+ is_channels_last);
if (grad_output.is_mkldnn()) {
return MKLDNNTensor(grad_x, grad_output.options());
@@ -1024,7 +1027,7 @@
ideep::tensor grad_w, grad_b;
if (bias_defined) {
- ideep::convolution_transpose_backward_weights::compute(
+ ideep::convolution_transpose_backward_weights::compute_v3(
x,
grad_y,
weight_size.vec(),
@@ -1034,9 +1037,10 @@
padding.vec(),
padding_r(padding, output_padding),
dilation.vec(),
- groups);
+ groups,
+ is_channels_last);
} else {
- ideep::convolution_transpose_backward_weights::compute(
+ ideep::convolution_transpose_backward_weights::compute_v3(
x,
grad_y,
weight_size.vec(),
@@ -1045,7 +1049,8 @@
padding.vec(),
padding_r(padding, output_padding),
dilation.vec(),
- groups);
+ groups,
+ is_channels_last);
}
if (!is_channels_last) {
@@ -1061,18 +1066,21 @@
}
std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
- const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
+ const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask)
{
- bool is_channels_last = mkldnn_conv_use_channels_last(input, weight);
- auto memory_format = mkldnn_convolution_memory_format(input.ndimension(), is_channels_last);
+ bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
+ auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);
+ auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
+ auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
int64_t dim = input.ndimension() - 2;
const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", dim);
const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", dim);
+
Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
grad_input = mkldnn_convolution_transpose_backward_input(
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 3f87544..bad7192 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -321,22 +321,25 @@
def test_conv3d_bf16(self):
self._test_conv_bf16_base(dim=3)
- def _test_conv2d_nhwc_base(self, conv_module, weight_memory_format, dtype):
- input_shapes = (55, 55)
+ def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, prec=None):
+ input_shapes = {2: (55, 55), 3: (14, 14, 14)}
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
+ if conv_module in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+ cl_format = torch.channels_last
+ input_shape = input_shapes[2]
+ elif conv_module in [torch.nn.Conv3d, torch.nn.ConvTranspose3d]:
+ cl_format = torch.channels_last_3d
+ input_shape = input_shapes[3]
+
for train, bias, dilation, groups in options:
N = torch.randint(3, 10, (1,)).item()
M = torch.randint(1, 3, (1,)).item() * groups
C = torch.randint(1, 3, (1,)).item() * groups
- x_shape = (N, C) + input_shapes
+ x_shape = (N, C) + input_shape
x = torch.randn(x_shape, dtype=dtype)
- # TODO: remove this when group depthwise is supported:
- if conv_module is torch.nn.ConvTranspose2d and groups > 1 and C == groups:
- continue
-
- # conv1: mkldnn conv in contiguous memory format (nchw)
- # conv2: mkldnn conv in channels last memory format (nhwc)
+ # conv1: mkldnn conv/deconv in contiguous memory format (nchw)
+ # conv2: mkldnn conv/deconv in channels last memory format (nhwc)
conv1 = conv_module(in_channels=C,
out_channels=M,
kernel_size=3,
@@ -347,46 +350,67 @@
groups=groups).to(dtype=dtype)
conv2 = copy.deepcopy(conv1).to(memory_format=weight_memory_format)
x1 = x.clone()
- x2 = x.clone().to(memory_format=torch.channels_last)
+ x2 = x.clone().to(memory_format=cl_format)
if train:
x1.requires_grad_()
x2.requires_grad_()
y1 = conv1(x1)
y2 = conv2(x2)
- self.assertEqual(y1, y2)
+ self.assertEqual(y1, y2, atol=prec, rtol=prec)
+
if train:
y1.sum().backward()
y2.sum().backward()
- self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
+ self.assertTrue(x2.grad.is_contiguous(memory_format=cl_format))
self.assertEqual(conv1.weight.grad,
conv2.weight.grad,
atol=1e-3,
rtol=1e-3)
if bias:
- self.assertEqual(conv1.bias.grad, conv2.bias.grad)
- self.assertEqual(x1.grad, x2.grad)
+ self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
+ self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
- def test_conv2d_nhwc(self):
- self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
- self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
+ def test_conv_nhwc_fp32(self):
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.float32)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
- def test_conv2d_nhwc_bf16(self):
+ def test_conv_nhwc_bf16(self):
# when torch.ops.mkldnn._is_mkldnn_bf16_supported() returns false, bf16 CPU conv will fall back to thnn impl
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
- self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16)
- self._test_conv2d_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.bfloat16)
+ # test fall back to thnn impl
+ with torch.backends.mkldnn.flags(enabled=False):
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-2)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.bfloat16, prec=1e-2)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-3)
+ self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=torch.bfloat16, prec=1e-3)
- def test_conv_transpose2d_nhwc(self):
- self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
- self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
+ def test_conv_transpose_nhwc_fp32(self):
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.float32)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.float32)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
- def test_conv_transpose2d_nhwc_bf16(self):
+ def test_conv_transpose_nhwc_bf16(self):
# when torch.ops.mkldnn._is_mkldnn_bf16_supported() returns false, bf16 CPU conv will fall back to thnn impl
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
- self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16)
- self._test_conv2d_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.bfloat16)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.bfloat16)
+ # test fall back to thnn impl
+ with torch.backends.mkldnn.flags(enabled=False):
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.bfloat16, prec=2e-2)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.bfloat16, prec=2e-2)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.contiguous_format, dtype=torch.bfloat16, prec=1e-3)
+ self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose3d, torch.channels_last_3d, dtype=torch.bfloat16, prec=1e-3)
def _test_conv_transpose_base(self, dim):
conv_module = {
diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py
index 4858a27..5cccd85 100644
--- a/test/test_mkldnn_fusion.py
+++ b/test/test_mkldnn_fusion.py
@@ -268,7 +268,7 @@
x = self.binary(x, other)
return x
- input_shapes = {2: (112, 112), 3: (55, 55, 55)}
+ input_shapes = {2: (112, 112), 3: (22, 22, 22)}
for pointwise_name, pointwise_fn in self._binary_list().items():
for dim in [2, 3]:
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
@@ -301,7 +301,7 @@
self.assertEqual(ref, other)
self.assertEqual(ref, fused_inplace)
- self.assertEqual(ref, fused)
+ self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
def test_linear_binary_fusion_ops(self):