C++ API: torch::nn::BatchNorm1d (#28176)

Summary:
Add torch::nn::BatchNorm1d function/module support for the C++ API.
torch::nn::BatchNorm{2,3}d will be added after this PR is merged.
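
A minimal usage sketch, based on the tests added in this PR (shapes and values are illustrative only):

```cpp
#include <torch/torch.h>

namespace F = torch::nn::functional;

int main() {
  // Module form: 5 features, default eps=1e-5 and momentum=0.1,
  // affine transform and running statistics enabled by default.
  torch::nn::BatchNorm1d bn(torch::nn::BatchNorm1dOptions(5));
  auto input = torch::randn({2, 5});
  auto output = bn->forward(input);

  // Functional form: the caller supplies the running statistics explicitly.
  auto mean = torch::zeros(5);
  auto variance = torch::ones(5);
  auto out = F::batch_norm(
      input, mean, variance,
      torch::nn::BatchNormOptions().momentum(0.1).eps(1e-5),
      /*training=*/false);
}
```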

Related Issue: https://github.com/pytorch/pytorch/issues/25883

Reviewer: yf225

I would like to discuss the items below.

* Necessity of `num_batches_tracked` in `BatchNormImplBase`
  * In the Python API, `num_batches_tracked` is needed to compute the effective `momentum` when no `momentum` argument is given; in the C++ API, however, the `momentum` argument has a default value (see the momentum sketch after this list).
  * Other than that, `num_batches_tracked` only counts the number of `BatchNorm1d::forward()` calls, so I think it is no longer necessary for users.
* The design of `BatchNorm{1,2,3}dOptions`
  * We already have `BatchNormOptions`, used by the deprecated `BatchNorm` module, but it is hard to reuse it for `BatchNorm{1,2,3}dOptions` because the modules' arguments disagree.
  * In this PR, I introduce a `BatchNormOptionsv2` template class for `BatchNorm{1,2,3}dOptions`, but I'm not sure whether this design is good (a sketch of the idea follows below).
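
For the first item: when `momentum` is `None` (Python) / `c10::nullopt` (C++), the running statistics are updated with a cumulative moving average whose factor depends on `num_batches_tracked`. A standalone sketch of that rule follows; the helper name `effective_momentum` is mine and not part of this PR (the actual logic lives in `BatchNormImplBase::forward` in the diff below):

```cpp
#include <c10/util/Optional.h>
#include <cstdint>

// Hypothetical helper (not part of this PR) showing how the update factor for
// the running statistics is chosen. `num_batches_tracked` is the counter value
// after it has been incremented for the current forward() call.
double effective_momentum(
    c10::optional<double> momentum,
    int64_t num_batches_tracked,
    bool training,
    bool track_running_stats) {
  if (training && track_running_stats && !momentum.has_value()) {
    // Cumulative moving average: factor 1/1, 1/2, 1/3, ... per forward() call.
    return 1.0 / static_cast<double>(num_batches_tracked);
  }
  // Exponential moving average with a fixed momentum
  // (0.0 if momentum is unset and running statistics are not being updated).
  return momentum.has_value() ? momentum.value() : 0.0;
}
```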
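
For the second item, a rough sketch of what a dimension-templated `BatchNormOptionsv2` could look like; the exact shape here is my assumption for illustration, and the diff below instead defines `using BatchNorm1dOptions = BatchNormOptions;`:

```cpp
#include <torch/arg.h>
#include <torch/types.h>

// Hypothetical illustration of a dimension-templated options class; not the
// code merged in this PR, which aliases BatchNorm1dOptions to BatchNormOptions.
template <size_t D>
struct BatchNormOptionsv2 {
  BatchNormOptionsv2() = default;
  /* implicit */ BatchNormOptionsv2(int64_t num_features)
      : num_features_(num_features) {}

  TORCH_ARG(int64_t, num_features) = 0;
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(c10::optional<double>, momentum) = 0.1;
  TORCH_ARG(bool, affine) = true;
  TORCH_ARG(bool, track_running_stats) = true;
};

using BatchNorm1dOptions = BatchNormOptionsv2<1>;
using BatchNorm2dOptions = BatchNormOptionsv2<2>;
using BatchNorm3dOptions = BatchNormOptionsv2<3>;
```
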
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28176

Differential Revision: D18196843

Pulled By: yf225

fbshipit-source-id: 667e2b5de4150d5776c41b9088c9e6c2ead24cd4
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index bd1a762..c058fbd 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1266,6 +1266,33 @@
   }
 }
 
+TEST_F(FunctionalTest, BatchNorm1d) {
+  int num_features = 5;
+  double eps = 1e-05;
+  double momentum = 0.1;
+
+  auto input = torch::randn({2, 5});
+  auto mean = torch::randn(5);
+  auto variance = torch::rand(5);
+  auto weight = torch::ones({num_features});
+  auto bias = torch::zeros({num_features});
+  auto output = F::batch_norm(
+    input, mean, variance,
+    BatchNormOptions().weight(weight).bias(bias).momentum(momentum).eps(eps),
+    /*training=*/false);
+  auto expected = (input - mean) / torch::sqrt(variance + eps);
+  ASSERT_TRUE(output.allclose(expected));
+}
+
+TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
+  auto input = torch::randn({2, 5});
+  auto mean = torch::randn(5);
+  auto variance = torch::rand(5);
+  auto output = F::batch_norm(input, mean, variance);
+  auto expected = (input - mean) / torch::sqrt(variance + 1e-5);
+  ASSERT_TRUE(output.allclose(expected));
+}
+
 TEST_F(FunctionalTest, Interpolate) {
   {
     // 1D interpolation
diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp
index e4620ea..68ff0f6 100644
--- a/test/cpp/api/modulelist.cpp
+++ b/test/cpp/api/modulelist.cpp
@@ -281,7 +281,7 @@
       "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       "  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
       "  (2): torch::nn::Dropout(rate=0.5)\n"
-      "  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+      "  (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 89b277f..4d1ac87 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1001,7 +1001,7 @@
   BatchNorm bn(5);
 
   // Is stateful by default.
-  ASSERT_TRUE(bn->options.stateful());
+  ASSERT_TRUE(bn->options.track_running_stats());
 
   ASSERT_TRUE(bn->running_mean.defined());
   ASSERT_EQ(bn->running_mean.dim(), 1);
@@ -1023,7 +1023,7 @@
   ASSERT_EQ(bn->bias.size(0), 5);
 }
 TEST_F(ModulesTest, BatchNormStateless) {
-  BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
+  BatchNorm bn(BatchNormOptions(5).track_running_stats(false).affine(false));
 
   ASSERT_FALSE(bn->running_mean.defined());
   ASSERT_FALSE(bn->running_var.defined());
@@ -1033,7 +1033,7 @@
   ASSERT_THROWS_WITH(
       bn(torch::ones({2, 5})),
       "Calling BatchNorm::forward is only permitted "
-      "when the 'stateful' option is true (was false). "
+      "when the 'track_running_stats' option is true (was false). "
       "Use BatchNorm::pure_forward instead.");
 }
 
@@ -1051,6 +1051,71 @@
   ASSERT_TRUE(output.allclose(expected));
 }
 
+TEST_F(ModulesTest, BatchNormLegacyWarning) {
+  std::stringstream buffer;
+  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
+
+  BatchNorm bn(5);
+
+  ASSERT_EQ(
+    count_substr_occurrences(
+      buffer.str(),
+      "torch::nn::BatchNorm module is deprecated"
+    ),
+  1);
+}
+
+TEST_F(ModulesTest, BatchNorm1dStateful) {
+  BatchNorm1d bn(BatchNorm1dOptions(5));
+
+  ASSERT_TRUE(bn->options.track_running_stats());
+
+  ASSERT_TRUE(bn->running_mean.defined());
+  ASSERT_EQ(bn->running_mean.dim(), 1);
+  ASSERT_EQ(bn->running_mean.size(0), 5);
+
+  ASSERT_TRUE(bn->running_var.defined());
+  ASSERT_EQ(bn->running_var.dim(), 1);
+  ASSERT_EQ(bn->running_var.size(0), 5);
+
+  ASSERT_TRUE(bn->num_batches_tracked.defined());
+  ASSERT_EQ(bn->num_batches_tracked.dim(), 1);
+  ASSERT_EQ(bn->num_batches_tracked.size(0), 1);
+
+  ASSERT_TRUE(bn->options.affine());
+
+  ASSERT_TRUE(bn->weight.defined());
+  ASSERT_EQ(bn->weight.dim(), 1);
+  ASSERT_EQ(bn->weight.size(0), 5);
+
+  ASSERT_TRUE(bn->bias.defined());
+  ASSERT_EQ(bn->bias.dim(), 1);
+  ASSERT_EQ(bn->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, BatchNorm1dStateless) {
+  BatchNorm1d bn(BatchNorm1dOptions(5).track_running_stats(false).affine(false));
+
+  ASSERT_FALSE(bn->running_mean.defined());
+  ASSERT_FALSE(bn->running_var.defined());
+  ASSERT_FALSE(bn->num_batches_tracked.defined());
+  ASSERT_FALSE(bn->weight.defined());
+  ASSERT_FALSE(bn->bias.defined());
+}
+
+TEST_F(ModulesTest, BatchNorm1d) {
+  BatchNorm1d bn(BatchNorm1dOptions(5));
+  bn->eval();
+
+  auto input = torch::randn({2, 5}, torch::requires_grad());
+  auto output = bn->forward(input);
+  auto s = output.sum();
+  s.backward();
+  
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
+  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
+}
+
 TEST_F(ModulesTest, Linear_CUDA) {
   Linear model(5, 2);
   model->to(torch::kCUDA);
@@ -2303,9 +2368,17 @@
 TEST_F(ModulesTest, PrettyPrintBatchNorm) {
   ASSERT_EQ(
       c10::str(BatchNorm(
-          BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
+          BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(
               true))),
-      "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
+      "torch::nn::BatchNorm(num_features=4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
+  ASSERT_EQ(
+      c10::str(BatchNorm1d(
+          BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
+          .track_running_stats(true))),
+      "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
 }
 
 TEST_F(ModulesTest, PrettyPrintLayerNorm) {
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index 543df66..c3393218 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -412,7 +412,7 @@
       "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       "  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
       "  (2): torch::nn::Dropout(rate=0.5)\n"
-      "  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+      "  (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
@@ -431,7 +431,7 @@
       "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       "  (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
       "  (dropout): torch::nn::Dropout(rate=0.5)\n"
-      "  (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+      "  (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 8a5d5f4..5ad56fe 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -69,7 +69,7 @@
 torch.nn.Softmax2d|Yes|No
 torch.nn.LogSoftmax|Yes|No
 torch.nn.AdaptiveLogSoftmaxWithLoss|No|No
-torch.nn.BatchNorm1d|No|No
+torch.nn.BatchNorm1d|Yes|No
 torch.nn.BatchNorm2d|No|No
 torch.nn.BatchNorm3d|No|No
 torch.nn.GroupNorm|No|No
diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h
index 90ce395..6531b7f 100644
--- a/torch/csrc/api/include/torch/nn/functional.h
+++ b/torch/csrc/api/include/torch/nn/functional.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <torch/nn/functional/batchnorm.h>
 #include <torch/nn/functional/distance.h>
 #include <torch/nn/functional/embedding.h>
 #include <torch/nn/functional/fold.h>
diff --git a/torch/csrc/api/include/torch/nn/functional/batchnorm.h b/torch/csrc/api/include/torch/nn/functional/batchnorm.h
new file mode 100644
index 0000000..a180f3f
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/functional/batchnorm.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <torch/nn/options/batchnorm.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+inline Tensor batch_norm(const Tensor& input, const Tensor& running_mean,
+                         const Tensor& running_var, const BatchNormOptions& options = {}, bool training = false) {
+  if (training) {
+    auto size = input.sizes();
+    int64_t size_prods = size[0];
+    for (size_t i = 0; i < size.size() - 2; i++) {
+      size_prods *= size[i + 2];
+    }
+    TORCH_CHECK(size_prods != 1,
+                "Expected more than 1 value per channel when training, got input size ", size);
+  }
+
+  return torch::batch_norm(
+    input,
+    options.weight(),
+    options.bias(),
+    running_mean,
+    running_var,
+    training,
+    options.momentum().value(),
+    options.eps(),
+    at::globalContext().userEnabledCuDNN());
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 210effd..eb3e267 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -25,8 +25,8 @@
 /// \endrst
 class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
  public:
-  explicit BatchNormImpl(int64_t features)
-      : BatchNormImpl(BatchNormOptions(features)) {}
+  explicit BatchNormImpl(int64_t num_features)
+      : BatchNormImpl(BatchNormOptions(num_features)) {}
   explicit BatchNormImpl(const BatchNormOptions& options_);
 
   void reset() override;
@@ -37,7 +37,7 @@
   /// Applies batch normalization on the `input` using the stored mean and
   /// variance.
   ///
-  /// The module must be constructed with `stateful = true` when calling this
+  /// The module must be constructed with `track_running_stats = true` when calling this
   /// method, as the module will otherwise not store running statistics. If you
   /// want to supply the mean and variance yourself, use `pure_forward`.
   Tensor forward(const Tensor& input);
@@ -61,11 +61,11 @@
   Tensor bias;
 
   /// The running mean.
-  /// Only defined if the `stateful` option was `true` upon construction.
+  /// Only defined if the `track_running_stats` option was `true` upon construction.
   Tensor running_mean;
 
   /// The running variance.
-  /// Only defined if the `stateful` option was `true` upon construction.
+  /// Only defined if the `track_running_stats` option was `true` upon construction.
   Tensor running_var;
 };
 
@@ -75,5 +75,62 @@
 /// module storage semantics.
 TORCH_MODULE(BatchNorm);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Base class for all (dimension-specialized) batchnorm modules.
+template <size_t D, typename Derived>
+class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
+ protected:
+  virtual void _check_input_dim(const Tensor& input) = 0;
+
+ public:
+  explicit BatchNormImplBase(const BatchNormOptions& options_);
+
+  Tensor forward(const Tensor& input);
+
+  void reset_running_stats();
+
+  void reset() override;
+
+  /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  /// The options with which this module was constructed.
+  BatchNormOptions options;
+
+  /// The learned weight.
+  /// Only defined if the `affine` option was `true` upon construction.
+  Tensor weight;
+
+  /// The learned bias.
+  /// Only defined if the `affine` option was `true` upon construction.
+  Tensor bias;
+
+  /// The running mean.
+  /// Only defined if the `track_running_stats` option was `true` upon construction.
+  Tensor running_mean;
+
+  /// The running variance.
+  /// Only defined if the `track_running_stats` option was `true` upon construction.
+  Tensor running_var;
+
+  /// The number of forward calls seen so far.
+  /// Only defined if the `track_running_stats` option was `true` upon construction.
+  Tensor num_batches_tracked;
+};
+
+/// Applies the BatchNorm1d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
+/// about the exact behavior of this module.
+class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
+ protected:
+  virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+  using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;
+};
+
+TORCH_MODULE(BatchNorm1d);
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h
index ca6a952..2aa175c 100644
--- a/torch/csrc/api/include/torch/nn/options/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h
@@ -9,26 +9,40 @@
 
 /// Options for the `BatchNorm` module.
 struct TORCH_API BatchNormOptions {
-  /* implicit */ BatchNormOptions(int64_t features);
+  BatchNormOptions() {}
+
+  /* implicit */ BatchNormOptions(int64_t num_features);
+
   /// The number of features of the input tensor.
   /// Changing this parameter after construction __has no effect__.
-  TORCH_ARG(int64_t, features);
+  TORCH_ARG(int64_t, num_features);
+
+  /// The epsilon value added for numerical stability.
+  /// Changing this parameter after construction __is effective__.
+  TORCH_ARG(double, eps) = 1e-5;
+
+  /// A momentum multiplier for the mean and variance.
+  /// Changing this parameter after construction __is effective__.
+  TORCH_ARG(c10::optional<double>, momentum) = 0.1;
+
   /// Whether to learn a scale and bias that are applied in an affine
   /// transformation on the input.
   /// Changing this parameter after construction __has no effect__.
   TORCH_ARG(bool, affine) = true;
+
   /// Whether to store and update batch statistics (mean and variance) in the
-  /// module. If `false`, you should call `pure_forward` and supply those batch
-  /// statistics yourself.
+  /// module.
   /// Changing this parameter after construction __has no effect__.
-  TORCH_ARG(bool, stateful) = true;
-  /// The epsilon value added for numerical stability.
-  /// Changing this parameter after construction __is effective__.
-  TORCH_ARG(double, eps) = 1e-5;
-  /// A momentum multiplier for the mean and variance.
-  /// Changing this parameter after construction __is effective__.
-  TORCH_ARG(double, momentum) = 0.1;
+  TORCH_ARG(bool, track_running_stats) = true;
+
+  /// This parameter is only used in `F::batch_norm`.
+  TORCH_ARG(Tensor, weight) = Tensor();
+
+  /// This parameter is only used in `F::batch_norm`.
+  TORCH_ARG(Tensor, bias) = Tensor();
 };
 
+using BatchNorm1dOptions = BatchNormOptions;
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 806d77b..816d07d 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -1,7 +1,9 @@
+#include <torch/nn/functional/batchnorm.h>
 #include <torch/nn/modules/batchnorm.h>
 
 #include <torch/cuda.h>
 #include <torch/types.h>
+#include <torch/nn/init.h>
 
 #include <c10/util/Exception.h>
 
@@ -10,41 +12,45 @@
 #include <utility>
 #include <vector>
 
+namespace F = torch::nn::functional;
+
 namespace torch {
 namespace nn {
 
 BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
+  TORCH_WARN("torch::nn::BatchNorm module is deprecated."
+             "Use BatchNorm{1,2,3}d instead.");
   reset();
 }
 
 void BatchNormImpl::reset() {
   if (options.affine()) {
     weight = register_parameter(
-        "weight", torch::empty({options.features()}).uniform_());
-    bias = register_parameter("bias", torch::zeros({options.features()}));
+        "weight", torch::empty({options.num_features()}).uniform_());
+    bias = register_parameter("bias", torch::zeros({options.num_features()}));
   }
 
-  if (options.stateful()) {
+  if (options.track_running_stats()) {
     running_mean =
-        register_buffer("running_mean", torch::zeros({options.features()}));
+        register_buffer("running_mean", torch::zeros({options.num_features()}));
     running_var =
-        register_buffer("running_var", torch::ones({options.features()}));
+        register_buffer("running_var", torch::ones({options.num_features()}));
   }
 }
 
 void BatchNormImpl::pretty_print(std::ostream& stream) const {
   stream << std::boolalpha
-         << "torch::nn::BatchNorm(features=" << options.features()
-         << ", eps=" << options.eps() << ", momentum=" << options.momentum()
-         << ", affine=" << options.affine() << ", stateful=" << options.stateful()
+         << "torch::nn::BatchNorm(num_features=" << options.num_features()
+         << ", eps=" << options.eps() << ", momentum=" << options.momentum().value()
+         << ", affine=" << options.affine() << ", track_running_stats=" << options.track_running_stats()
          << ")";
 }
 
 Tensor BatchNormImpl::forward(const Tensor& input) {
   TORCH_CHECK(
-      options.stateful(),
+      options.track_running_stats(),
       "Calling BatchNorm::forward is only permitted when "
-      "the 'stateful' option is true (was false). "
+      "the 'track_running_stats' option is true (was false). "
       "Use BatchNorm::pure_forward instead.");
   return pure_forward(input, running_mean, running_var);
 }
@@ -67,10 +73,100 @@
       mean,
       variance,
       is_training(),
-      options.momentum(),
+      options.momentum().value(),
       options.eps(),
       torch::cuda::cudnn_is_available());
 }
 
+template <size_t D, typename Derived>
+BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
+    : options(options_) {
+  reset();
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::reset_running_stats() {
+  if (options.track_running_stats()) {
+    running_mean.zero_();
+    running_var.fill_(1);
+    num_batches_tracked.zero_();
+  }
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::reset() {
+  if (options.affine()) {
+    weight = this->register_parameter("weight", torch::empty({options.num_features()}));
+    bias = this->register_parameter("bias", torch::empty({options.num_features()}));
+  } else {
+    weight = this->register_parameter("weight", Tensor());
+    bias = this->register_parameter("bias", Tensor());
+  }
+  if (options.track_running_stats()) {
+    running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
+    running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
+    num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
+  } else {
+    running_mean = this->register_buffer("running_mean", Tensor());
+    running_var = this->register_buffer("running_var", Tensor());
+    num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
+  }
+
+  reset_running_stats();
+  if (options.affine()) {
+    torch::nn::init::ones_(weight);
+    torch::nn::init::zeros_(bias);
+  }
+}
+
+template <size_t D, typename Derived>
+void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
+  stream << std::boolalpha
+         << "torch::nn::BatchNorm" << D << "d("
+         << options.num_features() << ", "
+         << "eps=" << options.eps() << ", "
+         << "momentum=" << options.momentum().value() << ", "
+         << "affine=" << options.affine() << ", "
+         << "track_running_stats=" << options.track_running_stats() << ")";
+}
+
+template <size_t D, typename Derived>
+Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
+  _check_input_dim(input);
+
+  double exponential_average_factor;
+  if (options.momentum() == c10::nullopt) {
+    exponential_average_factor = 0.0;
+  } else {
+    exponential_average_factor = options.momentum().value();
+  }
+
+  if (this->is_training() && options.track_running_stats()) {
+    if (num_batches_tracked.defined()) {
+      num_batches_tracked += 1;
+      if (options.momentum() == c10::nullopt) {  // use cumulative moving average
+        exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
+      } else {  // use exponential moving average
+        exponential_average_factor = options.momentum().value();
+      }
+    }
+  }
+
+  return F::batch_norm(
+      input,
+      running_mean,
+      running_var,
+      BatchNormOptions().weight(weight).bias(bias).momentum(exponential_average_factor).eps(options.eps()),
+      this->is_training() || !options.track_running_stats());
+}
+
+void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
+  TORCH_CHECK(
+      input.dim() == 2 || input.dim() == 3,
+      "expected 2D or 3D input (got ", input.dim(), "D input)");
+}
+
+template class BatchNormImplBase<1, BatchNorm1dImpl>;
+
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/options/batchnorm.cpp b/torch/csrc/api/src/nn/options/batchnorm.cpp
index 60c363b..2144443 100644
--- a/torch/csrc/api/src/nn/options/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/options/batchnorm.cpp
@@ -3,7 +3,7 @@
 namespace torch {
 namespace nn {
 
-BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
+BatchNormOptions::BatchNormOptions(int64_t num_features) : num_features_(num_features) {}
 
 } // namespace nn
 } // namespace torch