Fix Conv and ConvTranspose implementation (#35023)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35023
This PR fixes the Conv and ConvTranspose implementations to match the Python API implementation.
**TODO**: cherry-pick this PR into v1.5 release branch.
Test Plan: Imported from OSS
Differential Revision: D20559889
Pulled By: yf225
fbshipit-source-id: 53783a7398ef968ec6d25b6f568fde44907417c5
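For context, a minimal sketch (not part of this diff) of the user-visible behavior the change enables from the C++ frontend. The module and option names are the real libtorch API; the surrounding program is illustrative only:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  namespace nn = torch::nn;
  // Previously the C++ API accepted only torch::kZeros and torch::kCircular
  // here; torch::kReflect and torch::kReplicate now work, as in Python.
  nn::Conv2d conv(nn::Conv2dOptions(/*in_channels=*/3, /*out_channels=*/8,
                                    /*kernel_size=*/3)
                      .padding(1)
                      .padding_mode(torch::kReflect));
  auto y = conv->forward(torch::randn({1, 3, 16, 16}));
  std::cout << y.sizes() << '\n';  // [1, 8, 16, 16]
}
```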
diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h
index 70fd94c..e480675 100644
--- a/torch/csrc/api/include/torch/nn/modules/conv.h
+++ b/torch/csrc/api/include/torch/nn/modules/conv.h
@@ -4,6 +4,7 @@
#include <torch/nn/cloneable.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/common.h>
+#include <torch/nn/modules/utils.h>
#include <torch/nn/options/conv.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
@@ -32,6 +33,8 @@
options.out_channels() % options.groups() == 0,
"out_channels must be divisible by groups");
+ _padding_repeated_twice = torch::nn::modules::utils::_repeat_vector(options.padding(), 2);
+
if (options.transposed()) {
std::vector<int64_t> weight_sizes = {
options.in_channels(),
@@ -106,6 +109,9 @@
/// The learned bias. Only defined if the `bias` option was true.
Tensor bias;
+
+ protected:
+ std::vector<int64_t> _padding_repeated_twice;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -163,6 +169,9 @@
}
explicit Conv2dImpl(Conv2dOptions options_);
Tensor forward(const Tensor& input);
+
+ protected:
+ Tensor _conv_forward(const Tensor& input, const Tensor& weight);
};
/// A `ModuleHolder` subclass for `Conv2dImpl`.
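The split of `Conv2dImpl::forward` into a protected `_conv_forward(input, weight)` mirrors Python's `Conv2d._conv_forward`, which lets derived modules run the same padding-and-convolution path with a substituted weight. A hedged sketch of that use; `ScaledConv2dImpl` is hypothetical and not part of the PR:

```cpp
#include <torch/torch.h>

// Hypothetical subclass (illustration only): because `_conv_forward` is
// protected rather than private, a derived module can reuse Conv2d's
// padding-and-convolution path while supplying a modified weight, much like
// Python's quantization-aware modules reuse `Conv2d._conv_forward`.
struct ScaledConv2dImpl : torch::nn::Conv2dImpl {
  using torch::nn::Conv2dImpl::Conv2dImpl;

  torch::Tensor forward(const torch::Tensor& input) {
    // Same forward path as Conv2d, but with a rescaled copy of the weight.
    return _conv_forward(input, weight * 0.5);
  }
};
TORCH_MODULE(ScaledConv2d);
```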
diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h
index d28cd9a..baa7aa6 100644
--- a/torch/csrc/api/include/torch/nn/modules/utils.h
+++ b/torch/csrc/api/include/torch/nn/modules/utils.h
@@ -1,3 +1,5 @@
+#pragma once
+
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
@@ -8,6 +10,21 @@
namespace modules {
namespace utils {
+// Repeat each element of `t` `n` times.
+// This can be used to translate the padding arg used by Conv and Pooling
+// modules into the form expected by `F::pad`.
+//
+// This mirrors `_repeat_tuple` in `torch/nn/modules/utils.py`.
+inline std::vector<int64_t> _repeat_vector(at::ArrayRef<int64_t> t, int64_t n) {
+ std::vector<int64_t> ret;
+ for (int64_t elem : t) {
+ for (int64_t i = 0; i < n; i++) {
+ ret.emplace_back(elem);
+ }
+ }
+ return ret;
+}
+
inline std::vector<int64_t> _list_with_default(
torch::ArrayRef<c10::optional<int64_t>> out_size, torch::IntArrayRef defaults) {
TORCH_CHECK(
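An assumed test-style snippet showing what `_repeat_vector` produces for a typical Conv2d padding argument; it is illustrative, not a test from the PR:

```cpp
#include <torch/torch.h>

#include <cassert>
#include <vector>

int main() {
  // Each padding element is repeated `n` times, yielding the flat
  // (begin, end) pairs consumed by `F::pad`'s `pad` option.
  std::vector<int64_t> padding = {2, 1};
  auto repeated = torch::nn::modules::utils::_repeat_vector(padding, 2);
  assert((repeated == std::vector<int64_t>{2, 2, 1, 1}));
  return 0;
}
```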
diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h
index c849b47..20e45c5 100644
--- a/torch/csrc/api/include/torch/nn/options/conv.h
+++ b/torch/csrc/api/include/torch/nn/options/conv.h
@@ -11,7 +11,12 @@
namespace detail {
-typedef c10::variant<enumtype::kZeros, enumtype::kCircular> conv_padding_mode_t;
+typedef c10::variant<
+ enumtype::kZeros,
+ enumtype::kReflect,
+ enumtype::kReplicate,
+ enumtype::kCircular
+> conv_padding_mode_t;
/// Options for a `D`-dimensional convolution or convolution transpose module.
template <size_t D>
@@ -75,7 +80,7 @@
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(bool, bias) = true;
- /// Accepted values `zeros` and `circular` Default: `zeros`
+ /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
};
@@ -136,7 +141,7 @@
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(bool, bias) = true;
- /// Accepted values `zeros` and `circular` Default: `zeros`
+ /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};
@@ -283,7 +288,7 @@
/// This parameter __can__ be changed after construction.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
- /// Accepted values `zeros` and `circular` Default: `zeros`
+ /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};
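A short sketch of the widened option surface (the `configure` function and variable names are illustrative): all four enum tags now type-check, while an unsupported tag fails at compile time because it is not an alternative of the variant:

```cpp
#include <torch/torch.h>

void configure() {
  using torch::nn::Conv2dOptions;
  auto zeros     = Conv2dOptions(3, 8, 3).padding_mode(torch::kZeros);
  auto reflect   = Conv2dOptions(3, 8, 3).padding_mode(torch::kReflect);
  auto replicate = Conv2dOptions(3, 8, 3).padding_mode(torch::kReplicate);
  auto circular  = Conv2dOptions(3, 8, 3).padding_mode(torch::kCircular);
  // auto constant = Conv2dOptions(3, 8, 3).padding_mode(torch::kConstant);
  // ^ would not compile: kConstant is not part of conv_padding_mode_t.
}
```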
diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp
index 99fb947..274751a 100644
--- a/torch/csrc/api/src/nn/modules/conv.cpp
+++ b/torch/csrc/api/src/nn/modules/conv.cpp
@@ -2,6 +2,7 @@
#include <torch/nn/functional/padding.h>
#include <torch/nn/modules/conv.h>
+#include <torch/enum.h>
#include <torch/expanding_array.h>
#include <torch/nn/init.h>
#include <torch/types.h>
@@ -15,6 +16,20 @@
namespace F = torch::nn::functional;
+F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
+ F::PadFuncOptions::mode_t pad_mode;
+ if (c10::get_if<torch::enumtype::kReflect>(&conv_padding_mode)) {
+ pad_mode = torch::kReflect;
+ } else if (c10::get_if<torch::enumtype::kReplicate>(&conv_padding_mode)) {
+ pad_mode = torch::kReplicate;
+ } else if (c10::get_if<torch::enumtype::kCircular>(&conv_padding_mode)) {
+ pad_mode = torch::kCircular;
+ } else {
+ TORCH_CHECK(false, "Unsupported conv padding mode: ", torch::enumtype::get_enum_name(conv_padding_mode));
+ }
+ return pad_mode;
+}
+
namespace torch {
namespace nn {
Conv1dImpl::Conv1dImpl(
@@ -34,10 +49,9 @@
.padding_mode(options_.padding_mode())) {}
Tensor Conv1dImpl::forward(const Tensor& input) {
- if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
- std::vector<int64_t> expanded_padding = {((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+ if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv1d(
- F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+ F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
weight, bias,
options.stride(),
/*padding=*/0,
@@ -70,13 +84,10 @@
.bias(options_.bias())
.padding_mode(options_.padding_mode())) {}
-Tensor Conv2dImpl::forward(const Tensor& input) {
- if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
- std::vector<int64_t> expanded_padding = {
- ((*options.padding())[1] + 1) / 2, (*options.padding())[1] / 2,
- ((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
+ if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv2d(
- F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+ F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
weight, bias,
options.stride(),
/*padding=*/0,
@@ -93,6 +104,10 @@
options.groups());
}
+Tensor Conv2dImpl::forward(const Tensor& input) {
+ return _conv_forward(input, weight);
+}
+
Conv3dImpl::Conv3dImpl(
Conv3dOptions options_)
: ConvNdImpl(
@@ -110,13 +125,9 @@
.padding_mode(options_.padding_mode())) {}
Tensor Conv3dImpl::forward(const Tensor& input) {
- if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
- std::vector<int64_t> expanded_padding = {
- ((*options.padding())[2] + 1) / 2, (*options.padding())[2] / 2,
- ((*options.padding())[1] + 1) / 2, (*options.padding())[1] / 2,
- ((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+ if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
return F::detail::conv3d(
- F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+ F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
weight, bias,
options.stride(),
/*padding=*/0,
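Finally, a rough equivalence sketch assembled from the calls this file now makes: for a non-zero padding mode, `forward()` should match an explicit `F::pad` followed by an unpadded functional convolution. The test program itself is assumed, not part of the PR:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  namespace F = torch::nn::functional;
  torch::manual_seed(0);

  torch::nn::Conv1d conv(torch::nn::Conv1dOptions(2, 4, 3)
                             .padding(1)
                             .padding_mode(torch::kCircular));
  auto x = torch::randn({1, 2, 10});

  // Mirror the module's internal path: pad first, then convolve with
  // padding=0 using the module's own weight and bias.
  auto padded = F::pad(x, F::PadFuncOptions({1, 1}).mode(torch::kCircular));
  auto expected = F::conv1d(
      padded, conv->weight,
      F::Conv1dFuncOptions().bias(conv->bias).padding(0));
  std::cout << torch::allclose(conv->forward(x), expected) << '\n';  // 1
}
```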