Local response norm (#28759)

Summary:
Implemented LocalResponseNorm, with initial tests for both the module and the functional API. Reference: https://github.com/pytorch/pytorch/issues/25883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28759

Differential Revision: D18219745

Pulled By: yf225

fbshipit-source-id: e6aad568a8b1e81f54752decaefd4f9044029da9
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index c058fbd..02f882f 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -873,6 +873,26 @@
   ASSERT_TRUE(torch::allclose(y, y_exp));
 }
 
+TEST_F(FunctionalTest, LocalResponseNorm) {
+  const auto x = torch::arange(100, 118).resize_({3, 3, 2});
+  const auto y = F::local_response_norm(x, LocalResponseNormOptions(2));
+  ASSERT_EQ(y.ndimension(), 3);
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 2}));
+  const auto y_exp = torch::tensor(
+    {{{73.7788, 74.1462},
+      {60.1942, 60.3302},
+      {60.4609, 60.5865}},
+    {{75.8729, 76.2011},
+      {60.9331, 61.0390},
+      {61.1403, 61.2370}},
+    {{77.7387, 78.0303},
+      {61.5011, 61.5807},
+      {61.6563, 61.7279}}},
+    torch::kFloat
+  );
+  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
+}
+
 TEST_F(FunctionalTest, Linear) {
   {
     const auto x = torch::arange(100, 118).resize_({3, 3, 2});
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index e32d1eb..9b71f37 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -813,6 +813,47 @@
   }
 }
 
+TEST_F(ModulesTest, LocalResponseNorm) {
+  {
+    LocalResponseNorm model(LocalResponseNormOptions(2));
+    const auto x = torch::arange(100, 136, torch::requires_grad()).reshape({2, 3, 3, 2});
+    auto y = model(x);
+    const auto y_exp = torch::tensor(
+      {{{{73.7788, 74.1462},
+          {74.5031, 74.8572},
+          {75.2010, 75.5420}},
+
+         {{61.6057, 61.7227},
+          {61.8347, 61.9418},
+          {62.0441, 62.1418}},
+
+         {{62.2349, 62.3235},
+          {62.4077, 62.4877},
+          {62.5635, 62.6353}}},
+
+        {{{79.3915, 79.6491},
+          {79.8978, 80.1446},
+          {80.3827, 80.6190}},
+
+         {{63.0317, 63.0742},
+          {63.1135, 63.1496},
+          {63.1826, 63.2126}},
+
+         {{63.2396, 63.2637},
+          {63.2850, 63.3036},
+          {63.3195, 63.3328}}}},
+      torch::kFloat
+    );
+    torch::Tensor s = y.sum();
+
+    s.backward();
+    ASSERT_EQ(y.ndimension(), 4);
+    ASSERT_EQ(s.ndimension(), 0);
+    ASSERT_EQ(y.sizes(), x.sizes());
+    ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
+  }
+}
+
 TEST_F(ModulesTest, LayerNorm) {
   LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5));
   auto x = torch::randn({2, 2}, torch::requires_grad());
@@ -2425,6 +2466,15 @@
           "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)");
 }
 
+TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
+  ASSERT_EQ(
+    c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
+      "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)");
+  ASSERT_EQ(
+    c10::str(LocalResponseNorm(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
+      "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)");
+}
+
 TEST_F(ModulesTest, PrettyPrintEmbedding) {
   ASSERT_EQ(
       c10::str(Embedding(EmbeddingOptions(10, 2))),
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 0fbda1e..e3dd585 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -78,7 +78,7 @@
 torch.nn.InstanceNorm2d|No|No
 torch.nn.InstanceNorm3d|No|No
 torch.nn.LayerNorm|Yes|No
-torch.nn.LocalResponseNorm|No|No
+torch.nn.LocalResponseNorm|Yes|No
 torch.nn.RNN|No|No
 torch.nn.LSTM|No|No
 torch.nn.GRU|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h
index 70b5759..c4d0367 100644
--- a/torch/csrc/api/include/torch/nn/functional/normalization.h
+++ b/torch/csrc/api/include/torch/nn/functional/normalization.h
@@ -1,6 +1,9 @@
 #pragma once
 
 #include <torch/nn/options/normalization.h>
+#include <torch/nn/functional/padding.h>
+#include <torch/nn/functional/pooling.h>
+#include <torch/types.h>
 
 namespace torch {
 namespace nn {
@@ -10,7 +13,6 @@
     const Tensor& input,
     const NormalizeOptions& options = {},
     c10::optional<Tensor> out = c10::nullopt) {
-
     if (out == c10::nullopt) {
       auto denom = input.norm(options.p(), options.dim(), true).clamp_min(options.eps()).expand_as(input);
       return input / denom;
@@ -28,6 +30,26 @@
     return torch::layer_norm(input, options.normalized_shape(), weight, bias, options.eps());
 }
 
+inline Tensor local_response_norm(
+    const Tensor& input,
+    const LocalResponseNormOptions& options) {
+    auto dim = input.dim();
+    TORCH_CHECK(dim >=3, "Expected 3D or higher dimensionality input (got ", dim, " dimensions)");
+    auto div = input.mul(input).unsqueeze(1);
+    if (dim == 3) {
+      div = pad(div, PadOptions({0, 0, options.size() / 2, (options.size() - 1) / 2}));
+      div = avg_pool2d(div, AvgPool2dOptions({options.size(), 1}).stride(1)).squeeze(1);
+    } else {
+      auto sizes = input.sizes();
+      div = div.view({sizes[0], 1, sizes[1], sizes[2], -1});
+      div = pad(div, PadOptions({0, 0, 0, 0, options.size() / 2, (options.size() - 1) / 2}));
+      div = avg_pool3d(div, AvgPool3dOptions({options.size(), 1, 1}).stride(1)).squeeze(1);
+      div = div.view(sizes);
+    }
+    div = div.mul(options.alpha()).add(options.k()).pow(options.beta());
+    return input / div;
+}
+
 } // namespace functional
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/normalization.h b/torch/csrc/api/include/torch/nn/modules/normalization.h
index ace5ecd..3a65104 100644
--- a/torch/csrc/api/include/torch/nn/modules/normalization.h
+++ b/torch/csrc/api/include/torch/nn/modules/normalization.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/nn/cloneable.h>
-#include <torch/nn/cloneable.h>
 #include <torch/nn/functional/normalization.h>
 #include <torch/nn/options/normalization.h>
 #include <torch/nn/pimpl.h>
@@ -52,5 +51,29 @@
 /// module storage semantics.
 TORCH_MODULE(LayerNorm);
 
+/// Applies local response normalization over an input signal composed
+/// of several input planes, where channels occupy the second dimension;
+/// normalization is applied across channels.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LocalResponseNorm to learn
+/// about the exact behavior of this module.
+class TORCH_API LocalResponseNormImpl : public Cloneable<LocalResponseNormImpl> {
+ public:
+  LocalResponseNormImpl(int64_t size)
+      : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}
+  explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);
+
+  Tensor forward(const Tensor& input);
+
+  void reset() override;
+
+  /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  /// The options with which this `Module` was constructed.
+  LocalResponseNormOptions options;
+};
+
+TORCH_MODULE(LocalResponseNorm);
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h
index febe515..bc8482b 100644
--- a/torch/csrc/api/include/torch/nn/options/normalization.h
+++ b/torch/csrc/api/include/torch/nn/options/normalization.h
@@ -31,5 +31,23 @@
   TORCH_ARG(bool, elementwise_affine) = true;
 };
 
+// ============================================================================
+
+/// Options for LocalResponseNorm functional and module.
+struct TORCH_API LocalResponseNormOptions {
+  /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {}
+  /// The number of neighbouring channels used for normalization.
+  TORCH_ARG(int64_t, size);
+  
+  /// multiplicative factor. Default: 1e-4
+  TORCH_ARG(double, alpha) = 1e-4;
+
+  /// exponent. Default: 0.75
+  TORCH_ARG(double, beta) = 0.75;
+
+  /// additive factor. Default: 1
+  TORCH_ARG(double, k) = 1.;
+};
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/normalization.cpp b/torch/csrc/api/src/nn/modules/normalization.cpp
index 463f15b..895fcc9 100644
--- a/torch/csrc/api/src/nn/modules/normalization.cpp
+++ b/torch/csrc/api/src/nn/modules/normalization.cpp
@@ -41,5 +41,25 @@
 torch::Tensor LayerNormImpl::forward(const Tensor& input) {
   return F::layer_norm(input, options, weight, bias);
 }
+
+// ============================================================================
+
+LocalResponseNormImpl::LocalResponseNormImpl(const LocalResponseNormOptions& options_)
+    : options(options_) {}
+
+Tensor LocalResponseNormImpl::forward(const Tensor& input) {
+  return F::local_response_norm(input, options);
+}
+
+void LocalResponseNormImpl::reset() {}
+
+void LocalResponseNormImpl::pretty_print(std::ostream& stream) const {
+  stream << std::boolalpha
+         << "torch::nn::LocalResponseNorm(" <<  options.size()
+         << ", alpha=" << options.alpha()
+         << ", beta=" << options.beta()
+         << ", k=" << options.k()
+         << ")";
+}
 } // namespace nn
 } // namespace torch