C++ API parity: AdaptiveAvgPool1d
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26808
Test Plan: Imported from OSS
Differential Revision: D17627827
Pulled By: pbelevich
fbshipit-source-id: 13ad1d0414e7b62f4fc2f6573332bb2c07b16b53
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index eed61d1..e47298f 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -106,3 +106,12 @@
ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3})));
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 3, 3, 3}));
}
+
+TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
+ auto x = torch::ones({1, 1, 5});
+ auto y = F::adaptive_avg_pool1d(x, AdaptiveAvgPool1dOptions(3));
+
+ ASSERT_EQ(y.ndimension(), 3);
+ ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3})));
+ ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 3}));
+}
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index c863299..3b8e8c0 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -467,6 +467,20 @@
ASSERT_EQ(indices.sizes(), torch::IntArrayRef({1, 3, 3, 3}));
}
+TEST_F(ModulesTest, AdaptiveAvgPool1d) {
+ AdaptiveAvgPool1d model(3);
+ auto x = torch::tensor({{{1, 2, 3, 4, 5}}}, torch::requires_grad());
+ auto y = model(x);
+ torch::Tensor s = y.sum();
+
+ s.backward();
+ ASSERT_EQ(s.ndimension(), 0);
+
+ ASSERT_EQ(y.ndimension(), 3);
+ ASSERT_TRUE(torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat)));
+ ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 3}));
+}
+
TEST_F(ModulesTest, Linear) {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
@@ -843,6 +857,12 @@
"torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])");
}
+TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
+ ASSERT_EQ(
+ c10::str(AdaptiveAvgPool1d(5)),
+ "torch::nn::AdaptiveAvgPool1d(output_size=5)");
+}
+
TEST_F(ModulesTest, PrettyPrintDropout) {
ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
ASSERT_EQ(
diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h
index e706f89..4a286ad 100644
--- a/torch/csrc/api/include/torch/nn/functional/pooling.h
+++ b/torch/csrc/api/include/torch/nn/functional/pooling.h
@@ -132,6 +132,13 @@
return torch::adaptive_max_pool3d(input, options.output_size());
}
+// ============================================================================
+
+inline Tensor adaptive_avg_pool1d(const Tensor& input,
+ const AdaptiveAvgPool1dOptions& options) {
+ return torch::adaptive_avg_pool1d(input, options.output_size());
+}
+
} // namespace functional
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/pooling.h b/torch/csrc/api/include/torch/nn/modules/pooling.h
index c1a17f5..a9b0bf6 100644
--- a/torch/csrc/api/include/torch/nn/modules/pooling.h
+++ b/torch/csrc/api/include/torch/nn/modules/pooling.h
@@ -248,5 +248,43 @@
/// module storage semantics.
TORCH_MODULE(AdaptiveMaxPool3d);
+// ============================================================================
+
+/// Base class for all (dimension-specialized) adaptive avgpool modules.
+template <size_t D, typename Derived>
+class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable<Derived> {
+ public:
+ AdaptiveAvgPoolImpl(ExpandingArray<D> output_size)
+ : AdaptiveAvgPoolImpl(AdaptiveAvgPoolOptions<D>(output_size)) {}
+ explicit AdaptiveAvgPoolImpl(const AdaptiveAvgPoolOptions<D>& options_);
+
+ void reset() override;
+
+ /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
+ /// The options with which this `Module` was constructed.
+ AdaptiveAvgPoolOptions<D> options;
+};
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies adaptive avgpool over a 1-D input.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool1d
+/// to learn about the exact behavior of this module.
+class TORCH_API AdaptiveAvgPool1dImpl :
+ public AdaptiveAvgPoolImpl<1, AdaptiveAvgPool1dImpl> {
+ public:
+ using AdaptiveAvgPoolImpl<1, AdaptiveAvgPool1dImpl>::AdaptiveAvgPoolImpl;
+
+ Tensor forward(const Tensor& input);
+};
+
+/// A `ModuleHolder` subclass for `AdaptiveAvgPool1dImpl`.
+/// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what methods it
+/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(AdaptiveAvgPool1d);
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h
index 5fece7f..480a96a 100644
--- a/torch/csrc/api/include/torch/nn/options/pooling.h
+++ b/torch/csrc/api/include/torch/nn/options/pooling.h
@@ -96,5 +96,20 @@
/// `AdaptiveMaxPoolOptions` specialized for 3-D adaptive maxpool.
using AdaptiveMaxPool3dOptions = AdaptiveMaxPoolOptions<3>;
+// ============================================================================
+
+/// Options for a `D`-dimensional adaptive avgpool functional and module.
+template <size_t D>
+struct AdaptiveAvgPoolOptions {
+ AdaptiveAvgPoolOptions(ExpandingArray<D> output_size)
+ : output_size_(output_size) {}
+
+ /// the target output size
+ TORCH_ARG(ExpandingArray<D>, output_size);
+};
+
+/// `AdaptiveAvgPoolOptions` specialized for 1-D adaptive avgpool.
+using AdaptiveAvgPool1dOptions = AdaptiveAvgPoolOptions<1>;
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/pooling.cpp b/torch/csrc/api/src/nn/modules/pooling.cpp
index ae06b27..42b4dee 100644
--- a/torch/csrc/api/src/nn/modules/pooling.cpp
+++ b/torch/csrc/api/src/nn/modules/pooling.cpp
@@ -129,5 +129,26 @@
template class AdaptiveMaxPoolImpl<2, AdaptiveMaxPool2dImpl>;
template class AdaptiveMaxPoolImpl<3, AdaptiveMaxPool3dImpl>;
+// ============================================================================
+
+template <size_t D, typename Derived>
+AdaptiveAvgPoolImpl<D, Derived>::AdaptiveAvgPoolImpl(
+ const AdaptiveAvgPoolOptions<D>& options_) : options(options_) {}
+
+template <size_t D, typename Derived>
+void AdaptiveAvgPoolImpl<D, Derived>::reset() {}
+
+template <size_t D, typename Derived>
+void AdaptiveAvgPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::AdaptiveAvgPool" << D << "d"
+ << "(output_size=" << options.output_size() << ")";
+}
+
+Tensor AdaptiveAvgPool1dImpl::forward(const Tensor& input) {
+ return F::adaptive_avg_pool1d(input, options);
+}
+
+template class AdaptiveAvgPoolImpl<1, AdaptiveAvgPool1dImpl>;
+
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/options/pooling.cpp b/torch/csrc/api/src/nn/options/pooling.cpp
index f4e89cb..887998e 100644
--- a/torch/csrc/api/src/nn/options/pooling.cpp
+++ b/torch/csrc/api/src/nn/options/pooling.cpp
@@ -15,5 +15,7 @@
template struct AdaptiveMaxPoolOptions<2>;
template struct AdaptiveMaxPoolOptions<3>;
+template struct AdaptiveAvgPoolOptions<1>;
+
} // namespace nn
} // namespace torch