DNNL: enable conv3d (#35662)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35662

Enables 3d convolution through DNNL (MKL-DNN): drops the 4-d (NCHW) restriction from the MKL-DNN eligibility check in Convolution.cpp, renames `_mkldnn_conv2d` to `_mkldnn_convolution` now that it serves both 2d and 3d, adds a `mkldnn_reorder_conv3d_weight` native function and an `MkldnnConv3d` scripted module, replaces `AT_ASSERTM`/`AT_ERROR` with `TORCH_CHECK` in the touched MKL-DNN files, and adds a `test_conv3d` round-trip test.
Test Plan: Imported from OSS
Differential Revision: D22102408
Pulled By: VitalyFedyunin
fbshipit-source-id: 1e95cede429f1a950f26bc7052ab33d198857df3
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 48b14e5..aae3c75 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -236,8 +236,8 @@
(input.options().backend() == at::Backend::CPU &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
!is_dilated() && // doesn't support dilation
- !transposed && // or transposed tensors
- input.ndimension() == 4); // must be in NCHW format
+ !transposed // or transposed tensors
+ );
#endif
return false;
}
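The eligibility check now admits any non-dilated, non-transposed float CPU input, so 5-d NCDHW tensors can take the MKL-DNN path alongside 4-d NCHW ones. A minimal Python sketch of the predicate above (illustrative names, not the actual ATen API; `dilated`/`transposed` stand in for the ConvParams member checks):

```python
import torch

def can_use_mkldnn(input, dilated, transposed):
    return (torch.backends.mkldnn.is_available()
            and input.device.type == "cpu"
            and input.dtype == torch.float32  # only on CPU Float Tensors
            and not dilated                   # doesn't support dilation
            and not transposed)               # or transposed tensors
```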
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index b6df8de..0479aba 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -55,7 +55,7 @@
namespace at { namespace native {
-ideep::tensor _mkldnn_conv2d(
+ideep::tensor _mkldnn_convolution(
const ideep::tensor& x,
const ideep::tensor& w,
const c10::optional<ideep::tensor>& b,
@@ -113,7 +113,7 @@
mkldnn_bias = get_mkldnn_tensor(bias);
}
- ideep::tensor mkldnn_output = _mkldnn_conv2d(
+ ideep::tensor mkldnn_output = _mkldnn_convolution(
mkldnn_input,
mkldnn_weight,
mkldnn_bias,
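With the helper renamed from `_mkldnn_conv2d` to `_mkldnn_convolution`, the same entry point now serves 4-d and 5-d inputs. A hedged usage sketch, assuming an MKL-DNN build of PyTorch:

```python
import torch

# A dense CPU float conv3d; with MKL-DNN available and the eligibility
# check above passing, this dispatches through the renamed helper.
x = torch.randn(1, 3, 8, 32, 32)            # NCDHW
conv = torch.nn.Conv3d(3, 8, kernel_size=3)
y = conv(x)
```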
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
index 13250b7..6871026 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -52,7 +52,7 @@
}
ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
- AT_ASSERTM(mkldnn_tensor.is_mkldnn(),
+ TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
"mkldnn_to_dense expects MKL-DNN tensor input");
TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
@@ -60,13 +60,13 @@
}
ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
- AT_ASSERTM(
+ TORCH_CHECK(
tensor.device().type() == DeviceType::CPU,
"itensor_view_from_dense expects CPU tensor input");
- AT_ASSERTM(
+ TORCH_CHECK(
tensor.layout() == Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
- AT_ASSERTM(tensor.scalar_type() == ScalarType::Float,
+ TORCH_CHECK(tensor.scalar_type() == ScalarType::Float,
"itensor_view_from_dense expects float tensor input");
TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
return {{{tensor.sizes().cbegin(), tensor.sizes().cend()},
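Swapping `AT_ASSERTM` for `TORCH_CHECK` turns these input-validation failures into `c10::Error`s, which surface in Python as a catchable `RuntimeError` rather than an internal assertion. For example (MKL-DNN build assumed; the guard hit here is the float-only check in `dense_to_mkldnn` in the next file):

```python
import torch

try:
    torch.randn(2, 2, dtype=torch.float64).to_mkldnn()
except RuntimeError as err:
    assert "expects float tensor input" in str(err)
```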
diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
index c6736ce..971fa7a 100644
--- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
+++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -22,13 +22,13 @@
}
Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
- AT_ASSERTM(cpu_tensor.device().type() == DeviceType::CPU,
+ TORCH_CHECK(cpu_tensor.device().type() == DeviceType::CPU,
"dense_to_mkldnn expects CPU tensor input");
- AT_ASSERTM(cpu_tensor.layout() == Layout::Strided,
+ TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
"dense_to_mkldnn expects strided tensor input");
- AT_ASSERTM(cpu_tensor.scalar_type() == ScalarType::Float,
+ TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float,
"dense_to_mkldnn expects float tensor input");
- AT_ASSERTM(cpu_tensor.dim() <= 5,
+ TORCH_CHECK(cpu_tensor.dim() <= 5,
"Can't convert cpu tensor with the number of dimensions > 5");
// TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
auto cpu_tensor_cont = cpu_tensor.contiguous();
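`dense_to_mkldnn` accepts up to 5-d input, which covers the NCDHW activations conv3d needs. A quick round-trip sanity check (MKL-DNN build assumed):

```python
import torch

x = torch.randn(2, 3, 4, 8, 8)   # 5-d NCDHW
assert torch.equal(x.to_mkldnn().to_dense(), x)
```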
@@ -53,10 +53,6 @@
IntArrayRef dilation,
int64_t groups) {
- auto stride_vec = expand_param_if_needed(stride, "stride", 2);
- auto padding_vec = expand_param_if_needed(padding, "padding", 2);
- auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);
-
auto w = itensor_from_mkldnn(self);
// Legacy mkldnn conv2d jitted module may contain a 5-d weight with an extra
@@ -73,10 +69,36 @@
ideep::convolution_forward::expected_weights_desc(
w.get_dims(),
w.get_data_type(),
- {stride_vec.cbegin(), stride_vec.cend()},
- {padding_vec.cbegin(), padding_vec.cend()},
- {padding_vec.cbegin(), padding_vec.cend()},
- {dilation_vec.cbegin(), dilation_vec.cend()},
+ {stride.begin(), stride.end()},
+ {padding.begin(), padding.end()},
+ {padding.begin(), padding.end()},
+ {dilation.begin(), dilation.end()},
+ groups,
+ ideep::algorithm::convolution_direct);
+ ideep::tensor result;
+ result.init(desc);
+ result.feed_from(w);
+
+ return new_with_itensor_mkldnn(std::move(result), self.options());
+}
+
+Tensor mkldnn_reorder_conv3d_weight(
+ const Tensor& self,
+ IntArrayRef padding,
+ IntArrayRef stride,
+ IntArrayRef dilation,
+ int64_t groups) {
+
+ auto w = itensor_from_mkldnn(self);
+
+ auto desc =
+ ideep::convolution_forward::expected_weights_desc(
+ w.get_dims(),
+ w.get_data_type(),
+ {stride.begin(), stride.end()},
+ {padding.begin(), padding.end()},
+ {padding.begin(), padding.end()},
+ {dilation.begin(), dilation.end()},
groups,
ideep::algorithm::convolution_direct);
ideep::tensor result;
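`mkldnn_reorder_conv3d_weight` mirrors the conv2d variant: it asks ideep for the expected blocked weight descriptor and reorders once up front, so each forward pass skips the per-call weight reorder. A hedged call sketch, with the argument order taken from the yaml entry below:

```python
import torch

w = torch.randn(8, 3, 3, 3, 3)    # (out_channels, in_channels/groups, kD, kH, kW)
w_blocked = torch._C._nn.mkldnn_reorder_conv3d_weight(
    w.to_mkldnn(),                # input must already be an MKL-DNN tensor
    [1, 1, 1],                    # padding
    [2, 2, 2],                    # stride
    [1, 1, 1],                    # dilation
    1)                            # groups
```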
@@ -89,11 +111,11 @@
#else
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
- AT_ERROR("MKL-DNN build is disabled");
+ TORCH_CHECK(false, "MKL-DNN build is disabled");
}
Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
- AT_ERROR("MKL-DNN build is disabled");
+ TORCH_CHECK(false, "MKL-DNN build is disabled");
}
Tensor mkldnn_reorder_conv2d_weight(
@@ -102,7 +124,16 @@
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
- AT_ERROR("mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
+ TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
+}
+
+Tensor mkldnn_reorder_conv3d_weight(
+ const Tensor& self,
+ IntArrayRef padding,
+ IntArrayRef stride,
+ IntArrayRef dilation,
+ int64_t groups) {
+ TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 06aa828..66db2f6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3558,6 +3558,12 @@
dispatch:
MkldnnCPU: mkldnn_reorder_conv2d_weight
+- func: mkldnn_reorder_conv3d_weight(Tensor self, int[3] padding=0, int[3] stride=1, int[3] dilation=1, int groups=1) -> Tensor
+ variants: function
+ python_module: nn
+ dispatch:
+ MkldnnCPU: mkldnn_reorder_conv3d_weight
+
- func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor
use_c10_dispatcher: full
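The new entry dispatches to `mkldnn_reorder_conv3d_weight` on the MkldnnCPU key and, via `python_module: nn`, binds it under `torch._C._nn`:

```python
import torch

assert hasattr(torch._C._nn, "mkldnn_reorder_conv3d_weight")
```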
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 2043095..3bae59c 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -177,6 +177,29 @@
conv2d(x),
conv2d_loaded(x.to_mkldnn()).to_dense())
+ def test_conv3d(self):
+ for groups in [1, 4]:
+ N = torch.randint(3, 10, (1,)).item()
+ C = torch.randint(1, 3, (1,)).item() * groups
+ M = torch.randint(1, 3, (1,)).item() * groups
+ x = torch.randn(N, C, 55, 55, 55, dtype=torch.float32)
+ for bias in [True, False]:
+ conv3d = torch.nn.Conv3d(in_channels=C,
+ out_channels=M,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=bias,
+ groups=groups).float()
+ mkldnn_conv3d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv3d))
+ with torch.backends.mkldnn.flags(enabled=False):
+ y_aten = conv3d(x)
+ y_mkldnn = mkldnn_conv3d(x.to_mkldnn()).to_dense()
+ self.assertEqual(y_aten, y_mkldnn)
+
+ self._test_serialization(mkldnn_conv3d, (x.to_mkldnn(),))
+ self._test_tracing(mkldnn_conv3d, (x.to_mkldnn(),))
+
def test_relu(self):
x = torch.randn((4, 5), dtype=torch.float32) * 10
self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
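A trimmed-down, standalone version of `test_conv3d` above (MKL-DNN build assumed); the `mkldnn.flags(enabled=False)` block forces the reference run down the native ATen path:

```python
import copy
import torch
import torch.utils.mkldnn as mkldnn_utils

conv3d = torch.nn.Conv3d(3, 8, kernel_size=3, stride=2, padding=1).float()
mkldnn_conv3d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv3d))
x = torch.randn(2, 3, 16, 16, 16, dtype=torch.float32)
with torch.backends.mkldnn.flags(enabled=False):
    y_ref = conv3d(x)
y = mkldnn_conv3d(x.to_mkldnn()).to_dense()
assert torch.allclose(y_ref, y, atol=1e-5)
```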
diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py
index 1a9c595..b084c26 100644
--- a/torch/utils/mkldnn.py
+++ b/torch/utils/mkldnn.py
@@ -105,6 +105,29 @@
self.training = state[2]
+class MkldnnConv3d(_MkldnnConvNd):
+ def __init__(self, dense_module):
+ super(MkldnnConv3d, self).__init__(dense_module)
+
+ self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
+ dense_module.weight.to_mkldnn(),
+ self.padding,
+ self.stride,
+ self.dilation,
+ self.groups))
+
+ @torch.jit.script_method
+ def __setstate__(self, state):
+ self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
+ state[0].to_mkldnn(),
+ self.padding,
+ self.stride,
+ self.dilation,
+ self.groups)
+ self.bias = state[1].to_mkldnn()
+ self.training = state[2]
+
+
class MkldnnBatchNorm2d(torch.jit.ScriptModule):
__constants__ = ['exponential_average_factor', 'eps']
@@ -165,6 +188,8 @@
return MkldnnConv1d(m)
elif isinstance(m, torch.nn.Conv2d):
return MkldnnConv2d(m)
+ elif isinstance(m, torch.nn.Conv3d):
+ return MkldnnConv3d(m)
elif isinstance(m, torch.nn.BatchNorm2d):
return MkldnnBatchNorm2d(m)
else:
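`MkldnnConv3d` is a `ScriptModule`, so the `__setstate__` above is what runs on `torch.jit.load`; the `state[0].to_mkldnn()` call implies the base class's `__getstate__` (not shown in this diff) stores a dense weight. A usage sketch, assuming an MKL-DNN build:

```python
import torch
import torch.utils.mkldnn as mkldnn_utils

m = mkldnn_utils.to_mkldnn(torch.nn.Conv3d(3, 8, 3).float())
torch.jit.save(m, "conv3d_mkldnn.pt")         # weight densified on save
loaded = torch.jit.load("conv3d_mkldnn.pt")   # __setstate__ re-reorders it
```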