blob: dd7ad986897a40a32cf4161cc432a541e3aae7d4 [file] [log] [blame]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/DepthwiseConvKernel.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/macros/Macros.h>
#include <limits>
#include <utility>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/permute.h>
#endif
#if AT_NNPACK_ENABLED()
#include <nnpack.h>
#endif
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/Utils.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_conv_depthwise2d.h>
#include <ATen/ops/_convolution.h>
#include <ATen/ops/_convolution_double_backward_native.h>
#include <ATen/ops/_convolution_mode.h>
#include <ATen/ops/_convolution_mode_native.h>
#include <ATen/ops/_convolution_native.h>
#include <ATen/ops/_mps_convolution.h>
#include <ATen/ops/_mps_convolution_transpose.h>
#include <ATen/ops/_nnpack_available.h>
#include <ATen/ops/_nnpack_spatial_convolution.h>
#include <ATen/ops/_slow_conv2d_backward.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/constant_pad_nd.h>
#include <ATen/ops/conv1d_native.h>
#include <ATen/ops/conv2d_native.h>
#include <ATen/ops/conv3d_native.h>
#include <ATen/ops/conv_depthwise3d.h>
#include <ATen/ops/conv_transpose1d_native.h>
#include <ATen/ops/conv_transpose2d_native.h>
#include <ATen/ops/conv_transpose3d_native.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/convolution_backward_native.h>
#include <ATen/ops/convolution_backward_overrideable.h>
#include <ATen/ops/convolution_backward_overrideable_native.h>
#include <ATen/ops/convolution_native.h>
#include <ATen/ops/convolution_overrideable.h>
#include <ATen/ops/convolution_overrideable_native.h>
#include <ATen/ops/cudnn_convolution.h>
#include <ATen/ops/cudnn_convolution_transpose.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/miopen_convolution.h>
#include <ATen/ops/miopen_convolution_transpose.h>
#include <ATen/ops/miopen_depthwise_convolution.h>
#include <ATen/ops/mkldnn_convolution.h>
#include <ATen/ops/mps_convolution_backward.h>
#include <ATen/ops/mps_convolution_transpose_backward.h>
#include <ATen/ops/slow_conv3d.h>
#include <ATen/ops/slow_conv_dilated2d.h>
#include <ATen/ops/slow_conv_dilated3d.h>
#include <ATen/ops/slow_conv_transpose2d.h>
#include <ATen/ops/slow_conv_transpose3d.h>
#include <ATen/ops/thnn_conv2d.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif
constexpr int MIOPEN_DIM_MAX = 5;
namespace at { namespace native {
// Check workload to activate fast depthwise FP16 cudnn conv kernels
template <typename T>
bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) {
auto w = at::symint::size<T>(input, 3); // same as h
auto ch = at::symint::size<T>(input, 1);
auto bs = at::symint::size<T>(input, 0);
if (stride==1) {
if (w >= 7) {
// All batch sizes and nb_channels
if (w >= 112) {
return true;
}
// large nb_channels
if (ch >= 1024) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (w >= 56) {
return true;
} else if (bs >= 32) {
return true;
}
}
// batch_size specific
if (bs >= 128) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (ch >= 512) {
return true;
} else if (ch >= 64) {
if (w >= 14) {
return true;
}
} else if ((ch >= 32) && (w >=28)) {
return true;
}
} else if (bs >= 64) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 32) && (w >= 28)) {
return true;
}
} else if (bs >= 32) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 128) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 16) {
if ((ch >= 1024) && (w >= 14)) {
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 8) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 28)) {
return true;
} else if ((ch >= 64) && (w >= 56)) {
return true;
}
}
}
} else if (stride==2) {
if (ch < 256) {
return false;
}
if (w >= 7) {
if (bs >= 128) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (ch >= 1024) {
return true;
} else if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 64) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 32) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 1024) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 16) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 8) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 1024) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 1) {
if ((ch >= 512) && (w >=112)) {
return true;
}
}
}
}
return false;
}
// simplified version for cudnn 8.2 and above
template <typename T>
bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) {
// 1D conv
if(at::symint::size<T>(input, 2) == 1 && stride == 1){
return true;
}
// 2d conv
// only square filters
if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) return false;
auto filter = at::symint::size<T>(weight, 3);
// only 1/3/5 filter
if (filter != 1 && filter != 3 && filter != 5) return false;
// we don't enforce square input but only check width to reduce heuristic space
if (at::symint::size<T>(input, 3) < 7) return false; // min width 7
auto w = at::symint::size<T>(input, 3);
// only 1/2 stride, use cudnn for all stride 1
if (stride == 1) return true;
if (stride != 2) return false;
auto ch = at::symint::size<T>(input, 1);
auto bs = at::symint::size<T>(input, 0);
// special case since bs1 show good perf in lots of cases
if (bs == 1) {
if (filter == 1 && w <= 28) return true;
if (filter == 3 || filter == 5) return true;
} else {
if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
if (filter == 3 || filter == 5) {
if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
}
}
return false;
}
#if defined(C10_MOBILE)
static bool xnnpack_use_convolution2d(
const Tensor& input,
const Tensor& weight,
const at::OptionalIntArrayRef bias_sizes_opt,
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed) {
return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed);
}
static bool xnnpack_use_convolution2d(
const Tensor& input,
const Tensor& weight,
const at::OptionalSymIntArrayRef bias_sizes_opt,
const SymIntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed) {
// Never use xnnpack for symbolic tracing
return false;
}
#endif
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
// This struct is templated so that we can run backend selection in a dynamic
// shapes context; all of the real kernel selection in eager mode runs with
// int64_t
template <typename T>
struct ConvParams {
std::vector<int64_t> stride;
std::vector<T> padding;
std::vector<int64_t> dilation;
bool transposed;
std::vector<T> output_padding;
int groups;
bool benchmark;
bool deterministic;
bool cudnn_enabled;
bool allow_tf32;
bool is_strided() const {
bool is_strided = false;
for (auto s : stride) {
is_strided |= (s != 1);
}
return is_strided;
}
bool is_dilated() const {
bool is_dilated = false;
for (auto d : dilation) {
is_dilated |= (d != 1);
}
return is_dilated;
}
bool is_padded() const {
bool is_padded = false;
for (auto p : padding) {
is_padded |= (p != 0);
}
return is_padded;
}
bool is_output_padding_neg() const {
bool is_non_neg = false;
for (const auto& p : output_padding) {
is_non_neg |= (p < 0);
}
return is_non_neg;
}
bool is_output_padding_big() const {
bool is_big = false;
for (auto i: c10::irange(output_padding.size())) {
is_big |= (output_padding[i] >= stride[i]);
}
return is_big;
}
bool is_padding_neg() const {
bool is_non_neg = false;
for (const auto& p : padding) {
is_non_neg |= (p < 0);
}
return is_non_neg;
}
bool is_stride_nonpos() const {
bool is_nonpos = false;
for (auto s : stride) {
is_nonpos |= (s <= 0);
}
return is_nonpos;
}
void view1d_as_2d() {
if (stride.size() == 1) {
stride.insert(stride.begin(), 1);
padding.insert(padding.begin(), 0);
dilation.insert(dilation.begin(), 1);
output_padding.insert(output_padding.begin(), 0);
}
}
bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) const {
#if defined(__ARM_NEON__)
// Currently only 3x3 depthwise convolutions on tensors of float are supported.
return (input.ndimension() == 4) &&
(at::symint::size<T>(input, 1) == groups) &&
(weight.ndimension() == 4 ) &&
(at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
(at::symint::size<T>(weight, 1) == 1) &&
(at::symint::size<T>(weight, 2) == 3) &&
(at::symint::size<T>(weight, 3) == 3) &&
(input.device().is_cpu()) &&
(input.scalar_type() == at::kFloat) &&
input.is_contiguous() &&
(weight.device().is_cpu()) &&
(weight.scalar_type() == at::kFloat) &&
weight.is_contiguous() &&
(!bias.has_value() || bias->is_contiguous()) &&
!is_strided() &&
!is_dilated() &&
!transposed;
#else
return false;
#endif
}
bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
constexpr int64_t int_max = std::numeric_limits<int>::max();
auto numel_input = at::symint::numel<T>(input);
// empty input
if (numel_input == 0) {
return false;
}
// input size can not be reduced to the range of int by splitting the batch dim
auto n = at::symint::size<T>(input, 0);
if (numel_input / n > int_max) {
return true;
}
// output size can not be reduced to the range of int by splitting the batch dim
T outsize = 1;
if (transposed) {
auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
outsize = c10::multiply_integers(o.begin() + 1, o.end());
} else {
auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
outsize = c10::multiply_integers(o.begin() + 1, o.end());
}
return outsize > int_max;
}
bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
// Note [Mobile check segfaults]
// cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
// that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
#if !defined(C10_MOBILE)
if (needs_64bit_indexing_no_split(input, weight)) {
return false;
}
if (!detail::getCUDAHooks().compiledWithCuDNN()) {
return false;
}
if (!input.is_cuda() || !cudnn_enabled) {
return false;
}
if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
return false;
}
}
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
// bypass dilation checks for channels_last convolution
if (deterministic && is_dilated()) {
// cudnn doesn't support deterministic dilated convolution fully yet
return false;
}
if (is_dilated()) {
return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
}
}
return !is_output_padding_big();
#else
return false;
#endif
}
// Use cudnn for FP16 depthwise convolutions
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
// always use cudnn_depthwise for channels_last format
return true;
}
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
if (cudnn_version >= 8200) {
bool kernel_cond = (use_cudnn(input, weight) &&
input.scalar_type() == kHalf && // only for FP16
weight.scalar_type() == kHalf &&
is_depthwise(input, weight) &&
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
!is_dilated() && // no dilation supported
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
if (kernel_cond) {
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
}
}
// keep (7600 <= cudnn < 8200) code unchanged
bool kernel_cond = (cudnn_version >= 7600 &&
use_cudnn(input, weight) &&
input.scalar_type() == kHalf && // only for FP16
weight.scalar_type() == kHalf &&
is_depthwise(input, weight) &&
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
at::symint::size<T>(input, 2) >= 7 && // min width/height 7
!is_dilated() && // no dilation supported
stride[0] == stride[1] && // equal strides
((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
if (kernel_cond) {
return check_cudnn_depthwise_workload<T>(input, stride[0]);
} else {
return false;
}
} else {
return false;
}
}
bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const {
if (needs_64bit_indexing_no_split(input, weight)) {
return false;
}
return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
&& !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16
&& cudnn_enabled
;
}
bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const {
#if AT_MKLDNN_ENABLED()
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (transposed && is_output_padding_big()) {
return false;
}
if (transposed && groups > 1 && at::symint::size<T>(input, 1) == groups) {
return false;
}
if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
return true;
}
return (input.is_mkldnn()) || // input is mkldnn Tensor
(input.device().is_cpu() &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
// For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
// but THNN is faster when single-threaded.
(is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
(groups > 1
|| (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
|| at::symint::size<T>(input, 0) > 1
|| at::symint::size<T>(input, 0)*at::symint::size<T>(input, 1)*at::symint::size<T>(input, 2)*at::symint::size<T>(input, 3) > 20480) // for some case, native is faster
);
#endif
return false;
}
bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const {
#if AT_NNPACK_ENABLED()
return at::_nnpack_available() &&
input.device().is_cpu() &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
!is_dilated() && // or dilation
!transposed && // or transposed tensors
input.ndimension() == 4 && // must be in NCHW format
weight.ndimension() == 4 &&
(at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17) // NNPACK only supports kernels up to 16x16
#if !defined(C10_MOBILE)
&& at::symint::size<T>(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable
#endif
;
#endif
return false;
}
bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
const at::OptionalArrayRef<T> bias_sizes_opt) const {
#if defined(C10_MOBILE)
if (!transposed) {
// NB: for the call here, it MATTERS that we are templated. If you
// untemplate this to always use SymInt, the function
// xnnpack_use_convolution2d will always return false
return (at::symint::size<T>(input, 1) == groups) &&
xnnpack_use_convolution2d(
input,
weight,
bias_sizes_opt,
padding,
stride,
dilation,
groups,
transposed);
}
#endif
return false;
}
bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
// These checks need to be expanded. Currently we have very limited set of
// checks for MPS.
#ifdef USE_MPS
if (needs_64bit_indexing_no_split(input, weight)) {
return false;
}
if (!input.is_mps()) {
return false;
}
return true;
#else
return false;
#endif
}
// We currently only have depthwise support for the case where groups ==
// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
// a depthwise multiplier)
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
return input.is_cuda() &&
!transposed &&
(input.ndimension() == 4 || input.ndimension() == 5) &&
at::symint::size<T>(input, 1) == groups &&
groups > 1 && // no point if there is only a single group
at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0; // output channels must be a multiple of input channels
}
};
DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
DEFINE_DISPATCH(miopen_convolution_backward_stub);
DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub);
REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub);
template <typename T>
std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
out << "ConvParams {"
<< " stride = " << IntArrayRef{params.stride}
<< " padding = " << ArrayRef<T>{params.padding}
<< " dilation = " << IntArrayRef{params.dilation}
<< " transposed = " << params.transposed
<< " output_padding = " << ArrayRef<T>{params.output_padding}
<< " groups = " << params.groups
<< " benchmark = " << params.benchmark
<< " deterministic = " << params.deterministic
<< " cudnn_enabled = " << params.cudnn_enabled
<< " allow_tf32 = " << params.allow_tf32
<< "}";
return out;
}
template <typename T>
static void check_shape_forward(const at::Tensor& input,
const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
const ConvParams<T>& params) {
int64_t k = input.ndimension();
int64_t weight_dim = weight_sizes.size();
int64_t groups = params.groups;
const auto& padding = params.padding;
const auto& dilation = params.dilation;
bool transposed = params.transposed;
TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
TORCH_CHECK(weight_dim == k,
"Expected ", weight_dim, "-dimensional input for ", weight_dim,
"-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
at::symint::sizes<T>(input), " instead");
TORCH_CHECK(weight_sizes[0] >= groups,
"Given groups=", groups, ", expected weight to be at least ", groups,
" at dimension 0, but got weight of size ", weight_sizes, " instead");
TORCH_CHECK(weight_sizes[0] % groups == 0,
"Given groups=", groups, ", expected weight to be divisible by ",
groups, " at dimension 0, but got weight of size [", weight_sizes,
"] instead");
if (!transposed) {
std::vector<T> input_shape;
std::vector<T> kernel_shape;
bool kernel_size_correct = true;
TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
"Given groups=", groups, ", weight of size ", weight_sizes,
", expected input", at::symint::sizes<T>(input), " to have ",
(weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
" channels instead");
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
"Given weight of size ", weight_sizes,
", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
", but got bias of size ", at::symint::sizes<T>(bias), " instead");
for (const auto i : c10::irange(2, k)) {
input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
// log new kernel size considering dilation
kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
if (input_shape.back() < kernel_shape.back()) {
kernel_size_correct = false;
}
}
TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");
if (!kernel_size_correct) {
// If kernel size is incorrect
std::ostringstream input_ss;
std::ostringstream kernel_ss;
std::string separator = "";
for (int i = 0, len = input_shape.size(); i < len; ++i) {
input_ss << separator << input_shape[i];
kernel_ss << separator << kernel_shape[i];
separator = " x ";
}
AT_ERROR("Calculated padded input size per channel: (", input_ss.str(), "). "
"Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
}
} else { // transposed
TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
"Given transposed=", transposed, ", weight of size ", weight_sizes,
", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
" channels, but got ", at::symint::size<T>(input, 1), " channels instead");
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
"Given transposed=", transposed, ", weight of size ", weight_sizes,
", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
", but got bias of size ", at::symint::sizes<T>(bias), " instead");
}
}
template <typename T>
static void check_shape_backward(
const at::Tensor& input,
const c10::ArrayRef<T>& weight_sizes,
const ConvParams<T>& params) {
check_shape_forward<T>(input, weight_sizes, /*bias=*/ Tensor(), params);
}
// Given an input tensor and an expected number of spatial dimensions, checks that the
// input is a valid shape and returns the batched form of the input.
//
// Args:
// input (Tensor): Input tensor
// num_spatial_dims (int): Number of spatial dimensions expected for the input
// func_name (string): Function name to produce a nice error message for invalid input
//
// Returns a std::tuple containing:
// batched_input (Tensor): Input with a batch dimension
// is_batched (bool): Indicates whether the original input was already batched
static std::tuple<Tensor, bool> batchify(
const Tensor& input,
const int64_t num_spatial_dims,
const std::string& func_name) {
const auto dim_count_no_batch = num_spatial_dims + 1;
const auto dim_count_batch = dim_count_no_batch + 1;
const auto is_batched = (input.dim() == dim_count_batch);
TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
"Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
"D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
}
static void check_input_same_type_as_parameters(
const Tensor& input,
const Tensor& weight,
const Tensor& bias) {
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
}
static void check_input_same_type_as_parameters(
const Tensor& input,
const Tensor& weight) {
check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
}
static void check_input_same_type_as_parameters(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const ConvBackend backend) {
if (backend == ConvBackend::Mkldnn || backend == ConvBackend::MkldnnTranspose) {
TORCH_CHECK(input.options().type_equal(weight.options())
|| (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
|| (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
} else {
check_input_same_type_as_parameters(input, weight, bias);
}
}
static auto view4d(const at::Tensor& tensor) -> at::Tensor {
TORCH_CHECK(tensor.ndimension() == 3,
"expected 3D tensor, got tensor with ", tensor.ndimension(),
" dimensions instead");
return tensor.unsqueeze(2);
}
static auto view3d(const at::Tensor& tensor) -> at::Tensor {
TORCH_CHECK(tensor.ndimension() == 4,
"expected 4D tensor, got tensor with ", tensor.ndimension(),
" dimensions instead");
return tensor.squeeze(2);
}
static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
if (!tensor.defined()) {
return at::Tensor();
}
const auto memory_format = tensor.suggest_memory_format();
int64_t n = tensor.sizes()[dim] / groups;
return tensor.narrow(dim, n * g, n).contiguous(memory_format);
}
namespace {
std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
auto inp_view_as_complex = at::view_as_real(inp);
auto dim_i = inp_view_as_complex.dim() - 1;
auto i_r = inp_view_as_complex.select(dim_i, 0);
auto i_i = inp_view_as_complex.select(dim_i, 1);
return std::make_pair(i_r, i_i);
}
at::Tensor complex_convolution(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool transposed,
IntArrayRef output_padding,
int64_t groups) {
check_input_same_type_as_parameters(input, weight, bias);
Tensor i_r, i_i, w_r, w_i;
std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());
// [NOTE] Complex Convolution
// conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
// where W, x and b are all complex inputs.
// With Gauss Trick:
// a = conv(Wr, xr, br),
// b = conv(Wi, xi, 0),
// c = conv(Wr + Wi, xr + xi, bi + br)
// conv(W, x, b) = a - b + i(c - a - b)
Tensor a, b, c;
if (!bias.defined()) {
a = at::convolution(i_r, w_r, bias, stride, padding, dilation, transposed, output_padding, groups);
b = at::convolution(i_i, w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
c = at::convolution(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
} else {
Tensor b_r, b_i;
std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
a = at::convolution(i_r, w_r, b_r, stride, padding, dilation, transposed, output_padding, groups);
b = at::convolution(i_i, w_i, Tensor(), stride, padding, dilation, transposed, output_padding, groups);
c = at::convolution(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, transposed, output_padding, groups);
}
auto i = c10::Scalar(c10::complex<double>(0, 1));
return a - b + i * (c - a - b);
}
at::Tensor complex_convolution_mode(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef stride,
c10::string_view padding,
at::IntArrayRef dilation,
int64_t groups) {
auto bias = bias_opt.value_or(Tensor());
check_input_same_type_as_parameters(input, weight, bias);
Tensor i_r, i_i, w_r, w_i;
std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());
// See [NOTE] Complex Convolution
Tensor a, b, c;
if (!bias.defined()) {
a = at::_convolution_mode(i_r, w_r, bias, stride, padding, dilation, groups);
b = at::_convolution_mode(i_i, w_i, bias, stride, padding, dilation, groups);
c = at::_convolution_mode(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
} else {
Tensor b_r, b_i;
std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
a = at::_convolution_mode(i_r, w_r, b_r, stride, padding, dilation, groups);
b = at::_convolution_mode(i_i, w_i, Tensor(), stride, padding, dilation, groups);
c = at::_convolution_mode(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
}
auto i = c10::Scalar(c10::complex<double>(0, 1));
return a - b + i * (c - a - b);
}
} // namespace
at::Tensor conv1d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
} else {
output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv2d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
TORCH_CHECK(
!bias.defined() || bias.dtype() == input_.dtype(),
"Input type (",
input_.dtype().name(),
") and bias type (",
bias.dtype().name(),
") should be the same");
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
} else {
output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv3d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
} else {
output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
static Tensor convolution_same(
const Tensor &input, const Tensor &weight, const Tensor &bias,
IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
auto k = weight.dim();
TORCH_CHECK(k > 2, "weight should have at least three dimensions");
auto dim = static_cast<size_t>(k - 2);
auto weight_sizes = weight.sym_sizes();
auto input_sizes = input.sym_sizes();
TORCH_CHECK(k == input.dim(),
"Expected ", k, "-dimensional input for ",
k, "-dimensional weight", weight_sizes, ", but got ",
input.dim(), "-dimensional input of size ",
input.sizes(), " instead");
TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
"stride cannot broadcast to ", dim, " dimensions");
TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
"dilation cannot broadcast to ", dim, " dimensions");
for (auto i: c10::irange(stride.size())) {
TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
}
// Calculate the correct padding
SymDimVector padding_l, padding_r;
bool symmetric_padding = true;
for (auto i: c10::irange(dim)) {
auto s = stride.size() == 1 ? stride[0] : stride[i];
auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
auto pad = pooling_same_mode_padding_lr(
input_sizes[i + 2], weight_sizes[i + 2], s, d);
padding_l.push_back(pad.first);
padding_r.push_back(pad.second);
if (pad.first != pad.second) {
symmetric_padding = false;
}
}
if (symmetric_padding) {
// All backends handle symmetric padding natively
SymDimVector output_padding(static_cast<size_t>(dim));
return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
false, output_padding, groups);
}
TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
" require a zero-padded copy of the input be created");
SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
for (auto i: c10::irange(dim)) {
// Apply padding by the difference, leaving only a symmetric padding
auto delta_pad = padding_r[i] - padding_l[i];
auto pad_idx = 2 * (dim - 1 - i); // F.pad goes from last dim to first
if (delta_pad > 0) {
pad_nd[pad_idx + 1] = delta_pad;
} else {
pad_nd[pad_idx] = delta_pad;
padding_l[i] = padding_r[i];
}
}
auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
SymDimVector output_padding(static_cast<size_t>(dim));
return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
dilation, false, output_padding, groups);
}
Tensor _convolution_mode(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
if (padding == "same") {
return at::native::convolution_same(
input, weight, bias, stride, dilation, groups);
} else if (padding == "valid") {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
const int64_t padding_[] = {0};
return at::convolution(
input, weight, bias, stride, padding_, dilation, false, padding_, groups);
}
TORCH_CHECK(false, "Invalid padding string: '", padding, "'");
}
at::Tensor conv1d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
} else {
output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv2d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
} else {
output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv3d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
} else {
output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv_transpose1d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
} else {
output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv_transpose2d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
} else {
output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor conv_transpose3d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
Tensor input;
bool is_batched;
std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
output = complex_convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
} else {
output = at::convolution(
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
}
return is_batched ? std::move(output) : output.squeeze(0);
}
at::Tensor convolution(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
bool transposed, IntArrayRef output_padding, int64_t groups) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
auto& ctx = at::globalContext();
// See Note [Enabling Deterministic Operations]
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
}
at::Tensor convolution_overrideable(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
bool transposed, IntArrayRef output_padding, int64_t groups) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function ");
}
// Function to select the convolution backend based on the inputs and params.
// This overload is used within the convolution internals but not exposed to python.
// NB: The forward pass provides a bias tensor while the backward pass provides
// a bool indicating whether the bias is defined. This is done to save memory by
// avoiding saving the full bias tensor for backward.
template <typename T>
ConvBackend _select_conv_backend(
const Tensor& input,
const Tensor& weight,
const c10::optional<Tensor>& bias,
const at::OptionalArrayRef<T> bias_sizes_opt,
const bool need_backward,
const ConvParams<T>& params) {
// don't send empty inputs through backends
if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
} else if (at::symint::numel<T>(input) == 0) {
TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
}
if (params.is_depthwise(input, weight)) {
if (params.use_cudnn_depthwise(input, weight)) {
return ConvBackend::Cudnn;
} else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
return ConvBackend::MiopenDepthwise;
} else {
if (input.ndimension() == 4) {
return ConvBackend::CudaDepthwise2d;
} else if (input.ndimension() == 5) {
return ConvBackend::CudaDepthwise3d;
} else {
// unsupported
}
}
} else if (params.use_cudnn(input, weight)) {
if (params.transposed) {
return ConvBackend::CudnnTranspose;
} else {
return ConvBackend::Cudnn;
}
} else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
if (params.transposed) {
return ConvBackend::MiopenTranspose;
} else {
return ConvBackend::Miopen;
}
} else if (params.use_mkldnn(input, weight)) {
if (params.transposed) {
return ConvBackend::MkldnnTranspose;
} else {
return ConvBackend::Mkldnn;
}
} else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
// Using prepacked conv is preferred, but XNNPACK is still the fastest
// option for NHWC.
return ConvBackend::Xnnpack2d;
// 3x3 depthwith convolutions implementation is inference only
} else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
return ConvBackend::Winograd3x3Depthwise;
} else if (
!params.transposed && (input.ndimension() == 5) &&
(input.device().is_cpu()) &&
!params.is_dilated()) {
// fast path for grouped conv3d
return ConvBackend::Slow3d;
} else if (input.device().is_cpu() || input.is_cuda()) {
// backends without support for groups
if (params.transposed) {
if (input.ndimension() == 4) {
return ConvBackend::SlowTranspose2d;
} else if (input.ndimension() == 5) {
return ConvBackend::SlowTranspose3d;
} else {
// unsupported
}
} else { /* Not transposed */
if (input.ndimension() == 4) {
if (params.is_dilated()) {
return ConvBackend::SlowDilated2d;
} else { /* dim == 4, non-dilated */
if (params.use_nnpack(input, weight)) {
return ConvBackend::NnpackSpatial;
} else {
/* CPU implementation has specialized MM kernels
for non-dilated case here */
return ConvBackend::Slow2d;
}
}
} else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
return ConvBackend::SlowDilated3d;
} else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
/* CPU implementation has specialized MM kernels
for non-dilated case here */
return ConvBackend::Slow3d;
} else {
// unsupported
}
}
} else if (params.use_mps(input, weight)) {
if (params.transposed) {
return ConvBackend::MpsTranspose;
} else {
return ConvBackend::Mps;
}
} else {
// Only reach here when input is backend with out-of-source implementation.
return ConvBackend::Overrideable;
}
// Error out if no suitable backend was found.
AT_ERROR("unsupported ConvNd parameters");
}
// Selects a backend for convolution based on the inputs and params.
ConvBackend select_conv_backend(
const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_opt,
IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) {
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;
auto& ctx = at::globalContext();
auto k = weight_r.ndimension();
int64_t dim = k - 2;
ConvParams<c10::SymInt> params;
params.stride = expand_param_if_needed(stride_, "stride", dim);
params.padding = expand_param_if_needed(padding_, "padding", dim);
params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
params.transposed = transposed_;
params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
auto input = input_r;
auto weight = weight_r;
check_shape_forward(input, weight.sym_sizes(), bias, params);
// Expand 1d -> 2d.
// This is only done for backends that don't natively support 1d spatial input.
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
// avoid accidentally going through NHWC for permuted 3d input.
input = input.contiguous();
params.view1d_as_2d();
input = view4d(input);
weight = view4d(weight);
}
auto bias_sizes = bias.defined() ? c10::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt;
bool need_backward = GradMode::is_enabled() &&
(input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params);
}
// For BC reasons, have a copy that does not require bias_opt
static ConvBackend select_conv_backend(
const Tensor& input,
const Tensor& weight,
const at::OptionalIntArrayRef bias_sizes_opt,
const bool need_backward,
const ConvParams<int64_t>& params) {
return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params);
}
static at::Tensor _convolution_nogroup_backend(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const ConvBackend backend,
const ConvParams<int64_t>& params) {
auto kernel_size = weight.sizes().slice(2);
switch(backend) {
case ConvBackend::NnpackSpatial:
#if AT_NNPACK_ENABLED()
return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride);
#else
TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support");
#endif
case ConvBackend::Slow2d:
return at::thnn_conv2d(input, weight, kernel_size, bias, params.stride, params.padding);
case ConvBackend::SlowDilated2d:
return at::slow_conv_dilated2d(
input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
case ConvBackend::SlowDilated3d:
return at::slow_conv_dilated3d(
input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
case ConvBackend::SlowTranspose2d:
return at::slow_conv_transpose2d(
input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
case ConvBackend::SlowTranspose3d:
return at::slow_conv_transpose3d(
input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
default:
TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
}
}
static inline std::vector<int64_t> calc_output_size(
const Tensor& input,
const Tensor& weight,
const ConvParams<int64_t>& params) {
std::vector<int64_t> output_size = params.transposed ?
conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding,
params.stride, params.dilation, params.groups) :
conv_output_size(input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation);
// Handle empty # of channels.
if (input.size(1) == 0) {
output_size[input_channels_dim] = 0;
}
return output_size;
}
static inline at::MemoryFormat determine_backend_memory_format(
const Tensor& input,
const Tensor& weight,
const ConvBackend backend) {
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
auto k = weight.ndimension();
#if !defined(C10_MOBILE)
// See Note [Mobile check segfaults]
switch(backend) {
case ConvBackend::Cudnn:
case ConvBackend::CudnnTranspose:
if (detail::getCUDAHooks().compiledWithCuDNN()) {
backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
}
break;
case ConvBackend::Miopen:
case ConvBackend::MiopenDepthwise:
case ConvBackend::MiopenTranspose:
if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
"Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
}
break;
case ConvBackend::Mkldnn:
case ConvBackend::MkldnnTranspose:
if (mkldnn_conv_use_channels_last(input, weight)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
}
break;
case ConvBackend::Slow2d:
case ConvBackend::SlowDilated2d:
case ConvBackend::SlowTranspose2d:
if (thnn_conv_use_channels_last(input, weight)) {
backend_memory_format = at::MemoryFormat::ChannelsLast;
}
break;
default:
backend_memory_format = at::MemoryFormat::Contiguous;
}
#endif
return backend_memory_format;
}
at::MemoryFormat _determine_backend_memory_format(
const Tensor& input,
const Tensor& weight,
const ConvBackend backend) {
return determine_backend_memory_format(input, weight, backend);
}
at::Tensor _convolution(
const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto input = input_r;
auto weight = weight_r;
auto bias = bias_r;
auto k = weight.ndimension();
c10::IntArrayRef weight_sizes = weight.sizes();
int64_t dim = k - 2;
TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");
ConvParams<int64_t> params;
params.stride = expand_param_if_needed(stride_, "stride", dim);
params.padding = expand_param_if_needed(padding_, "padding", dim);
params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
params.transposed = transposed_;
params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
params.benchmark = benchmark;
params.deterministic = deterministic;
params.cudnn_enabled = cudnn_enabled;
params.allow_tf32 = allow_tf32;
check_shape_forward(input, weight_sizes, bias, params);
// Expand 1d -> 2d.
// This is only done for backends that don't natively support 1d spatial input.
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
// avoid accidentally going through NHWC for permuted 3d input.
input = input.contiguous();
params.view1d_as_2d();
input = view4d(input);
weight = view4d(weight);
}
// Select appropriate backend to use.
auto bias_sizes_opt = bias.defined() ? c10::optional<IntArrayRef>(bias.sizes()) : c10::nullopt;
bool need_backward = GradMode::is_enabled() &&
(input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
// Call the backend.
Tensor output;
auto kernel_size = weight.sizes().slice(2);
switch (backend) {
case ConvBackend::CudaDepthwise2d:
output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
params.stride, params.padding, params.dilation);
break;
case ConvBackend::CudaDepthwise3d:
output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
params.stride, params.padding, params.dilation);
break;
case ConvBackend::Cudnn:
check_input_same_type_as_parameters(input, weight, bias);
output = at::cudnn_convolution(
input.contiguous(backend_memory_format), weight, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}
break;
case ConvBackend::CudnnTranspose:
check_input_same_type_as_parameters(input, weight, bias);
output = at::cudnn_convolution_transpose(
input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}
break;
case ConvBackend::Empty:
{
Tensor weight_view;
// Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where
// view size is not compatible with input tensor's size and stride.
if(weight.is_contiguous()) {
weight_view = at::_unsafe_view(weight, -1);
} else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) {
weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1);
} else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1);
} else {
weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1);
}
output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
if (bias.defined()) {
output.add_(bias[0]);
}
output = output.view(calc_output_size(input, weight, params));
break;
}
case ConvBackend::Miopen:
check_input_same_type_as_parameters(input, weight, bias);
output = at::miopen_convolution(
input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic);
break;
case ConvBackend::MiopenDepthwise:
output = at::miopen_depthwise_convolution(
input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic);
break;
case ConvBackend::MiopenTranspose:
check_input_same_type_as_parameters(input, weight, bias);
output = at::miopen_convolution_transpose(
input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding,
params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
break;
case ConvBackend::Mkldnn:
#if AT_MKLDNN_ENABLED()
check_input_same_type_as_parameters(input, weight, bias, backend);
if (!input.is_mkldnn()) {
// need to ensure contiguous for non-mkldnn tensors
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
bias = bias.defined() ? bias.contiguous() : bias;
}
output = at::mkldnn_convolution(
input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
#else
TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
break;
case ConvBackend::MkldnnTranspose:
#if AT_MKLDNN_ENABLED()
check_input_same_type_as_parameters(input, weight, bias, backend);
if (!input.is_mkldnn()) {
// need to ensure contiguous for non-mkldnn tensors
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
bias = bias.defined() ? bias.contiguous() : bias;
}
output = mkldnn_convolution_transpose_stub(input.device().type(),
input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups);
#else
TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
break;
case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
output = empty_mkldnn(
calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()),
input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
#else
TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
break;
case ConvBackend::Overrideable:
output = at::convolution_overrideable(
input, weight, bias, params.stride, params.padding, params.dilation, params.transposed,
params.output_padding, params.groups);
break;
case ConvBackend::Slow3d:
output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding);
break;
case ConvBackend::Winograd3x3Depthwise:
output = convolution_depthwise3x3_winograd_stub(
input.device().type(), input, weight, bias, params.stride, params.padding, params.groups);
break;
case ConvBackend::Xnnpack2d:
output = xnnpack::convolution2d(
input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
break;
// Handle backends that don't natively support groups > 1.
case ConvBackend::NnpackSpatial:
case ConvBackend::Slow2d:
case ConvBackend::SlowDilated2d:
case ConvBackend::SlowDilated3d:
case ConvBackend::SlowTranspose2d:
case ConvBackend::SlowTranspose3d:
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
if (params.groups == 1) {
output = _convolution_nogroup_backend(input, weight, bias, backend, params);
} else {
std::vector<Tensor> outputs(params.groups);
for (const auto g : c10::irange(params.groups)) {
auto input_g = subtensor(input, 1, params.groups, g);
auto weight_g = subtensor(weight, 0, params.groups, g);
auto bias_g = subtensor(bias, 0, params.groups, g);
outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params);
}
output = at::cat(outputs, 1);
}
break;
case ConvBackend::Mps:
#ifdef USE_MPS
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
output = at::_mps_convolution(input.contiguous(), weight, bias.defined() ? bias.contiguous() : bias,
params.padding, params.stride, params.dilation,
params.groups);
#else
TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
break;
case ConvBackend::MpsTranspose:
#ifdef USE_MPS
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
output = at::_mps_convolution_transpose(
input.contiguous(backend_memory_format), weight,
params.padding, params.output_padding,
params.stride, params.dilation, params.groups);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}
#else
TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
break;
}
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
output = view3d(output);
}
return output;
}
at::Tensor _convolution(
const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled)
{
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
}
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function ");
return std::tuple<Tensor, Tensor, Tensor>(
at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
at::empty({}));
}
static Tensor subvariable(const Tensor& var, int dim, int groups, int g) {
int64_t n = var.sizes()[dim] / groups;
auto result = var.narrow(dim, n * g, n);
return result;
}
std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::optional<Tensor>& ggI_opt, const c10::optional<Tensor>& ggW_r_opt, const c10::optional<Tensor>& ggb_opt,
const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
std::array<bool, 3> output_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
const Tensor& ggI = *ggI_maybe_owned;
const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();});
const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();});
auto ggW = ggW_r;
auto gO = gO_r;
auto weight = weight_r;
int64_t dim = weight.ndimension() - 2;
ConvParams<int64_t> params;
params.stride = expand_param_if_needed(stride_, "stride", dim);
params.padding = expand_param_if_needed(padding_, "padding", dim);
params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
params.transposed = transposed_;
params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
// TODO: hacky way of inferring the groups number for grouped Conv3D
// See: https://github.com/pytorch/pytorch/pull/36355
if (!params.transposed && input.dim() > 4) {
// Avoid undefined behavior when num channels == 0; params are unused for that case.
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1;
} else {
params.groups = groups_;
}
// Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
Tensor ggO;
if (input.numel() != 0) {
if (ggI.defined()) {
if (weight.is_cuda()) {
weight = weight.contiguous();
}
ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
}
if (ggW.defined()) {
if (ggW.is_cuda()) {
ggW = ggW.contiguous();
}
auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
if (ggO.defined()) {
ggO = ggO + ggW_term;
} else {
ggO = ggW_term;
}
}
}
if (ggb.defined()) {
// View as (1, ggb.size(0), 1, 1...)
// Expand
std::vector<int64_t> new_size(gO.ndimension(), 1);
new_size[1] = ggb.sizes()[0];
auto ggb_contiguous = ggb.contiguous();
auto ggb_view = ggb_contiguous.view(new_size);
// Expand
auto ggb_expanded = ggb_view.expand(gO.sizes());
if (ggO.defined()) {
ggO = ggO + ggb_expanded;
} else {
ggO = ggb_expanded;
}
}
// Compute gW = conv(ggI, gO)
Tensor gW;
if (ggI.defined()) {
// Modified params with correct padding
ConvParams<int64_t> gw_conv_params(params);
// Disable groups as they are handled separately
auto groups = gw_conv_params.groups;
gw_conv_params.groups = 1;
std::swap(gw_conv_params.dilation, gw_conv_params.stride);
// Transpose gO and ggI to accumulate over batch
auto gOt = gO.transpose(0, 1);
auto ggIt = ggI.transpose(0, 1);
Tensor gWt;
// Compute conv
if (input.numel() != 0) {
if (groups == 1) {
if (gOt.is_cuda()) {
gOt = gOt.contiguous();
}
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
} else {
gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
}
} else {
std::vector<Tensor> gWt_list(groups);
for (const auto g : c10::irange(groups)) {
auto ggIt_g = subvariable(ggIt, 0, groups, g);
auto gOt_g = subvariable(gOt, 0, groups, g);
if (gOt_g.is_cuda()) {
gOt_g = gOt_g.contiguous();
}
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
} else {
gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
}
}
gWt = at::cat(gWt_list, 1);
}
// Transpose gW to match chan_in and chan_out
gW = gWt.transpose(0, 1);
// narrow gW to only relevant portion
// we do it this way instead of narrowing the input itself because
// the ConvForward kernels don't support asymmetric padding.
auto gW_size = gW.sizes();
auto w_size = weight.sizes();
for (const auto i : c10::irange(2, gW_size.size())) {
if (gW_size[i] > w_size[i]) {
gW = gW.narrow(i, 0, w_size[i]);
gW_size = gW.sizes();
}
}
}
}
// Compute gI = convT(gO, ggW) if !transposed
// gI = conv(gO, ggw) if transposed
Tensor gI;
if (input.numel() != 0) {
if (ggW.defined()) {
ConvParams<int64_t> gi_conv_params(params);
gi_conv_params.transposed = !params.transposed;
if (params.transposed) {
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
// narrow gI to only relevant portion
// we do it this way because negative output_padding is not supported
// TODO: figure out if we can narrow gO and save some compute,
// rather than narrowing the computed gI
auto gI_size = gI.sizes();
auto i_size = input.sizes();
for (const auto i : c10::irange(2, gI_size.size())) {
if (gI_size[i] > i_size[i]) {
gI = gI.narrow(i, 0, i_size[i]);
gI_size = gI.sizes();
}
}
} else {
// calculate output_padding
// TODO: figure out why this needs to be computed...
auto kernel_size = weight.sizes().slice(2);
auto input_shape = input.sizes().slice(2);
auto grad_output_shape = gO.sizes().slice(2);
for (const auto i : c10::irange(kernel_size.size())) {
// Check if whole input has been used or not
auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
- 2 * gi_conv_params.padding[i]
+ (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1);
if (expected_input_shape != input_shape[i]) {
gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
}
}
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
}
}
}
return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
}
static std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend(
const Tensor& grad_output,
const Tensor& input,
const Tensor& weight,
const std::array<bool, 3> output_mask,
const ConvBackend backend,
const ConvParams<int64_t>& params) {
auto kernel_size = weight.sizes().slice(2);
switch(backend) {
case ConvBackend::Slow2d:
return at::_slow_conv2d_backward(
grad_output, input, weight, kernel_size, params.stride, params.padding, output_mask);
// NB: nnpack backward does not support strided convolutions; use slow impl instead
case ConvBackend::NnpackSpatial:
case ConvBackend::SlowDilated2d:
return slow_conv_dilated2d_backward_stub(
input.device().type(),
grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
case ConvBackend::SlowDilated3d:
return slow_conv_dilated3d_backward_stub(
input.device().type(),
grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
case ConvBackend::SlowTranspose2d:
return slow_conv_transpose2d_backward_stub(
input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
params.output_padding, params.dilation, output_mask);
case ConvBackend::SlowTranspose3d:
return slow_conv_transpose3d_backward_stub(
input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
params.output_padding, params.dilation, output_mask);
default:
TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
}
}
// Backward pass for convolution. Computes gradients for input, weight, and bias depending on the
// output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires
// a single batch dimension to be present.
//
// Args:
// grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out)
// input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in)
// weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number
// of input spatial dimensions
// bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape
// of the bias. While the bias shape can be computed from other inputs, it is provided to this function for
// ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups)
// for transposed convolution.
// stride: single value or an array with dimension matching the number of input spatial dimensions
// padding: single value or an array with dimension matching the number of input spatial dimensions
// dilation: single value or an array with dimension matching the number of input spatial dimensions
// transposed: boolean indicating whether the convolution is transposed
// output_padding: single value or dimension == number of input spatial dimensions; only supported when
// transposed is true
// groups: number of groups for grouped convolution
// output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order
std::tuple<Tensor, Tensor, Tensor> convolution_backward(
const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
const at::OptionalIntArrayRef bias_sizes_opt,
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
int64_t groups, std::array<bool, 3> output_mask) {
auto grad_output = grad_output_;
auto input = input_;
auto weight = weight_;
auto k = weight.ndimension();
int64_t dim = k - 2;
TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
auto& ctx = at::globalContext();
ConvParams<int64_t> params;
params.stride = expand_param_if_needed(stride, "stride", dim);
params.padding = expand_param_if_needed(padding, "padding", dim);
params.dilation = expand_param_if_needed(dilation, "dilation", dim);
params.transposed = transposed;
params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim);
params.groups = groups;
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
// Validate inputs.
check_shape_backward(input, weight.sizes(), params);
TORCH_CHECK(input.dim() == grad_output.dim(),
"Expected input and grad_output to have the same number of dimensions, but got: ",
input.dim(), " and ", grad_output.dim());
// output_padding is only supported for transposed convolutions
if (!params.transposed) {
for (auto pad : params.output_padding) {
TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ",
params.output_padding);
}
}
// Expand 1d -> 2d.
// This is only done for backends that don't natively support 1d spatial input.
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
// avoid accidentally going through NHWC for permuted 3d input.
input = input.contiguous();
params.view1d_as_2d();
grad_output = view4d(grad_output);
input = view4d(input);
weight = view4d(weight);
}
// Select appropriate backend to use.
ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
// Call the backend.
Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
auto kernel_size = weight.sizes().slice(2);
switch(backend) {
case ConvBackend::CudaDepthwise2d:
{
std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
std::tie(backend_grad_input, backend_grad_weight) =
conv_depthwise2d_backward_stub(input.device().type(), grad_output, input,
weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask);
break;
}
case ConvBackend::CudaDepthwise3d:
TORCH_CHECK(input.ndimension() == 5);
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
conv_depthwise3d_backward_stub(
input.device().type(), grad_output, input, weight, kernel_size, params.stride,
params.padding, params.dilation, output_mask);
break;
case ConvBackend::Cudnn:
{
check_input_same_type_as_parameters(input, weight);
std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub(
input.device().type(),
// Only make input contiguous when it is necessary for the backwards computation
output_mask[1] ? input.contiguous(backend_memory_format) : input,
grad_output, weight, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
input_weight_output_mask);
break;
}
case ConvBackend::Mps:
{
#ifdef USE_MPS
check_input_same_type_as_parameters(input, weight);
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
at::mps_convolution_backward(input, grad_output, weight, params.padding,
params.stride, params.dilation, params.groups, output_mask);
#else
TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
break;
}
case ConvBackend::MpsTranspose:
{
#ifdef USE_MPS
check_input_same_type_as_parameters(input, weight);
std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward(
// Only make input contiguous when it is necessary for the backwards computation
output_mask[1] ? input.contiguous(backend_memory_format) : input,
grad_output, weight, params.padding, params.output_padding,
params.stride, params.dilation, params.groups, input_weight_output_mask);
#else
TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
break;
}
case ConvBackend::CudnnTranspose:
{
check_input_same_type_as_parameters(input, weight);
std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
input.device().type(),
// Only make input contiguous when it is necessary for the backwards computation
output_mask[1] ? input.contiguous(backend_memory_format) : input,
grad_output, weight, params.padding, params.output_padding,
params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
input_weight_output_mask);
break;
}
case ConvBackend::Empty:
if (output_mask[0]) {
backend_grad_input = at::zeros_like(input);
}
if (output_mask[1]) {
backend_grad_weight = at::zeros_like(weight);
}
if (output_mask[2]) {
backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
}
break;
case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
if (output_mask[0]) {
if (input.is_mkldnn()) {
backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()),
input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
backend_grad_input.zero_();
} else {
backend_grad_input = at::zeros_like(input);
}
}
if (output_mask[1]) {
// mkldnn weight is not supported during training by the mkldnn backend
backend_grad_weight = at::zeros_like(weight);
}
if (output_mask[2]) {
// mkldnn bias is not supported during training by the mkldnn backend
backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
}
#else
TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
break;
case ConvBackend::Miopen:
check_input_same_type_as_parameters(input, weight);
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
miopen_convolution_backward_stub(
input.device().type(),
input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
break;
case ConvBackend::MiopenDepthwise:
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
miopen_depthwise_convolution_backward_stub(
input.device().type(),
input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
break;
case ConvBackend::MiopenTranspose:
check_input_same_type_as_parameters(input, weight);
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
miopen_convolution_transpose_backward_stub(
input.device().type(),
input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
break;
case ConvBackend::Mkldnn:
TORCH_CHECK(!weight.is_mkldnn(),
"The MKLDNN backend does not support weight as an MKLDNN tensor during training");
if (!input.is_mkldnn()) {
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
}
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
params.stride, params.dilation, params.groups, output_mask);
break;
case ConvBackend::MkldnnTranspose:
TORCH_CHECK(!weight.is_mkldnn(),
"The MKLDNN backend does not support weight as an MKLDNN tensor during training");
if (!input.is_mkldnn()) {
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
}
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
params.output_padding, params.stride, params.dilation, params.groups, output_mask);
break;
case ConvBackend::Overrideable:
// Only reach here when input is backend with out-of-source implementation.
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding,
params.dilation, params.transposed, params.output_padding, params.groups, output_mask);
break;
case ConvBackend::Slow3d:
// Note that no CUDA implementation of this kernel exists currently.
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
slow_conv3d_backward_cpu(
grad_output, input, weight, kernel_size,
params.stride, params.padding, output_mask);
break;
// Handle backends that don't natively support groups > 1.
case ConvBackend::NnpackSpatial:
case ConvBackend::Slow2d:
case ConvBackend::SlowDilated2d:
case ConvBackend::SlowDilated3d:
case ConvBackend::SlowTranspose2d:
case ConvBackend::SlowTranspose3d:
{
input = input.contiguous(backend_memory_format);
weight = weight.contiguous(backend_memory_format);
if (params.groups == 1) {
std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
_convolution_backward_nogroup_backend(
grad_output, input, weight, output_mask, backend, params);
} else {
std::vector<Tensor> backend_grad_inputs(params.groups);
std::vector<Tensor> backend_grad_weights(params.groups);
std::vector<Tensor> backend_grad_biases(params.groups);
for (int g = 0; g < params.groups; ++g) {
auto grad_output_g = subtensor(grad_output, 1, params.groups, g);
auto input_g = subtensor(input, 1, params.groups, g);
auto weight_g = subtensor(weight, 0, params.groups, g);
std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) =
_convolution_backward_nogroup_backend(
grad_output_g, input_g, weight_g, output_mask, backend, params);
}
if (output_mask[0]) {
backend_grad_input = at::cat(backend_grad_inputs, 1);
}
if (output_mask[1]) {
backend_grad_weight = at::cat(backend_grad_weights, 0);
}
if (output_mask[2]) {
backend_grad_bias = at::cat(backend_grad_biases, 0);
}
}
break;
}
// Backward is not supported for these backends.
case ConvBackend::Winograd3x3Depthwise:
TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd");
break;
case ConvBackend::Xnnpack2d:
TORCH_CHECK(false, "Backward is not supported for xnnpack");
break;
}
// Convert 2D inputs back to 1D for backends that don't natively support 1D
// spatial inputs.
if (output_mask[0]) {
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
backend_grad_input = view3d(backend_grad_input);
}
}
if (output_mask[1]) {
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
backend_grad_weight = view3d(backend_grad_weight);
}
}
if (output_mask[2]) {
if (!backend_grad_bias.defined()) {
// Calculate bias gradients outside of the backend for those that don't support it.
backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3});
}
}
return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
}
}} // at::native