Local response norm (#28759)
Summary:
Implemented LocalResponseNorm and some initial tests for modules and functional. Reference https://github.com/pytorch/pytorch/issues/25883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28759
Differential Revision: D18219745
Pulled By: yf225
fbshipit-source-id: e6aad568a8b1e81f54752decaefd4f9044029da9
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index c058fbd..02f882f 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -873,6 +873,26 @@
ASSERT_TRUE(torch::allclose(y, y_exp));
}
+TEST_F(FunctionalTest, LocalResponseNorm) {
+ const auto x = torch::arange(100, 118).resize_({3, 3, 2});
+ const auto y = F::local_response_norm(x, LocalResponseNormOptions(2));
+ ASSERT_EQ(y.ndimension(), 3);
+ ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 2}));
+ const auto y_exp = torch::tensor(
+ {{{73.7788, 74.1462},
+ {60.1942, 60.3302},
+ {60.4609, 60.5865}},
+ {{75.8729, 76.2011},
+ {60.9331, 61.0390},
+ {61.1403, 61.2370}},
+ {{77.7387, 78.0303},
+ {61.5011, 61.5807},
+ {61.6563, 61.7279}}},
+ torch::kFloat
+ );
+ ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
+}
+
TEST_F(FunctionalTest, Linear) {
{
const auto x = torch::arange(100, 118).resize_({3, 3, 2});
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index e32d1eb..9b71f37 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -813,6 +813,47 @@
}
}
+TEST_F(ModulesTest, LocalResponseNorm) {
+ {
+ LocalResponseNorm model(LocalResponseNormOptions(2));
+ const auto x = torch::arange(100, 136, torch::requires_grad()).reshape({2, 3, 3, 2});
+ auto y = model(x);
+ const auto y_exp = torch::tensor(
+ {{{{73.7788, 74.1462},
+ {74.5031, 74.8572},
+ {75.2010, 75.5420}},
+
+ {{61.6057, 61.7227},
+ {61.8347, 61.9418},
+ {62.0441, 62.1418}},
+
+ {{62.2349, 62.3235},
+ {62.4077, 62.4877},
+ {62.5635, 62.6353}}},
+
+ {{{79.3915, 79.6491},
+ {79.8978, 80.1446},
+ {80.3827, 80.6190}},
+
+ {{63.0317, 63.0742},
+ {63.1135, 63.1496},
+ {63.1826, 63.2126}},
+
+ {{63.2396, 63.2637},
+ {63.2850, 63.3036},
+ {63.3195, 63.3328}}}},
+ torch::kFloat
+ );
+ torch::Tensor s = y.sum();
+
+ s.backward();
+ ASSERT_EQ(y.ndimension(), 4);
+ ASSERT_EQ(s.ndimension(), 0);
+ ASSERT_EQ(y.sizes(), x.sizes());
+ ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
+ }
+}
+
TEST_F(ModulesTest, LayerNorm) {
LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5));
auto x = torch::randn({2, 2}, torch::requires_grad());
@@ -2425,6 +2466,15 @@
"torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)");
}
+TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
+ ASSERT_EQ(
+ c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
+ "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)");
+ ASSERT_EQ(
+ c10::str(LocalResponseNorm(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
+ "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)");
+}
+
TEST_F(ModulesTest, PrettyPrintEmbedding) {
ASSERT_EQ(
c10::str(Embedding(EmbeddingOptions(10, 2))),
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 0fbda1e..e3dd585 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -78,7 +78,7 @@
torch.nn.InstanceNorm2d|No|No
torch.nn.InstanceNorm3d|No|No
torch.nn.LayerNorm|Yes|No
-torch.nn.LocalResponseNorm|No|No
+torch.nn.LocalResponseNorm|Yes|No
torch.nn.RNN|No|No
torch.nn.LSTM|No|No
torch.nn.GRU|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h
index 70b5759..c4d0367 100644
--- a/torch/csrc/api/include/torch/nn/functional/normalization.h
+++ b/torch/csrc/api/include/torch/nn/functional/normalization.h
@@ -1,6 +1,9 @@
#pragma once
#include <torch/nn/options/normalization.h>
+#include <torch/nn/functional/padding.h>
+#include <torch/nn/functional/pooling.h>
+#include <torch/types.h>
namespace torch {
namespace nn {
@@ -10,7 +13,6 @@
const Tensor& input,
const NormalizeOptions& options = {},
c10::optional<Tensor> out = c10::nullopt) {
-
if (out == c10::nullopt) {
auto denom = input.norm(options.p(), options.dim(), true).clamp_min(options.eps()).expand_as(input);
return input / denom;
@@ -28,6 +30,26 @@
return torch::layer_norm(input, options.normalized_shape(), weight, bias, options.eps());
}
+inline Tensor local_response_norm(
+ const Tensor& input,
+ const LocalResponseNormOptions& options) {
+ auto dim = input.dim();
+ TORCH_CHECK(dim >=3, "Expected 3D or higher dimensionality input (got ", dim, " dimensions)");
+ auto div = input.mul(input).unsqueeze(1);
+ if (dim == 3) {
+ div = pad(div, PadOptions({0, 0, options.size() / 2, (options.size() - 1) / 2}));
+ div = avg_pool2d(div, AvgPool2dOptions({options.size(), 1}).stride(1)).squeeze(1);
+ } else {
+ auto sizes = input.sizes();
+ div = div.view({sizes[0], 1, sizes[1], sizes[2], -1});
+ div = pad(div, PadOptions({0, 0, 0, 0, options.size() / 2, (options.size() - 1) / 2}));
+ div = avg_pool3d(div, AvgPool3dOptions({options.size(), 1, 1}).stride(1)).squeeze(1);
+ div = div.view(sizes);
+ }
+ div = div.mul(options.alpha()).add(options.k()).pow(options.beta());
+ return input / div;
+}
+
} // namespace functional
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/normalization.h b/torch/csrc/api/include/torch/nn/modules/normalization.h
index ace5ecd..3a65104 100644
--- a/torch/csrc/api/include/torch/nn/modules/normalization.h
+++ b/torch/csrc/api/include/torch/nn/modules/normalization.h
@@ -1,7 +1,6 @@
#pragma once
#include <torch/nn/cloneable.h>
-#include <torch/nn/cloneable.h>
#include <torch/nn/functional/normalization.h>
#include <torch/nn/options/normalization.h>
#include <torch/nn/pimpl.h>
@@ -52,5 +51,29 @@
/// module storage semantics.
TORCH_MODULE(LayerNorm);
+/// Applies local response normalization over an input signal composed
+/// of several input planes, where channels occupy the second dimension.
+/// Applies normalization across channels
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LocalResponseNorm to learn
+/// about the exact behavior of this module.
+class TORCH_API LocalResponseNormImpl : public Cloneable<LocalResponseNormImpl> {
+ public:
+ LocalResponseNormImpl(int64_t size)
+ : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}
+ explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);
+
+ Tensor forward(const Tensor& input);
+
+ void reset() override;
+
+ /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
+ /// The options with which this `Module` was constructed.
+ LocalResponseNormOptions options;
+};
+
+TORCH_MODULE(LocalResponseNorm);
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h
index febe515..bc8482b 100644
--- a/torch/csrc/api/include/torch/nn/options/normalization.h
+++ b/torch/csrc/api/include/torch/nn/options/normalization.h
@@ -31,5 +31,23 @@
TORCH_ARG(bool, elementwise_affine) = true;
};
+// ============================================================================
+
+/// Options for LocalResponseNorm functional and module.
+struct TORCH_API LocalResponseNormOptions {
+ /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {}
+ /// amount of neighbouring channels used for normalization
+ TORCH_ARG(int64_t, size);
+
+ /// multiplicative factor. Default: 1e-4
+ TORCH_ARG(double, alpha) = 1e-4;
+
+ /// exponent. Default: 0.75
+ TORCH_ARG(double, beta) = 0.75;
+
+ /// additive factor. Default: 1
+ TORCH_ARG(double, k) = 1.;
+};
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/normalization.cpp b/torch/csrc/api/src/nn/modules/normalization.cpp
index 463f15b..895fcc9 100644
--- a/torch/csrc/api/src/nn/modules/normalization.cpp
+++ b/torch/csrc/api/src/nn/modules/normalization.cpp
@@ -41,5 +41,25 @@
torch::Tensor LayerNormImpl::forward(const Tensor& input) {
return F::layer_norm(input, options, weight, bias);
}
+
+// ============================================================================
+
+LocalResponseNormImpl::LocalResponseNormImpl(const LocalResponseNormOptions& options_)
+ : options(options_) {}
+
+Tensor LocalResponseNormImpl::forward(const Tensor& input) {
+ return F::local_response_norm(input, options);
+}
+
+void LocalResponseNormImpl::reset() {}
+
+void LocalResponseNormImpl::pretty_print(std::ostream& stream) const {
+ stream << std::boolalpha
+ << "torch::nn::LocalResponseNorm(" << options.size()
+ << ", alpha=" << options.alpha()
+ << ", beta=" << options.beta()
+ << ", k=" << options.k()
+ << ")";
+}
} // namespace nn
} // namespace torch