C++ API: torch::nn::BatchNorm1d (#28176)
Summary:
Add torch::nn::BatchNorm1d function/module support for the C++ API.
torch::nn::BatchNorm{2,3}d will be added after this PR is merged.
Related Issue: https://github.com/pytorch/pytorch/issues/25883
Reviewer: yf225
I would like to discuss the following items.
* Necessity of `num_batches_tracked` in `BatchNormImplBase`
* In the Python API, `num_batches_tracked` is needed to calculate the effective `momentum` when the `momentum` argument is not given. In the C++ API, however, the `momentum` argument has a default value (see the sketch after this list).
* Otherwise, `num_batches_tracked` is only used to count `BatchNorm1d::forward()` calls. I think it is no longer necessary for users.
* The design of `BatchNorm{1,2,3}dOptions`
* We already have `BatchNormOptions`, used by the deprecated `BatchNorm` module. However, it is hard to reuse it for `BatchNorm{1,2,3}dOptions` because the modules' arguments disagree.
* In this PR, I introduce a `BatchNormOptionsv2` template class for `BatchNorm{1,2,3}dOptions`, but I'm not sure whether this design is good.
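For context, here is a minimal sketch of the momentum logic in question, mirroring the `forward` implementation in this diff. The free function `effective_average_factor` is illustrative only, not part of the API:

```cpp
#include <c10/util/Optional.h>

#include <cstdint>

// When `momentum` is unset, fall back to the cumulative moving average
// 1/num_batches_tracked (the Python API's `momentum=None` behavior);
// otherwise use the given momentum as an exponential moving average.
double effective_average_factor(
    c10::optional<double> momentum, int64_t num_batches_tracked) {
  if (momentum == c10::nullopt) {
    return 1.0 / num_batches_tracked; // cumulative moving average
  }
  return momentum.value(); // exponential moving average
}
```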
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28176
Differential Revision: D18196843
Pulled By: yf225
fbshipit-source-id: 667e2b5de4150d5776c41b9088c9e6c2ead24cd4
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index bd1a762..c058fbd 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1266,6 +1266,33 @@
}
}
+TEST_F(FunctionalTest, BatchNorm1d) {
+ int num_features = 5;
+ double eps = 1e-05;
+ double momentum = 0.1;
+
+ auto input = torch::randn({2, 5});
+ auto mean = torch::randn(5);
+ auto variance = torch::rand(5);
+ auto weight = torch::ones({num_features});
+ auto bias = torch::zeros({num_features});
+ auto output = F::batch_norm(
+ input, mean, variance,
+ BatchNormOptions().weight(weight).bias(bias).momentum(momentum).eps(eps),
+ /*training=*/false);
+ auto expected = (input - mean) / torch::sqrt(variance + eps);
+ ASSERT_TRUE(output.allclose(expected));
+}
+
+TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
+ auto input = torch::randn({2, 5});
+ auto mean = torch::randn(5);
+ auto variance = torch::rand(5);
+ auto output = F::batch_norm(input, mean, variance);
+ auto expected = (input - mean) / torch::sqrt(variance + 1e-5);
+ ASSERT_TRUE(output.allclose(expected));
+}
+
TEST_F(FunctionalTest, Interpolate) {
{
// 1D interpolation
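The tests above check `F::batch_norm` against the closed-form inference rule `y = (x - mean) / sqrt(var + eps) * weight + bias`; with `weight = 1` and `bias = 0` the affine terms drop out. For reference, a standalone sketch with non-trivial affine parameters, calling the underlying `torch::batch_norm` (same argument order as in the functional added later in this diff):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  double eps = 1e-5;
  auto input = torch::randn({2, 5});
  auto mean = torch::randn(5);
  auto variance = torch::rand(5);
  auto weight = torch::rand(5);  // non-trivial scale, unlike the test's ones
  auto bias = torch::randn(5);   // non-trivial shift, unlike the test's zeros

  auto output = torch::batch_norm(
      input, weight, bias, mean, variance,
      /*training=*/false, /*momentum=*/0.1, eps, /*cudnn_enabled=*/false);

  // Closed-form inference rule: y = (x - mean) / sqrt(var + eps) * w + b
  auto expected = (input - mean) / torch::sqrt(variance + eps) * weight + bias;
  std::cout << output.allclose(expected) << std::endl;  // 1
}
```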
diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp
index e4620ea..68ff0f6 100644
--- a/test/cpp/api/modulelist.cpp
+++ b/test/cpp/api/modulelist.cpp
@@ -281,7 +281,7 @@
" (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(rate=0.5)\n"
- " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+ " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 89b277f..4d1ac87 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1001,7 +1001,7 @@
BatchNorm bn(5);
// Is stateful by default.
- ASSERT_TRUE(bn->options.stateful());
+ ASSERT_TRUE(bn->options.track_running_stats());
ASSERT_TRUE(bn->running_mean.defined());
ASSERT_EQ(bn->running_mean.dim(), 1);
@@ -1023,7 +1023,7 @@
ASSERT_EQ(bn->bias.size(0), 5);
}
TEST_F(ModulesTest, BatchNormStateless) {
- BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
+ BatchNorm bn(BatchNormOptions(5).track_running_stats(false).affine(false));
ASSERT_FALSE(bn->running_mean.defined());
ASSERT_FALSE(bn->running_var.defined());
@@ -1033,7 +1033,7 @@
ASSERT_THROWS_WITH(
bn(torch::ones({2, 5})),
"Calling BatchNorm::forward is only permitted "
- "when the 'stateful' option is true (was false). "
+ "when the 'track_running_stats' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
}
@@ -1051,6 +1051,71 @@
ASSERT_TRUE(output.allclose(expected));
}
+TEST_F(ModulesTest, BatchNormLegacyWarning) {
+ std::stringstream buffer;
+ torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
+
+ BatchNorm bn(5);
+
+ ASSERT_EQ(
+ count_substr_occurrences(
+ buffer.str(),
+ "torch::nn::BatchNorm module is deprecated"
+ ),
+ 1);
+}
+
+TEST_F(ModulesTest, BatchNorm1dStateful) {
+ BatchNorm1d bn(BatchNorm1dOptions(5));
+
+ ASSERT_TRUE(bn->options.track_running_stats());
+
+ ASSERT_TRUE(bn->running_mean.defined());
+ ASSERT_EQ(bn->running_mean.dim(), 1);
+ ASSERT_EQ(bn->running_mean.size(0), 5);
+
+ ASSERT_TRUE(bn->running_var.defined());
+ ASSERT_EQ(bn->running_var.dim(), 1);
+ ASSERT_EQ(bn->running_var.size(0), 5);
+
+ ASSERT_TRUE(bn->num_batches_tracked.defined());
+  ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
+  ASSERT_EQ(bn->num_batches_tracked.item<int64_t>(), 0);
+
+ ASSERT_TRUE(bn->options.affine());
+
+ ASSERT_TRUE(bn->weight.defined());
+ ASSERT_EQ(bn->weight.dim(), 1);
+ ASSERT_EQ(bn->weight.size(0), 5);
+
+ ASSERT_TRUE(bn->bias.defined());
+ ASSERT_EQ(bn->bias.dim(), 1);
+ ASSERT_EQ(bn->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, BatchNorm1dStateless) {
+ BatchNorm1d bn(BatchNorm1dOptions(5).track_running_stats(false).affine(false));
+
+ ASSERT_FALSE(bn->running_mean.defined());
+ ASSERT_FALSE(bn->running_var.defined());
+ ASSERT_FALSE(bn->num_batches_tracked.defined());
+ ASSERT_FALSE(bn->weight.defined());
+ ASSERT_FALSE(bn->bias.defined());
+}
+
+TEST_F(ModulesTest, BatchNorm1d) {
+ BatchNorm1d bn(BatchNorm1dOptions(5));
+ bn->eval();
+
+ auto input = torch::randn({2, 5}, torch::requires_grad());
+ auto output = bn->forward(input);
+ auto s = output.sum();
+ s.backward();
+
+ ASSERT_EQ(input.sizes(), input.grad().sizes());
+ ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
+}
+
TEST_F(ModulesTest, Linear_CUDA) {
Linear model(5, 2);
model->to(torch::kCUDA);
@@ -2303,9 +2368,17 @@
TEST_F(ModulesTest, PrettyPrintBatchNorm) {
ASSERT_EQ(
c10::str(BatchNorm(
- BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
+ BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(
true))),
- "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
+ "torch::nn::BatchNorm(num_features=4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
+ ASSERT_EQ(
+ c10::str(BatchNorm1d(
+ BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
+ .track_running_stats(true))),
+ "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
TEST_F(ModulesTest, PrettyPrintLayerNorm) {
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index 543df66..c3393218 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -412,7 +412,7 @@
" (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(rate=0.5)\n"
- " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+ " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
@@ -431,7 +431,7 @@
" (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (dropout): torch::nn::Dropout(rate=0.5)\n"
- " (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+ " (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 8a5d5f4..5ad56fe 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -69,7 +69,7 @@
torch.nn.Softmax2d|Yes|No
torch.nn.LogSoftmax|Yes|No
torch.nn.AdaptiveLogSoftmaxWithLoss|No|No
-torch.nn.BatchNorm1d|No|No
+torch.nn.BatchNorm1d|Yes|No
torch.nn.BatchNorm2d|No|No
torch.nn.BatchNorm3d|No|No
torch.nn.GroupNorm|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h
index 90ce395..6531b7f 100644
--- a/torch/csrc/api/include/torch/nn/functional.h
+++ b/torch/csrc/api/include/torch/nn/functional.h
@@ -1,5 +1,6 @@
#pragma once
+#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/functional/distance.h>
#include <torch/nn/functional/embedding.h>
#include <torch/nn/functional/fold.h>
diff --git a/torch/csrc/api/include/torch/nn/functional/batchnorm.h b/torch/csrc/api/include/torch/nn/functional/batchnorm.h
new file mode 100644
index 0000000..a180f3f
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/functional/batchnorm.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <torch/nn/options/batchnorm.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+inline Tensor batch_norm(const Tensor& input, const Tensor& running_mean,
+ const Tensor& running_var, const BatchNormOptions& options = {}, bool training = false) {
+ if (training) {
+ auto size = input.sizes();
+ int64_t size_prods = size[0];
+ for (size_t i = 0; i < size.size() - 2; i++) {
+ size_prods *= size[i + 2];
+ }
+ TORCH_CHECK(size_prods != 1,
+ "Expected more than 1 value per channel when training, got input size ", size);
+ }
+
+ return torch::batch_norm(
+ input,
+ options.weight(),
+ options.bias(),
+ running_mean,
+ running_var,
+ training,
+ options.momentum().value(),
+ options.eps(),
+ at::globalContext().userEnabledCuDNN());
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
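A minimal usage sketch for this functional (inference mode; note that `weight` and `bias` ride along on `BatchNormOptions` rather than being separate tensor arguments):

```cpp
#include <torch/torch.h>

namespace F = torch::nn::functional;

int main() {
  auto input = torch::randn({4, 3});
  auto running_mean = torch::zeros(3);
  auto running_var = torch::ones(3);

  // Equivalent to the Python F.batch_norm(..., training=False) call.
  auto output = F::batch_norm(
      input, running_mean, running_var,
      torch::nn::BatchNormOptions()
          .weight(torch::ones(3))
          .bias(torch::zeros(3)),
      /*training=*/false);
}
```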
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 210effd..eb3e267 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -25,8 +25,8 @@
/// \endrst
class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
public:
- explicit BatchNormImpl(int64_t features)
- : BatchNormImpl(BatchNormOptions(features)) {}
+ explicit BatchNormImpl(int64_t num_features)
+ : BatchNormImpl(BatchNormOptions(num_features)) {}
explicit BatchNormImpl(const BatchNormOptions& options_);
void reset() override;
@@ -37,7 +37,7 @@
/// Applies batch normalization on the `input` using the stored mean and
/// variance.
///
- /// The module must be constructed with `stateful = true` when calling this
+ /// The module must be constructed with `track_running_stats = true` when calling this
/// method, as the module will otherwise not store running statistics. If you
/// want to supply the mean and variance yourself, use `pure_forward`.
Tensor forward(const Tensor& input);
@@ -61,11 +61,11 @@
Tensor bias;
/// The running mean.
- /// Only defined if the `stateful` option was `true` upon construction.
+ /// Only defined if the `track_running_stats` option was `true` upon construction.
Tensor running_mean;
/// The running variance.
- /// Only defined if the `stateful` option was `true` upon construction.
+ /// Only defined if the `track_running_stats` option was `true` upon construction.
Tensor running_var;
};
@@ -75,5 +75,62 @@
/// module storage semantics.
TORCH_MODULE(BatchNorm);
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Base class for all (dimension-specialized) batchnorm modules.
+template <size_t D, typename Derived>
+class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
+ protected:
+ virtual void _check_input_dim(const Tensor& input) = 0;
+
+ public:
+ explicit BatchNormImplBase(const BatchNormOptions& options_);
+
+ Tensor forward(const Tensor& input);
+
+ void reset_running_stats();
+
+ void reset() override;
+
+ /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
+ /// The options with which this module was constructed.
+ BatchNormOptions options;
+
+ /// The learned weight.
+ /// Only defined if the `affine` option was `true` upon construction.
+ Tensor weight;
+
+ /// The learned bias.
+ /// Only defined if the `affine` option was `true` upon construction.
+ Tensor bias;
+
+ /// The running mean.
+ /// Only defined if the `track_running_stats` option was `true` upon construction.
+ Tensor running_mean;
+
+ /// The running variance.
+ /// Only defined if the `track_running_stats` option was `true` upon construction.
+ Tensor running_var;
+
+  /// The number of forward calls seen so far.
+ /// Only defined if the `track_running_stats` option was `true` upon construction.
+ Tensor num_batches_tracked;
+};
+
+/// Applies the BatchNorm1d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
+/// about the exact behavior of this module.
+class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
+ protected:
+ virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+ using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;
+};
+
+TORCH_MODULE(BatchNorm1d);
+
} // namespace nn
} // namespace torch
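A usage sketch for the new module: `forward` updates the running statistics (and `num_batches_tracked`) in training mode, and consumes them in eval mode:

```cpp
#include <torch/torch.h>

int main() {
  torch::nn::BatchNorm1d bn(torch::nn::BatchNorm1dOptions(3));

  bn->train();
  for (int i = 0; i < 4; ++i) {
    bn->forward(torch::randn({8, 3}));  // accumulates running_mean/var
  }

  bn->eval();
  auto y = bn->forward(torch::randn({8, 3}));  // uses the running stats
}
```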
diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h
index ca6a952..2aa175c 100644
--- a/torch/csrc/api/include/torch/nn/options/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h
@@ -9,26 +9,40 @@
/// Options for the `BatchNorm` module.
struct TORCH_API BatchNormOptions {
- /* implicit */ BatchNormOptions(int64_t features);
+ BatchNormOptions() {}
+
+ /* implicit */ BatchNormOptions(int64_t num_features);
+
/// The number of features of the input tensor.
/// Changing this parameter after construction __has no effect__.
- TORCH_ARG(int64_t, features);
+ TORCH_ARG(int64_t, num_features);
+
+ /// The epsilon value added for numerical stability.
+ /// Changing this parameter after construction __is effective__.
+ TORCH_ARG(double, eps) = 1e-5;
+
+ /// A momentum multiplier for the mean and variance.
+ /// Changing this parameter after construction __is effective__.
+ TORCH_ARG(c10::optional<double>, momentum) = 0.1;
+
/// Whether to learn a scale and bias that are applied in an affine
/// transformation on the input.
/// Changing this parameter after construction __has no effect__.
TORCH_ARG(bool, affine) = true;
+
/// Whether to store and update batch statistics (mean and variance) in the
- /// module. If `false`, you should call `pure_forward` and supply those batch
- /// statistics yourself.
+ /// module.
/// Changing this parameter after construction __has no effect__.
- TORCH_ARG(bool, stateful) = true;
- /// The epsilon value added for numerical stability.
- /// Changing this parameter after construction __is effective__.
- TORCH_ARG(double, eps) = 1e-5;
- /// A momentum multiplier for the mean and variance.
- /// Changing this parameter after construction __is effective__.
- TORCH_ARG(double, momentum) = 0.1;
+ TORCH_ARG(bool, track_running_stats) = true;
+
+ /// This parameter is only used in `F::batch_norm`.
+ TORCH_ARG(Tensor, weight) = Tensor();
+
+ /// This parameter is only used in `F::batch_norm`.
+ TORCH_ARG(Tensor, bias) = Tensor();
};
+using BatchNorm1dOptions = BatchNormOptions;
+
} // namespace nn
} // namespace torch
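Since `momentum` is now a `c10::optional<double>`, it can be cleared to request the cumulative-moving-average behavior. A sketch of building the (currently aliased) `BatchNorm1dOptions`:

```cpp
#include <torch/torch.h>

int main() {
  // nullopt momentum selects the cumulative moving average, the
  // equivalent of `momentum=None` in the Python API.
  auto options = torch::nn::BatchNorm1dOptions(16)
                     .eps(1e-4)
                     .momentum(c10::nullopt)
                     .affine(true)
                     .track_running_stats(true);
  torch::nn::BatchNorm1d bn(options);
}
```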
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 806d77b..816d07d 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -1,7 +1,9 @@
+#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/modules/batchnorm.h>
#include <torch/cuda.h>
#include <torch/types.h>
+#include <torch/nn/init.h>
#include <c10/util/Exception.h>
@@ -10,41 +12,45 @@
#include <utility>
#include <vector>
+namespace F = torch::nn::functional;
+
namespace torch {
namespace nn {
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
+  TORCH_WARN("torch::nn::BatchNorm module is deprecated. "
+      "Use BatchNorm{1,2,3}d instead.");
reset();
}
void BatchNormImpl::reset() {
if (options.affine()) {
weight = register_parameter(
- "weight", torch::empty({options.features()}).uniform_());
- bias = register_parameter("bias", torch::zeros({options.features()}));
+ "weight", torch::empty({options.num_features()}).uniform_());
+ bias = register_parameter("bias", torch::zeros({options.num_features()}));
}
- if (options.stateful()) {
+ if (options.track_running_stats()) {
running_mean =
- register_buffer("running_mean", torch::zeros({options.features()}));
+ register_buffer("running_mean", torch::zeros({options.num_features()}));
running_var =
- register_buffer("running_var", torch::ones({options.features()}));
+ register_buffer("running_var", torch::ones({options.num_features()}));
}
}
void BatchNormImpl::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
- << "torch::nn::BatchNorm(features=" << options.features()
- << ", eps=" << options.eps() << ", momentum=" << options.momentum()
- << ", affine=" << options.affine() << ", stateful=" << options.stateful()
+ << "torch::nn::BatchNorm(num_features=" << options.num_features()
+ << ", eps=" << options.eps() << ", momentum=" << options.momentum().value()
+ << ", affine=" << options.affine() << ", track_running_stats=" << options.track_running_stats()
<< ")";
}
Tensor BatchNormImpl::forward(const Tensor& input) {
TORCH_CHECK(
- options.stateful(),
+ options.track_running_stats(),
"Calling BatchNorm::forward is only permitted when "
- "the 'stateful' option is true (was false). "
+ "the 'track_running_stats' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
return pure_forward(input, running_mean, running_var);
}
@@ -67,10 +73,100 @@
mean,
variance,
is_training(),
- options.momentum(),
+ options.momentum().value(),
options.eps(),
torch::cuda::cudnn_is_available());
}
+template <size_t D, typename Derived>
+BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
+ : options(options_) {
+ reset();
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::reset_running_stats() {
+ if (options.track_running_stats()) {
+ running_mean.zero_();
+ running_var.fill_(1);
+ num_batches_tracked.zero_();
+ }
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::reset() {
+ if (options.affine()) {
+ weight = this->register_parameter("weight", torch::empty({options.num_features()}));
+ bias = this->register_parameter("bias", torch::empty({options.num_features()}));
+ } else {
+ weight = this->register_parameter("weight", Tensor());
+ bias = this->register_parameter("bias", Tensor());
+ }
+ if (options.track_running_stats()) {
+ running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
+ running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
+ num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
+ } else {
+ running_mean = this->register_buffer("running_mean", Tensor());
+ running_var = this->register_buffer("running_var", Tensor());
+ num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
+ }
+
+ reset_running_stats();
+ if (options.affine()) {
+ torch::nn::init::ones_(weight);
+ torch::nn::init::zeros_(bias);
+ }
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
+ stream << std::boolalpha
+ << "torch::nn::BatchNorm" << D << "d("
+ << options.num_features() << ", "
+ << "eps=" << options.eps() << ", "
+ << "momentum=" << options.momentum().value() << ", "
+ << "affine=" << options.affine() << ", "
+ << "track_running_stats=" << options.track_running_stats() << ")";
+}
+
+template <size_t D, typename Derived>
+Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
+ _check_input_dim(input);
+
+ double exponential_average_factor;
+ if (options.momentum() == c10::nullopt) {
+ exponential_average_factor = 0.0;
+ } else {
+ exponential_average_factor = options.momentum().value();
+ }
+
+ if (this->is_training() && options.track_running_stats()) {
+ if (num_batches_tracked.defined()) {
+ num_batches_tracked += 1;
+ if (options.momentum() == c10::nullopt) { // use cumulative moving average
+ exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
+ } else { // use exponential moving average
+ exponential_average_factor = options.momentum().value();
+ }
+ }
+ }
+
+ return F::batch_norm(
+ input,
+ running_mean,
+ running_var,
+ BatchNormOptions().weight(weight).bias(bias).momentum(exponential_average_factor).eps(options.eps()),
+ this->is_training() || !options.track_running_stats());
+}
+
+void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
+ TORCH_CHECK(
+ input.dim() == 2 || input.dim() == 3,
+ "expected 2D or 3D input (got ", input.dim(), "D input)");
+}
+
+template class BatchNormImplBase<1, BatchNorm1dImpl>;
+
} // namespace nn
} // namespace torch
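To make the two averaging modes concrete, here is a scalar stand-in for the per-channel update that `torch::batch_norm` applies with the factor computed in `forward` above (assuming the standard update rule `running = (1 - factor) * running + factor * batch_stat`):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  double running_mean = 0.0;
  int64_t num_batches_tracked = 0;
  // With momentum == nullopt, factor = 1/k after batch k, which makes
  // running_mean the plain average of all batch means seen so far.
  for (double batch_mean : {1.0, 2.0, 3.0, 4.0}) {
    ++num_batches_tracked;
    double factor = 1.0 / num_batches_tracked;
    running_mean = (1 - factor) * running_mean + factor * batch_mean;
  }
  std::cout << running_mean << std::endl;  // 2.5
}
```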
diff --git a/torch/csrc/api/src/nn/options/batchnorm.cpp b/torch/csrc/api/src/nn/options/batchnorm.cpp
index 60c363b..2144443 100644
--- a/torch/csrc/api/src/nn/options/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/options/batchnorm.cpp
@@ -3,7 +3,7 @@
namespace torch {
namespace nn {
-BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
+BatchNormOptions::BatchNormOptions(int64_t num_features) : num_features_(num_features) {}
} // namespace nn
} // namespace torch