add channels last (2d) support for mkldnn_convolution (#55584)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55584
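This adds channels last (NHWC) support to the 2d MKLDNN convolution path: when the
input or weight suggests `torch.channels_last`, the input and output are used as NHWC
views instead of being reordered to MKLDNN's blocked format and back (the weight is
still reordered internally).

A minimal usage sketch (standard eager API; assumes a PyTorch build with MKLDNN so the
MKLDNN backend is selected for this CPU float convolution):

    import torch

    x = torch.randn(8, 3, 224, 224).to(memory_format=torch.channels_last)
    conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1).to(memory_format=torch.channels_last)
    y = conv(x)  # y is produced directly in channels last (NHWC)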
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27941368
Pulled By: VitalyFedyunin
fbshipit-source-id: 7dd6f02a5787efa1995f31cdbd3244b25653840c
(cherry picked from commit bb555ed0fedafd529cb552807326384e95c90df9)
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index 54a4b5d..c0a681b 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -289,4 +289,30 @@
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
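+// Decide whether the MKLDNN (oneDNN) convolution should run in channels last (NHWC):
+// returns true only for dense (non-MkldnnCPU), non-float64 input and weight where at
+// least one of them already suggests the ChannelsLast memory format.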
+static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+ // disable NHWC for float64 input.
+ if (input.scalar_type() == at::kDouble ||
+ weight.scalar_type() == at::kDouble) {
+ return false;
+ }
+
+ // disable NHWC for MkldnnCPU tensor.
+ if (input.is_mkldnn() || weight.is_mkldnn()) {
+ return false;
+ }
+
+ auto input_memory_format = input.suggest_memory_format();
+ auto weight_memory_format = weight.suggest_memory_format();
+
+ bool can_use_mkldnn_channels_last_2d =
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
+ (weight_memory_format == at::MemoryFormat::ChannelsLast);
+
+ // TODO: add channels last 3d support
+ bool can_use_mkldnn_channels_last_3d = false;
+
+ return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 02b1794..ea040e2 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -629,6 +629,25 @@
check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
}
+static void check_input_same_type_as_parameters(
+ const Tensor& input,
+ const Tensor& weight,
+ const Tensor& bias,
+ const ConvBackend backend) {
+ if (backend == ConvBackend::Mkldnn) {
+ TORCH_CHECK(input.options().type_equal(weight.options())
+ || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
+ "Input type (", input.toString(), ") and weight type (", weight.toString(),
+ ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
+ TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
+ || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
+ "Input type (", input.toString(), ") and bias type (", bias.toString(),
+ ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
+ } else {
+ check_input_same_type_as_parameters(input, weight, bias);
+ }
+}
+
static auto view4d(const at::Tensor& tensor) -> at::Tensor {
TORCH_CHECK(tensor.ndimension() == 3,
"expected 3D tensor, got tensor with ", tensor.ndimension(),
@@ -1187,18 +1206,35 @@
static inline at::MemoryFormat determine_backend_memory_format(
const Tensor& input,
- const Tensor& weight) {
+ const Tensor& weight,
+ const ConvBackend backend) {
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
auto k = weight.ndimension();
#if !defined(C10_MOBILE)
// See Note [Mobile check segfaults]
- if (detail::getCUDAHooks().compiledWithCuDNN()) {
- backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
- }
- if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
- TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
- "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
- backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
+ switch(backend) {
+ case ConvBackend::Cudnn:
+ case ConvBackend::CudnnTranspose:
+ if (detail::getCUDAHooks().compiledWithCuDNN()) {
+ backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
+ }
+ break;
+ case ConvBackend::Miopen:
+ case ConvBackend::MiopenDepthwise:
+ case ConvBackend::MiopenTranspose:
+ if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
+ TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
+ "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
+ backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
+ }
+ break;
+ case ConvBackend::Mkldnn:
+ if (mkldnn_conv_use_channels_last(input, weight)) {
+ backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
+ }
+ break;
+ default:
+ backend_memory_format = at::MemoryFormat::Contiguous;
}
#endif
return backend_memory_format;
@@ -1251,7 +1287,7 @@
bool need_backward = GradMode::is_enabled() &&
(input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, need_backward, params);
- at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight);
+ at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
// Call the backend.
Tensor output;
@@ -1312,18 +1348,11 @@
break;
case ConvBackend::Mkldnn:
#if AT_MKLDNN_ENABLED()
- TORCH_CHECK(input.options().type_equal(weight.options())
- || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
- "Input type (", input.toString(), ") and weight type (", weight.toString(),
- ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
- TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
- || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
- "Input type (", input.toString(), ") and bias type (", bias.toString(),
- ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
+ check_input_same_type_as_parameters(input, weight, bias, backend);
if (!input.is_mkldnn()) {
// need to ensure contiguous for non-mkldnn tensors
- input = input.contiguous();
- weight = weight.contiguous();
+ input = input.contiguous(backend_memory_format);
+ weight = weight.contiguous(backend_memory_format);
bias = bias.defined() ? bias.contiguous() : bias;
}
output = at::mkldnn_convolution(
@@ -1726,7 +1755,7 @@
// Select appropriate backend to use.
ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
- at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight);
+ at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
// Call the backend.
Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
@@ -1834,8 +1863,8 @@
TORCH_CHECK(!weight.is_mkldnn(),
"The MKLDNN backend does not support weight as an MKLDNN tensor during training");
if (!input.is_mkldnn()) {
- input = input.contiguous();
- weight = weight.contiguous();
+ input = input.contiguous(backend_memory_format);
+ weight = weight.contiguous(backend_memory_format);
}
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 50b366e..a2489e4 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -43,48 +43,42 @@
namespace at { namespace native {
-ideep::tensor _mkldnn_convolution(
- const ideep::tensor& x,
- const ideep::tensor& w,
- const c10::optional<ideep::tensor>& b,
- IntArrayRef padding,
- IntArrayRef stride,
- IntArrayRef dilation,
- int64_t groups) {
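+// Wraps an ideep::tensor into an opaque MKLDNN Tensor, taking the dtype and device
+// from the given TensorOptions.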
+#define MKLDNNTensor(itensor, options) \
+ new_with_itensor_mkldnn( \
+ std::move(itensor), \
+ optTypeMetaToScalarType(options.dtype_opt()), \
+ options.device_opt())
- auto kernel_size = w.get_dims();
-
- std::vector<int64_t> input_size = x.get_dims();
- std::vector<int64_t> output_sizes =
- conv_output_size(input_size, kernel_size, padding, stride, dilation);
-
- ideep::tensor y;
- if (b.has_value()) {
- ideep::convolution_forward::compute(
- x,
- w,
- b.value(),
- {output_sizes.cbegin(), output_sizes.cend()},
- y,
- {stride.begin(), stride.end()},
- {dilation.begin(), dilation.end()},
- {padding.begin(), padding.end()},
- {padding.begin(), padding.end()},
- groups);
- } else {
- ideep::convolution_forward::compute(
- x,
- w,
- {output_sizes.cbegin(), output_sizes.cend()},
- y,
- {stride.begin(), stride.end()},
- {dilation.begin(), dilation.end()},
- {padding.begin(), padding.end()},
- {padding.begin(), padding.end()},
- groups);
- }
- return y;
-}
+// Note [MKLDNN Convolution Memory Formats]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// MKLDNN has 3 types of memory formats in convolution:
+//
+// In case the memory format passed from PyTorch (a.k.a. the user layout)
+// differs from the internal layout that MKLDNN uses, a `reorder` is needed;
+// otherwise, when the user layout is identical to the internal layout,
+// MKLDNN uses a memory `view` upon the existing CPU tensor.
+//
+// 1. NCHW (CPU tensor, contiguous)
+// input reorder: NCHW(user) -> Blocked(internal)
+// weight reorder: OIHW(user) -> Blocked(internal)
+// output reorder: Blocked(internal) -> NCHW(user)
+//
+// 2. NHWC (CPU tensor, channels last)
+// input view: NHWC(user) -> NHWC(internal)
+// weight reorder: OHWI(user) -> Blocked(internal)
+// output view: NHWC(internal) -> NHWC(user)
+//
+// 3. Blocked (MKLDNN tensor):
+// By explicitly converting a tensor to mkldnn, e.g. `x.to_mkldnn()`,
+//   the blocked format will propagate between layers; input and output will be in blocked format.
+//
+// For the inference case, the weight can be prepacked into blocked format
+// (so as to save the weight reorder overhead) by:
+// model = torch.utils.mkldnn.to_mkldnn(model)
+//
+// For the training case, grad_output can be a CPU tensor or an MKLDNN tensor,
+// but weight/bias and grad_weight/grad_bias are always CPU tensors.
+//
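+// A minimal sketch of how case 2 (channels last) is reached from the eager API
+// (assumed standard usage, nothing specific to this file):
+//   x = x.to(memory_format=torch.channels_last)
+//   conv = conv.to(memory_format=torch.channels_last)
+//   y = conv(x)   # input/output are NHWC views; only the weight is reordered
+//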
Tensor mkldnn_convolution(
const Tensor& input,
@@ -101,29 +95,53 @@
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
}
- const ideep::tensor mkldnn_input = itensor_from_tensor(input);
- const ideep::tensor mkldnn_weight = itensor_from_tensor(weight);
- c10::optional<ideep::tensor> mkldnn_bias{c10::nullopt};
+
+ bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
+
+ auto output_sizes = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation);
+ auto output = at::empty({0}, input.options());
+
+ const ideep::tensor x = itensor_from_tensor(input);
+ const ideep::tensor w = itensor_from_tensor(weight);
+
+ ideep::tensor y;
+ if (is_channels_last) {
+ output.resize_(output_sizes, input.suggest_memory_format());
+ y = itensor_from_tensor(output);
+ }
if (bias.defined()) {
- mkldnn_bias = itensor_from_tensor(bias);
+ const ideep::tensor b = itensor_from_tensor(bias);
+ ideep::convolution_forward::compute(
+ x,
+ w,
+ b,
+ {output_sizes.cbegin(), output_sizes.cend()},
+ y,
+ {stride.begin(), stride.end()},
+ {dilation.begin(), dilation.end()},
+ {padding.begin(), padding.end()},
+ {padding.begin(), padding.end()},
+ groups);
+ } else {
+ ideep::convolution_forward::compute(
+ x,
+ w,
+ {output_sizes.cbegin(), output_sizes.cend()},
+ y,
+ {stride.begin(), stride.end()},
+ {dilation.begin(), dilation.end()},
+ {padding.begin(), padding.end()},
+ {padding.begin(), padding.end()},
+ groups);
}
- ideep::tensor mkldnn_output = _mkldnn_convolution(
- mkldnn_input,
- mkldnn_weight,
- mkldnn_bias,
- padding,
- stride,
- dilation,
- groups);
-
if (input.is_mkldnn()) {
- return new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()),
- input.options().device_opt());
+ return MKLDNNTensor(y, input.options());
+ } else if (!is_channels_last) {
+ return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
} else {
- return mkldnn_to_dense(
- new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()),
- input.options().device_opt()));
+ TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc());
+ return output;
}
}
@@ -131,17 +149,22 @@
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
- // for training case, grad_output can be cpu tensor or MKLDNN tensor,
- // but weight and bias always cpu tensor.
- auto mkldnn_grad_output = itensor_from_tensor(grad_output);
- auto mkldnn_weight = itensor_view_from_dense(weight);
+ bool is_channels_last = grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
+ auto grad_input = at::empty({0}, grad_output.options());
- ideep::tensor mkldnn_grad_input;
+ auto grad_y = itensor_from_tensor(grad_output);
+ auto w = itensor_view_from_dense(weight);
+
+ ideep::tensor grad_x;
+ if (is_channels_last) {
+ grad_input.resize_(input_size, grad_output.suggest_memory_format());
+ grad_x = itensor_from_tensor(grad_input);
+ }
ideep::convolution_backward_data::compute(
- mkldnn_grad_output,
- mkldnn_weight,
+ grad_y,
+ w,
input_size.vec(),
- mkldnn_grad_input,
+ grad_x,
stride.vec(),
dilation.vec(),
padding.vec(),
@@ -149,14 +172,12 @@
groups);
if (grad_output.is_mkldnn()) {
- return new_with_itensor_mkldnn(std::move(mkldnn_grad_input),
- optTypeMetaToScalarType(grad_output.options().dtype_opt()),
- grad_output.options().device_opt());
-
+ return MKLDNNTensor(grad_x, grad_output.options());
+  } else if (!is_channels_last) {
+ return mkldnn_to_dense(MKLDNNTensor(grad_x, grad_output.options()));
} else {
- return mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_input),
- optTypeMetaToScalarType(grad_output.options().dtype_opt()),
- grad_output.options().device_opt()));
+ TORCH_INTERNAL_ASSERT(grad_x.get_desc().is_nhwc());
+ return grad_input;
}
}
@@ -164,19 +185,19 @@
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
- // for training case, grad_output and input can be cpu tensor or MKLDNN tensor,
- // but weight and bias are always cpu tensor.
- const ideep::tensor mkldnn_grad_output = itensor_from_tensor(grad_output);
- const ideep::tensor mkldnn_input = itensor_from_tensor(input);
+ bool is_channels_last = grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
- ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias;
+ const ideep::tensor grad_y = itensor_from_tensor(grad_output);
+ const ideep::tensor x = itensor_from_tensor(input);
+
+ ideep::tensor grad_w, grad_b;
if (bias_defined) {
ideep::convolution_backward_weights::compute(
- mkldnn_input,
- mkldnn_grad_output,
+ x,
+ grad_y,
weight_size.vec(),
- mkldnn_grad_weight,
- mkldnn_grad_bias,
+ grad_w,
+ grad_b,
stride.vec(),
dilation.vec(),
padding.vec(),
@@ -184,10 +205,10 @@
groups);
} else {
ideep::convolution_backward_weights::compute(
- mkldnn_input,
- mkldnn_grad_output,
+ x,
+ grad_y,
weight_size.vec(),
- mkldnn_grad_weight,
+ grad_w,
stride.vec(),
dilation.vec(),
padding.vec(),
@@ -195,20 +216,23 @@
groups);
}
- return std::make_tuple(
- mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight),
- optTypeMetaToScalarType(grad_output.options().dtype_opt()),
- grad_output.options().device_opt())),
- bias_defined ? mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias),
- optTypeMetaToScalarType(grad_output.options().dtype_opt()),
- grad_output.options().device_opt())) : Tensor());
+ if (!is_channels_last) {
+ return std::make_tuple(
+ mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())),
+ bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
+ } else {
+ return std::make_tuple(
+ mkldnn_to_dense(MKLDNNTensor(grad_w, grad_output.options())).to(at::MemoryFormat::ChannelsLast),
+ bias_defined ? mkldnn_to_dense(MKLDNNTensor(grad_b, grad_output.options())) : Tensor());
+ }
}
std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
{
- Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous();
+ auto memory_format = input.suggest_memory_format();
+ Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);
Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
index cfbbf5c..fbfb329 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -30,7 +30,7 @@
: stensor.to_public(cpu_tensor.template data_ptr<BFloat16>(),
ideep::tensor::data_type::bf16);
cpu_tensor.as_strided_(dims, pub_tensor.get_strides());
- return cpu_tensor;
+ return cpu_tensor.contiguous();
}
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional<ScalarType> dtype) {
@@ -43,7 +43,7 @@
"dense_to_mkldnn expects float or bfloat16 tensor input");
TORCH_CHECK(cpu_tensor.dim() <= 5,
"Can't convert cpu tensor with the number of dimensions > 5");
- // TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
+  // NOTE: forbid direct conversion from a non-contiguous (or channels last) tensor to `ideep::tensor`.
auto cpu_tensor_cont = cpu_tensor.contiguous();
auto data_type = dtype.has_value() ? dtype.value() : cpu_tensor.scalar_type();
TORCH_CHECK(data_type == ScalarType::Float || data_type == ScalarType::BFloat16,
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index bfaca50..5294816 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -241,6 +241,47 @@
def test_conv3d(self):
self._test_conv_base(dim=3)
+ def test_conv2d_nhwc(self):
+ conv_module = torch.nn.Conv2d
+ input_shapes = (224, 224)
+ options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
+ for train, bias, dilation, groups in options:
+ N = torch.randint(3, 10, (1,)).item()
+ M = torch.randint(1, 3, (1,)).item() * groups
+ C = torch.randint(1, 3, (1,)).item() * groups
+ x_shape = (N, C) + input_shapes
+ x = torch.randn(x_shape, dtype=torch.float32)
+ # conv1: mkldnn conv2d in contiguous memory format (nchw)
+ # conv2: mkldnn conv2d in channels last memory format (nhwc)
+ conv1 = conv_module(in_channels=C,
+ out_channels=M,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ dilation=dilation,
+ bias=bias,
+ groups=groups).float()
+ conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
+ x1 = x.clone()
+ x2 = x.clone().to(memory_format=torch.channels_last)
+ if train:
+ x1.requires_grad_()
+ x2.requires_grad_()
+ y1 = conv1(x1)
+ y2 = conv2(x2)
+ self.assertEqual(y1, y2)
+ if train:
+ y1.sum().backward()
+ y2.sum().backward()
+ self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
+ self.assertEqual(conv1.weight.grad,
+ conv2.weight.grad,
+ atol=1e-3,
+ rtol=1e-3)
+ if bias:
+ self.assertEqual(conv1.bias.grad, conv2.bias.grad)
+ self.assertEqual(x1.grad, x2.grad)
+
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def _test_conv_bf16_base(self, dim):
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 44554ba..42b1ea2 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -1825,7 +1825,7 @@
x = self.sub1(x)
x = self.dequant(x)
x = self.sub2(x)
- x = x.view(-1, 36).contiguous()
+ x = x.reshape(-1, 36).contiguous()
x = self.fc(x)
y = self.conv2(y)
y = self.relu2(y)