Fix Conv and ConvTranspose implementation (#35023)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35023

This PR fixes Conv and ConvTranspose implementation to match the Python API implementation.

**TODO**: cherry-pick this PR into v1.5 release branch.

Test Plan: Imported from OSS

Differential Revision: D20559889

Pulled By: yf225

fbshipit-source-id: 53783a7398ef968ec6d25b6f568fde44907417c5
diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h
index 70fd94c..e480675 100644
--- a/torch/csrc/api/include/torch/nn/modules/conv.h
+++ b/torch/csrc/api/include/torch/nn/modules/conv.h
@@ -4,6 +4,7 @@
 #include <torch/nn/cloneable.h>
 #include <torch/nn/init.h>
 #include <torch/nn/modules/common.h>
+#include <torch/nn/modules/utils.h>
 #include <torch/nn/options/conv.h>
 #include <torch/nn/pimpl.h>
 #include <torch/types.h>
@@ -32,6 +33,8 @@
       options.out_channels() % options.groups() == 0,
       "out_channels must be divisible by groups");
 
+    _padding_repeated_twice = torch::nn::modules::utils::_repeat_vector(options.padding(), 2);
+
     if (options.transposed()) {
       std::vector<int64_t> weight_sizes = {
         options.in_channels(),
@@ -106,6 +109,9 @@
 
   /// The learned bias. Only defined if the `bias` option was true.
   Tensor bias;
+
+ protected:
+  std::vector<int64_t> _padding_repeated_twice;
 };
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -163,6 +169,9 @@
   }
   explicit Conv2dImpl(Conv2dOptions options_);
   Tensor forward(const Tensor& input);
+
+ protected:
+  Tensor _conv_forward(const Tensor& input, const Tensor& weight);
 };
 
 /// A `ModuleHolder` subclass for `Conv2dImpl`.
diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h
index d28cd9a..baa7aa6 100644
--- a/torch/csrc/api/include/torch/nn/modules/utils.h
+++ b/torch/csrc/api/include/torch/nn/modules/utils.h
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Optional.h>
 
