blob: 7750097eb22188e292213ebd0cc7b2144db8ecff [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#if !AT_MKLDNN_ENABLED()
namespace at {
namespace native {

// Fallback definitions compiled when ATen is built without MKLDNN support.
// Every entry point keeps the same signature as the real implementation so
// that callers link, but unconditionally raises an error through AT_ERROR.

at::Tensor mkldnn_convolution(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  AT_ERROR("mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}

at::Tensor mkldnn_convolution_backward_input(
    IntArrayRef input_size,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined) {
  AT_ERROR("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
}

std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
    IntArrayRef weight_size,
    const at::Tensor& grad_output,
    const at::Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool bias_defined) {
  AT_ERROR("mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> mkldnn_convolution_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output_t,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    std::array<bool, 3> output_mask) {
  AT_ERROR("mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at
#else // AT_MKLDNN_ENABLED
#include <ATen/mkldnn/Runtime.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/ConvUtils.h>
using namespace mkldnn;
namespace {

// Wraps an aten Tensor as an ideep tensor.
//
// - MKLDNN-layout tensors: returns the ideep tensor obtained from
//   itensor_from_mkldnn.
// - Dense tensors: returns itensor_view_from_dense, i.e. a view over the
//   aten tensor's storage. The caller must keep the aten tensor alive for at
//   least as long as the returned ideep tensor is in use.
inline ideep::tensor get_mkldnn_tensor(const at::Tensor& tensor) {
  if (!tensor.is_mkldnn()) {
    return at::native::itensor_view_from_dense(tensor);
  }
  return at::native::itensor_from_mkldnn(tensor);
}

} // namespace
namespace at { namespace native {
// Runs a forward convolution through ideep and returns the freshly computed
// output tensor.
//
// x        - source tensor; its dims determine the output shape via
//            conv_output_size.
// w        - weights. Either the same rank as x, or one rank higher when a
//            grouped weight was reordered to 5-d by
//            mkldnn_reorder_conv2d_weight.
// b        - optional bias; selects the with-bias overload of
//            ideep::convolution_forward::compute.
// padding  - used as both the left and the right padding.
// stride, dilation, groups - forwarded to ideep unchanged.
ideep::tensor _mkldnn_conv2d(
    const ideep::tensor& x,
    const ideep::tensor& w,
    const c10::optional<ideep::tensor>& b,
    at::IntArrayRef padding,
    at::IntArrayRef stride,
    at::IntArrayRef dilation,
    int64_t groups) {
  const auto rank = x.ndims();

  // Logical weight shape (same rank as x) handed to conv_output_size.
  std::vector<int64_t> kernel_size(rank);
  if (w.ndims() != rank + 1) {
    std::copy_n(w.get_dims().cbegin(), rank, kernel_size.begin());
  } else {
    // Reordered 5-d grouped weight: merge the two leading dims back into a
    // single output-channel dim and keep the remaining dims as-is.
    AT_ASSERTM(
        groups > 1,
        "Only group _mkldnn_conv2d weights could have been reordered to 5d");
    kernel_size[0] = w.get_dim(0) * w.get_dim(1);
    std::copy_n(w.get_dims().cbegin() + 2, rank - 1, kernel_size.begin() + 1);
  }

  const ideep::param::dims src_dims = x.get_dims();
  const std::vector<int64_t> src_sizes(src_dims.cbegin(), src_dims.cend());
  const std::vector<int64_t> dst_sizes =
      conv_output_size(src_sizes, kernel_size, padding, stride, dilation);

  ideep::tensor y;
  if (!b.has_value()) {
    ideep::convolution_forward::compute<AllocForMKLDNN>(
        x,
        w,
        {dst_sizes.cbegin(), dst_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        ideep::descriptor_group::attr_t{},
        ideep::algorithm::convolution_direct,
        ideep::prop_kind::forward);
  } else {
    ideep::convolution_forward::compute<AllocForMKLDNN>(
        x,
        w,
        *b,
        {dst_sizes.cbegin(), dst_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        ideep::descriptor_group::attr_t{},
        ideep::algorithm::convolution_direct,
        ideep::prop_kind::forward);
  }
  return y;
}
// Forward convolution entry point backed by ideep.
//
// Accepts either MKLDNN-layout or dense aten tensors. The result has the same
// layout kind as `input`: an MKLDNN tensor for MKLDNN input, otherwise a dense
// tensor produced by mkldnn_to_dense. `bias` may be undefined.
at::Tensor mkldnn_convolution(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  // For dense aten tensors these ideep tensors are views onto the aten
  // storage, so input/weight/bias must outlive them (they do: they are
  // parameters of this function).
  const ideep::tensor x = get_mkldnn_tensor(input);
  const ideep::tensor w = get_mkldnn_tensor(weight);
  c10::optional<ideep::tensor> b;
  if (bias.defined()) {
    b = get_mkldnn_tensor(bias);
  }

  ideep::tensor y = _mkldnn_conv2d(x, w, b, padding, stride, dilation, groups);
  Tensor result = new_with_itensor_mkldnn(std::move(y), input.options());

  if (!input.is_mkldnn()) {
    // Dense callers receive a dense result.
    return mkldnn_to_dense(result);
  }
  return result;
}
// Computes the gradient of a 2-D convolution with respect to its input using
// raw MKL-DNN primitives.
//
// A forward primitive descriptor is built only because the backward-data
// primitive_desc requires one as a hint. grad_output and weight are wrapped as
// user memories in plain nchw / (g)oihw layout and reordered into whatever
// layout the primitive selected. The primitive is then run, and its result is
// reordered back into the dense grad_input buffer when the layouts differ.
//
// input_size   - sizes of the forward input, read as [N, C, H, W].
// grad_output  - gradient of the forward output, read as [N, OC, OH, OW].
// weight       - weight whose size(2)/size(3) give the kernel extents. Its
//                data is described as goihw when groups != 1, else oihw.
// padding, stride, dilation - 2-element spatial parameters in PyTorch
//                convention (dilation 1 means an undilated kernel).
// groups       - number of convolution groups.
// bias_defined - only selects whether the forward hint descriptor has a bias.
// Returns a newly allocated dense tensor of shape input_size.
//
// NOTE(review): the data type is hard-coded to f32; non-float tensors are not
// rejected here.
Tensor mkldnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
  // The user memories below wrap raw data_ptr()s with packed nchw / (g)oihw
  // descriptors, so the source buffers must be contiguous. contiguous() is a
  // no-op for tensors that already are.
  const Tensor grad_output_contig = grad_output.contiguous();
  const Tensor weight_contig = weight.contiguous();

  auto grad_input = at::empty(input_size, grad_output_contig.options());
  auto cpu_engine = CpuEngine::Instance().get_engine();

  int32_t g = groups;
  int32_t n = grad_input.size(0);
  int32_t ic = grad_input.size(1);
  int32_t ih = grad_input.size(2);
  int32_t iw = grad_input.size(3);
  int32_t oc = grad_output_contig.size(1);
  int32_t oh = grad_output_contig.size(2);
  int32_t ow = grad_output_contig.size(3);
  int32_t kh = weight_contig.size(2);
  int32_t kw = weight_contig.size(3);
  int32_t sh = stride[0];
  int32_t sw = stride[1];
  int32_t ph = padding[0];
  int32_t pw = padding[1];
  // The forward path (_mkldnn_conv2d) honors dilation, so the backward must
  // too. Otherwise dilated convolutions get gradients for an undilated kernel.
  // PyTorch dilation is 1-based (1 == dense); MKL-DNN dilates count the
  // elements skipped between kernel taps (0 == dense), hence the "- 1".
  int32_t dh = dilation[0] - 1;
  int32_t dw = dilation[1] - 1;

  auto data_t = memory::data_type::f32;
  auto format_any = memory::format::any;
  auto format_nchw = memory::format::nchw;
  auto format_weight = (g != 1) ? memory::format::goihw : memory::format::oihw;

  memory::dims input_tz = {n, ic, ih, iw};
  memory::dims weight_tz = (g != 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
  memory::dims bias_tz = {oc};
  memory::dims output_tz = {n, oc, oh, ow};
  memory::dims _stride = {sh, sw};
  memory::dims _dilation = {dh, dw};
  memory::dims _padding = {ph, pw};

  auto input_md = memory::desc({input_tz}, data_t, format_any);
  auto weight_md = memory::desc({weight_tz}, data_t, format_any);
  auto bias_md = memory::desc({bias_tz}, data_t, format_any);
  auto output_md = memory::desc({output_tz}, data_t, format_any);

  // need to re-create conv_forward_pd to feed conv_backward_data_pd
  std::shared_ptr<convolution_forward::desc> conv_forward_desc;
  if (bias_defined) {
    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
        convolution_direct, input_md, weight_md, bias_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  } else {
    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
        convolution_direct, input_md, weight_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  }

  std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
  conv_forward_pd.reset(new convolution_forward::primitive_desc(
      *conv_forward_desc, cpu_engine));

  std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
  conv_backward_data_desc.reset(new convolution_backward_data::desc(
      convolution_direct, input_md, weight_md, output_md,
      _stride, _dilation, _padding, _padding, padding_kind::zero));

  std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
  conv_backward_data_pd.reset(new convolution_backward_data::primitive_desc(
      *conv_backward_data_desc, cpu_engine, *conv_forward_pd));

  auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
      grad_output_contig.data_ptr());
  auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
      weight_contig.data_ptr());
  auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
      grad_input.data_ptr());

  std::vector<primitive> net;

  // Reorder inputs into the layouts chosen by the primitive when they differ.
  auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
  auto grad_output_memory = grad_output_usr_memory;
  if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
    grad_output_memory = memory(grad_output_pd);
    net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
  }

  auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
  auto weight_memory = weight_usr_memory;
  if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
    weight_memory = memory(weight_pd);
    net.push_back(reorder(weight_usr_memory, weight_memory));
  }

  // Write into a scratch buffer when the preferred output layout is not nchw.
  auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
  auto grad_input_memory = grad_input_usr_memory;
  if (grad_input_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
    grad_input_memory = memory(grad_input_pd);
  }

  std::shared_ptr<convolution_backward_data> conv_backward_data;
  conv_backward_data.reset(new convolution_backward_data(*conv_backward_data_pd,
      grad_output_memory, weight_memory, grad_input_memory));
  net.push_back(*conv_backward_data);

  // Copy the scratch result back into the aten-owned grad_input buffer.
  if (grad_input_memory != grad_input_usr_memory) {
    net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
  }

  Stream::Instance().get_stream().submit(net);
  return grad_input;
}
// Computes the gradients of a 2-D convolution with respect to its weight and,
// when bias_defined is true, its bias, using raw MKL-DNN primitives.
//
// A forward primitive descriptor is built only because the backward-weights
// primitive_desc requires one as a hint. input and grad_output are wrapped as
// plain nchw user memories and reordered into the primitive's preferred
// layouts. The weight gradient is reordered back into the dense grad_weight
// buffer when needed. The bias gradient is written directly as a 1-d ("x")
// buffer.
//
// weight_size  - shape of the weight; size(2)/size(3) give the kernel extents.
// grad_output  - gradient of the forward output, read as [N, OC, OH, OW].
// input        - forward input, read as [N, C, H, W].
// padding, stride, dilation - 2-element spatial parameters in PyTorch
//                convention (dilation 1 means an undilated kernel).
// groups       - number of convolution groups.
// bias_defined - whether to also compute grad_bias (left undefined otherwise).
// Returns (grad_weight, grad_bias).
//
// NOTE(review): the data type is hard-coded to f32; non-float tensors are not
// rejected here.
std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
  // input and grad_output are wrapped below as packed nchw user memories via
  // data_ptr(), so their buffers must be contiguous. contiguous() is a no-op
  // for tensors that already are.
  const Tensor input_contig = input.contiguous();
  const Tensor grad_output_contig = grad_output.contiguous();

  auto grad_weight = at::empty(weight_size, grad_output_contig.options());
  Tensor grad_bias;
  if (bias_defined) {
    grad_bias = at::empty({grad_output_contig.size(1)}, grad_output_contig.options());
  }

  auto cpu_engine = CpuEngine::Instance().get_engine();

  int32_t g = groups;
  int32_t n = input_contig.size(0);
  int32_t ic = input_contig.size(1);
  int32_t ih = input_contig.size(2);
  int32_t iw = input_contig.size(3);
  int32_t oc = grad_output_contig.size(1);
  int32_t oh = grad_output_contig.size(2);
  int32_t ow = grad_output_contig.size(3);
  int32_t kh = grad_weight.size(2);
  int32_t kw = grad_weight.size(3);
  int32_t sh = stride[0];
  int32_t sw = stride[1];
  int32_t ph = padding[0];
  int32_t pw = padding[1];
  // The forward path (_mkldnn_conv2d) honors dilation, so the backward must
  // too. Otherwise dilated convolutions get gradients for an undilated kernel.
  // PyTorch dilation is 1-based (1 == dense); MKL-DNN dilates count the
  // elements skipped between kernel taps (0 == dense), hence the "- 1".
  int32_t dh = dilation[0] - 1;
  int32_t dw = dilation[1] - 1;

  auto data_t = memory::data_type::f32;
  auto format_any = memory::format::any;
  auto format_nchw = memory::format::nchw;
  auto format_weight = (g != 1) ? memory::format::goihw : memory::format::oihw;
  auto format_x = memory::format::x;

  memory::dims input_tz = {n, ic, ih, iw};
  memory::dims weight_tz = (g != 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
  memory::dims bias_tz = {oc};
  memory::dims output_tz = {n, oc, oh, ow};
  memory::dims _stride = {sh, sw};
  memory::dims _dilation = {dh, dw};
  memory::dims _padding = {ph, pw};

  memory::desc input_md({input_tz}, data_t, format_any);
  memory::desc weight_md({weight_tz}, data_t, format_any);
  memory::desc bias_md({bias_tz}, data_t, format_any);
  memory::desc output_md({output_tz}, data_t, format_any);

  // need to re-create conv_forward_pd to feed conv_backward_weight_pd
  std::shared_ptr<convolution_forward::desc> conv_forward_desc;
  if (bias_defined) {
    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
        convolution_direct, input_md, weight_md, bias_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  } else {
    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
        convolution_direct, input_md, weight_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  }

  std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
  conv_forward_pd.reset(new convolution_forward::primitive_desc(
      *conv_forward_desc, cpu_engine));

  std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
  if (bias_defined) {
    conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
        convolution_direct, input_md, weight_md, bias_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  } else {
    conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
        convolution_direct, input_md, weight_md, output_md,
        _stride, _dilation, _padding, _padding, padding_kind::zero));
  }

  std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
  conv_backward_weight_pd.reset(new convolution_backward_weights::primitive_desc(
      *conv_backward_weight_desc, cpu_engine, *conv_forward_pd));

  auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
      input_contig.data_ptr());
  auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
      grad_output_contig.data_ptr());
  auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
      grad_weight.data_ptr());
  std::shared_ptr<memory> grad_bias_memory;

  std::vector<primitive> net;

  // Reorder inputs into the layouts chosen by the primitive when they differ.
  auto input_pd = conv_backward_weight_pd->src_primitive_desc();
  auto input_memory = input_usr_memory;
  if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
    input_memory = memory(input_pd);
    net.push_back(reorder(input_usr_memory, input_memory));
  }

  auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
  auto grad_output_memory = grad_output_usr_memory;
  if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
    grad_output_memory = memory(grad_output_pd);
    net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
  }

  // Write into a scratch buffer when the preferred weight layout is not (g)oihw.
  auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
  auto grad_weight_memory = grad_weight_usr_memory;
  if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
    grad_weight_memory = memory(grad_weight_pd);
  }

  std::shared_ptr<convolution_backward_weights> conv_backward_weight;
  if (bias_defined) {
    grad_bias_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
        grad_bias.data_ptr()));
    conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
        input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
  } else {
    conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
        input_memory, grad_output_memory, grad_weight_memory));
  }
  net.push_back(*conv_backward_weight);

  // Copy the scratch result back into the aten-owned grad_weight buffer.
  if (grad_weight_memory != grad_weight_usr_memory) {
    net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
  }

  Stream::Instance().get_stream().submit(net);
  return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
}
// Combined backward for mkldnn_convolution.
//
// output_mask selects which gradients to compute: [0] input, [1] weight,
// [2] bias. The weight/bias gradients come from a single backward-weights
// call, which runs if either of them is requested. Gradients that are not
// computed are returned as undefined tensors. grad_output_t is made
// contiguous once and shared by both calls.
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
{
  const bool need_input_grad = output_mask[0];
  const bool need_weight_grad = output_mask[1];
  const bool need_bias_grad = output_mask[2];

  const Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input;
  if (need_input_grad) {
    grad_input = at::mkldnn_convolution_backward_input(
        input.sizes(), grad_output, weight, padding, stride, dilation, groups, need_bias_grad);
  }

  Tensor grad_weight;
  Tensor grad_bias;
  if (need_weight_grad || need_bias_grad) {
    std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(
        weight.sizes(), grad_output, input, padding, stride, dilation, groups, need_bias_grad);
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
}} // namespace at::native
#endif