@@ -8,6 +10,21 @@
 namespace modules {
 namespace utils {
 
+// Reverse the order of `t` and repeat each element for `n` times.
+// This can be used to translate the padding arg used by Conv and Pooling
+// modules to the one used by `F::pad`, which expects padding pairs in
+// last-dimension-first order.
+//
+// This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`.
+inline std::vector<int64_t> _repeat_vector(at::ArrayRef<int64_t> t, int64_t n) {
+  std::vector<int64_t> ret;
+  ret.reserve(t.size() * n);
+  for (auto rit = t.rbegin(); rit != t.rend(); ++rit) {
+    ret.insert(ret.end(), n, *rit);
+  }
+  return ret;
+}
+
 inline std::vector<int64_t> _list_with_default(
   torch::ArrayRef<c10::optional<int64_t>> out_size, torch::IntArrayRef defaults) {
   TORCH_CHECK(
diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h
index c849b47..20e45c5 100644
--- a/torch/csrc/api/include/torch/nn/options/conv.h
+++ b/torch/csrc/api/include/torch/nn/options/conv.h
@@ -11,7 +11,12 @@
 
 namespace detail {
 
-typedef c10::variant<enumtype::kZeros, enumtype::kCircular> conv_padding_mode_t;
+typedef c10::variant<
+  enumtype::kZeros,
+  enumtype::kReflect,
+  enumtype::kReplicate,
+  enumtype::kCircular
+> conv_padding_mode_t;
 
 /// Options for a `D`-dimensional convolution or convolution transpose module.
 template <size_t D>
@@ -75,7 +80,7 @@
   /// Changing this parameter after construction __has no effect__.
   TORCH_ARG(bool, bias) = true;
 
-  /// Accepted values `zeros` and `circular` Default: `zeros`
+  /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
   TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
 };
 
@@ -136,7 +141,7 @@
   /// Changing this parameter after construction __has no effect__.
   TORCH_ARG(bool, bias) = true;
 
-  /// Accepted values `zeros` and `circular` Default: `zeros`
+  /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
   TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
 };
 
@@ -283,7 +288,7 @@
   /// This parameter __can__ be changed after construction.
   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
 
-  /// Accepted values `zeros` and `circular` Default: `zeros`
+  /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
   TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
 };
 
diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp
index 99fb947..274751a 100644
--- a/torch/csrc/api/src/nn/modules/conv.cpp
+++ b/torch/csrc/api/src/nn/modules/conv.cpp
@@ -2,6 +2,7 @@
 #include <torch/nn/functional/padding.h>
 #include <torch/nn/modules/conv.h>
 
+#include <torch/enum.h>
 #include <torch/expanding_array.h>
 #include <torch/nn/init.h>
 #include <torch/types.h>
@@ -15,6 +16,20 @@
 
 namespace F = torch::nn::functional;
 
+static F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
+  F::PadFuncOptions::mode_t pad_mode;
+  if (c10::get_if<torch::enumtype::kReflect>(&conv_padding_mode)) {
+    pad_mode = torch::kReflect;
+  } else if (c10::get_if<torch::enumtype::kReplicate>(&conv_padding_mode)) {
+    pad_mode = torch::kReplicate;
+  } else if (c10::get_if<torch::enumtype::kCircular>(&conv_padding_mode)) {
+    pad_mode = torch::kCircular;
+  } else {
+    TORCH_CHECK(false, "Unsupported conv padding mode: ", torch::enumtype::get_enum_name(conv_padding_mode));
+  }
+  return pad_mode;
+}
+
 namespace torch {
 namespace nn {
 Conv1dImpl::Conv1dImpl(
@@ -34,10 +49,9 @@
           .padding_mode(options_.padding_mode())) {}
 
 Tensor Conv1dImpl::forward(const Tensor& input) {
-  if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
-    std::vector<int64_t> expanded_padding = {((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+  if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
     return F::detail::conv1d(
-      F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+      F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
       weight, bias,
       options.stride(),
       /*padding=*/0,
@@ -70,13 +84,10 @@
           .bias(options_.bias())
           .padding_mode(options_.padding_mode())) {}
 
-Tensor Conv2dImpl::forward(const Tensor& input) {
-  if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
-    std::vector<int64_t> expanded_padding = {
-      ((*options.padding())[1] + 1) / 2, (*options.padding())[1] / 2,
-      ((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
+  if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
     return F::detail::conv2d(
-      F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+      F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
       weight, bias,
       options.stride(),
       /*padding=*/0,
@@ -93,6 +104,10 @@
     options.groups());
 }
 
+Tensor Conv2dImpl::forward(const Tensor& input) {
+  return _conv_forward(input, weight);
+}
+
 Conv3dImpl::Conv3dImpl(
     Conv3dOptions options_)
     : ConvNdImpl(
@@ -110,13 +125,9 @@
           .padding_mode(options_.padding_mode())) {}
 
 Tensor Conv3dImpl::forward(const Tensor& input) {
-  if (c10::get_if<enumtype::kCircular>(&options.padding_mode())) {
-    std::vector<int64_t> expanded_padding = {
-      ((*options.padding())[2] + 1) / 2, (*options.padding())[2] / 2,
-      ((*options.padding())[1] + 1) / 2, (*options.padding())[1] / 2,
-      ((*options.padding())[0] + 1) / 2, (*options.padding())[0] / 2};
+  if (!c10::get_if<enumtype::kZeros>(&options.padding_mode())) {
     return F::detail::conv3d(
-      F::detail::pad(input, expanded_padding, torch::kCircular, 0),
+      F::pad(input, F::PadFuncOptions(_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))),
       weight, bias,
       options.stride(),
       /*padding=*/0,